Repository: gorse-io/gorse Branch: master Commit: 64847a36c966 Files: 274 Total size: 2.4 MB Directory structure: gitextract_2md3h71u/ ├── .devcontainer/ │ ├── Dockerfile │ └── devcontainer.json ├── .dockerignore ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── custom.md │ │ └── feature_request.md │ └── workflows/ │ ├── assign_issue.yml │ ├── backport.yml │ ├── build_docker.yml │ ├── build_release.yml │ ├── build_test.yml │ ├── dockerhub-description.yml │ └── translate_issues.yml ├── .gitignore ├── .golangci.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── client/ │ ├── README.md │ ├── client_test.go │ ├── config.go │ ├── config.toml │ ├── docker-compose.yml.j2 │ └── setup-test.sh ├── cmd/ │ ├── goat/ │ │ └── README.md │ ├── gorse-cli/ │ │ └── main.go │ ├── gorse-in-one/ │ │ ├── Dockerfile │ │ ├── Dockerfile.cuda │ │ ├── Dockerfile.mkl │ │ ├── Dockerfile.openblas │ │ ├── Dockerfile.windows │ │ └── main.go │ ├── gorse-master/ │ │ ├── Dockerfile │ │ ├── Dockerfile.cuda │ │ ├── Dockerfile.mkl │ │ ├── Dockerfile.openblas │ │ ├── Dockerfile.windows │ │ └── main.go │ ├── gorse-server/ │ │ ├── Dockerfile │ │ ├── Dockerfile.cuda │ │ ├── Dockerfile.mkl │ │ ├── Dockerfile.openblas │ │ ├── Dockerfile.windows │ │ └── main.go │ ├── gorse-worker/ │ │ ├── Dockerfile │ │ ├── Dockerfile.cuda │ │ ├── Dockerfile.mkl │ │ ├── Dockerfile.openblas │ │ ├── Dockerfile.windows │ │ └── main.go │ └── version/ │ └── version.go ├── codecov.yml ├── common/ │ ├── ann/ │ │ ├── ann.go │ │ ├── ann_test.go │ │ ├── bruteforce.go │ │ └── hnsw.go │ ├── blas/ │ │ ├── blas.go │ │ ├── blas_darwin_arm64.go │ │ ├── blas_mkl.go │ │ └── blas_openblas.go │ ├── copier/ │ │ ├── copier.go │ │ └── copier_test.go │ ├── datautil/ │ │ ├── datautil.go │ │ └── datautil_test.go │ ├── encoding/ │ │ ├── encoding.go │ │ └── encoding_test.go │ ├── expression/ │ │ ├── expression.go │ │ └── expression_test.go │ ├── floats/ │ │ ├── floats.go │ │ ├── floats_amd64.go │ │ ├── 
floats_amd64_test.go │ │ ├── floats_arm64.go │ │ ├── floats_arm64_test.go │ │ ├── floats_avx.go │ │ ├── floats_avx.s │ │ ├── floats_avx512.go │ │ ├── floats_avx512.s │ │ ├── floats_neon.go │ │ ├── floats_neon.s │ │ ├── floats_noasm.go │ │ ├── floats_riscv64.go │ │ ├── floats_riscv64_test.go │ │ ├── floats_rvv.go │ │ ├── floats_rvv.s │ │ ├── floats_test.go │ │ ├── mm.go │ │ ├── mm_darwin_arm64.go │ │ ├── mm_mkl.go │ │ ├── mm_openblas.go │ │ └── src/ │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── floats_avx.c │ │ ├── floats_avx512.c │ │ ├── floats_neon.c │ │ ├── floats_rvv.c │ │ ├── floats_sve2.c │ │ ├── floats_test.c │ │ ├── munit.c │ │ └── munit.h │ ├── heap/ │ │ ├── filter.go │ │ ├── filter_test.go │ │ ├── pq.go │ │ └── pq_test.go │ ├── jsonutil/ │ │ ├── json.go │ │ └── json_test.go │ ├── log/ │ │ ├── log.go │ │ └── log_test.go │ ├── mock/ │ │ ├── openai.go │ │ └── openai_test.go │ ├── monitor/ │ │ ├── progress.go │ │ └── progress_test.go │ ├── nn/ │ │ ├── functions.go │ │ ├── layers.go │ │ ├── nn_test.go │ │ ├── op.go │ │ ├── op_test.go │ │ ├── optimizers.go │ │ ├── tensor.go │ │ └── tensor_test.go │ ├── parallel/ │ │ ├── parallel.go │ │ ├── parallel_test.go │ │ ├── ratelimit.go │ │ └── ratelimit_test.go │ ├── rc/ │ │ ├── rc.go │ │ └── rc_test.go │ ├── reranker/ │ │ ├── client.go │ │ └── client_test.go │ ├── sizeof/ │ │ ├── size.go │ │ └── size_test.go │ └── util/ │ ├── random.go │ ├── random_test.go │ ├── strconv.go │ ├── tls.go │ ├── util.go │ └── util_test.go ├── config/ │ ├── config.go │ ├── config.toml │ └── config_test.go ├── dataset/ │ ├── dataset.go │ ├── dataset_test.go │ ├── dict.go │ ├── dict_test.go │ ├── index.go │ ├── index_test.go │ ├── unified_index.go │ └── unified_index_test.go ├── docker-bake.hcl ├── docker-compose.yml ├── go.mod ├── go.sum ├── logics/ │ ├── cf.go │ ├── cf_test.go │ ├── chat.go │ ├── chat_test.go │ ├── external.go │ ├── external_test.go │ ├── item_to_item.go │ ├── item_to_item_test.go │ ├── non_personalized.go │ ├── 
non_personalized_test.go │ ├── recommend.go │ ├── recommend_test.go │ ├── user_to_user.go │ └── user_to_user_test.go ├── master/ │ ├── master.go │ ├── master_test.go │ ├── metrics.go │ ├── metrics_test.go │ ├── rest.go │ ├── rest_test.go │ ├── rpc.go │ ├── rpc_test.go │ ├── tasks.go │ └── tasks_test.go ├── model/ │ ├── built_in.go │ ├── built_in_test.go │ ├── cf/ │ │ ├── evaluator.go │ │ ├── evaluator_test.go │ │ ├── model.go │ │ ├── model_test.go │ │ ├── optimize.go │ │ └── optimize_test.go │ ├── ctr/ │ │ ├── data.go │ │ ├── data_test.go │ │ ├── evaluator.go │ │ ├── evaluator_test.go │ │ ├── fm.go │ │ ├── fm_xla.go │ │ ├── model.go │ │ ├── model.py │ │ ├── model_test.go │ │ ├── optimize.go │ │ └── optimize_test.go │ ├── model.go │ ├── params.go │ └── params_test.go ├── protocol/ │ ├── cache_store.pb.go │ ├── cache_store.proto │ ├── cache_store_grpc.pb.go │ ├── data_store.pb.go │ ├── data_store.proto │ ├── data_store_grpc.pb.go │ ├── encoding.pb.go │ ├── encoding.proto │ ├── generate.go │ ├── protocol.pb.go │ ├── protocol.proto │ ├── protocol_grpc.pb.go │ ├── vector_store.pb.go │ ├── vector_store.proto │ └── vector_store_grpc.pb.go ├── server/ │ ├── metrics.go │ ├── rest.go │ ├── rest_test.go │ ├── server.go │ └── server_test.go ├── storage/ │ ├── blob/ │ │ ├── azure.go │ │ ├── azure_test.go │ │ ├── blob.go │ │ ├── blob_test.go │ │ ├── gcs.go │ │ ├── gcs_test.go │ │ ├── posix.go │ │ ├── posix_test.go │ │ ├── s3.go │ │ └── s3_test.go │ ├── cache/ │ │ ├── database.go │ │ ├── database_test.go │ │ ├── mongodb.go │ │ ├── mongodb_test.go │ │ ├── no_database.go │ │ ├── no_database_test.go │ │ ├── proxy.go │ │ ├── proxy_test.go │ │ ├── redis.go │ │ ├── redis_test.go │ │ ├── sql.go │ │ └── sql_test.go │ ├── data/ │ │ ├── database.go │ │ ├── database_test.go │ │ ├── mongodb.go │ │ ├── mongodb_test.go │ │ ├── no_database.go │ │ ├── no_database_test.go │ │ ├── proxy.go │ │ ├── proxy_test.go │ │ ├── sql.go │ │ └── sql_test.go │ ├── docker-compose.yml │ ├── meta/ │ │ ├── 
database.go │ │ ├── database_test.go │ │ ├── sqlite.go │ │ └── sqlite_test.go │ ├── options.go │ ├── schema_test.go │ ├── scheme.go │ └── vectors/ │ ├── database.go │ ├── database_test.go │ ├── milvus.go │ ├── milvus_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── qdrant.go │ ├── qdrant_test.go │ ├── sqlite.go │ ├── sqlite_test.go │ ├── weaviate.go │ └── weaviate_test.go └── worker/ ├── metrics.go ├── pipeline.go ├── pipeline_test.go ├── worker.go └── worker_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .devcontainer/Dockerfile ================================================ FROM mcr.microsoft.com/devcontainers/base:ubuntu-24.04 # use this Dockerfile to install additional tools you might need, e.g. # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # && apt-get -y install --no-install-recommends ================================================ FILE: .devcontainer/devcontainer.json ================================================ // The Dev Container format allows you to configure your environment. At the heart of it // is a Docker image or Dockerfile which controls the tools available in your environment. // // See https://aka.ms/devcontainer.json for more information. { "name": "Gorse", // Use "image": "mcr.microsoft.com/devcontainers/base:ubuntu-24.04", // instead of the build to use a pre-built image. "build": { "context": ".", "dockerfile": "Dockerfile" }, // Features add additional features to your environment. See https://containers.dev/features // Beware: features are not supported on all platforms and may have unintended side-effects. 
"features": { "ghcr.io/devcontainers/features/docker-in-docker": { "moby": false }, "ghcr.io/devcontainers/features/go": {}, "ghcr.io/devcontainers/features/python": {}, "ghcr.io/devcontainers-extra/features/protoc": {} }, "postCreateCommand": [ "go install google.golang.org/protobuf/cmd/protoc-gen-go@latest", "go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest" ] } ================================================ FILE: .dockerignore ================================================ .github assets LICENSE *.yml *.md ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: bug assignees: '' --- Please answer these questions before submitting your issue. Thanks! **Gorse version** Print build info by the `--version` option. **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior. **Expected behavior** A clear and concise description of what you expected to happen. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/custom.md ================================================ --- name: Custom issue template about: Describe this issue template's purpose here. title: '' labels: '' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. 
**Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/workflows/assign_issue.yml ================================================ name: 'assign issues' on: issue_comment: types: [created, edited] jobs: assign_issues: name: assign issues if: ${{ !github.event.issue.pull_request && github.event.comment.body == '/assign' }} runs-on: ubuntu-latest steps: - name: 'Assign issue' uses: pozil/auto-assign-issue@v1.4.0 with: assignees: ${{ github.event.comment.user.login }} ================================================ FILE: .github/workflows/backport.yml ================================================ name: backport merged pull request on: pull_request_target: types: [closed] issue_comment: types: [created] permissions: contents: write # so it can comment pull-requests: write # so it can create pull requests jobs: backport: name: backport pull request runs-on: ubuntu-latest # Only run when pull request is merged # or when a comment containing `/backport` is created by someone other than the # https://github.com/backport-action bot user (user id: 97796249). Note that if you use your # own PAT as `github_token`, that you should replace this id with yours. 
if: > ( github.event_name == 'pull_request_target' && github.event.pull_request.merged ) || ( github.event_name == 'issue_comment' && github.event.issue.pull_request && github.event.comment.user.id != 97796249 && contains(github.event.comment.body, '/backport') ) steps: - uses: actions/checkout@v3 - name: Create backport pull requests uses: korthout/backport-action@v1 ================================================ FILE: .github/workflows/build_docker.yml ================================================ name: build on: push: branches: - master jobs: windows_images: name: docker images (windows) runs-on: windows-latest steps: - name: Pull source uses: actions/checkout@v5 with: fetch-depth: 0 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build docker image run: | foreach ($image in "gorse-master", "gorse-server", "gorse-worker", "gorse-in-one") { docker build -f cmd/$image/Dockerfile.windows ` -t zhenghaoz/${image}:nightly-windowsservercore . docker image push --all-tags zhenghaoz/$image } docker_images: name: docker images runs-on: ubuntu-latest strategy: matrix: targets: [default] steps: - name: Pull source uses: actions/checkout@v5 with: fetch-depth: 0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: buildkitd-config-inline: | [worker.oci] max-parallelism = 1 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build docker image uses: docker/bake-action@v6 with: source: .
targets: ${{ matrix.targets }} push: true env: AWS_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} ================================================ FILE: .github/workflows/build_release.yml ================================================ name: release on: release: types: [published] jobs: binaries: name: binaries runs-on: macos-latest steps: - name: Pull source uses: actions/checkout@v5 - name: Set up Go uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: Install gox run: go install github.com/mitchellh/gox@master - name: Build release (windows and linux) run: > gox -output="{{.OS}}/{{.Arch}}/{{.Dir}}" \ -osarch='windows/arm64 windows/amd64 linux/arm64 linux/amd64 linux/riscv64' -ldflags=" -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'" ./... env: CGO_ENABLED: 0 - name: Build release (darwin) run: > gox -output="{{.OS}}/{{.Arch}}/{{.Dir}}" \ -osarch='darwin/arm64' -ldflags=" -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'" ./... 
- name: Install zip run: brew install zip - name: Zip binaries run: | zip -j gorse_linux_amd64.zip linux/amd64/gorse-* zip -j gorse_linux_arm64.zip linux/arm64/gorse-* zip -j gorse_linux_riscv64.zip linux/riscv64/gorse-* zip -j gorse_windows_amd64.zip windows/amd64/gorse-* zip -j gorse_windows_arm64.zip windows/arm64/gorse-* zip -j gorse_darwin_arm64.zip darwin/arm64/gorse-* - name: Upload release uses: svenstaro/upload-release-action@v2 with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: gorse_*_*.zip tag: ${{ github.ref }} overwrite: true file_glob: true docker_images: name: docker images runs-on: ubuntu-latest steps: - name: Pull source uses: actions/checkout@v5 with: fetch-depth: 0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: buildkitd-config-inline: | [worker.oci] max-parallelism = 1 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - id: get_version uses: battila7/get-version-action@v2 - if: github.event.release.prerelease == false name: Build Docker image uses: docker/bake-action@v6 with: source: . targets: default push: true env: VERSIONS: latest,${{ steps.get_version.outputs.major }}.${{ steps.get_version.outputs.minor }},${{ steps.get_version.outputs.version-without-v }} - if: github.event.release.prerelease == true name: Build prerelease Docker image uses: docker/bake-action@v6 with: source: . 
targets: default push: true env: VERSIONS: ${{ steps.get_version.outputs.version-without-v }} windows_images: name: docker images (windows) runs-on: windows-latest steps: - name: Pull source uses: actions/checkout@v5 with: fetch-depth: 0 - id: get_version uses: battila7/get-version-action@v2 - name: Login to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build docker image if: github.event.release.prerelease == false run: | foreach ($image in "gorse-master", "gorse-server", "gorse-worker", "gorse-in-one") { docker build -f cmd/${image}/Dockerfile.windows ` -t zhenghaoz/${image}:windowsservercore ` -t zhenghaoz/${image}:${{ steps.get_version.outputs.major }}.${{ steps.get_version.outputs.minor }}-windowsservercore ` -t zhenghaoz/${image}:${{ steps.get_version.outputs.version-without-v }}-windowsservercore . docker image push --all-tags zhenghaoz/$image } - name: Build prerelease docker image if: github.event.release.prerelease == true run: | foreach ($image in "gorse-master", "gorse-server", "gorse-worker", "gorse-in-one") { docker build -f cmd/${image}/Dockerfile.windows ` -t zhenghaoz/${image}:${{ steps.get_version.outputs.version-without-v }}-windowsservercore . 
docker image push --all-tags zhenghaoz/$image } ================================================ FILE: .github/workflows/build_test.yml ================================================ name: test on: push: branches: - master - 'release-**' pull_request: branches: - master - 'release-**' jobs: unit_test: strategy: matrix: os: [ubuntu-latest, ubuntu-24.04-arm] name: unit tests runs-on: ${{ matrix.os }} services: mysql: image: mysql:8.0 ports: - 3306 env: MYSQL_ROOT_PASSWORD: password options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 postgres: image: postgres:10.0 ports: - 5432 env: POSTGRES_USER: gorse POSTGRES_PASSWORD: gorse_pass options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 mongo: image: mongo:4.0 ports: - 27017 env: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: password options: >- --health-cmd mongo --health-interval 10s --health-timeout 5s --health-retries 5 clickhouse: image: clickhouse/clickhouse-server:23 ports: - 8123 options: >- --health-cmd="clickhouse-client --query 'SELECT 1'" --health-interval=10s --health-timeout=5s --health-retries=5 redis: image: redis/redis-stack:6.2.6-v9 ports: - 6379 rustfs: image: rustfs/rustfs:alpha ports: - 9000 env: RUSTFS_ACCESS_KEY: rustfsadmin RUSTFS_SECRET_KEY: rustfsadmin qdrant: image: qdrant/qdrant:latest ports: - 6334 weaviate: image: cr.weaviate.io/semitechnologies/weaviate:1.35.7 ports: - 8080 steps: - name: Set up dataset run: | mkdir -p ~/.gorse/dataset mkdir -p ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-100k.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-1m.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/pinterest-20.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/frappe.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-tag.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/criteo.zip -P ~/.gorse/download wget 
https://cdn.gorse.io/datasets/mnist.zip -P ~/.gorse/download unzip ~/.gorse/download/ml-100k.zip -d ~/.gorse/dataset unzip ~/.gorse/download/ml-1m.zip -d ~/.gorse/dataset unzip ~/.gorse/download/pinterest-20.zip -d ~/.gorse/dataset unzip ~/.gorse/download/frappe.zip -d ~/.gorse/dataset unzip ~/.gorse/download/ml-tag.zip -d ~/.gorse/dataset unzip ~/.gorse/download/criteo.zip -d ~/.gorse/dataset unzip ~/.gorse/download/mnist.zip -d ~/.gorse/dataset - uses: actions/checkout@v2 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: Install and start Azurite uses: potatoqualitee/azuright@v1 - name: Setup Milvus run: | curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh bash standalone_embed.sh start working-directory: ${{ runner.temp }} - name: Test run: go test -timeout 30m -v ./... -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... env: # MySQL MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mysql.ports[3306] }})/ # Postgres POSTGRES_URI: postgres://gorse:gorse_pass@localhost:${{ job.services.postgres.ports[5432] }}/ # MongoDB MONGO_URI: mongodb://root:password@localhost:${{ job.services.mongo.ports[27017] }}/ # ClickHouse CLICKHOUSE_URI: clickhouse://localhost:${{ job.services.clickhouse.ports[8123] }}/ # Redis REDIS_URI: redis://localhost:${{ job.services.redis.ports[6379] }}/ # S3 S3_ENDPOINT: localhost:${{ job.services.rustfs.ports[9000] }} S3_ACCESS_KEY_ID: rustfsadmin S3_SECRET_ACCESS_KEY: rustfsadmin # Azure Blob (Azurite) AZURE_STORAGE_CONNECTION_STRING: DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1; # Qdrant QDRANT_URI: qdrant://localhost:${{ job.services.qdrant.ports[6334] }}/ # Weaviate WEAVIATE_URI: weaviate://localhost:${{ job.services.weaviate.ports[8080] }}/ - name: Upload coverage to 
Codecov if: matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v5 with: files: ./coverage.txt fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} unit_test_macos: name: unit tests (macos-latest) runs-on: macos-latest steps: - name: Set up dataset run: | mkdir -p ~/.gorse/dataset mkdir -p ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-100k.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-1m.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/pinterest-20.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/frappe.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/ml-tag.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/criteo.zip -P ~/.gorse/download wget https://cdn.gorse.io/datasets/mnist.zip -P ~/.gorse/download unzip ~/.gorse/download/ml-100k.zip -d ~/.gorse/dataset unzip ~/.gorse/download/ml-1m.zip -d ~/.gorse/dataset unzip ~/.gorse/download/pinterest-20.zip -d ~/.gorse/dataset unzip ~/.gorse/download/frappe.zip -d ~/.gorse/dataset unzip ~/.gorse/download/ml-tag.zip -d ~/.gorse/dataset unzip ~/.gorse/download/criteo.zip -d ~/.gorse/dataset unzip ~/.gorse/download/mnist.zip -d ~/.gorse/dataset - uses: actions/checkout@v2 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: Test run: go test -timeout 20m -v ./... 
-skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate" unit_test_windows: strategy: matrix: os: [windows-latest, windows-11-arm] name: unit tests runs-on: ${{ matrix.os }} steps: - name: Set up dataset run: | New-Item -Type Directory -Path ~/.gorse/dataset New-Item -Type Directory -Path ~/.gorse/download Invoke-WebRequest https://cdn.gorse.io/datasets/ml-100k.zip -OutFile ~/.gorse/download/ml-100k.zip Invoke-WebRequest https://cdn.gorse.io/datasets/ml-1m.zip -OutFile ~/.gorse/download/ml-1m.zip Invoke-WebRequest https://cdn.gorse.io/datasets/pinterest-20.zip -OutFile ~/.gorse/download/pinterest-20.zip Invoke-WebRequest https://cdn.gorse.io/datasets/frappe.zip -OutFile ~/.gorse/download/frappe.zip Invoke-WebRequest https://cdn.gorse.io/datasets/ml-tag.zip -OutFile ~/.gorse/download/ml-tag.zip Invoke-WebRequest https://cdn.gorse.io/datasets/criteo.zip -OutFile ~/.gorse/download/criteo.zip Invoke-WebRequest https://cdn.gorse.io/datasets/mnist.zip -OutFile ~/.gorse/download/mnist.zip Expand-Archive ~/.gorse/download/ml-100k.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/ml-1m.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/pinterest-20.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/frappe.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/ml-tag.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/criteo.zip -DestinationPath ~/.gorse/dataset Expand-Archive ~/.gorse/download/mnist.zip -DestinationPath ~/.gorse/dataset - uses: actions/checkout@v2 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: Test run: go test -timeout 20m -v ./... 
-skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate" integrate_test: name: integrate tests runs-on: ubuntu-latest strategy: matrix: database: [mysql, postgres, mongo, sqlite] steps: - uses: actions/checkout@v5 with: fetch-depth: 0 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - uses: cuchi/jinja2-action@v1.2.0 with: template: client/docker-compose.yml.j2 output_file: docker-compose.yml strict: true variables: | database=${{ matrix.database }} - name: Setup test run: ./client/setup-test.sh env: DOCKER_BUILDKIT: 1 - name: Test run: go test ./client/ env: GORSE_SERVER_ENDPOINT: http://localhost:8087 GORSE_DASHBOARD_ENDPOINT: http://localhost:8088 compat_test: name: compatibility tests runs-on: ubuntu-latest services: mariadb: image: mariadb:10.2 ports: - 3306 env: MYSQL_ROOT_PASSWORD: password kvrocks: image: apache/kvrocks:nightly ports: - 6666 steps: - name: Install pre-requisites uses: awalsh128/cache-apt-pkgs-action@latest with: packages: redis-tools - uses: actions/checkout@v5 with: repository: gorse-cloud/redis-stack path: redis-stack - name: Setup Redis cluster run: | docker compose -f redis-stack/docker-compose.yml --project-directory redis-stack up -d for i in {1..5}; do redis-cli -p 7005 ping | grep PONG && break sleep 10 done - uses: actions/checkout@v2 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: Test MariaDB run: go test ./storage/data -run TestMySQL env: MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mariadb.ports[3306] }})/ - name: Test Kvrocks run: go test ./storage/cache -run ^TestRedis env: REDIS_URI: redis://localhost:${{ job.services.kvrocks.ports[6666] }}/ - name: Test Redis cluster (1 address) run: go test ./storage/cache -run ^TestRedis env: REDIS_URI: redis+cluster://localhost:7000 - name: Test Redis cluster (6 addresses) run: go test ./storage/cache -run ^TestRedis env: REDIS_URI: 
redis+cluster://localhost:7000?addr=localhost:7001&addr=localhost:7002&addr=localhost:7003&addr=localhost:7004&addr=localhost:7005 playground: name: playground runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Build Docker image run: docker build -t zhenghaoz/gorse-in-one -f ./cmd/gorse-in-one/Dockerfile . - name: Run playground run: docker run -d -p 8088:8088 --name playground zhenghaoz/gorse-in-one --playground - name: Check dashboard URL run: | for i in {1..10}; do curl -sSf http://localhost:8088 && break docker logs playground sleep 10 done golangci: name: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: args: --timeout 20m ================================================ FILE: .github/workflows/dockerhub-description.yml ================================================ name: update on: push: branches: - master paths: - 'README.md' jobs: dockerhub_description: name: dockerHub description runs-on: ubuntu-latest strategy: matrix: image: [gorse-master, gorse-server, gorse-worker, gorse-in-one] steps: - uses: actions/checkout@v5 - name: Resolve Images on Description run: | sed -i -E "s/src=\"assets\//src=\"https:\/\/github.com\/gorse-io\/gorse\/raw\/master\/assets\//" README.md - name: Update DockerHub Description uses: peter-evans/dockerhub-description@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} repository: zhenghaoz/${{ matrix.image }} ================================================ FILE: .github/workflows/translate_issues.yml ================================================ name: 'translate issues' on: issue_comment: types: [created] issues: types: [opened] jobs: translate-isssues: name: translate issues runs-on: ubuntu-latest steps: - name: Issues Translator uses: tomsun28/issues-translate-action@v2.5 ================================================ FILE: 
.gitignore ================================================ # Created by https://www.gitignore.io/api/go,windows,jetbrains ### Go ### # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out ### Go Patch ### /vendor/ /Godeps/ ### JetBrains ### # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # User-specific stuff .idea/**/workspace.xml .idea/**/tasks.xml .idea/**/usage.statistics.xml .idea/**/dictionaries .idea/**/shelf # Sensitive or high-churn files .idea/**/dataSources/ .idea/**/dataSources.ids .idea/**/dataSources.local.xml .idea/**/sqlDataSources.xml .idea/**/dynamic.xml .idea/**/uiDesigner.xml .idea/**/dbnavigator.xml # Gradle .idea/**/gradle.xml .idea/**/libraries # Gradle and Maven with auto-import # When using Gradle or Maven with auto-import, you should exclude module files, # since they will be recreated, and may cause churn. Uncomment if using # auto-import. 
# .idea/modules.xml # .idea/*.iml # .idea/modules # CMake cmake-build-*/ # Mongo Explorer plugin .idea/**/mongoSettings.xml # File-based project format *.iws # IntelliJ out/ # mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Cursive Clojure plugin .idea/replstate.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties # Editor-based Rest Client .idea/httpRequests ### JetBrains Patch ### # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 # *.iml # modules.xml # .idea/misc.xml # *.ipr # Sonarlint plugin .idea/sonarlint ### Windows ### # Windows thumbnail cache files Thumbs.db ehthumbs.db ehthumbs_vista.db # Dump file *.stackdump # Folder config file [Dd]esktop.ini # Recycle Bin used on file shares $RECYCLE.BIN/ # Windows Installer files *.cab *.msi *.msix *.msm *.msp # Windows shortcuts *.lnk .vscode # End of https://www.gitignore.io/api/go,windows,jetbrains ================================================ FILE: .golangci.yml ================================================ version: "2" linters: settings: govet: disable: - composites staticcheck: checks: - all - -QF1004 - -QF1008 - -SA1019 - -ST1003 exclusions: presets: - comments - common-false-positives - legacy - std-error-handling ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at coc@gorse.io. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # Contribution Guide Welcome and thank you for considering contributing to Gorse! Reading and following these guidelines will help us make the contribution process easy and effective for everyone involved. It also communicates that you agree to respect the time of the developers managing and developing these open source projects. In return, we will reciprocate that respect by addressing your issue, assessing changes, and helping you finalize your pull requests. 
* [Getting Started](#getting-started) * [Setup Develop Environment](#setup-develop-environment) * [Option 1: Run an All-in-one Node](#option-1-run-an-all-in-one-node) * [Option 2: Run Nodes](#option-2-run-nodes) * [Run Unit Tests](#run-unit-tests) * [Your First Contribution](#your-first-contribution) * [Contribution Workflow](#contribution-workflow) * [Getting Help](#getting-help) ## Getting Started ### Setup Develop Environment The following installations are required: - **Go**: The compiler version must be at least the one declared by the `go` directive in `go.mod`. Go 1.18 is no longer sufficient; for example, the tests call `testing.T.Context`, which was added in Go 1.24. GoLand or Visual Studio Code is highly recommended as the IDE to develop Gorse. - **Docker Compose**: Multiple databases are required for unit tests. It's convenient to manage them with Docker Compose. ```bash cd storage docker compose up -d ``` If you need to import sample data, download the SQL file github.sql and import it into the MySQL instance. ```bash # Download sample data. wget https://cdn.gorse.io/example/github.sql # Import sample data. mysql -h 127.0.0.1 -u gorse -pgorse_pass gorse < github.sql ``` ### Option 1: Run an All-in-one Node ```bash go run cmd/gorse-in-one/main.go --config config/config.toml ``` ### Option 2: Run Nodes - Start the master node with the configuration file. ```bash go run cmd/gorse-master/main.go --config config/config.toml ``` - Start the worker node. ```bash go run cmd/gorse-worker/main.go ``` - Start the server node. ```bash go run cmd/gorse-server/main.go ``` ### Run Unit Tests Most of the logic in Gorse is covered by unit tests. Run the unit tests with the following command: ```bash go test -v ./... ``` The default database URLs point to the databases defined in `storage/docker-compose.yml`.
Test databases can be overridden by setting the following environment variables: | Environment Variable | Default Value | |-------------------|----------------------------------------------| | `MYSQL_URI` | `mysql://root:password@tcp(127.0.0.1:3306)/` | | `POSTGRES_URI` | `postgres://gorse:gorse_pass@127.0.0.1/` | | `MONGO_URI` | `mongodb://root:password@127.0.0.1:27017/` | | `CLICKHOUSE_URI` | `clickhouse://127.0.0.1:8123/` | | `REDIS_URI` | `redis://127.0.0.1:6379/` | For example, to use TiDB as the MySQL test database, run the following (quote the URI, because unquoted parentheses are a shell syntax error): ```bash MYSQL_URI='mysql://root:password@tcp(127.0.0.1:4000)/' go test -v ./... ``` ## Your First Contribution You can start by finding an existing issue with the [help wanted](https://github.com/gorse-io/gorse/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) label in the Gorse repository. These issues are well suited for new contributors. Issues can be claimed by posting an `/assign` comment. ### Contribution Workflow To contribute to the Gorse code base, please follow the workflow defined in this section. - Fork the repository to your own GitHub account. - Make commits, and add test cases if the change fixes a bug or adds new functionality. - Run the tests and make sure they all pass. - Push your changes to a topic branch in your fork of the repository. - Submit a pull request. This is a rough outline of what a contributor's workflow looks like. Thanks for your contributions! ## Getting Help Join us on [Discord](https://discord.gg/x6gAtNNkAE) and post your question in the `#developers` channel. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Gorse Open-source Recommender System Engine ![](https://img.shields.io/github/go-mod/go-version/zhenghaoz/gorse) [![test](https://github.com/gorse-io/gorse/actions/workflows/build_test.yml/badge.svg)](https://github.com/gorse-io/gorse/actions/workflows/build_test.yml) [![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse) [![Discord](https://img.shields.io/discord/830635934210588743)](https://discord.gg/x6gAtNNkAE) [![Twitter Follow](https://img.shields.io/twitter/follow/gorse_io?label=Follow&style=social)](https://twitter.com/gorse_io) [![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20Gorse%20Guru-006BFF)](https://gurubase.io/g/gorse) Gorse is an AI powered open-source recommender system written in Go. 
Gorse aims to be a universal open-source recommender system that can be quickly integrated into a wide variety of online services. By importing items, users, and interaction data into Gorse, the system will automatically train models to generate recommendations for each user. Project features are as follows. ![](https://github.com/gorse-io/docs/blob/main/src/img/dashboard/recflow.png?raw=true) - **Multi-source:** Recommend items from multiple sources: latest items, user-to-user, item-to-item, collaborative filtering, etc. - **Multimodal:** Support multimodal content (text, images, videos, etc.) via embeddings. - **AI-powered:** Support both classical recommenders and LLM-based recommenders. - **GUI Dashboard:** Provide a GUI dashboard for recommendation pipeline editing, system monitoring, and data management. - **RESTful APIs:** Expose RESTful APIs for data CRUD and recommendation requests. ## Quick Start The playground mode has been prepared for beginners. Set up a recommender system for GitHub repositories with the following command. ```bash docker run -p 8088:8088 zhenghaoz/gorse-in-one --playground ``` The playground mode will download data from [GitRec](https://gitrec.gorse.io/) and import it into Gorse. The dashboard is available at `http://localhost:8088`. ![](https://github.com/gorse-io/docs/blob/main/src/img/dashboard/overview.png?raw=true) After the "Generate item-to-item recommendation" task is completed on the "Tasks" page, try inserting some feedback into Gorse. Suppose Bob is a developer who is interested in LLM-related repositories. We insert his star feedback into Gorse.
```bash read -d '' JSON << EOF [ { \"FeedbackType\": \"star\", \"UserId\": \"bob\", \"ItemId\": \"ollama:ollama\", \"Value\": 1.0, \"Timestamp\": \"2022-02-24\" }, { \"FeedbackType\": \"star\", \"UserId\": \"bob\", \"ItemId\": \"huggingface:transformers\", \"Value\": 1.0, \"Timestamp\": \"2022-02-25\" }, { \"FeedbackType\": \"star\", \"UserId\": \"bob\", \"ItemId\": \"rasbt:llms-from-scratch\", \"Value\": 1.0, \"Timestamp\": \"2022-02-26\" }, { \"FeedbackType\": \"star\", \"UserId\": \"bob\", \"ItemId\": \"vllm-project:vllm\", \"Value\": 1.0, \"Timestamp\": \"2022-02-27\" }, { \"FeedbackType\": \"star\", \"UserId\": \"bob\", \"ItemId\": \"hiyouga:llama-factory\", \"Value\": 1.0, \"Timestamp\": \"2022-02-28\" } ] EOF curl -X POST http://127.0.0.1:8088/api/feedback \ -H 'Content-Type: application/json' \ -d "$JSON" ``` Then, fetch 10 recommended items from Gorse. We can find that LLM-related repositories are recommended for Bob. ```bash curl http://127.0.0.1:8088/api/recommend/bob?n=10 ``` For more information: - Read [official documents](https://gorse.io/docs/) - Visit [playground](https://play.gorse.io/) of Gorse dashboard - Explore [live demo](https://gitrec.gorse.io/), a recommender system for GitHub repositories - Discuss on [Discord](https://discord.gg/x6gAtNNkAE) or [GitHub Discussion](https://github.com/gorse-io/gorse/discussions) ## Architecture Gorse is a single-node training and distributed prediction recommender system. Gorse stores data in MySQL, MongoDB, Postgres, or ClickHouse, with intermediate results cached in Redis, MySQL, MongoDB and Postgres. 1. The cluster consists of a master node, multiple worker nodes, and server nodes. 1. The master node is responsible for model training, non-personalized recommendation, configuration management, and membership management. 1. The server node is responsible for exposing the RESTful APIs and online real-time recommendations. 1. Worker nodes are responsible for offline recommendations for each user. 
In addition, the administrator can perform system monitoring, data import and export, and system status checking via the dashboard on the master node. ## Contributors Any contribution is appreciated: report a bug, give advice or create a pull request. Read [CONTRIBUTING.md](CONTRIBUTING.md) for more information. ## Acknowledgments `gorse` is inspired by the following projects: - [Guibing Guo's librec](https://github.com/guoguibing/librec) - [Nicolas Hug's Surprise](https://github.com/NicolasHug/Surprise) - [Golang Samples's gopher-vector](https://github.com/golang-samples/gopher-vector) ================================================ FILE: client/README.md ================================================ Go SDK has been moved to https://github.com/gorse-io/gorse-go ================================================ FILE: client/client_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package client import ( "os" "testing" "time" client "github.com/gorse-io/gorse-go" "github.com/stretchr/testify/suite" ) var ( serverEndpoint string dashboardEndpoint string ) func init() { serverEndpoint = os.Getenv("GORSE_SERVER_ENDPOINT") dashboardEndpoint = os.Getenv("GORSE_DASHBOARD_ENDPOINT") } type GorseClientTestSuite struct { suite.Suite client *client.GorseClient } func (suite *GorseClientTestSuite) SetupSuite() { if serverEndpoint == "" || dashboardEndpoint == "" { suite.T().Skip("GORSE_SERVER_ENDPOINT or GORSE_DASHBOARD_ENDPOINT is not set") } suite.client = client.NewGorseClient(serverEndpoint, "") } func (suite *GorseClientTestSuite) TestUsers() { ctx := suite.T().Context() cursor, err := suite.client.GetUsers(ctx, 3, "") suite.NoError(err) suite.NotEmpty(cursor.Cursor) if suite.Len(cursor.Users, 3) { suite.Equal(client.User{ UserId: "1", Labels: map[string]any{ "age": float64(24), "gender": "M", "occupation": "technician", "zip_code": "85711", }, }, cursor.Users[0]) suite.Equal(client.User{ UserId: "10", Labels: map[string]any{ "age": float64(53), "gender": "M", "occupation": "lawyer", "zip_code": "90703", }, }, cursor.Users[1]) suite.Equal(client.User{ UserId: "100", Labels: map[string]any{ "age": float64(36), "gender": "M", "occupation": "executive", "zip_code": "90254", }, }, cursor.Users[2]) } user := client.User{ UserId: "1000", Labels: map[string]any{"gender": "M", "occupation": "engineer"}, Comment: "zhenghaoz", } rowAffected, err := suite.client.InsertUser(ctx, user) suite.NoError(err) suite.Equal(1, rowAffected.RowAffected) resp, err := suite.client.GetUser(ctx, "1000") suite.NoError(err) suite.Equal(user, resp) patch := client.UserPatch{ Comment: new("hongmi"), } rowAffected, err = suite.client.UpdateUser(ctx, user.UserId, patch) suite.NoError(err) suite.Equal(1, rowAffected.RowAffected) resp, err = suite.client.GetUser(ctx, "1000") suite.NoError(err) suite.Equal("hongmi", resp.Comment) deleteAffect, err := suite.client.DeleteUser(ctx, 
"1000") suite.NoError(err) suite.Equal(1, deleteAffect.RowAffected) _, err = suite.client.GetUser(ctx, "1000") suite.Equal("1000: user not found", err.Error()) } func (suite *GorseClientTestSuite) TestItems() { ctx := suite.T().Context() items, err := suite.client.GetItems(ctx, 3, "") suite.NoError(err) suite.NotEmpty(items.Cursor) if suite.Len(items.Items, 3) { suite.Equal("1", items.Items[0].ItemId) suite.Equal([]string{"Animation", "Children's", "Comedy"}, items.Items[0].Categories) suite.Equal(time.Date(1995, 1, 1, 0, 0, 0, 0, time.UTC), items.Items[0].Timestamp) suite.Equal("Toy Story (1995)", items.Items[0].Comment) suite.Equal("10", items.Items[1].ItemId) suite.Equal([]string{"Drama", "War"}, items.Items[1].Categories) suite.Equal(time.Date(1996, 1, 22, 0, 0, 0, 0, time.UTC), items.Items[1].Timestamp) suite.Equal("Richard III (1995)", items.Items[1].Comment) suite.Equal("100", items.Items[2].ItemId) suite.Equal([]string{"Crime", "Drama", "Thriller"}, items.Items[2].Categories) suite.Equal(time.Date(1997, 2, 14, 0, 0, 0, 0, time.UTC), items.Items[2].Timestamp) suite.Equal("Fargo (1996)", items.Items[2].Comment) } item := client.Item{ ItemId: "2000", IsHidden: true, Labels: map[string]any{ "embedding": []any{0.1, 0.2, 0.3}, }, Categories: []string{"Comedy", "Animation"}, Timestamp: time.Now().UTC().Truncate(time.Second), Comment: "Minions (2015)", } rowAffected, err := suite.client.InsertItem(ctx, item) suite.NoError(err) suite.Equal(1, rowAffected.RowAffected) resp, err := suite.client.GetItem(ctx, "2000") suite.NoError(err) suite.Equal(item, resp) patch := client.ItemPatch{ Comment: new("小黄人 (2015)"), } rowAffected, err = suite.client.UpdateItem(ctx, item.ItemId, patch) suite.NoError(err) suite.Equal(1, rowAffected.RowAffected) resp, err = suite.client.GetItem(ctx, "2000") suite.NoError(err) suite.Equal("小黄人 (2015)", resp.Comment) deleteAffect, err := suite.client.DeleteItem(ctx, "2000") suite.NoError(err) suite.Equal(1, deleteAffect.RowAffected) _, err = 
suite.client.GetItem(ctx, "2000") suite.Equal("2000: item not found", err.Error()) } func (suite *GorseClientTestSuite) TestFeedback() { ctx := suite.T().Context() _, err := suite.client.InsertUser(ctx, client.User{UserId: "2000"}) suite.NoError(err) feedback := []client.Feedback{ { FeedbackType: "watch", UserId: "2000", ItemId: "1", Value: 1.0, Timestamp: time.Now().UTC().Truncate(time.Second), }, { FeedbackType: "watch", UserId: "2000", ItemId: "1060", Value: 2.0, Timestamp: time.Now().UTC().Truncate(time.Second), }, { FeedbackType: "watch", UserId: "2000", ItemId: "11", Value: 3.0, Timestamp: time.Now().UTC().Truncate(time.Second), }, } for _, fb := range feedback { _, err := suite.client.DeleteFeedbacks(ctx, fb.UserId, fb.ItemId) suite.NoError(err) } _, err = suite.client.InsertFeedback(ctx, feedback) suite.NoError(err) userFeedback, err := suite.client.ListFeedbacks(ctx, "watch", "2000") suite.NoError(err) suite.Equal(feedback, userFeedback) _, err = suite.client.DeleteFeedback(ctx, "watch", "2000", "1") suite.NoError(err) userFeedback, err = suite.client.ListFeedbacks(ctx, "watch", "2000") suite.NoError(err) suite.Equal([]client.Feedback{feedback[1], feedback[2]}, userFeedback) } func (suite *GorseClientTestSuite) TestLatest() { ctx := suite.T().Context() items, err := suite.client.GetLatestItems(ctx, "", "", 3, 0) suite.NoError(err) if suite.Len(items, 3) { suite.Equal("315", items[0].Id) suite.Equal("1432", items[1].Id) suite.Equal("918", items[2].Id) } } func (suite *GorseClientTestSuite) TestItemToItem() { ctx := suite.T().Context() neighbors, err := suite.client.GetNeighbors(ctx, "1", 3) suite.NoError(err) if suite.Len(neighbors, 3) { suite.Equal("1060", neighbors[0].Id) suite.Equal("404", neighbors[1].Id) suite.Equal("1219", neighbors[2].Id) } } func (suite *GorseClientTestSuite) TestRecommend() { ctx := suite.T().Context() _, err := suite.client.InsertUser(ctx, client.User{UserId: "3000"}) suite.NoError(err) recommendations, err := 
suite.client.GetRecommend(ctx, "3000", "", 3, 0) suite.NoError(err) suite.Len(recommendations, 3) if suite.Len(recommendations, 3) { suite.Equal("315", recommendations[0]) suite.Equal("1432", recommendations[1]) suite.Equal("918", recommendations[2]) } } func TestGorseClientTestSuite(t *testing.T) { suite.Run(t, new(GorseClientTestSuite)) } ================================================ FILE: client/config.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package client import _ "embed" //go:embed config.toml var ConfigTOML string ================================================ FILE: client/config.toml ================================================ [master] # GRPC port of the master node. The default value is 8086. port = 8086 # gRPC host of the master node. The default values is "0.0.0.0". host = "0.0.0.0" # Enable SSL for the gRPC communication. The default value is false. ssl_mode = false # SSL certification authority for the gRPC communication. ssl_ca = "" # SSL certification for the gRPC communication. ssl_cert = "" # SSL certification key for the gRPC communication. ssl_key = "" # HTTP port of the master node. The default values is 8088. http_port = 8088 # HTTP host of the master node. The default values is "0.0.0.0". http_host = "0.0.0.0" # AllowedDomains is a list of allowed values for Http Origin. 
# The list may contain the special wildcard string ".*" ; all is allowed # If empty all are allowed. http_cors_domains = [] # AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive. http_cors_methods = [] # Number of working jobs in the master node. The default value is 1. n_jobs = 1 # Meta information timeout. The default value is 10s. meta_timeout = "10s" # Username for the master node dashboard. dashboard_user_name = "" # Password for the master node dashboard. dashboard_password = "" # Secret key for admin APIs (SSL required). admin_api_key = "" [server] # Default number of returned items. The default value is 10. default_n = 10 # Secret key for RESTful APIs (SSL required). api_key = "" # Clock error in the cluster. The default value is 5s. clock_error = "5s" # Insert new users while inserting feedback. The default value is true. auto_insert_user = true # Insert new items while inserting feedback. The default value is true. auto_insert_item = true # Server-side cache expire time. The default value is 10s. cache_expire = "10s" [recommend] # The cache size for recommended/popular/latest items. The default value is 10. cache_size = 100 # Recommended cache expire time. The default value is 72h. cache_expire = "72h" # The context size for online recommendations. Online recommendations can't use all user feedbacks to generate # recommendations for efficiency consideration. Instead, only the latest `context_size` feedbacks are used. # The default value is 100. context_size = 100 # The time-to-live (days) of active users, 0 means disabled. Recommendation won't be cached for inactive users. The default value is 0. active_user_ttl = 0 [recommend.data_source] # The feedback types for positive events. positive_feedback_types = ["rating>=4"] # The feedback types for read events. read_feedback_types = ["rating"] # The time-to-live (days) of positive feedback, 0 means disabled. The default value is 0. 
positive_feedback_ttl = 0 # The time-to-live (days) of items, 0 means disabled. The default value is 0. item_ttl = 0 [[recommend.non-personalized]] # The name of the leaderboard. name = "popular" # The score function for items in the leaderboard. score = "len(feedback)" [[recommend.item-to-item]] # The name of the item-to-item recommender. name = "neighbors" # The type of the item-to-item recommender. There are four types: # embedding: recommend by Euclidean distance of embeddings. # tags: recommend by number of common tags. # users: recommend by number of common users. # chat: recommend by chat completion model. type = "embedding" # The column of the item embeddings. Leave blank if type is "users". column = "item.Labels.embedding" [[recommend.user-to-user]] # The name of the user-to-user recommender. name = "neighbors" # The type of the user-to-user recommender. There are three types: # embedding: recommend by Euclidean distance of embeddings. # tags: recommend by number of common tags. # items: recommend by number of common items. type = "items" [recommend.collaborative] # The type of collaborative filtering. Supported values: # none: disable collaborative filtering. # mf: matrix factorization. type = "mf" # The time period for model fitting. The default value is "60m". fit_period = "60m" # The number of epochs for model fitting. The default value is 100. fit_epoch = 100 [recommend.collaborative.early_stopping] # Number of epochs without improvement to wait before training is stopped. The default value is 10. patience = 10 [recommend.ranker] # The type of the ranker. There are two types: # none: no ranking. # fm: factorization machines. type = "none" # The recommenders used to fetch candidate items before ranking. By default, all recommenders are used.
recommenders = ["latest"] ================================================ FILE: client/docker-compose.yml.j2 ================================================ version: "3" services: {% if database == 'mysql' %} mysql: image: mysql/mysql-server restart: unless-stopped ports: - 3306:3306 environment: MYSQL_ROOT_PASSWORD: root_pass MYSQL_DATABASE: gorse MYSQL_USER: gorse MYSQL_PASSWORD: gorse_pass healthcheck: test: mysqladmin ping interval: 10s timeout: 5s retries: 5 {% elif database == 'postgres' %} postgres: image: postgres:10.0 ports: - 5432:5432 environment: POSTGRES_DB: gorse POSTGRES_USER: gorse POSTGRES_PASSWORD: gorse_pass healthcheck: test: pg_isready interval: 10s timeout: 5s retries: 5 {% elif database == 'mongo' %} mongo: image: mongo:4.0 ports: - 27017:27017 environment: MONGO_INITDB_DATABASE: gorse MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: password healthcheck: test: mongo interval: 10s timeout: 5s retries: 5 {% elif database == 'clickhouse+redis' %} clickhouse: image: clickhouse/clickhouse-server:23 ports: - 8123:8123 environment: CLICKHOUSE_DB: gorse CLICKHOUSE_USER: gorse CLICKHOUSE_PASSWORD: gorse_pass healthcheck: test: clickhouse-client --user $$CLICKHOUSE_USER --password $$CLICKHOUSE_PASSWORD --query "SELECT 1" interval: 10s timeout: 5s retries: 5 redis: image: redis/redis-stack:6.2.6-v9 restart: unless-stopped ports: - 6379:6379 healthcheck: test: redis-cli ping interval: 10s timeout: 5s retries: 5 {% endif %} worker: build: context: . dockerfile: cmd/gorse-worker/Dockerfile cache_from: - type=gha cache_to: - type=gha,mode=max restart: unless-stopped ports: - 8089:8089 command: > --master-host master --master-port 8086 --http-host 0.0.0.0 --http-port 8089 --log-path /var/log/gorse/worker.log --cache-path /var/lib/gorse/worker_cache.data depends_on: - master server: build: context: . 
dockerfile: cmd/gorse-server/Dockerfile cache_from: - type=gha cache_to: - type=gha,mode=max restart: unless-stopped ports: - 8087:8087 command: > --master-host master --master-port 8086 --http-host 0.0.0.0 --http-port 8087 --log-path /var/log/gorse/server.log --cache-path /var/lib/gorse/server_cache.data depends_on: - master master: build: context: . dockerfile: cmd/gorse-master/Dockerfile cache_from: - type=gha cache_to: - type=gha,mode=max restart: unless-stopped ports: - 8086:8086 - 8088:8088 environment: {% if database == 'mysql' %} GORSE_DATA_STORE: mysql://gorse:gorse_pass@tcp(mysql:3306)/gorse GORSE_CACHE_STORE: mysql://gorse:gorse_pass@tcp(mysql:3306)/gorse {% elif database == 'postgres' %} GORSE_DATA_STORE: postgres://gorse:gorse_pass@postgres/gorse?sslmode=disable GORSE_CACHE_STORE: postgres://gorse:gorse_pass@postgres/gorse?sslmode=disable {% elif database == 'mongo' %} GORSE_DATA_STORE: mongodb://root:password@mongo:27017/gorse?authSource=admin&connect=direct GORSE_CACHE_STORE: mongodb://root:password@mongo:27017/gorse?authSource=admin&connect=direct {% elif database == 'clickhouse+redis' %} GORSE_DATA_STORE: clickhouse://gorse:gorse_pass@clickhouse:8123/gorse?mutations_sync=2 GORSE_CACHE_STORE: redis://redis:6379 {% elif database == 'sqlite' %} GORSE_DATA_STORE: sqlite:///var/lib/gorse/data.sqlite3 GORSE_CACHE_STORE: sqlite:///var/lib/gorse/cache.sqlite3 {% endif %} command: > -c /etc/gorse/config.toml --log-path /var/log/gorse/master.log --cache-path /var/lib/gorse volumes: - ./client/config.toml:/etc/gorse/config.toml {% if database != 'sqlite' %} depends_on: {% if database == 'mysql' %} mysql: condition: service_healthy {% elif database == 'postgres' %} postgres: condition: service_healthy {% elif database == 'mongo' %} mongo: condition: service_healthy {% elif database == 'clickhouse+redis' %} clickhouse: condition: service_healthy redis: condition: service_healthy {% endif %} {% endif %} ================================================ FILE: 
client/setup-test.sh ================================================ #!/bin/bash set -e # Download config if [ ! -f ./config.toml ]; then wget https://github.com/gorse-io/gorse/raw/refs/heads/master/client/config.toml fi # Create docker-compose.yml if [ ! -f ./docker-compose.yml ]; then cat > docker-compose.yml < 0 { userIndex := indices[0] feedbackCount[userIndex]++ } } var features []lo.Tuple2[[]int32, []float32] var embeddings [][][]float32 positives := make(map[int32][]int) negatives := make(map[int32][]int) for i := 0; i < test.Count(); i++ { indices, values, embedding, target := test.Get(i) features = append(features, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) embeddings = append(embeddings, embedding) userIndex := indices[0] if target > 0 { positives[userIndex] = append(positives[userIndex], i) } else { negatives[userIndex] = append(negatives[userIndex], i) } } predictions := ml.BatchInternalPredict(features, embeddings, runtime.NumCPU()) var csvFile *os.File var csvWriter *csv.Writer if exportUserAUC { var err error csvFile, err = os.Create("AFM.csv") if err != nil { log.Logger().Error("failed to create AFM.csv", zap.Error(err)) exportUserAUC = false } else { defer csvFile.Close() csvWriter = csv.NewWriter(csvFile) defer csvWriter.Flush() _ = csvWriter.Write([]string{"Feedback", "Candidates", "AUC"}) } } var sum float32 var count float32 for userIndex, posIndices := range positives { negIndices := negatives[userIndex] if len(negIndices) == 0 || feedbackCount[userIndex] == 0 || feedbackCount[userIndex] > cfg.Recommend.ContextSize { continue } var posPredictions, negPredictions []float32 for _, posIndex := range posIndices { posPredictions = append(posPredictions, predictions[posIndex]) } for _, negIndex := range negIndices { negPredictions = append(negPredictions, predictions[negIndex]) } userAUC := ctr.AUC(posPredictions, negPredictions) if exportUserAUC { _ = csvWriter.Write([]string{ strconv.Itoa(feedbackCount[userIndex]), 
strconv.Itoa(len(posIndices) + len(negIndices)), fmt.Sprintf("%.4f", userAUC), }) } userCount := float32(len(posIndices) + len(negIndices)) sum += userAUC * userCount count += userCount } var score float64 if count > 0 { score = float64(sum / count) } scores.Store("AFM", score) } func EvaluateLLM(cfg *config.Config, train, test *ctr.Dataset, items []data.Item, exportUserAUC bool, scores *sync.Map) { chat, err := logics.NewChatReranker( cfg.Recommend.Ranker.RerankerAPI, cfg.Recommend.Ranker.QueryTemplate, cfg.Recommend.Ranker.DocumentTemplate, ) if err != nil { log.Logger().Fatal("failed to create chat ranker", zap.Error(err)) } feedbacks := make(map[int32][]*logics.FeedbackItem) for i := 0; i < train.Count(); i++ { indices, _, _, target := train.Get(i) if target <= 0 { continue } userIndex := indices[0] itemIndex := indices[1] - int32(train.CountUsers()) feedbacks[userIndex] = append(feedbacks[userIndex], &logics.FeedbackItem{ Item: items[itemIndex], }) } positives := make(map[int32][]int32) negatives := make(map[int32][]int32) for i := 0; i < test.Count(); i++ { indices, _, _, target := test.Get(i) userIndex := indices[0] itemIndex := indices[1] - int32(test.CountUsers()) if target > 0 { positives[userIndex] = append(positives[userIndex], itemIndex) } else { negatives[userIndex] = append(negatives[userIndex], itemIndex) } } var csvFile *os.File var csvWriter *csv.Writer var csvMu sync.Mutex if exportUserAUC { var err error csvFile, err = os.Create(fmt.Sprintf("%s.csv", cfg.Recommend.Ranker.RerankerAPI.Model)) if err != nil { log.Logger().Error("failed to create LLM.csv", zap.Error(err)) exportUserAUC = false } else { defer csvFile.Close() csvWriter = csv.NewWriter(csvFile) defer csvWriter.Flush() _ = csvWriter.Write([]string{"Feedback", "Candidates", "AUC"}) } } var sum atomic.Float32 var count atomic.Float32 lo.Must0(parallel.ForEach(context.Background(), slices.Collect(maps.Keys(positives)), runtime.NumCPU(), func(_ int, userIndex int32) { posIndices := 
positives[userIndex] negIndices := negatives[userIndex] if len(negIndices) == 0 { return } candidates := make([]*data.Item, 0, len(posIndices)+len(negIndices)) positiveItems := mapset.NewSet[string]() negativeItems := mapset.NewSet[string]() for _, negIndex := range negIndices { item := items[negIndex] candidates = append(candidates, &item) negativeItems.Add(item.ItemId) } for _, posIndex := range posIndices { item := items[posIndex] candidates = append(candidates, &item) positiveItems.Add(item.ItemId) } feedback := feedbacks[int32(userIndex)] if len(feedback) == 0 || len(feedback) > cfg.Recommend.ContextSize { return } result, err := chat.Rank(context.Background(), &data.User{}, feedback, candidates) if err != nil { log.Logger().Error("failed to rank items for user", zap.Int32("user_index", userIndex), zap.Error(err)) return } var posPredictions, negPredictions []float32 for _, item := range result { if positiveItems.Contains(item.Id) { posPredictions = append(posPredictions, float32(item.Score)) } else if negativeItems.Contains(item.Id) { negPredictions = append(negPredictions, float32(item.Score)) } } userAUC := ctr.AUC(posPredictions, negPredictions) if exportUserAUC { csvMu.Lock() _ = csvWriter.Write([]string{ strconv.Itoa(len(feedback)), strconv.Itoa(len(posIndices) + len(negIndices)), fmt.Sprintf("%.4f", userAUC), }) csvMu.Unlock() } userCount := float32(len(posIndices) + len(negIndices)) sum.Add(userAUC * userCount) count.Add(userCount) })) var score float64 if count.Load() > 0 { score = float64(sum.Load() / count.Load()) } scores.Store(cfg.Recommend.Ranker.RerankerAPI.Model, score) } func EvaluateEmbedding(cfg *config.Config, train, test dataset.CFSplit, embeddingExpr, textExpr string, topK, jobs int, scores *sync.Map) { // Compile expression var embeddingProgram, textProgram *vm.Program var err error if embeddingExpr != "" { embeddingProgram, err = expr.Compile(embeddingExpr, expr.Env(map[string]any{ "item": data.Item{}, })) if err != nil { 
log.Logger().Fatal("failed to compile embedding expression", zap.Error(err)) } } else if textExpr != "" { textProgram, err = expr.Compile(textExpr, expr.Env(map[string]any{ "item": data.Item{}, })) if err != nil { log.Logger().Fatal("failed to compile text expression", zap.Error(err)) } } else { log.Logger().Fatal("one of embedding-column or text-column is required") } // Extract embeddings var dimensions atomic.Int64 embeddings := make([][]float32, test.CountItems()) if textExpr != "" { clientConfig := openai.DefaultConfig(cfg.OpenAI.AuthToken) clientConfig.BaseURL = cfg.OpenAI.BaseURL client := openai.NewClientWithConfig(clientConfig) // Generate embeddings bar := progressbar.Default(int64(test.CountItems())) lo.Must0(parallel.For(context.Background(), test.CountItems(), jobs, func(i int) { _ = bar.Add(1) item := &test.GetItems()[i] result, err := expr.Run(textProgram, map[string]any{ "item": *item, }) if err != nil { return } text, ok := result.(string) if !ok { return } // Generate embedding resp, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ Input: text, Model: openai.EmbeddingModel(cfg.OpenAI.EmbeddingModel), Dimensions: cfg.OpenAI.EmbeddingDimensions, }) if err != nil { log.Logger().Error("failed to create embeddings", zap.String("item_id", item.ItemId), zap.Error(err)) return } embeddings[i] = resp.Data[0].Embedding if dimensions.Load() == 0 { dimensions.Store(int64(len(resp.Data[0].Embedding))) } })) } else { lo.Must0(parallel.For(context.Background(), test.CountItems(), jobs, func(i int) { item := test.GetItems()[i] result, err := expr.Run(embeddingProgram, map[string]any{ "item": item, }) if err == nil { if e, ok := result.([]float32); ok { embeddings[i] = e if dim := dimensions.Swap(int64(len(e))); dim == 0 { dimensions.Store(int64(len(e))) } else if dim != int64(len(e)) { log.Logger().Fatal("inconsistent embedding dimensions", zap.Int64("expected", dim), zap.Int64("got", int64(len(e)))) } } } })) } var ( ndcg 
atomic.Float32 precision atomic.Float32 recall atomic.Float32 count atomic.Float32 ) negatives := test.SampleUserNegatives(train, 99) lo.Must0(parallel.For(context.Background(), test.CountUsers(), jobs, func(userIdx int) { targetSet := mapset.NewSet(test.GetUserFeedback()[userIdx]...) negativeSample := negatives[userIdx] if len(test.GetUserFeedback()[userIdx]) == 0 { return } candidates := make([]int32, 0, targetSet.Cardinality()+len(negativeSample)) candidates = append(candidates, test.GetUserFeedback()[userIdx]...) candidates = append(candidates, negativeSample...) feedback := train.GetUserFeedback()[userIdx] if len(feedback) == 0 { return } h := heap.NewTopKFilter[int32, float32](topK) for _, candidateIdx := range candidates { candidateEmbedding := embeddings[candidateIdx] if candidateEmbedding == nil { continue } var totalDistance float32 var validShots int for _, shotIdx := range feedback { shotEmbedding := embeddings[shotIdx] if shotEmbedding == nil { continue } totalDistance -= floats.Euclidean(candidateEmbedding, shotEmbedding) validShots++ } if validShots > 0 { h.Push(candidateIdx, totalDistance) } } if h.Len() == 0 { return } rankList := h.PopAllValues() ndcg.Add(cf.NDCG(targetSet, rankList)) precision.Add(cf.Precision(targetSet, rankList)) recall.Add(cf.Recall(targetSet, rankList)) count.Add(1) })) var score cf.Score if count.Load() > 0 { score = cf.Score{ NDCG: ndcg.Load() / count.Load(), Precision: precision.Load() / count.Load(), Recall: recall.Load() / count.Load(), } } scores.Store(fmt.Sprintf("%s (%d)", cfg.OpenAI.EmbeddingModel, dimensions.Load()), score) } var benchEmbeddingCmd = &cobra.Command{ Use: "bench-embedding", Short: "Benchmark embedding models for item-to-item", Run: func(cmd *cobra.Command, args []string) { // Load configuration configPath, _ := cmd.Flags().GetString("config") cfg, err := config.LoadConfig(configPath) if err != nil { log.Logger().Fatal("failed to load config", zap.Error(err)) } shots, _ := cmd.Flags().GetInt("shots") 
embeddingColumn, _ := cmd.Flags().GetString("embedding-column") textColumn, _ := cmd.Flags().GetString("text-column") if embeddingColumn == "" && textColumn == "" { log.Logger().Fatal("one of embedding-column or text-column is required") } // Load dataset m := master.NewMaster(cfg, os.TempDir(), false, configPath) m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix, storage.WithIsolationLevel(m.Config.Database.MySQL.IsolationLevel)) if err != nil { log.Logger().Fatal("failed to open data client", zap.Error(err)) } evaluator := master.NewOnlineEvaluator( m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes) _, dataset, err := m.LoadDataFromDatabase(context.Background(), m.DataClient, m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes, m.Config.Recommend.DataSource.ItemTTL, m.Config.Recommend.DataSource.PositiveFeedbackTTL, evaluator, nil) if err != nil { log.Logger().Fatal("failed to load dataset", zap.Error(err)) } // Override config if cmd.Flags().Changed("embedding-model") { cfg.OpenAI.EmbeddingModel = cmd.Flag("embedding-model").Value.String() } dimensions, _ := cmd.Flags().GetInt("embedding-dimensions") cfg.OpenAI.EmbeddingDimensions = dimensions // Split dataset var scores sync.Map train, test := dataset.SplitLatest(shots) test.SampleUserNegatives(dataset, 99) table := tablewriter.NewWriter(os.Stdout) table.Header([]string{"", "#users", "#items", "#interactions"}) lo.Must0(table.Bulk([][]string{ {"total", strconv.Itoa(dataset.CountUsers()), strconv.Itoa(dataset.CountItems()), strconv.Itoa(dataset.CountFeedback())}, {"train", strconv.Itoa(train.CountUsers()), strconv.Itoa(train.CountItems()), strconv.Itoa(train.CountFeedback())}, {"test", strconv.Itoa(test.CountUsers()), strconv.Itoa(test.CountItems()), strconv.Itoa(test.CountFeedback())}, })) lo.Must0(table.Render()) topK, _ := cmd.Flags().GetInt("top") jobs, _ := 
cmd.Flags().GetInt("jobs") EvaluateEmbedding(cfg, train, test, embeddingColumn, textColumn, topK, jobs, &scores) data := [][]string{{ "Embedding Model", fmt.Sprintf("NDCG@%d", topK), fmt.Sprintf("Precision@%d", topK), fmt.Sprintf("Recall@%d", topK), }} scores.Range(func(key, value any) bool { score := value.(cf.Score) data = append(data, []string{ key.(string), fmt.Sprintf("%.4f", score.NDCG), fmt.Sprintf("%.4f", score.Precision), fmt.Sprintf("%.4f", score.Recall), }) return true }) table = tablewriter.NewWriter(os.Stdout) table.Header(data[0]) lo.Must0(table.Bulk(data[1:])) lo.Must0(table.Render()) }, } func init() { rootCmd.PersistentFlags().StringP("config", "c", "", "Path to configuration file") rootCmd.PersistentFlags().IntP("jobs", "j", runtime.NumCPU(), "Number of jobs to run in parallel") rootCmd.AddCommand(benchLLMCmd) rootCmd.AddCommand(benchEmbeddingCmd) benchLLMCmd.PersistentFlags().Bool("user-auc", false, "Export user-level AUC scores to CSV file") benchEmbeddingCmd.PersistentFlags().IntP("top", "k", 10, "Number of top items to evaluate for each user") benchEmbeddingCmd.PersistentFlags().IntP("shots", "s", math.MaxInt, "Number of shots for each user") benchEmbeddingCmd.PersistentFlags().Int("embedding-dimensions", 0, "Embedding dimensions") benchEmbeddingCmd.PersistentFlags().String("embedding-model", "", "Embedding model") benchEmbeddingCmd.PersistentFlags().String("embedding-column", "", "Column name of embedding in item label") benchEmbeddingCmd.PersistentFlags().String("text-column", "", "Column name of text in item label") } func main() { if err := rootCmd.Execute(); err != nil { log.Logger().Fatal("failed to execute command", zap.Error(err)) } } ================================================ FILE: cmd/gorse-in-one/Dockerfile ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 WORKDIR 
/src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS ARG TARGETARCH RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-in-one && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=0 go build -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . && \ mv gorse-in-one /usr/bin ############################ # STEP 2 build a small image ############################ FROM scratch COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /usr/bin/gorse-in-one /usr/bin/gorse-in-one ENV USER=root ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-in-one/Dockerfile.cuda ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM nvidia/cuda:12.8.1-devel-ubuntu24.04 COPY --from=golang:1.26 /usr/local/go/ /usr/local/go/ ENV PATH=/usr/local/go/bin:$PATH RUN apt update && apt install -y git WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN cd common/blas/cublas && make RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-in-one && \ go build -tags xla -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . 
&& \ mv gorse-in-one /usr/bin ############################ # STEP 2 build runtime image ############################ FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04 COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /usr/bin/gorse-in-one /usr/bin/gorse-in-one RUN /usr/bin/gorse-in-one --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-in-one/Dockerfile.mkl ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM golang:1.26-bookworm # Install Intel® oneAPI Toolkits RUN apt update && apt install -y wget gnupg2 git RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list RUN apt update && apt install -y intel-oneapi-base-toolkit # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN --mount=type=cache,target=/root/.cache/go-build \ . /opt/intel/oneapi/setvars.sh && \ cd cmd/gorse-in-one && \ go build -tags mkl -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . 
############################ # STEP 2 build runtime image ############################ FROM debian:bookworm-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-in-one/gorse-in-one /usr/bin/gorse-in-one RUN /usr/bin/gorse-in-one --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-in-one/Dockerfile.openblas ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 # Install OpenBLAS ARG BUILDARCH ARG TARGETARCH RUN if [ "${TARGETARCH}" = "amd64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-x86-64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-aarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-riscv64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ dpkg --add-architecture ppc64el && apt update && apt install -y gcc-powerpc64le-linux-gnu libopenblas-dev:ppc64el git; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-s390x-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-loongarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . 
./ ARG TARGETOS RUN --mount=type=cache,target=/root/.cache/go-build \ if [ "${TARGETARCH}" = "amd64" ]; then \ export CC=x86_64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ export CC=aarch64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ export CC=riscv64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ export CC=powerpc64le-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ export CC=s390x-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ export CC=loongarch64-linux-gnu-gcc; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi && \ cd cmd/gorse-in-one && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=1 go build -tags openblas -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:trixie-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-in-one/gorse-in-one /usr/bin/gorse-in-one RUN /usr/bin/gorse-in-one --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-in-one/Dockerfile.windows ================================================ ############################ # STEP 1 build executable binary ############################ FROM golang:1.26 WORKDIR /src COPY . ./ ENV CGO_ENABLED=0 RUN go build -o / -ldflags="\" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'\"" ./cmd/... 
############################ # STEP 2 build a small image ############################ FROM mcr.microsoft.com/windows/servercore:ltsc2022 COPY --from=0 /gorse-in-one.exe /gorse-in-one.exe RUN /gorse-in-one.exe --version ENTRYPOINT [ "/gorse-in-one.exe" ] ================================================ FILE: cmd/gorse-in-one/main.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "compress/gzip" "fmt" "net/http" "os" "os/signal" "path/filepath" "runtime" "github.com/gorse-io/gorse/client" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/master" "github.com/gorse-io/gorse/storage/data" "github.com/klauspost/cpuid/v2" "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" "go.uber.org/zap" ) var oneCommand = &cobra.Command{ Use: "gorse-in-one", Short: "The all in one distribution of gorse recommender system.", Run: func(cmd *cobra.Command, args []string) { // Show version if showVersion, _ := cmd.PersistentFlags().GetBool("version"); showVersion { fmt.Println(version.BuildInfo()) return } // setup logger debug, _ := cmd.PersistentFlags().GetBool("debug") log.SetLogger(cmd.PersistentFlags(), debug) // locate config file var configPath string cachePath, _ := cmd.PersistentFlags().GetString("cache-path") playground, _ := cmd.PersistentFlags().GetString("playground") if 
playground != "" { userHomeDir, err := os.UserHomeDir() if err != nil { log.Logger().Fatal("failed to get user home directory", zap.Error(err)) } etcDir := filepath.Join(userHomeDir, ".gorse", "etc") if err = os.MkdirAll(etcDir, os.ModePerm); err != nil { log.Logger().Fatal("failed to create config directory", zap.Error(err)) } configPath = filepath.Join(etcDir, "config.toml") if playground == "ml-100k" { err = os.WriteFile(configPath, []byte(client.ConfigTOML), 0644) } else { err = os.WriteFile(configPath, []byte(config.ConfigTOML), 0644) } if err != nil { log.Logger().Fatal("failed to write playground config", zap.Error(err)) } fmt.Println("Generated config file:", configPath) fmt.Println("Using cache directory:", cachePath) } else { configPath, _ = cmd.PersistentFlags().GetString("config") log.Logger().Info("load config", zap.String("config", configPath)) } // load config conf, err := config.LoadConfig(configPath) if err != nil { log.Logger().Fatal("failed to load config", zap.Error(err)) } // create master m := master.NewMaster(conf, cachePath, true, configPath) if playground != "" { setup(m, playground) } // Stop master done := make(chan struct{}) go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt) <-sigint m.Shutdown() close(done) }() // Start master m.Serve() <-done log.Logger().Info("stop gorse-in-one successfully") }, }

// init registers the command-line flags of gorse-in-one.
func init() {
	log.AddFlags(oneCommand.PersistentFlags())
	oneCommand.PersistentFlags().Bool("debug", false, "use debug log mode")
	oneCommand.PersistentFlags().BoolP("version", "v", false, "gorse version")
	// setup() downloads the ml-100k dataset for "ml-100k" and the GitHub
	// dataset for any other value, so the help text advertises both.
	oneCommand.PersistentFlags().String("playground", "", "playground mode (setup a recommender system for GitHub repositories; use --playground=ml-100k for the MovieLens 100K dataset)")
	// A bare --playground (no value) selects the "default" (GitHub) playground.
	oneCommand.PersistentFlags().Lookup("playground").NoOptDefVal = "default"
	oneCommand.PersistentFlags().StringP("config", "c", "", "configuration file path")
	oneCommand.PersistentFlags().String("cache-path", config.MkDir("master"), "path of cache folder")
}

// main runs the gorse-in-one command and logs a fatal error if it fails.
func main() {
	if err := oneCommand.Execute(); err != nil {
		log.Logger().Fatal("failed to execute", zap.Error(err))
	}
}

// setup prepares master m for playground mode.
//
// It performs these steps:
//   - Points the data and cache stores at the defaults returned by
//     config.GetDefaultConfig.
//   - Sets the master's job count to the number of CPU cores.
//   - Opens and initializes the data store.
//   - Downloads the sample dataset selected by playground ("ml-100k" selects
//     the ml-100k dump, any other value the GitHub dump) and restores it via
//     m.Restore.
//   - Prints the dashboard and API URLs.
//
// Every failure, including a non-200 HTTP response from the CDN, is fatal.
func setup(m *master.Master, playground string) {
	// set database to user home directory
	m.Config.Database.DataStore = config.GetDefaultConfig().Database.DataStore
	fmt.Println("Using database:", m.Config.Database.DataStore)
	m.Config.Database.CacheStore = config.GetDefaultConfig().Database.CacheStore
	fmt.Println("Using cache:", m.Config.Database.CacheStore)
	m.Config.Master.NumJobs = runtime.NumCPU()
	fmt.Printf("Using %d CPU cores: %s\n", m.Config.Master.NumJobs, cpuid.CPU.BrandName)

	// connect database
	var err error
	dataOpts := m.Config.Database.StorageOptions(m.Config.Database.DataStore)
	m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix, dataOpts...)
	if err != nil {
		log.Logger().Fatal("failed to connect data database", zap.Error(err),
			zap.String("database", log.RedactDBURL(m.Config.Database.DataStore)))
	}
	defer m.DataClient.Close()
	if err = m.DataClient.Init(); err != nil {
		log.Logger().Fatal("failed to init database", zap.Error(err))
	}

	// import playground data
	dataURL := "https://cdn.gorse.io/example/github.bin.gz"
	if playground == "ml-100k" {
		dataURL = "https://cdn.gorse.io/example/ml-100k.bin.gz"
	}
	resp, err := http.Get(dataURL)
	if err != nil {
		log.Logger().Fatal("failed to download playground data", zap.Error(err))
	}
	defer resp.Body.Close()
	// http.Get does not return an error for non-2xx responses. Without this
	// check an error page would be handed to gzip and reported as a
	// confusing decompression failure.
	if resp.StatusCode != http.StatusOK {
		log.Logger().Fatal("failed to download playground data",
			zap.String("url", dataURL), zap.String("status", resp.Status))
	}
	bar := progressbar.DefaultBytes(
		resp.ContentLength,
		"Importing data",
	)
	p := progressbar.NewReader(resp.Body, bar)
	d, err := gzip.NewReader(&p)
	if err != nil {
		log.Logger().Fatal("failed to decompress playground data", zap.Error(err))
	}
	_, err = m.Restore(d)
	if err != nil {
		log.Logger().Fatal("failed to import playground data", zap.Error(err))
	}

	// show info
	fmt.Printf("Welcome to Gorse Playground\n")
	fmt.Println()
	fmt.Printf(" Dashboard: http://127.0.0.1:%d/overview\n", m.Config.Master.HttpPort)
	fmt.Printf(" RESTful APIs: http://127.0.0.1:%d/apidocs\n", m.Config.Master.HttpPort)
	fmt.Printf(" Documentation: https://gorse.io\n")
	fmt.Println()
}
================================================ FILE: cmd/gorse-master/Dockerfile ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS ARG TARGETARCH RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-master && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=0 go build -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build a small image ############################ FROM scratch COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master ENV USER=root ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-master/Dockerfile.cuda ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM nvidia/cuda:12.8.1-devel-ubuntu24.04 COPY --from=golang:1.26 /usr/local/go/ /usr/local/go/ ENV PATH=/usr/local/go/bin:$PATH RUN apt update && apt install -y git WORKDIR /src COPY go.* ./ RUN go mod download COPY .
./ RUN cd common/blas/cublas && make RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-master && \ go build -tags xla -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04 COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master RUN /usr/bin/gorse-master --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-master/Dockerfile.mkl ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM golang:1.26-bookworm # Install Intel® oneAPI Toolkits RUN apt update && apt install -y wget gnupg2 git RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list RUN apt update && apt install -y intel-oneapi-base-toolkit # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN --mount=type=cache,target=/root/.cache/go-build \ . 
/opt/intel/oneapi/setvars.sh && \ cd cmd/gorse-master && \ go build -tags mkl -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:bookworm-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master RUN /usr/bin/gorse-master --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-master/Dockerfile.openblas ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 # Install OpenBLAS ARG BUILDARCH ARG TARGETARCH RUN if [ "${TARGETARCH}" = "amd64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-x86-64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-aarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-riscv64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ dpkg --add-architecture ppc64el && apt update && apt install -y gcc-powerpc64le-linux-gnu libopenblas-dev:ppc64el git; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-s390x-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ 
"${TARGETARCH}" = "loong64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-loongarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS RUN --mount=type=cache,target=/root/.cache/go-build \ if [ "${TARGETARCH}" = "amd64" ]; then \ export CC=x86_64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ export CC=aarch64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ export CC=riscv64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ export CC=powerpc64le-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ export CC=s390x-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ export CC=loongarch64-linux-gnu-gcc; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi && \ cd cmd/gorse-master && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=1 go build -tags openblas -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:trixie-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master RUN /usr/bin/gorse-master --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"] ================================================ FILE: cmd/gorse-master/Dockerfile.windows ================================================ ############################ # STEP 1 build executable binary ############################ FROM golang:1.26 WORKDIR /src COPY . 
./ ENV CGO_ENABLED=0 RUN go build -o / -ldflags="\" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'\"" ./cmd/... ############################ # STEP 2 build a small image ############################ FROM mcr.microsoft.com/windows/servercore:ltsc2022 COPY --from=0 /gorse-master.exe /gorse-master.exe RUN /gorse-master.exe --version ENTRYPOINT [ "/gorse-master.exe" ] ================================================ FILE: cmd/gorse-master/main.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package main

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"github.com/gorse-io/gorse/cmd/version"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/master"
	"github.com/spf13/cobra"
	"go.uber.org/zap"
)

// masterCommand is the root command of gorse-master.
//
// It loads the configuration file given by --config and starts a master node.
// It calls Shutdown on the master when the process receives SIGINT or SIGTERM.
var masterCommand = &cobra.Command{
	Use:   "gorse-master",
	Short: "The master node of gorse recommender system.",
	Run: func(cmd *cobra.Command, args []string) {
		// Show version
		if showVersion, _ := cmd.PersistentFlags().GetBool("version"); showVersion {
			fmt.Println(version.BuildInfo())
			return
		}
		// setup logger
		debug, _ := cmd.PersistentFlags().GetBool("debug")
		log.SetLogger(cmd.PersistentFlags(), debug)
		// Create master
		configPath, _ := cmd.PersistentFlags().GetString("config")
		log.Logger().Info("load config", zap.String("config", configPath))
		conf, err := config.LoadConfig(configPath)
		if err != nil {
			log.Logger().Fatal("failed to load config", zap.Error(err))
		}
		cachePath, _ := cmd.PersistentFlags().GetString("cache-path")
		m := master.NewMaster(conf, cachePath, false, configPath)
		// Stop master on Ctrl-C (SIGINT) or SIGTERM. SIGTERM is what
		// `docker stop` and most process supervisors send. Without
		// registering it, the Go runtime would terminate the process without
		// running m.Shutdown().
		done := make(chan struct{})
		go func() {
			sigint := make(chan os.Signal, 1)
			signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
			<-sigint
			m.Shutdown()
			close(done)
		}()
		// Start master
		m.Serve()
		<-done
		log.Logger().Info("stop gorse master successfully")
	},
}

// init registers the command-line flags of gorse-master.
func init() {
	log.AddFlags(masterCommand.PersistentFlags())
	masterCommand.PersistentFlags().Bool("debug", false, "use debug log mode")
	masterCommand.PersistentFlags().StringP("config", "c", "", "configuration file path")
	masterCommand.PersistentFlags().BoolP("version", "v", false, "gorse version")
	masterCommand.PersistentFlags().String("cache-path", config.MkDir("master"), "path of cache folder")
}

// main runs the gorse-master command and logs a fatal error if it fails.
func main() {
	if err := masterCommand.Execute(); err != nil {
		log.Logger().Fatal("failed to execute", zap.Error(err))
	}
}
================================================ FILE: cmd/gorse-server/Dockerfile ================================================ # syntax = docker/dockerfile:1 ############################ #
STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS ARG TARGETARCH RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-server && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=0 go build -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build a small image ############################ FROM scratch COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server ENV USER=root ENTRYPOINT ["/usr/bin/gorse-server"] ================================================ FILE: cmd/gorse-server/Dockerfile.cuda ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM nvidia/cuda:12.8.1-devel-ubuntu24.04 COPY --from=golang:1.26 /usr/local/go/ /usr/local/go/ ENV PATH=/usr/local/go/bin:$PATH RUN apt update && apt install -y git WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN cd common/blas/cublas && make RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-server && \ go build -tags xla -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . 
############################ # STEP 2 build runtime image ############################ FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04 COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server RUN /usr/bin/gorse-server --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-server"] ================================================ FILE: cmd/gorse-server/Dockerfile.mkl ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM golang:1.26-bookworm # Install Intel® oneAPI Toolkits RUN apt update && apt install -y wget gnupg2 git RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list RUN apt update && apt install -y intel-oneapi-base-toolkit # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN --mount=type=cache,target=/root/.cache/go-build \ . /opt/intel/oneapi/setvars.sh && \ cd cmd/gorse-server && \ go build -tags mkl -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . 
############################ # STEP 2 build runtime image ############################ FROM debian:bookworm-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server RUN /usr/bin/gorse-server --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-server"] ================================================ FILE: cmd/gorse-server/Dockerfile.openblas ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 # Install OpenBLAS ARG BUILDARCH ARG TARGETARCH RUN if [ "${TARGETARCH}" = "amd64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-x86-64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-aarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-riscv64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ dpkg --add-architecture ppc64el && apt update && apt install -y gcc-powerpc64le-linux-gnu libopenblas-dev:ppc64el git; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-s390x-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-loongarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . 
./ ARG TARGETOS RUN --mount=type=cache,target=/root/.cache/go-build \ if [ "${TARGETARCH}" = "amd64" ]; then \ export CC=x86_64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ export CC=aarch64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ export CC=riscv64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ export CC=powerpc64le-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ export CC=s390x-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ export CC=loongarch64-linux-gnu-gcc; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi && \ cd cmd/gorse-server && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=1 go build -tags openblas -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:trixie-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server RUN /usr/bin/gorse-server --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-server"] ================================================ FILE: cmd/gorse-server/Dockerfile.windows ================================================ ############################ # STEP 1 build executable binary ############################ FROM golang:1.26 WORKDIR /src COPY . ./ ENV CGO_ENABLED=0 RUN go build -o / -ldflags="\" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'\"" ./cmd/... 
############################ # STEP 2 build a small image ############################ FROM mcr.microsoft.com/windows/servercore:ltsc2022 COPY --from=0 /gorse-server.exe /gorse-server.exe RUN /gorse-server.exe --version ENTRYPOINT [ "/gorse-server.exe" ] ================================================ FILE: cmd/gorse-server/main.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "fmt" "os" "os/signal" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/server" "github.com/spf13/cobra" "go.uber.org/zap" ) var serverCommand = &cobra.Command{ Use: "gorse-server", Short: "The server node of gorse recommender system.", Run: func(cmd *cobra.Command, args []string) { // show version showVersion, _ := cmd.PersistentFlags().GetBool("version") if showVersion { fmt.Println(version.BuildInfo()) return } // setup logger debug, _ := cmd.PersistentFlags().GetBool("debug") log.SetLogger(cmd.PersistentFlags(), debug) // create server masterPort, _ := cmd.PersistentFlags().GetInt("master-port") masterHost, _ := cmd.PersistentFlags().GetString("master-host") httpPort, _ := cmd.PersistentFlags().GetInt("http-port") httpHost, _ := cmd.PersistentFlags().GetString("http-host") cachePath, _ := cmd.PersistentFlags().GetString("cache-path") caFile, _ := 
cmd.PersistentFlags().GetString("ssl-ca") certFile, _ := cmd.PersistentFlags().GetString("ssl-cert") keyFile, _ := cmd.PersistentFlags().GetString("ssl-key") var tlsConfig *util.TLSConfig if caFile != "" && certFile != "" && keyFile != "" { tlsConfig = &util.TLSConfig{ SSLCA: caFile, SSLCert: certFile, SSLKey: keyFile, } } else if caFile == "" && certFile == "" && keyFile == "" { tlsConfig = nil } else { log.Logger().Fatal("incomplete SSL configuration", zap.String("ssl_ca", caFile), zap.String("ssl_cert", certFile), zap.String("ssl_key", keyFile)) } s := server.NewServer(masterHost, masterPort, httpHost, httpPort, cachePath, tlsConfig) // stop server done := make(chan struct{}) go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt) <-sigint s.Shutdown() close(done) }() // start server s.Serve() <-done log.Logger().Info("stop gorse server successfully") }, } func init() { log.AddFlags(serverCommand.PersistentFlags()) serverCommand.PersistentFlags().BoolP("version", "v", false, "gorse version") serverCommand.PersistentFlags().Int("master-port", 8086, "port of master node") serverCommand.PersistentFlags().String("master-host", "127.0.0.1", "host of master node") serverCommand.PersistentFlags().Int("http-port", 8087, "host for RESTful APIs and Prometheus metrics export") serverCommand.PersistentFlags().String("http-host", "127.0.0.1", "port for RESTful APIs and Prometheus metrics export") serverCommand.PersistentFlags().Bool("debug", false, "use debug log mode") serverCommand.PersistentFlags().String("cache-path", "server_cache.data", "path of cache file") serverCommand.PersistentFlags().String("ssl-ca", "", "path of SSL CA") serverCommand.PersistentFlags().String("ssl-cert", "", "path of SSL certificate") serverCommand.PersistentFlags().String("ssl-key", "", "path of SSL key") } func main() { if err := serverCommand.Execute(); err != nil { log.Logger().Fatal("failed to execute", zap.Error(err)) } } 
================================================ FILE: cmd/gorse-worker/Dockerfile ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS ARG TARGETARCH RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-worker && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=0 go build -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build a small image ############################ FROM scratch COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker ENV USER=root ENTRYPOINT ["/usr/bin/gorse-worker"] ================================================ FILE: cmd/gorse-worker/Dockerfile.cuda ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM nvidia/cuda:12.8.1-devel-ubuntu24.04 COPY --from=golang:1.26 /usr/local/go/ /usr/local/go/ ENV PATH=/usr/local/go/bin:$PATH RUN apt update && apt install -y git WORKDIR /src COPY go.* ./ RUN go mod download COPY . 
./ RUN cd common/blas/cublas && make RUN --mount=type=cache,target=/root/.cache/go-build \ cd cmd/gorse-worker && \ go build -tags xla -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04 COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker RUN /usr/bin/gorse-worker --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-worker"] ================================================ FILE: cmd/gorse-worker/Dockerfile.mkl ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM golang:1.26-bookworm # Install Intel® oneAPI Toolkits RUN apt update && apt install -y wget gnupg2 git RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list RUN apt update && apt install -y intel-oneapi-base-toolkit # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ RUN --mount=type=cache,target=/root/.cache/go-build \ . 
/opt/intel/oneapi/setvars.sh && \ cd cmd/gorse-worker && \ go build -tags mkl -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:bookworm-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker RUN /usr/bin/gorse-worker --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-worker"] ================================================ FILE: cmd/gorse-worker/Dockerfile.openblas ================================================ # syntax = docker/dockerfile:1 ############################ # STEP 1 build executable binary ############################ FROM --platform=$BUILDPLATFORM golang:1.26 # Install OpenBLAS ARG BUILDARCH ARG TARGETARCH RUN if [ "${TARGETARCH}" = "amd64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-x86-64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-aarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-riscv64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ dpkg --add-architecture ppc64el && apt update && apt install -y gcc-powerpc64le-linux-gnu libopenblas-dev:ppc64el git; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-s390x-linux-gnu libopenblas-dev:${TARGETARCH} git; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ 
dpkg --add-architecture ${TARGETARCH} && apt update && apt install -y gcc-loongarch64-linux-gnu libopenblas-dev:${TARGETARCH} git; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi # Build Gorse WORKDIR /src COPY go.* ./ RUN go mod download COPY . ./ ARG TARGETOS RUN --mount=type=cache,target=/root/.cache/go-build \ if [ "${TARGETARCH}" = "amd64" ]; then \ export CC=x86_64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "arm64" ]; then \ export CC=aarch64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "riscv64" ]; then \ export CC=riscv64-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "ppc64le" ]; then \ export CC=powerpc64le-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "s390x" ]; then \ export CC=s390x-linux-gnu-gcc; \ elif [ "${TARGETARCH}" = "loong64" ]; then \ export CC=loongarch64-linux-gnu-gcc; \ else \ echo "Unsupported TARGETARCH ${TARGETARCH}"; exit 1; \ fi && \ cd cmd/gorse-worker && \ GOOS=${TARGETOS} GOARCH=${TARGETARCH} CGO_ENABLED=1 go build -tags openblas -ldflags=" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)' \ -X 'github.com/gorse-io/gorse/config.RootDir=/var/lib/gorse'" . ############################ # STEP 2 build runtime image ############################ FROM debian:trixie-slim RUN apt update && apt install -y ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker RUN /usr/bin/gorse-worker --version ENV USER=root ENTRYPOINT ["/usr/bin/gorse-worker"] ================================================ FILE: cmd/gorse-worker/Dockerfile.windows ================================================ ############################ # STEP 1 build executable binary ############################ FROM golang:1.26 WORKDIR /src COPY . 
./ ENV CGO_ENABLED=0 RUN go build -o / -ldflags="\" \ -X 'github.com/gorse-io/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ -X 'github.com/gorse-io/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ -X 'github.com/gorse-io/gorse/cmd/version.BuildTime=$(date)'\"" ./cmd/... ############################ # STEP 2 build a small image ############################ FROM mcr.microsoft.com/windows/servercore:ltsc2022 COPY --from=0 /gorse-worker.exe /gorse-worker.exe RUN /gorse-worker.exe --version ENTRYPOINT [ "/gorse-worker.exe" ] ================================================ FILE: cmd/gorse-worker/main.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package main

import (
	"fmt"
	_ "net/http/pprof"
	"time"

	"github.com/gorse-io/gorse/cmd/version"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/worker"
	"github.com/spf13/cobra"
	"go.uber.org/zap"
)

// workerCommand is the root command of the gorse-worker binary. It reads its
// settings from persistent flags, configures logging and TLS, and then runs
// the worker until Serve returns.
var workerCommand = &cobra.Command{
	Use:   "gorse-worker",
	Short: "The worker node of gorse recommender system.",
	Run: func(cmd *cobra.Command, _ []string) {
		flags := cmd.PersistentFlags()

		// --version only prints build information; no worker is started.
		if showVersion, _ := flags.GetBool("version"); showVersion {
			fmt.Println(version.BuildInfo())
			return
		}

		// Master connection, metrics endpoint and parallelism.
		masterHost, _ := flags.GetString("master-host")
		masterPort, _ := flags.GetInt("master-port")
		httpHost, _ := flags.GetString("http-host")
		httpPort, _ := flags.GetInt("http-port")
		workingJobs, _ := flags.GetInt("jobs")

		// Configure logging before anything below may log.
		debug, _ := flags.GetBool("debug")
		log.SetLogger(flags, debug)

		cachePath, _ := flags.GetString("cache-path")
		caFile, _ := flags.GetString("ssl-ca")
		certFile, _ := flags.GetString("ssl-cert")
		keyFile, _ := flags.GetString("ssl-key")

		// TLS is all-or-nothing: either all three files are given or none is.
		var tlsConfig *util.TLSConfig
		switch {
		case caFile != "" && certFile != "" && keyFile != "":
			tlsConfig = &util.TLSConfig{
				SSLCA:   caFile,
				SSLCert: certFile,
				SSLKey:  keyFile,
			}
		case caFile == "" && certFile == "" && keyFile == "":
			// Plain connection: tlsConfig stays nil.
		default:
			log.Logger().Fatal("incomplete SSL configuration",
				zap.String("ssl_ca", caFile),
				zap.String("ssl_cert", certFile),
				zap.String("ssl_key", keyFile))
		}

		interval, _ := flags.GetDuration("interval")
		w := worker.NewWorker(masterHost, masterPort, httpHost, httpPort, workingJobs, cachePath, tlsConfig, interval)
		w.Serve()
	},
}

// init registers the command-line flags of gorse-worker.
func init() {
	flags := workerCommand.PersistentFlags()
	log.AddFlags(flags)
	flags.BoolP("version", "v", false, "gorse version")
	flags.String("master-host", "127.0.0.1", "host of master node")
	flags.Int("master-port", 8086, "port of master node")
	flags.String("http-host", "127.0.0.1", "host for Prometheus metrics export")
	flags.Int("http-port", 8089, "port for Prometheus metrics export")
	flags.Bool("debug", false, "use debug log mode")
	flags.IntP("jobs", "j", 1, "number of working jobs.")
	flags.String("cache-path", "worker_cache.data", "path of cache file")
	flags.String("ssl-ca", "", "path of SSL CA")
	flags.String("ssl-cert", "", "path to SSL certificate")
	flags.String("ssl-key", "", "path to SSL key")
	flags.Duration("interval", time.Minute, "interval between checking users")
}

func main() {
	err := workerCommand.Execute()
	if err != nil {
		log.Logger().Fatal("failed to execute", zap.Error(err))
	}
}

================================================
FILE: cmd/version/version.go
================================================
package version

import (
	"fmt"
	"runtime"
)

// Default build-time variable.
// These values are overridden via ldflags
var (
	Version    = "unknown-version"
	GitCommit  = "unknown-commit"
	BuildTime  = "unknown-buildtime"
	APIVersion = "v0.2.7"
)

// BuildInfo returns a multi-line, tab-aligned description of the binary:
// version, API version, Go version, git commit, build time and OS/arch.
func BuildInfo() string {
	rows := [][2]string{
		{"Version:\t", Version},
		{"API version:\t", APIVersion},
		{"Go version:\t", runtime.Version()},
		{"Git commit:\t", GitCommit},
		{"Built:\t\t", BuildTime},
	}
	info := ""
	for _, row := range rows {
		// Sprintln inserts a space between label and value and appends '\n'.
		info += fmt.Sprintln(row[0], row[1])
	}
	return info + fmt.Sprintf("OS/Arch:\t %s/%s\n", runtime.GOOS, runtime.GOARCH)
}

================================================
FILE: codecov.yml
================================================
coverage:
  status:
    patch:
      default:
        enabled: no
ignore:
  - "protocol/*.pb.go"
  - "cmd/**"

================================================
FILE: common/ann/ann.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ann

import (
	"github.com/samber/lo"
)

// Index is the common interface of the vector indexes in this package.
// Search results are (element id, distance) pairs.
type Index interface {
	Add(v []float32) int
	SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error)
	SearchVector(q []float32, k int, prune0 bool) []lo.Tuple2[int, float32]
}

================================================
FILE: common/ann/ann_test.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ann

import (
	"bufio"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"testing"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/datautil"
	"github.com/gorse-io/gorse/common/floats"
	"github.com/gorse-io/gorse/common/util"
	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
	"go.uber.org/atomic"
)

const (
	trainSize = 6000
	testSize  = 1000
)

// recall returns the fraction of ground-truth ids in gt that also appear in pred.
func recall(gt, pred []lo.Tuple2[int, float32]) float64 {
	s := mapset.NewSet[int]()
	for _, pair := range gt {
		s.Add(pair.A)
	}
	hit := 0
	for _, pair := range pred {
		if s.Contains(pair.A) {
			hit++
		}
	}
	return float64(hit) / float64(len(gt))
}

type MNIST struct {
	TrainImages [][]float32
	TrainLabels []uint8
	TestImages  [][]float32
	TestLabels  []uint8
}

// mnist downloads (if needed) and parses the MNIST train/test splits.
func mnist() (*MNIST, error) {
	// Download and unzip dataset
	path, err := datautil.DownloadAndUnzip("mnist")
	if err != nil {
		return nil, err
	}
	// Open dataset
	m := new(MNIST)
	m.TrainImages, m.TrainLabels, err = m.openFile(filepath.Join(path, "train.libfm"))
	if err != nil {
		return nil, err
	}
	m.TestImages, m.TestLabels, err = m.openFile(filepath.Join(path, "test.libfm"))
	if err != nil {
		return nil, err
	}
	return m, nil
}

// openFile parses a libfm-style file: each line is "label idx:value idx:value ...",
// decoded into a dense 784-element image.
func (m *MNIST) openFile(path string) ([][]float32, []uint8, error) {
	// Open file
	f, err := os.Open(path)
	if err != nil {
		return nil, nil, err
	}
	defer f.Close()
	// Read data line by line
	var (
		images [][]float32
		labels []uint8
	)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()
		splits := strings.Split(line, " ")
		// Parse label
		label, err := util.ParseUInt[uint8](splits[0])
		if err != nil {
			return nil, nil, err
		}
		labels = append(labels, label)
		// Parse image
		image := make([]float32, 784)
		for _, split := range splits[1:] {
			kv := strings.Split(split, ":")
			index, err := strconv.Atoi(kv[0])
			if err != nil {
				return nil, nil, err
			}
			value, err := util.ParseFloat[float32](kv[1])
			if err != nil {
				return nil, nil, err
			}
			image[index] = value
		}
		images = append(images, image)
	}
	// A read error ends Scan early; without this check the dataset would be
	// silently truncated.
	if err := scanner.Err(); err != nil {
		return nil, nil, err
	}
	return images, labels, nil
}

func TestMNIST(t *testing.T) {
	dat, err := mnist()
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Create brute-force index
	bf := NewBruteforce(floats.Euclidean)
	for _, image := range dat.TrainImages[:trainSize] {
		_ = bf.Add(image)
	}

	// Create HNSW index
	hnsw := NewHNSW(floats.Euclidean)
	for _, image := range dat.TrainImages[:trainSize] {
		_ = hnsw.Add(image)
	}

	// Test search
	r := 0.0
	for _, image := range dat.TestImages[:testSize] {
		gt := bf.SearchVector(image, 100, false)
		assert.Len(t, gt, 100)
		scores := hnsw.SearchVector(image, 100, false)
		assert.Len(t, scores, 100)
		r += recall(gt, scores)
	}
	r /= float64(testSize)
	assert.Greater(t, r, 0.99)

	// Test save and load
	path := filepath.Join(t.TempDir(), "mnist.bin")
	f, err := os.Create(path)
	assert.NoError(t, err)
	assert.NoError(t, hnsw.Marshal(f))
	// Close the writer before reading the file back.
	assert.NoError(t, f.Close())
	f, err = os.Open(path)
	assert.NoError(t, err)
	defer f.Close()
	assert.NoError(t, hnsw.Unmarshal(f))
	r = 0
	for _, image := range dat.TestImages[:testSize] {
		gt := bf.SearchVector(image, 100, false)
		assert.Len(t, gt, 100)
		scores := hnsw.SearchVector(image, 100, false)
		assert.Len(t, scores, 100)
		r += recall(gt, scores)
	}
	r /= float64(testSize)
	assert.Greater(t, r, 0.99)
}

func TestMultithread(t *testing.T) {
	dat, err := mnist()
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Create HNSW index
	indices := make([]int, trainSize)
	hnsw := NewHNSW(floats.Euclidean)
	var wg1 sync.WaitGroup
	wg1.Add(trainSize)
	for i := range dat.TrainImages[:trainSize] {
		go func(i int) {
			defer wg1.Done()
			indices[i] = hnsw.Add(dat.TrainImages[i])
		}(i)
	}
	wg1.Wait()

	// Create brute-force index whose ids match the HNSW ids
	reverse := make([]int, trainSize)
	for i, index := range indices {
		reverse[index] = i
	}
	bf := NewBruteforce(floats.Euclidean)
	for i := range reverse {
		_ = bf.Add(dat.TrainImages[reverse[i]])
	}

	// Test search
	var r atomic.Float64
	var wg2 sync.WaitGroup
	wg2.Add(testSize)
	for _, image := range dat.TestImages[:testSize] {
		go func(image []float32) {
			defer wg2.Done()
			gt := bf.SearchVector(image, 100, false)
			assert.Len(t, gt, 100)
			scores := hnsw.SearchVector(image, 100, false)
			assert.Len(t, scores, 100)
			r.Add(recall(gt, scores))
		}(image)
	}
	wg2.Wait()
	assert.Greater(t, r.Load()/float64(testSize), 0.99)
}

// movieLens loads ml-1m train.txt (tab-separated "user\tmovie...") and returns,
// for each movie id, the list of user ids that rated it.
func movieLens() ([][]int, error) {
	// Download and unzip dataset
	path, err := datautil.DownloadAndUnzip("ml-1m")
	if err != nil {
		return nil, err
	}
	// Open file
	f, err := os.Open(filepath.Join(path, "train.txt"))
	if err != nil {
		return nil, err
	}
	defer f.Close()
	// Read data line by line
	movies := make([][]int, 0)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()
		splits := strings.Split(line, "\t")
		userId, err := strconv.Atoi(splits[0])
		if err != nil {
			return nil, err
		}
		movieId, err := strconv.Atoi(splits[1])
		if err != nil {
			return nil, err
		}
		for movieId >= len(movies) {
			movies = append(movies, make([]int, 0))
		}
		movies[movieId] = append(movies[movieId], userId)
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return movies, nil
}

// jaccard returns the Jaccard distance (1 - |a∩b|/|a∪b|) between two id lists,
// or 1 when both are empty. The merge-style intersection assumes a and b are
// sorted ascending — NOTE(review): movieLens relies on train.txt being ordered
// by user id for this; confirm.
func jaccard(a, b []int) float32 {
	var i, j, intersection int
	for i < len(a) && j < len(b) {
		if a[i] == b[j] {
			intersection++
			i++
			j++
		} else if a[i] < b[j] {
			i++
		} else {
			j++
		}
	}
	if len(a)+len(b)-intersection == 0 {
		return 1
	}
	return 1 - float32(intersection)/float32(len(a)+len(b)-intersection)
}

func TestMovieLens(t *testing.T) {
	movies, err := movieLens()
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Create brute-force index
	bf := NewBruteforce(jaccard)
	for _, movie := range movies {
		_ = bf.Add(movie)
	}

	// Create HNSW index
	hnsw := NewHNSW(jaccard)
	for _, movie := range movies {
		_ = hnsw.Add(movie)
	}

	// Test search
	r := 0.0
	for i := range movies[:testSize] {
		gt, err := bf.SearchIndex(i, 100, false)
		assert.NoError(t, err)
		scores, err := hnsw.SearchIndex(i, 100, false)
		assert.NoError(t, err)
		r += recall(gt, scores)
	}
	r /= float64(testSize)
	assert.Greater(t, r, 0.98)
}

================================================
FILE: common/ann/bruteforce.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ann

import (
	"github.com/gorse-io/gorse/common/heap"
	"github.com/juju/errors"
	"github.com/samber/lo"
)

// Bruteforce is a naive implementation of vector index.
type Bruteforce[T any] struct { distanceFunc func(a, b T) float32 vectors []T } func NewBruteforce[T any](distanceFunc func(a, b T) float32) *Bruteforce[T] { return &Bruteforce[T]{distanceFunc: distanceFunc} } func (b *Bruteforce[T]) Add(v T) int { // Add vector b.vectors = append(b.vectors, v) return len(b.vectors) } func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { // Check index if q < 0 || q >= len(b.vectors) { return nil, errors.Errorf("index out of range: %v", q) } // Search pq := heap.NewPriorityQueue(true) for i, vec := range b.vectors { if i != q { pq.Push(int32(i), b.distanceFunc(b.vectors[q], vec)) if pq.Len() > k { pq.Pop() } } } pq = pq.Reverse() scores := make([]lo.Tuple2[int, float32], 0) for pq.Len() > 0 { value, score := pq.Pop() if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores, nil } func (b *Bruteforce[T]) SearchVector(q T, k int, prune0 bool) []lo.Tuple2[int, float32] { // Search pq := heap.NewPriorityQueue(true) for i, vec := range b.vectors { pq.Push(int32(i), b.distanceFunc(q, vec)) if pq.Len() > k { pq.Pop() } } pq = pq.Reverse() scores := make([]lo.Tuple2[int, float32], 0) for pq.Len() > 0 { value, score := pq.Pop() if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores } ================================================ FILE: common/ann/hnsw.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ann import ( "encoding/binary" "io" "math/rand" "sync" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/encoding" "github.com/gorse-io/gorse/common/heap" "github.com/pkg/errors" "github.com/samber/lo" "modernc.org/mathutil" ) // HNSW is a vector index based on Hierarchical Navigable Small Worlds. type HNSW[T any] struct { distanceFunc func(a, b T) float32 vectors []T bottomNeighbors []*heap.PriorityQueue upperNeighbors []sync.Map enterPoint int32 initOnce sync.Once indexMutex sync.Mutex rootMutex sync.Mutex bottomMutex []*sync.RWMutex levelFactor float32 maxConnection int // maximum number of connections for each element per layer maxConnection0 int ef int efConstruction int } func NewHNSW[T any](distanceFunc func(a, b T) float32) *HNSW[T] { return &HNSW[T]{ distanceFunc: distanceFunc, levelFactor: 1.0 / math32.Log(48), maxConnection: 48, maxConnection0: 96, efConstruction: 100, } } func (h *HNSW[T]) Add(v T) int { // Add vector h.indexMutex.Lock() h.vectors = append(h.vectors, v) h.bottomNeighbors = append(h.bottomNeighbors, heap.NewPriorityQueue(false)) h.bottomMutex = append(h.bottomMutex, new(sync.RWMutex)) q := len(h.vectors) - 1 h.indexMutex.Unlock() h.insert(int32(q)) return q } func (h *HNSW[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { // Check index if q < 0 || q >= len(h.vectors) { return nil, errors.Errorf("index out of range: %v", q) } w := h.knnSearch(h.vectors[q], k, h.efSearchValue(k)) scores := make([]lo.Tuple2[int, float32], 0) for w.Len() > 0 { 
value, score := w.Pop() if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores, nil } func (h *HNSW[T]) SearchVector(q T, k int, prune0 bool) []lo.Tuple2[int, float32] { w := h.knnSearch(q, k, h.efSearchValue(k)) scores := make([]lo.Tuple2[int, float32], 0) for w.Len() > 0 { value, score := w.Pop() if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores } func (h *HNSW[T]) knnSearch(q T, k, ef int) *heap.PriorityQueue { var ( w *heap.PriorityQueue // set for the current the nearest element enterPoints = h.distance(q, []int32{h.enterPoint}) // get enter point for hnsw topLayer = len(h.upperNeighbors) // top layer for hnsw ) for currentLayer := topLayer; currentLayer > 0; currentLayer-- { w = h.searchLayer(q, enterPoints, 1, currentLayer) enterPoints = heap.NewPriorityQueue(false) enterPoints.Push(w.Peek()) } w = h.searchLayer(q, enterPoints, ef, 0) return h.selectNeighbors(q, w, k) } // insert i-th vector into the vector index. 
func (h *HNSW[T]) insert(q int32) { // insert first point var isFirstPoint bool h.initOnce.Do(func() { if h.upperNeighbors == nil { h.bottomNeighbors[q] = heap.NewPriorityQueue(false) h.upperNeighbors = make([]sync.Map, 0) h.enterPoint = q isFirstPoint = true return } }) if isFirstPoint { return } h.rootMutex.Lock() var ( w *heap.PriorityQueue // list for the currently found nearest elements enterPoints = h.distance(h.vectors[q], []int32{h.enterPoint}) // get enter point for hnsw l = int(math32.Floor(-math32.Log(rand.Float32()) * h.levelFactor)) topLayer = len(h.upperNeighbors) ) if l <= topLayer { h.rootMutex.Unlock() } else { defer h.rootMutex.Unlock() } for currentLayer := topLayer; currentLayer >= l+1; currentLayer-- { w = h.searchLayer(h.vectors[q], enterPoints, 1, currentLayer) enterPoints = h.selectNeighbors(h.vectors[q], w, 1) } h.bottomMutex[q].Lock() for currentLayer := mathutil.Min(topLayer, l); currentLayer >= 0; currentLayer-- { w = h.searchLayer(h.vectors[q], enterPoints, h.efConstruction, currentLayer) neighbors := h.selectNeighbors(h.vectors[q], w, h.maxConnection) // add bidirectional connections from upperNeighbors to q at layer l_c h.setNeighbourhood(q, currentLayer, neighbors) for _, e := range neighbors.Elems() { h.bottomMutex[e.Value].Lock() h.getNeighbourhood(e.Value, currentLayer).Push(q, e.Weight) connections := h.getNeighbourhood(e.Value, currentLayer) var currentMaxConnection int if currentLayer == 0 { currentMaxConnection = h.maxConnection0 } else { currentMaxConnection = h.maxConnection } if connections.Len() > currentMaxConnection { // shrink connections of e if lc = 0 then M_max = M_max0 newConnections := h.selectNeighbors(h.vectors[q], connections, h.maxConnection) h.setNeighbourhood(e.Value, currentLayer, newConnections) } h.bottomMutex[e.Value].Unlock() } enterPoints = w } h.bottomMutex[q].Unlock() if l > topLayer { // set enter point for hnsw to q h.enterPoint = q h.upperNeighbors = append(h.upperNeighbors, sync.Map{}) 
h.setNeighbourhood(q, topLayer+1, heap.NewPriorityQueue(false)) } } func (h *HNSW[T]) searchLayer(q T, enterPoints *heap.PriorityQueue, ef, currentLayer int) *heap.PriorityQueue { var ( v = mapset.NewSet(enterPoints.Values()...) // set of visited elements candidates = enterPoints.Clone() // set of candidates w = enterPoints.Reverse() // dynamic list of found nearest upperNeighbors ) for candidates.Len() > 0 { // extract nearest element from candidates to q c, cq := candidates.Pop() // get the furthest element from w to q _, fq := w.Peek() if cq > fq { break // all elements in w are evaluated } // update candidates and w h.bottomMutex[c].RLock() neighbors := h.getNeighbourhood(c, currentLayer).Values() h.bottomMutex[c].RUnlock() for _, e := range neighbors { if !v.Contains(e) { v.Add(e) // get the furthest element from w to q _, fq = w.Peek() if eq := h.distanceFunc(h.vectors[e], q); eq < fq || w.Len() < ef { candidates.Push(e, eq) w.Push(e, eq) if w.Len() > ef { // remove the furthest element from w to q w.Pop() } } } } } return w.Reverse() } func (h *HNSW[T]) setNeighbourhood(e int32, currentLayer int, connections *heap.PriorityQueue) { if currentLayer == 0 { h.bottomNeighbors[e] = connections } else { h.upperNeighbors[currentLayer-1].Store(e, connections) } } func (h *HNSW[T]) getNeighbourhood(e int32, currentLayer int) *heap.PriorityQueue { if currentLayer == 0 { return h.bottomNeighbors[e] } else { if connections, ok := h.upperNeighbors[currentLayer-1].Load(e); ok { return connections.(*heap.PriorityQueue) } return nil } } func (h *HNSW[T]) selectNeighbors(_ T, candidates *heap.PriorityQueue, m int) *heap.PriorityQueue { pq := candidates.Reverse() for pq.Len() > m { pq.Pop() } return pq.Reverse() } func (h *HNSW[T]) distance(q T, points []int32) *heap.PriorityQueue { pq := heap.NewPriorityQueue(false) for _, point := range points { pq.Push(point, h.distanceFunc(h.vectors[point], q)) } return pq } // efSearchValue returns the efSearch value to use, given the 
current number of elements desired. func (h *HNSW[T]) efSearchValue(n int) int { if h.ef > 0 { return mathutil.Max(h.ef, n) } return mathutil.Max(h.efConstruction, n) } func (h *HNSW[T]) Marshal(w io.Writer) error { if err := binary.Write(w, binary.LittleEndian, h.levelFactor); err != nil { return err } if err := binary.Write(w, binary.LittleEndian, int64(h.maxConnection)); err != nil { return err } if err := binary.Write(w, binary.LittleEndian, int64(h.maxConnection0)); err != nil { return err } if err := binary.Write(w, binary.LittleEndian, int64(h.ef)); err != nil { return err } if err := binary.Write(w, binary.LittleEndian, int64(h.efConstruction)); err != nil { return err } // save vectors numVectors := int64(len(h.vectors)) if err := binary.Write(w, binary.LittleEndian, numVectors); err != nil { return err } for i := int64(0); i < numVectors; i++ { if err := encoding.WriteGob(w, h.vectors[i]); err != nil { return err } } // save neighbors for i := int64(0); i < numVectors; i++ { if err := h.bottomNeighbors[i].Marshal(w); err != nil { return err } } numLayers := len(h.upperNeighbors) if err := binary.Write(w, binary.LittleEndian, int64(numLayers)); err != nil { return err } for i := 0; i < numLayers; i++ { var elements []lo.Tuple2[int32, *heap.PriorityQueue] h.upperNeighbors[i].Range(func(key, value any) bool { elements = append(elements, lo.Tuple2[int32, *heap.PriorityQueue]{ A: key.(int32), B: value.(*heap.PriorityQueue)}) return true }) numElements := int32(len(elements)) if err := binary.Write(w, binary.LittleEndian, numElements); err != nil { return err } for j := int32(0); j < numElements; j++ { if err := binary.Write(w, binary.LittleEndian, elements[j].A); err != nil { return err } if err := elements[j].B.Marshal(w); err != nil { return err } } } if err := binary.Write(w, binary.LittleEndian, h.enterPoint); err != nil { return err } return nil } func (h *HNSW[T]) Unmarshal(r io.Reader) error { if err := binary.Read(r, binary.LittleEndian, 
&h.levelFactor); err != nil { return err } var maxConnection int64 if err := binary.Read(r, binary.LittleEndian, &maxConnection); err != nil { return err } h.maxConnection = int(maxConnection) var maxConnection0 int64 if err := binary.Read(r, binary.LittleEndian, &maxConnection0); err != nil { return err } h.maxConnection0 = int(maxConnection0) var ef int64 if err := binary.Read(r, binary.LittleEndian, &ef); err != nil { return err } h.ef = int(ef) var efConstruction int64 if err := binary.Read(r, binary.LittleEndian, &efConstruction); err != nil { return err } h.efConstruction = int(efConstruction) // read vectors var numVectors int64 if err := binary.Read(r, binary.LittleEndian, &numVectors); err != nil { return errors.WithStack(err) } h.vectors = make([]T, numVectors) for i := int64(0); i < numVectors; i++ { if err := encoding.ReadGob(r, &h.vectors[i]); err != nil { return errors.WithStack(err) } } // save neighbors h.bottomNeighbors = make([]*heap.PriorityQueue, numVectors) for i := int64(0); i < numVectors; i++ { h.bottomNeighbors[i] = heap.NewPriorityQueue(false) if err := h.bottomNeighbors[i].Unmarshal(r); err != nil { return errors.WithStack(err) } } var numLayers int64 if err := binary.Read(r, binary.LittleEndian, &numLayers); err != nil { return errors.WithStack(err) } h.upperNeighbors = make([]sync.Map, numLayers) for i := int64(0); i < numLayers; i++ { var numElements int32 if err := binary.Read(r, binary.LittleEndian, &numElements); err != nil { return errors.WithStack(err) } for j := int32(0); j < numElements; j++ { var e int32 if err := binary.Read(r, binary.LittleEndian, &e); err != nil { return errors.WithStack(err) } pq := heap.NewPriorityQueue(false) if err := pq.Unmarshal(r); err != nil { return errors.WithStack(err) } h.upperNeighbors[i].Store(e, pq) } } if err := binary.Read(r, binary.LittleEndian, &h.enterPoint); err != nil { return errors.WithStack(err) } h.bottomMutex = make([]*sync.RWMutex, numVectors) for i := int64(0); i < numVectors; 
i++ { h.bottomMutex[i] = new(sync.RWMutex) } return nil } ================================================ FILE: common/blas/blas.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package blas type Order int const RowMajor Order = 101 type Transpose int const ( NoTrans Transpose = 111 Trans Transpose = 112 ) func NewTranspose(transpose bool) Transpose { if transpose { return Trans } else { return NoTrans } } ================================================ FILE: common/blas/blas_darwin_arm64.go ================================================ //go:build cgo // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blas // #cgo CFLAGS: -DACCELERATE_NEW_LAPACK // #cgo LDFLAGS: -framework Accelerate // #include import "C" func SGEMM(order Order, transA, transB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { C.cblas_sgemm(uint32(order), uint32(transA), uint32(transB), C.int(m), C.int(n), C.int(k), C.float(alpha), (*C.float)(&a[0]), C.int(lda), (*C.float)(&b[0]), C.int(ldb), C.float(beta), (*C.float)(&c[0]), C.int(ldc)) } ================================================ FILE: common/blas/blas_mkl.go ================================================ //go:build cgo && mkl // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blas // #cgo CFLAGS: -I/opt/intel/oneapi/mkl/latest/include // #cgo LDFLAGS: -L/opt/intel/oneapi/mkl/latest/lib/intel64 -Wl,--start-group -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -Wl,--end-group -lpthread -lm -ldl -static // #include "mkl.h" import "C" func SGEMM(order Order, transA, transB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { C.cblas_sgemm(C.CBLAS_LAYOUT(order), C.CBLAS_TRANSPOSE(transA), C.CBLAS_TRANSPOSE(transB), C.int(m), C.int(n), C.int(k), C.float(alpha), (*C.float)(&a[0]), C.int(lda), (*C.float)(&b[0]), C.int(ldb), C.float(beta), (*C.float)(&c[0]), C.int(ldc)) } ================================================ FILE: common/blas/blas_openblas.go ================================================ //go:build cgo && openblas // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blas // #cgo LDFLAGS: -lopenblas -lm -pthread -static // #include import "C" func SGEMM(order Order, transA, transB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { C.cblas_sgemm(uint32(order), uint32(transA), uint32(transB), C.int(m), C.int(n), C.int(k), C.float(alpha), (*C.float)(&a[0]), C.int(lda), (*C.float)(&b[0]), C.int(ldb), C.float(beta), (*C.float)(&c[0]), C.int(ldc)) } ================================================ FILE: common/copier/copier.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package copier

import (
	"encoding"
	"reflect"

	"github.com/juju/errors"
)

// Copy deep-copies src into the value that dst points to.
//
// dst must be a pointer, otherwise an errors.NotValid error is returned.
// Storage the destination already holds is reused where possible: a
// non-nil slice with enough capacity is overwritten element by element, a
// non-nil pointer is copied through, and a map that is already deeply equal
// to src is left untouched.
func Copy(dst, src interface{}) error {
	dstPtr := reflect.ValueOf(dst)
	if dstPtr.Kind() != reflect.Ptr {
		return errors.NotValidf("expect dst to be a pointer, but receive %v", dstPtr.Kind())
	}
	dstValue := dstPtr.Elem()
	srcValue := reflect.ValueOf(src)
	return copyValue(dstValue, srcValue)
}

// copyValue recursively copies src into dst.
//
// The kinds of dst and src must match, except that any value may be stored
// into an interface-kinded dst, in which case a fresh deep copy of src is
// assigned. Scalars are assigned directly; slices, maps, pointers and
// interfaces are copied element by element. Structs must share the same type
// name: if *T of src implements encoding.BinaryMarshaler and *T of dst
// implements encoding.BinaryUnmarshaler, the value goes through a binary
// round trip (which lets such types carry unexported state); otherwise every
// settable field is copied recursively and the rest are skipped. Any other
// kind (arrays, channels, funcs, ...) yields an errors.NotValid error.
func copyValue(dst, src reflect.Value) error {
	if dst.Kind() != src.Kind() {
		if dst.Kind() != reflect.Interface {
			return errors.NotValidf("different type: %v != %v", dst.Kind(), src.Kind())
		}
		// Store a deep copy so the interface does not alias src's storage.
		newValuePointer := reflect.New(src.Type())
		err := copyValue(newValuePointer.Elem(), src)
		if err != nil {
			return err
		}
		dst.Set(newValuePointer.Elem())
		return nil
	}
	switch dst.Kind() {
	case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
		dst.Set(src)
	case reflect.Slice:
		// Allocate a fresh slice only when dst cannot hold src's elements in place.
		if dst.IsNil() || (!dst.CanAddr() && dst.Len() != src.Len()) || dst.Cap() < src.Len() {
			newSlice := reflect.MakeSlice(src.Type(), src.Len(), src.Len())
			dst.Set(newSlice)
		} else if dst.CanAddr() {
			dst.SetLen(src.Len())
		}
		for i := 0; i < src.Len(); i++ {
			err := copyValue(dst.Index(i), src.Index(i))
			if err != nil {
				return err
			}
		}
	case reflect.Map:
		// A map already deeply equal to src is kept as-is; otherwise dst is
		// rebuilt with deep copies of src's values (keys are reused).
		if !reflect.DeepEqual(dst.Interface(), src.Interface()) {
			dst.Set(reflect.MakeMap(dst.Type()))
			keys := src.MapKeys()
			for _, k := range keys {
				value := src.MapIndex(k)
				newValuePointer := reflect.New((value).Type())
				err := copyValue(newValuePointer.Elem(), src.MapIndex(k))
				if err != nil {
					return err
				}
				dst.SetMapIndex(k, newValuePointer.Elem())
			}
		}
	case reflect.Struct:
		if dst.Type().Name() != src.Type().Name() {
			return errors.NotValidf("different struct: %v != %v", dst.Type().Name(), src.Type().Name())
		}
		// The marshaler interfaces are usually implemented on *T, so probe
		// freshly allocated pointers rather than the (possibly unaddressable) values.
		dstPointer := reflect.New(dst.Type())
		srcPointer := reflect.New(src.Type())
		srcPointer.Elem().Set(src)
		srcMarshaller, hasSrcMarshaller := srcPointer.Interface().(encoding.BinaryMarshaler)
		dstUnmarshaler, hasDstUnmarshaler := dstPointer.Interface().(encoding.BinaryUnmarshaler)
		if hasDstUnmarshaler && hasSrcMarshaller {
			dstByte, err := srcMarshaller.MarshalBinary()
			if err != nil {
				return err
			}
			err = dstUnmarshaler.UnmarshalBinary(dstByte)
			if err != nil {
				return err
			}
			dst.Set(dstPointer.Elem())
		} else {
			numField := src.NumField()
			for i := 0; i < numField; i++ {
				fieldDST := dst.Field(i)
				fieldSRC := src.Field(i)
				// Unexported fields cannot be set through reflection.
				if !fieldDST.CanSet() {
					continue
				}
				err := copyValue(fieldDST, fieldSRC)
				if err != nil {
					return err
				}
			}
		}
	case reflect.Ptr:
		if src.IsNil() {
			// If source is nil, set dst to nil.
			dst.Set(reflect.Zero(dst.Type()))
			return nil
		}
		if dst.IsNil() {
			dst.Set(reflect.New(src.Elem().Type()))
		}
		srcElem := src.Elem()
		dstElem := dst.Elem()
		err := copyValue(dstElem, srcElem)
		if err != nil {
			return err
		}
	case reflect.Interface:
		if src.IsNil() {
			// If source is nil, set dst to nil.
			dst.Set(reflect.Zero(dst.Type()))
			return nil
		}
		if !dst.IsNil() {
			switch dst.Elem().Kind() {
			case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
				reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
				// A scalar boxed in an interface cannot be set in place;
				// replace the boxed value with a copy of src's.
				newValuePointer := reflect.New(src.Elem().Type())
				err := copyValue(newValuePointer.Elem(), src.Elem())
				if err != nil {
					return err
				}
				dst.Set(newValuePointer.Elem())
			default:
				// Other kinds are copied into the existing dynamic value.
				err := copyValue(dst.Elem(), src.Elem())
				if err != nil {
					return err
				}
			}
		} else {
			newValuePointer := reflect.New(src.Elem().Type())
			err := copyValue(newValuePointer.Elem(), src.Elem())
			if err != nil {
				return err
			}
			dst.Set(newValuePointer.Elem())
		}
	default:
		return errors.NotValidf("unsupported type: %v", dst.Kind())
	}
	return nil
}

================================================
FILE: common/copier/copier_test.go
================================================
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package copier

import (
	"testing"

	"github.com/juju/errors"
	"github.com/stretchr/testify/assert"
)

func TestPrimitives(t *testing.T) {
	var a = 1
	var b int
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// not a pointer
	err = Copy(b, a)
	assert.True(t, errors.Is(err, errors.NotValid))
	// test type mismatch
	var c bool
	err = Copy(&c, a)
	assert.True(t, errors.Is(err, errors.NotValid))
	// copy to interface
	var d interface{}
	err = Copy(&d, a)
	assert.NoError(t, err)
	assert.Equal(t, a, d)
}

func TestSlice(t *testing.T) {
	a := [][]int{{1}, {2}, {3}, {4}}
	b := make([][]int, 0)
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// test deep copy
	a[0][0] = 100
	assert.Equal(t, 1, b[0][0])
	// test reuse memory
	var integers = []int{10}
	c := [][]int{integers, {20}, {30}, {40}}
	err = Copy(&c, a)
	assert.NoError(t, err)
	integers[0] = 100
	assert.Equal(t, 100, c[0][0])
	// copy to interface
	var d interface{}
	err = Copy(&d, a)
	assert.NoError(t, err)
	assert.Equal(t, a, d)
	// copy empty slice
	var e interface{}
	err = Copy(&e, make([]int, 0))
	assert.NoError(t, err)
	assert.NotNil(t, e)
	// copy to larger slice
	var f = [][]int{{10}, {20}, {30}, {40}, {50}}
	err = Copy(&f, a)
	assert.NoError(t, err)
	assert.Equal(t, a, f)
	assert.Equal(t, 5, cap(f))
}

func TestMap(t *testing.T) {
	a := map[int64][]int64{1: {1}, 2: {1}, 3: {1}}
	b := map[int64][]int64{4: {100}, 5: {200}, 6: {300}}
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// test deep copy
	a[1][0] = 100
	assert.Equal(t, int64(1), b[1][0])
	// copy to interface
	var d interface{}
	err = Copy(&d, a)
	assert.NoError(t, err)
	assert.Equal(t, a, d)
	// test no copy
	var integers = []int64{100}
	c := map[int64][]int64{1: integers, 2: {1}, 3: {1}}
	err = Copy(&c, a)
	assert.NoError(t, err)
	assert.Equal(t, a, c)
	integers[0] = 10
	assert.Equal(t, int64(10), c[1][0])
}

type Foo struct {
	A int64
	B []string
}

type Bar struct {
	A int64
}

func TestStruct(t *testing.T) {
	a := Foo{A: 3, B: []string{"3"}}
	b := Foo{A: 4, B: []string{"4"}}
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// test deep copy
	a.B[0] = "100"
	assert.Equal(t, "3", b.B[0])
	// test type mismatch
	var c Bar
	err = Copy(&c, a)
	assert.True(t, errors.Is(err, errors.NotValid))
}

func TestPtr(t *testing.T) {
	a := &Foo{A: 3, B: []string{"3"}}
	b := &Foo{A: 4, B: []string{"4"}}
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
}

func TestInterface(t *testing.T) {
	var a = []interface{}{&Foo{A: 3, B: []string{"3"}}, []int{100}, 1}
	var b []interface{}
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// test reuse memory
	var strings = []string{"30"}
	var c = []interface{}{&Foo{A: 30, B: strings}, []int{1000}, 10}
	err = Copy(&c, a)
	assert.NoError(t, err)
	assert.Equal(t, a, c)
	strings[0] = "123"
	assert.Equal(t, "123", c[0].(*Foo).B[0])
}

type PrivateStruct struct {
	text *string
}

func (ps *PrivateStruct) MarshalBinary() (data []byte, err error) {
	return []byte(*ps.text), nil
}

func (ps *PrivateStruct) UnmarshalBinary(data []byte) error {
	ps.text = new(string(data))
	return nil
}

func TestPrivate(t *testing.T) {
	var a = PrivateStruct{new("hello")}
	var b PrivateStruct
	err := Copy(&b, a)
	assert.NoError(t, err)
	assert.Equal(t, a, b)
	// test deep copy
	*a.text = "world"
	assert.Equal(t, "hello", *b.text)
}

type NilInterface interface{}

type NilStruct struct {
	Interface NilInterface
	Pointer   *Foo
}

func TestNil(t *testing.T) {
	var d = NilStruct{
		Interface: 100,
		Pointer:   &Foo{A: 100},
	}
	var e NilStruct
	err := Copy(&d, e)
	assert.NoError(t, err)
	assert.Nil(t, d.Interface)
	assert.Nil(t, d.Pointer)
}

================================================
FILE: common/datautil/datautil.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package datautil

import (
	"archive/zip"
	"encoding/csv"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/util"
	"go.uber.org/zap"
)

var (
	tempDir    string
	datasetDir string
)

// init resolves the dataset and download directories under ~/.gorse.
func init() {
	usr, err := user.Current()
	if err != nil {
		log.Logger().Fatal("failed to get user directory", zap.Error(err))
	}
	datasetDir = filepath.Join(usr.HomeDir, ".gorse", "dataset")
	tempDir = filepath.Join(usr.HomeDir, ".gorse", "temp")
}

// LoadIris downloads (if needed) and parses the Iris dataset from
// iris.data. It returns the first four columns of every record as float32
// features and, per record, an integer label assigned to the fifth-column
// class name in order of first appearance.
func LoadIris() ([][]float32, []int, error) {
	// Download dataset
	path, err := DownloadAndUnzip("iris")
	if err != nil {
		return nil, nil, err
	}
	dataFile := filepath.Join(path, "iris.data")
	// Load data
	f, err := os.Open(dataFile)
	if err != nil {
		return nil, nil, err
	}
	defer f.Close()
	reader := csv.NewReader(f)
	rows, err := reader.ReadAll()
	if err != nil {
		return nil, nil, err
	}
	// Parse data
	data := make([][]float32, len(rows))
	target := make([]int, len(rows))
	types := make(map[string]int)
	for i, row := range rows {
		// Four features plus a label are required; report malformed input
		// instead of panicking on row[:4] / row[4].
		if len(row) < 5 {
			return nil, nil, fmt.Errorf("%s: record %d has %d fields, expected 5", dataFile, i+1, len(row))
		}
		data[i] = make([]float32, 4)
		for j, cell := range row[:4] {
			data[i][j], err = util.ParseFloat[float32](cell)
			if err != nil {
				return nil, nil, err
			}
		}
		if _, exist := types[row[4]]; !exist {
			types[row[4]] = len(types)
		}
		target[i] = types[row[4]]
	}
	return data, target, nil
}

// DownloadAndUnzip returns filepath.Join(datasetDir, name). If that path does
// not exist yet, https://cdn.gorse.io/datasets/<name>.zip is first downloaded
// into the temporary directory and extracted into datasetDir.
// NOTE(review): this assumes the archive contains a top-level <name> entry.
func DownloadAndUnzip(name string) (string, error) {
	url := fmt.Sprintf("https://cdn.gorse.io/datasets/%s.zip", name)
	path := filepath.Join(datasetDir, name)
	if _, err := os.Stat(path); os.IsNotExist(err) {
		zipFileName, err := downloadFromUrl(url, tempDir)
		if err != nil {
			return "", err
		}
		if _, err := unzip(zipFileName, datasetDir); err != nil {
			return "", err
		}
	}
	return path, nil
}

// downloadFromUrl downloads src into the directory dst, naming the file after
// the last "/"-separated segment of src, and returns the local file name.
func downloadFromUrl(src, dst string) (string, error) {
	log.Logger().Info("Download dataset", zap.String("source", src), zap.String("destination", dst))
	// Extract file name
	tokens := strings.Split(src, "/")
	fileName := filepath.Join(dst, tokens[len(tokens)-1])
	// Create file
	if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil {
		return fileName, err
	}
	output, err := os.Create(fileName)
	if err != nil {
		log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName))
		return fileName, err
	}
	defer output.Close()
	// Download file
	response, err := http.Get(src)
	if err != nil {
		log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
		return fileName, err
	}
	defer response.Body.Close()
	// http.Get does not fail on HTTP error statuses; without this check an
	// error page would be saved as the archive and fail later in unzip.
	if response.StatusCode != http.StatusOK {
		err = fmt.Errorf("download %s: unexpected status %s", src, response.Status)
		log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
		return fileName, err
	}
	// Save file
	if _, err = io.Copy(output, response.Body); err != nil {
		log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
		return fileName, err
	}
	// Report errors from closing the written file; the deferred Close is
	// then a no-op whose error is ignored.
	if err = output.Close(); err != nil {
		return fileName, err
	}
	return fileName, nil
}

// unzip extracts every entry of the zip archive src into the directory dst
// and returns the extracted paths. Entries whose resolved path would escape
// dst ("Zip Slip") are rejected with an error.
func unzip(src, dst string) ([]string, error) {
	var fileNames []string
	// Open zip file
	r, err := zip.OpenReader(src)
	if err != nil {
		return nil, err
	}
	defer r.Close()
	// Extract files
	for _, f := range r.File {
		// Store filename/path for returning and using later on
		filePath := filepath.Join(dst, f.Name)
		// Check for ZipSlip. More Info: http://bit.ly/2MsjAWE
		if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) {
			return nil, fmt.Errorf("%s: illegal file path", filePath)
		}
		// Add filename
		fileNames = append(fileNames, filePath)
		if err := extractZipEntry(f, filePath); err != nil {
			return nil, err
		}
	}
	return fileNames, nil
}

// extractZipEntry writes one archive entry to filePath (creating a directory
// for directory entries); every handle it opens is closed before returning.
func extractZipEntry(f *zip.File, filePath string) error {
	if f.FileInfo().IsDir() {
		// Create folder
		return os.MkdirAll(filePath, os.ModePerm)
	}
	// Create all folders
	if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
		return err
	}
	// Open file
	rc, err := f.Open()
	if err != nil {
		return err
	}
	defer rc.Close()
	// Create file
	outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
	if err != nil {
		return err
	}
	// Save file
	if _, err = io.Copy(outFile, rc); err != nil {
		_ = outFile.Close() // the copy error is the one worth reporting
		return err
	}
	return outFile.Close()
}

================================================
FILE: common/datautil/datautil_test.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datautil

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestLoadIris(t *testing.T) {
	data, target, err := LoadIris()
	assert.NoError(t, err)
	assert.Len(t, data, 150)
	assert.Len(t, data[0], 4)
	assert.Len(t, target, 150)
}

================================================
FILE: common/encoding/encoding.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package encoding

import (
	"bytes"
	"encoding/binary"
	"encoding/gob"
	"io"

	"github.com/pkg/errors"
)

// WriteSlice writes s to w as a little-endian int32 element count followed by
// the elements encoded with binary.Write, so T must be a fixed-size type
// accepted by encoding/binary.
func WriteSlice[T any](w io.Writer, s []T) error {
	if err := binary.Write(w, binary.LittleEndian, int32(len(s))); err != nil {
		return errors.WithStack(err)
	}
	return binary.Write(w, binary.LittleEndian, s)
}

// ReadSlice reads a slice written by WriteSlice from r into *s. A negative
// element count (corrupted input) is reported as an error rather than
// panicking in make.
func ReadSlice[T any](r io.Reader, s *[]T) error {
	var length int32
	if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
		return errors.WithStack(err)
	}
	if length < 0 {
		return errors.Errorf("invalid slice length %d", length)
	}
	*s = make([]T, length)
	return binary.Read(r, binary.LittleEndian, *s)
}

// WriteString writes s to w in the length-prefixed format of WriteBytes.
func WriteString(w io.Writer, s string) error {
	return WriteBytes(w, []byte(s))
}

// ReadString reads a string written by WriteString from r.
func ReadString(r io.Reader) (string, error) {
	data, err := ReadBytes(r)
	return string(data), err
}

// WriteBytes writes s to w as a little-endian int32 length followed by the
// raw bytes.
func WriteBytes(w io.Writer, s []byte) error {
	err := binary.Write(w, binary.LittleEndian, int32(len(s)))
	if err != nil {
		return err
	}
	n, err := w.Write(s)
	if err != nil {
		return err
	}
	if n != len(s) {
		return errors.New("fail to write string")
	}
	return nil
}

// ReadBytes reads a byte slice written by WriteBytes from r.
//
// It calls r.Read until the announced number of bytes has arrived. The
// io.Reader contract allows the final bytes to arrive together with a
// non-nil error such as io.EOF, so completion is checked before the error.
// A read that makes no progress without an error is treated as a failure
// instead of being retried forever, and a negative length prefix is rejected.
func ReadBytes(r io.Reader) ([]byte, error) {
	var length int32
	err := binary.Read(r, binary.LittleEndian, &length)
	if err != nil {
		return nil, err
	}
	if length < 0 {
		return nil, errors.Errorf("invalid byte length %d", length)
	}
	data := make([]byte, length)
	readCount := 0
	for readCount < len(data) {
		n, err := r.Read(data[readCount:])
		readCount += n
		if readCount == len(data) {
			break
		}
		if err != nil {
			return nil, err
		}
		if n == 0 {
			return nil, errors.New("fail to read string")
		}
	}
	return data, nil
}

// WriteGob gob-encodes v and writes the encoding to w in the length-prefixed
// format of WriteBytes.
func WriteGob(w io.Writer, v interface{}) error {
	buffer := bytes.NewBuffer(nil)
	encoder := gob.NewEncoder(buffer)
	err := encoder.Encode(v)
	if err != nil {
		return err
	}
	return WriteBytes(w, buffer.Bytes())
}

// ReadGob reads a value written by WriteGob from r and gob-decodes it into v,
// which must be a pointer accepted by gob.Decoder.Decode.
func ReadGob(r io.Reader, v interface{}) error {
	data, err := ReadBytes(r)
	if err != nil {
		return err
	}
	buffer := bytes.NewBuffer(data)
	decoder := gob.NewDecoder(buffer)
	return decoder.Decode(v)
}

================================================
FILE: common/encoding/encoding_test.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package encoding

import (
	"bytes"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestReadWriteSlice checks that a float32 slice survives a
// WriteSlice/ReadSlice round trip.
func TestReadWriteSlice(t *testing.T) {
	want := []float32{1, 2, 3, 4}
	var stream bytes.Buffer
	assert.NoError(t, WriteSlice(&stream, want))
	var got []float32
	assert.NoError(t, ReadSlice(&stream, &got))
	assert.Equal(t, want, got)
}

// TestWriteString checks that a string survives a WriteString/ReadString
// round trip.
func TestWriteString(t *testing.T) {
	want := "abc"
	var stream bytes.Buffer
	assert.NoError(t, WriteString(&stream, want))
	got, err := ReadString(&stream)
	assert.NoError(t, err)
	assert.Equal(t, want, got)
}

// TestWriteGob checks that a value survives a WriteGob/ReadGob round trip.
func TestWriteGob(t *testing.T) {
	want := "abc"
	var stream bytes.Buffer
	assert.NoError(t, WriteGob(&stream, want))
	var got string
	assert.NoError(t, ReadGob(&stream, &got))
	assert.Equal(t, want, got)
}

================================================
FILE: common/expression/expression.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression

import (
	"encoding/json"
	"errors"
	"fmt"
	"regexp"
	"strconv"

	"github.com/gorse-io/gorse/protocol"
	"github.com/samber/lo"
)

// expressionPattern matches "<feedback_type>[<expr_type><value>]". FromString
// looks the capture groups up by name via SubexpNames, so the group names
// must stay in sync with the cases there.
var expressionPattern = regexp.MustCompile(`^(?P<feedback_type>[a-zA-Z][a-zA-Z0-9_]*)(?P<expr_type><=|>=|<|>|=)?(?P<value>[0-9]*\.?[0-9]*)$`)

// ExprType is the comparison operator of a FeedbackTypeExpression.
type ExprType int

const (
	None ExprType = iota
	Less
	LessOrEqual
	Greater
	GreaterOrEqual
)

// String returns the operator symbol, or "" for None.
func (typ ExprType) String() string {
	switch typ {
	case Less:
		return "<"
	case LessOrEqual:
		return "<="
	case Greater:
		return ">"
	case GreaterOrEqual:
		return ">="
	default:
		return ""
	}
}

// FeedbackTypeExpression selects feedback by type and, unless ExprType is
// None, by comparing the feedback value against Value.
type FeedbackTypeExpression struct {
	FeedbackType string
	ExprType     ExprType
	Value        float64
}

// String formats f as "<feedback_type>" or "<feedback_type><op><value>".
func (f *FeedbackTypeExpression) String() string {
	if f.ExprType == None {
		return f.FeedbackType
	}
	return fmt.Sprintf("%s%v%v", f.FeedbackType, f.ExprType, f.Value)
}

// FromString parses data of the form <feedback_type>[<op><value>] into f,
// where <op> is one of <, <=, > or >= and <value> is a decimal number, e.g.
// "read", "rating>=4" or "watch>0.5". On error f is left unchanged.
//
// NOTE(review): the pattern also accepts "=", which maps to None (match on
// feedback type only) because ExprType has no equality variant — confirm
// this is intended.
func (f *FeedbackTypeExpression) FromString(data string) error {
	subMatches := expressionPattern.FindStringSubmatch(data)
	if len(subMatches) == 0 {
		return errors.New("invalid expression format, expected format: <feedback_type>[<expr_type><value>]")
	}
	// Parse into a fresh value so fields from an earlier parse (such as a
	// stale Value) never leak into the result, and f is intact on error.
	var parsed FeedbackTypeExpression
	for i, name := range expressionPattern.SubexpNames() {
		match := subMatches[i]
		switch name {
		case "feedback_type":
			parsed.FeedbackType = match
		case "expr_type":
			switch match {
			case "<":
				parsed.ExprType = Less
			case "<=":
				parsed.ExprType = LessOrEqual
			case ">":
				parsed.ExprType = Greater
			case ">=":
				parsed.ExprType = GreaterOrEqual
			default:
				parsed.ExprType = None
			}
		case "value":
			if len(match) > 0 {
				value, err := strconv.ParseFloat(match, 64)
				if err != nil {
					return fmt.Errorf("invalid value: %w", err)
				}
				parsed.Value = value
			}
		}
	}
	*f = parsed
	return nil
}

// MarshalJSON encodes f as a JSON string holding its String form.
func (f *FeedbackTypeExpression) MarshalJSON() ([]byte, error) {
	return json.Marshal(f.String())
}

// UnmarshalJSON decodes a JSON string and parses it with FromString.
func (f *FeedbackTypeExpression) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return fmt.Errorf("unmarshal FeedbackTypeExpression: %w", err)
	}
	if err := f.FromString(s); err != nil {
		return fmt.Errorf("unmarshal FeedbackTypeExpression: %w", err)
	}
	return nil
}

// ToPB converts f to its protobuf representation.
func (f *FeedbackTypeExpression) ToPB() *protocol.FeedbackTypeExpression {
	pb := &protocol.FeedbackTypeExpression{}
	pb.FeedbackType = f.FeedbackType
	switch f.ExprType {
	case None:
		pb.ExpressionType = protocol.ExpressionType_None
	case Less:
		pb.ExpressionType = protocol.ExpressionType_Less
	case LessOrEqual:
		pb.ExpressionType = protocol.ExpressionType_LessOrEqual
	case Greater:
		pb.ExpressionType = protocol.ExpressionType_Greater
	case GreaterOrEqual:
		pb.ExpressionType = protocol.ExpressionType_GreaterOrEqual
	}
	pb.Value = f.Value
	return pb
}

// FromPB overwrites f with the fields of pb. An unrecognized expression type
// leaves f.ExprType unchanged.
func (f *FeedbackTypeExpression) FromPB(pb *protocol.FeedbackTypeExpression) {
	f.FeedbackType = pb.FeedbackType
	switch pb.ExpressionType {
	case protocol.ExpressionType_None:
		f.ExprType = None
	case protocol.ExpressionType_Less:
		f.ExprType = Less
	case protocol.ExpressionType_LessOrEqual:
		f.ExprType = LessOrEqual
	case protocol.ExpressionType_Greater:
		f.ExprType = Greater
	case protocol.ExpressionType_GreaterOrEqual:
		f.ExprType = GreaterOrEqual
	}
	f.Value = pb.Value
}

// Match reports whether feedback of the given type and value satisfies f:
// feedbackType must equal f.FeedbackType and, unless ExprType is None, value
// must compare against f.Value with f's operator.
func (f *FeedbackTypeExpression) Match(feedbackType string, value float64) bool {
	if f.FeedbackType != feedbackType {
		return false
	}
	switch f.ExprType {
	case None:
		return true
	case Less:
		return value < f.Value
	case LessOrEqual:
		return value <= f.Value
	case Greater:
		return value > f.Value
	case GreaterOrEqual:
		return value >= f.Value
	default:
		return false
	}
}

// MatchFeedbackTypeExpressions reports whether any expression in exprs
// matches the given feedback type and value.
func MatchFeedbackTypeExpressions(exprs []FeedbackTypeExpression, feedbackType string, value float64) bool {
	for _, expr := range exprs {
		if expr.Match(feedbackType, value) {
			return true
		}
	}
	return false
}

// MustParseFeedbackTypeExpression parses s with FromString and panics on error.
func MustParseFeedbackTypeExpression(s string) FeedbackTypeExpression {
	var expr FeedbackTypeExpression
	lo.Must0(expr.FromString(s))
	return expr
}

================================================
FILE: common/expression/expression_test.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
	"testing"

	"github.com/gorse-io/gorse/protocol"
	"github.com/stretchr/testify/assert"
)

func TestFeedbackTypeExpression_UnmarshalJSON(t *testing.T) {
	var f FeedbackTypeExpression
	err := f.UnmarshalJSON([]byte(`"test"`))
	assert.NoError(t, err)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, None, f.ExprType)

	err = f.UnmarshalJSON([]byte(`"1a"`))
	assert.Error(t, err)

	err = f.UnmarshalJSON([]byte(`"test<16"`))
	assert.NoError(t, err)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, Less, f.ExprType)
	assert.Equal(t, 16.0, f.Value)

	err = f.UnmarshalJSON([]byte(`"test<=16"`))
	assert.NoError(t, err)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, LessOrEqual, f.ExprType)
	assert.Equal(t, 16.0, f.Value)

	err = f.UnmarshalJSON([]byte(`"test>16"`))
	assert.NoError(t, err)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, Greater, f.ExprType)
	assert.Equal(t, 16.0, f.Value)

	err = f.UnmarshalJSON([]byte(`"test>=16"`))
	assert.NoError(t, err)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, GreaterOrEqual, f.ExprType)
	assert.Equal(t, 16.0, f.Value)
}

func TestFeedbackTypeExpression_MarshalJSON(t *testing.T) {
	f := FeedbackTypeExpression{FeedbackType: "test", Value: 16}
	buf, err := f.MarshalJSON()
	assert.NoError(t, err)
	assert.Equal(t, `"test"`, string(buf))

	f.ExprType = Less
	buf, err = f.MarshalJSON()
	assert.NoError(t, err)
	assert.Equal(t, `"test\u003c16"`, string(buf))

	f.ExprType = LessOrEqual
	buf, err = f.MarshalJSON()
	assert.NoError(t, err)
	assert.Equal(t, `"test\u003c=16"`, string(buf))

	f.ExprType = Greater
	buf, err = f.MarshalJSON()
	assert.NoError(t, err)
	assert.Equal(t, `"test\u003e16"`, string(buf))

	f.ExprType = GreaterOrEqual
	buf, err = f.MarshalJSON()
	assert.NoError(t, err)
	assert.Equal(t, `"test\u003e=16"`, string(buf))
}

func TestFeedbackTypeExpression_ToPB(t *testing.T) {
	f := FeedbackTypeExpression{FeedbackType: "test", Value: 6}
	pb := f.ToPB()
	assert.Equal(t, "test", pb.FeedbackType)
	assert.Equal(t, protocol.ExpressionType_None, pb.ExpressionType)
	assert.Equal(t, 6.0, pb.Value)

	f.ExprType = Less
	pb = f.ToPB()
	assert.Equal(t, protocol.ExpressionType_Less, pb.ExpressionType)

	f.ExprType = LessOrEqual
	pb = f.ToPB()
	assert.Equal(t, protocol.ExpressionType_LessOrEqual, pb.ExpressionType)

	f.ExprType = Greater
	pb = f.ToPB()
	assert.Equal(t, protocol.ExpressionType_Greater, pb.ExpressionType)

	f.ExprType = GreaterOrEqual
	pb = f.ToPB()
	assert.Equal(t, protocol.ExpressionType_GreaterOrEqual, pb.ExpressionType)
}

func TestFeedbackTypeExpression_FromPB(t *testing.T) {
	pb := &protocol.FeedbackTypeExpression{
		FeedbackType:   "test",
		ExpressionType: protocol.ExpressionType_Less,
		Value:          6.0,
	}
	f := FeedbackTypeExpression{}
	f.FromPB(pb)
	assert.Equal(t, "test", f.FeedbackType)
	assert.Equal(t, Less, f.ExprType)
	assert.Equal(t, 6.0, f.Value)

	pb.ExpressionType = protocol.ExpressionType_LessOrEqual
	f.FromPB(pb)
	assert.Equal(t, LessOrEqual, f.ExprType)

	pb.ExpressionType = protocol.ExpressionType_Greater
	f.FromPB(pb)
	assert.Equal(t, Greater, f.ExprType)

	pb.ExpressionType = protocol.ExpressionType_GreaterOrEqual
	f.FromPB(pb)
	assert.Equal(t, GreaterOrEqual, f.ExprType)

	pb.ExpressionType = protocol.ExpressionType_None
	f.FromPB(pb)
	assert.Equal(t, None, f.ExprType)
}

func TestFeedbackTypeExpression_Match(t *testing.T) {
	f := FeedbackTypeExpression{FeedbackType: "test", Value: 6}
	assert.True(t, f.Match("test", 0))
	assert.False(t, f.Match("a", 1))

	f.ExprType = Less
	assert.True(t, f.Match("test", 5))
	assert.False(t, f.Match("test", 6))

	f.ExprType = LessOrEqual
	assert.True(t, f.Match("test", 6))
	assert.False(t, f.Match("test", 7))

	f.ExprType = Greater
	assert.True(t, f.Match("test", 7))
	assert.False(t, f.Match("test", 6))

	f.ExprType = GreaterOrEqual
	assert.True(t, f.Match("test", 6))
	assert.False(t, f.Match("test", 5))
}

func TestMatchFeedbackTypeExpressions(t *testing.T) {
	expressions := []FeedbackTypeExpression{
		{FeedbackType: "a"},
		{FeedbackType: "b"},
		{FeedbackType: "c"},
	}
	assert.True(t, MatchFeedbackTypeExpressions(expressions, "a", 0))
	assert.False(t, MatchFeedbackTypeExpressions(expressions, "d", 0))
}

func TestMustParseFeedbackTypeExpression(t *testing.T) {
	assert.Panics(t, func() {
		MustParseFeedbackTypeExpression("test+")
	})
}

================================================
FILE: common/floats/floats.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package floats

import (
	"math"

	"github.com/chewxy/math32"
)

// The lower-case functions below are plain Go loops over len(a); the
// exported wrappers further down dispatch through the package-level feature
// value, whose SIMD kernels take &a[0] and therefore need non-empty input.

// dot returns the inner product of a and b.
func dot(a, b []float32) (ret float32) {
	for i := range a {
		ret += a[i] * b[i]
	}
	return
}

// euclidean returns the Euclidean distance between a and b.
func euclidean(a, b []float32) (ret float32) {
	for i := range a {
		ret += (a[i] - b[i]) * (a[i] - b[i])
	}
	return math32.Sqrt(ret)
}

// addConst computes a += c element-wise.
func addConst(a []float32, c float32) {
	for i := range a {
		a[i] += c
	}
}

// sub computes a -= b element-wise.
func sub(a, b []float32) {
	for i := range a {
		a[i] -= b[i]
	}
}

// subTo computes c = a - b element-wise.
func subTo(a, b, c []float32) {
	for i := range a {
		c[i] = a[i] - b[i]
	}
}

// mulTo computes c = a * b element-wise.
func mulTo(a, b, c []float32) {
	for i := range a {
		c[i] = a[i] * b[i]
	}
}

// mulConstAddTo computes dst = a*b + c element-wise.
func mulConstAddTo(a []float32, b float32, c []float32, dst []float32) {
	for i := range a {
		dst[i] = a[i]*b + c[i]
	}
}

// mulConstAdd computes dst += a*c element-wise.
func mulConstAdd(a []float32, c float32, dst []float32) {
	for i := range a {
		dst[i] += a[i] * c
	}
}

// mulConstTo computes c = a*b element-wise.
func mulConstTo(a []float32, b float32, c []float32) {
	for i := range a {
		c[i] = a[i] * b
	}
}

// mulConst computes a *= b element-wise.
func mulConst(a []float32, b float32) {
	for i := range a {
		a[i] *= b
	}
}

// divTo computes c = a / b element-wise.
func divTo(a, b, c []float32) {
	for i := range a {
		c[i] = a[i] / b[i]
	}
}

// sqrtTo computes b = sqrt(a) element-wise.
func sqrtTo(a, b []float32) {
	for i := range a {
		b[i] = math32.Sqrt(a[i])
	}
}

// MatZero fills zeros in a matrix of 32-bit floats.
func MatZero(x [][]float32) {
	for i := range x {
		for j := range x[i] {
			x[i][j] = 0
		}
	}
}

// Zero fills zeros in a slice of 32-bit floats.
func Zero(a []float32) {
	for i := range a {
		a[i] = 0
	}
}

// SubTo subtracts one vector by another and saves the result in dst: dst = a - b
func SubTo(a, b, dst []float32) {
	if len(dst) != len(b) || len(a) != len(b) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.subTo(a, b, dst)
}

// Add two vectors: dst = dst + s
func Add(dst, s []float32) {
	if len(dst) != len(s) {
		panic("floats: slice lengths do not match")
	}
	for i := range dst {
		dst[i] += s[i]
	}
}

// MulConst multiplies a vector with a const: dst = dst * c
func MulConst(dst []float32, c float32) {
	if len(dst) == 0 {
		return // SIMD kernels index dst[0]; nothing to do
	}
	feature.mulConst(dst, c)
}

// Div divides one vector by another: dst = dst / s
func Div(dst, s []float32) {
	if len(dst) != len(s) {
		panic("floats: slice lengths do not match")
	}
	for i := range dst {
		dst[i] /= s[i]
	}
}

// MulTo multiplies two vectors element-wise and saves the result in c: c = a * b
func MulTo(a, b, c []float32) {
	if len(a) != len(b) || len(a) != len(c) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.mulTo(a, b, c)
}

// Sub one vector by another: dst = dst - s
func Sub(dst, s []float32) {
	if len(dst) != len(s) {
		panic("floats: slice lengths do not match")
	}
	if len(dst) == 0 {
		return // SIMD kernels index dst[0]; nothing to do
	}
	feature.sub(dst, s)
}

// MulConstTo multiplies a vector and a const, then saves the result in dst: dst = a * c
func MulConstTo(a []float32, c float32, dst []float32) {
	if len(a) != len(dst) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.mulConstTo(a, c, dst)
}

// MulConstAddTo multiplies a vector and a const, adds another vector and
// saves the result in dst: dst = a * c + b
func MulConstAddTo(a []float32, c float32, b, dst []float32) {
	if len(a) != len(b) || len(a) != len(dst) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.mulConstAddTo(a, c, b, dst)
}

// MulConstAdd multiplies a vector and a const, then adds to dst: dst = dst + a * c
func MulConstAdd(a []float32, c float32, dst []float32) {
	if len(a) != len(dst) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.mulConstAdd(a, c, dst)
}

// MulAddTo multiplies a vector and a vector, then adds to a vector: c += a * b
func MulAddTo(a, b, c []float32) {
	if len(a) != len(b) || len(a) != len(c) {
		panic("floats: slice lengths do not match")
	}
	for i := range a {
		c[i] += a[i] * b[i]
	}
}

// AddTo adds two vectors and saves the result in dst: dst = a + b
func AddTo(a, b, dst []float32) {
	if len(a) != len(b) || len(a) != len(dst) {
		panic("floats: slice lengths do not match")
	}
	for i := range a {
		dst[i] = a[i] + b[i]
	}
}

// AddConst adds a const to every element of a vector: dst = dst + c
func AddConst(dst []float32, c float32) {
	if len(dst) == 0 {
		return // SIMD kernels index dst[0]; nothing to do
	}
	feature.addConst(dst, c)
}

// DivTo divides two vectors element-wise and saves the result in c: c = a / b
func DivTo(a, b, c []float32) {
	if len(a) != len(b) || len(a) != len(c) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.divTo(a, b, c)
}

// SqrtTo saves the element-wise square root of a in b: b = sqrt(a)
func SqrtTo(a, b []float32) {
	if len(a) != len(b) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return // SIMD kernels index a[0]; nothing to do
	}
	feature.sqrtTo(a, b)
}

// Sqrt replaces every element of a with its square root, computed in float64.
func Sqrt(a []float32) {
	for i := range a {
		a[i] = float32(math.Sqrt(float64(a[i])))
	}
}

// Dot returns the inner product of two vectors; it is 0 for empty vectors.
func Dot(a, b []float32) (ret float32) {
	if len(a) != len(b) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return 0 // SIMD kernels index a[0]; the empty sum is 0
	}
	return feature.dot(a, b)
}

// Euclidean returns the Euclidean distance between two vectors; it is 0 for
// empty vectors.
func Euclidean(a, b []float32) float32 {
	if len(a) != len(b) {
		panic("floats: slice lengths do not match")
	}
	if len(a) == 0 {
		return 0 // SIMD kernels index a[0]; the empty distance is 0
	}
	return feature.euclidean(a, b)
}

// MM multiplies the m×k matrix a by the k×n matrix b into c (after optional
// transposition via transA/transB), where lda/ldb/ldc are the leading
// dimensions. It delegates to the platform-specific feature.mm.
// NOTE(review): whether c is overwritten or accumulated into is defined by
// feature.mm — confirm there.
func MM(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) {
	feature.mm(transA, transB, m, n, k, a, lda, b, ldb, c, ldc)
}

================================================
FILE: common/floats/floats_amd64.go
================================================
//go:build !noasm

// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package floats

import (
	"strings"
	"unsafe"

	"golang.org/x/sys/cpu"
)

//go:generate goat src/floats_avx.c -O3 -mavx
//go:generate goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f

// Feature is a bit set describing which accelerated code paths may be used on
// amd64. The CPU flags (AVX, FMA, AVX512F) are detected in init below; MKL and
// OPENBLAS are not set in this file.
type Feature uint64

const (
	AVX Feature = 1 << iota
	FMA
	AVX512F
	MKL
	OPENBLAS
)

// AVX512 is the set of flags the AVX-512 kernels require; they are generated
// with -mavx -mfma -mavx512f.
const AVX512 = AVX | FMA | AVX512F

// feature is the package-wide dispatch state used by the exported helpers.
var feature Feature

// init records the SIMD extensions reported by the running CPU.
func init() {
	if cpu.X86.HasAVX {
		feature |= AVX
	}
	if cpu.X86.HasFMA {
		feature |= FMA
	}
	if cpu.X86.HasAVX512F {
		feature |= AVX512F
	}
}

// String names the widest SIMD family enabled ("AVX512" or "AVX"), or "AMD64"
// when neither is available. The BLAS flags are not reported.
func (feature Feature) String() string {
	var features []string
	if feature&AVX512 == AVX512 {
		features = append(features, "AVX512")
	} else if feature&AVX == AVX {
		features = append(features, "AVX")
	}
	if len(features) == 0 {
		return "AMD64"
	}
	return strings.Join(features, "+")
}

// simdLevel picks the kernel family for a vector of length n: AVX512 or AVX
// when enabled, otherwise 0 for the portable Go implementation. Empty vectors
// always get 0 because every assembly kernel takes the address of element 0,
// which panics on an empty slice.
func (feature Feature) simdLevel(n int) Feature {
	switch {
	case n == 0:
		return 0
	case feature&AVX512 == AVX512:
		return AVX512
	case feature&AVX == AVX:
		return AVX
	default:
		return 0
	}
}

// mulConstAddTo computes dst = a * b + c element-wise.
func (feature Feature) mulConstAddTo(a []float32, b float32, c []float32, dst []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_mul_const_add_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), unsafe.Pointer(&dst[0]), int64(len(a)))
	case AVX:
		_mm256_mul_const_add_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), unsafe.Pointer(&dst[0]), int64(len(a)))
	default:
		mulConstAddTo(a, b, c, dst)
	}
}

// mulConstAdd computes c = c + a * b element-wise.
func (feature Feature) mulConstAdd(a []float32, b float32, c []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_mul_const_add(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a)))
	case AVX:
		_mm256_mul_const_add(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a)))
	default:
		mulConstAdd(a, b, c)
	}
}

// mulConstTo computes c = a * b element-wise.
func (feature Feature) mulConstTo(a []float32, b float32, c []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_mul_const_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a)))
	case AVX:
		_mm256_mul_const_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a)))
	default:
		mulConstTo(a, b, c)
	}
}

// addConst adds the scalar b to every element of a in place.
func (feature Feature) addConst(a []float32, b float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_add_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a)))
	case AVX:
		_mm256_add_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a)))
	default:
		addConst(a, b)
	}
}

// sub computes a = a - b element-wise.
func (feature Feature) sub(a, b []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_sub(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	case AVX:
		_mm256_sub(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	default:
		sub(a, b)
	}
}

// subTo computes c = a - b element-wise.
func (feature Feature) subTo(a, b, c []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_sub_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	case AVX:
		_mm256_sub_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	default:
		subTo(a, b, c)
	}
}

// mulTo computes c = a * b element-wise.
func (feature Feature) mulTo(a, b, c []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_mul_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	case AVX:
		_mm256_mul_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	default:
		mulTo(a, b, c)
	}
}

// mulConst multiplies every element of a by the scalar b in place.
func (feature Feature) mulConst(a []float32, b float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_mul_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a)))
	case AVX:
		_mm256_mul_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a)))
	default:
		mulConst(a, b)
	}
}

// divTo computes c = a / b element-wise.
func (feature Feature) divTo(a, b, c []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_div_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	case AVX:
		_mm256_div_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a)))
	default:
		divTo(a, b, c)
	}
}

// sqrtTo stores the element-wise square root of a in b.
func (feature Feature) sqrtTo(a, b []float32) {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		_mm512_sqrt_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	case AVX:
		_mm256_sqrt_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	default:
		sqrtTo(a, b)
	}
}

// dot returns the dot product of a and b.
func (feature Feature) dot(a, b []float32) float32 {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		return _mm512_dot(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	case AVX:
		return _mm256_dot(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	default:
		return dot(a, b)
	}
}

// euclidean returns the Euclidean distance between a and b as computed by the
// selected kernel.
func (feature Feature) euclidean(a, b []float32) float32 {
	switch feature.simdLevel(len(a)) {
	case AVX512:
		return _mm512_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	case AVX:
		return _mm256_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a)))
	default:
		return euclidean(a, b)
	}
}

// mm dispatches a matrix multiplication. The hand-written AVX/AVX-512 kernels
// are bypassed in favour of the package-level mm function when MKL or OpenBLAS
// is enabled, and when any operand is empty (the kernels take &x[0]).
func (feature Feature) mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) {
	level := feature.simdLevel(len(a))
	if feature&(MKL|OPENBLAS) != 0 || len(b) == 0 || len(c) == 0 {
		level = 0
	}
	switch level {
	case AVX512:
		_mm512_mm(transA, transB, int64(m), int64(n), int64(k), unsafe.Pointer(&a[0]), int64(lda), unsafe.Pointer(&b[0]), int64(ldb), unsafe.Pointer(&c[0]), int64(ldc))
	case AVX:
		_mm256_mm(transA, transB, int64(m), int64(n), int64(k), unsafe.Pointer(&a[0]), int64(lda), unsafe.Pointer(&b[0]), int64(ldb), unsafe.Pointer(&c[0]), int64(ldc))
	default:
		mm(transA, transB, m, n, k, a, lda, b, ldb, c, ldc)
	}
}

================================================
FILE: common/floats/floats_amd64_test.go
================================================
//go:build !noasm

// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package floats import ( "fmt" "math/rand" "strconv" "testing" "github.com/stretchr/testify/suite" "golang.org/x/sys/cpu" ) var supportedFeatures []Feature func init() { supportedFeatures = []Feature{0} if cpu.X86.HasAVX { supportedFeatures = append(supportedFeatures, AVX) } if cpu.X86.HasAVX && cpu.X86.HasFMA && cpu.X86.HasAVX512F { supportedFeatures = append(supportedFeatures, AVX512) } } func TestAVX(t *testing.T) { suite.Run(t, &SIMDTestSuite{Feature: AVX}) } func TestAVX512(t *testing.T) { suite.Run(t, &SIMDTestSuite{Feature: AVX512}) } func initializeFloat32Array(n int) []float32 { x := make([]float32, n) for i := 0; i < n; i++ { x[i] = rand.Float32() } return x } func BenchmarkDot(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.dot(v1, v2) } }) } }) } } func BenchmarkEuclidean(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.euclidean(v1, v2) } }) } }) } } func BenchmarkMulConstAddTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := 
initializeFloat32Array(i) v2 := initializeFloat32Array(i) v3 := make([]float32, i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mulConstAddTo(v1, 2, v2, v3) } }) } }) } } func BenchmarkMulConstAdd(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mulConstAdd(v1, 2, v2) } }) } }) } } func BenchmarkMulConst(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mulConst(v1, 2) } }) } }) } } func BenchmarkMulConstTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mulConstTo(v1, 2, v2) } }) } }) } } func BenchmarkAddConst(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.addConst(v1, 2) } }) } }) } } func BenchmarkSub(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.sub(v1, v2) } }) } }) } } func BenchmarkSubTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b 
*testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) v3 := make([]float32, i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.subTo(v1, v2, v3) } }) } }) } } func BenchmarkMulTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) v3 := make([]float32, i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mulTo(v1, v2, v3) } }) } }) } } func BenchmarkDivTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := initializeFloat32Array(i) v3 := make([]float32, i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.divTo(v1, v2, v3) } }) } }) } } func BenchmarkSqrtTo(b *testing.B) { for _, feat := range supportedFeatures { b.Run(feat.String(), func(b *testing.B) { for i := 16; i <= 128; i *= 2 { b.Run(strconv.Itoa(i), func(b *testing.B) { v1 := initializeFloat32Array(i) v2 := make([]float32, i) b.ResetTimer() for i := 0; i < b.N; i++ { feat.sqrtTo(v1, v2) } }) } }) } } func BenchmarkMM(b *testing.B) { for _, transA := range []bool{false, true} { for _, transB := range []bool{false, true} { for _, feat := range supportedFeatures { b.Run(fmt.Sprintf("(%v,%v,%v)", transA, transB, feat.String()), func(b *testing.B) { for n := 16; n <= 128; n *= 2 { b.Run(strconv.Itoa(n), func(b *testing.B) { matA := initializeFloat32Array(n * n) matB := initializeFloat32Array(n * n) matC := make([]float32, n*n) b.ResetTimer() for i := 0; i < b.N; i++ { feat.mm(transA, transB, n, n, n, matA, n, matB, n, matC, n) } }) } }) } } } } ================================================ FILE: common/floats/floats_arm64.go ================================================ //go:build !noasm // Copyright 2022 gorse Project Authors // // Licensed under 
the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package floats import ( "strings" "unsafe" ) //go:generate goat src/floats_neon.c -O3 type Feature uint64 const ( AMX Feature = 1 << iota // Apple matrix extension OPENBLAS ) var feature Feature func (feature Feature) String() string { var features = []string{"ARM64"} if feature&AMX > 0 { features = append(features, "AMX") } return strings.Join(features, "+") } func (feature Feature) mulConstAddTo(a []float32, b float32, c, dst []float32) { vmul_const_add_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), unsafe.Pointer(&dst[0]), int64(len(a))) } func (feature Feature) mulConstAdd(a []float32, b float32, c []float32) { vmul_const_add(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a))) } func (feature Feature) mulConstTo(a []float32, b float32, c []float32) { vmul_const_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a))) } func (feature Feature) addConst(a []float32, b float32) { vadd_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a))) } func (feature Feature) sub(a, b []float32) { vsub(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } func (feature Feature) subTo(a, b, c []float32) { vsub_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } func (feature Feature) mulTo(a, b, c []float32) { vmul_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } func (feature Feature) 
divTo(a, b, c []float32) { vdiv_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } func (feature Feature) sqrtTo(a, b []float32) { vsqrt_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } func (feature Feature) mulConst(a []float32, b float32) { vmul_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a))) } func (feature Feature) dot(a, b []float32) float32 { return vdot(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } func (feature Feature) euclidean(a, b []float32) float32 { return veuclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } func (feature Feature) mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { if feature&AMX == AMX || feature&OPENBLAS == OPENBLAS { mm(transA, transB, m, n, k, a, lda, b, ldb, c, ldc) } else { vmm(transA, transB, int64(m), int64(n), int64(k), unsafe.Pointer(&a[0]), int64(lda), unsafe.Pointer(&b[0]), int64(ldb), unsafe.Pointer(&c[0]), int64(ldc)) } } ================================================ FILE: common/floats/floats_arm64_test.go ================================================ //go:build !noasm // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package floats

import (
	"fmt"
	"math/rand"
	"runtime"
	"strconv"
	"testing"

	"github.com/stretchr/testify/suite"
)

// TestASIMD runs the SIMD suite with no optional accelerators enabled.
func TestASIMD(t *testing.T) {
	suite.Run(t, &SIMDTestSuite{})
}

// TestAMX runs the SIMD suite with the AMX flag set; it only runs on macOS ARM64.
func TestAMX(t *testing.T) {
	if runtime.GOOS != "darwin" || runtime.GOARCH != "arm64" {
		t.Skip("AMX is only supported on macOS ARM64")
	}
	suite.Run(t, &SIMDTestSuite{Feature: AMX})
}

// initializeFloat32Array returns n pseudo-random values in [0, 1).
func initializeFloat32Array(n int) []float32 {
	x := make([]float32, n)
	for i := range x {
		x[i] = rand.Float32()
	}
	return x
}

// runVectorBenchmarks runs body for vector lengths 16, 32, 64 and 128, naming
// sub-benchmarks "<feature>/<length>". Only feature 0 is benchmarked because
// the element-wise operations on arm64 do not depend on the feature flags.
// body prepares its inputs and must call b.ResetTimer before the timed loop.
func runVectorBenchmarks(b *testing.B, body func(b *testing.B, feat Feature, n int)) {
	for _, feat := range []Feature{0} {
		b.Run(feat.String(), func(b *testing.B) {
			for n := 16; n <= 128; n *= 2 {
				b.Run(strconv.Itoa(n), func(b *testing.B) {
					body(b, feat, n)
				})
			}
		})
	}
}

func BenchmarkDot(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.dot(v1, v2)
		}
	})
}

func BenchmarkEuclidean(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.euclidean(v1, v2)
		}
	})
}

func BenchmarkMulConstAddTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstAddTo(v1, 2, v2, v3)
		}
	})
}

func BenchmarkMulConstAdd(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstAdd(v1, 2, v2)
		}
	})
}

func BenchmarkMulConstTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstTo(v1, 2, v2)
		}
	})
}

func BenchmarkAddConst(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.addConst(v1, 2)
		}
	})
}

func BenchmarkMulConst(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConst(v1, 2)
		}
	})
}

func BenchmarkSubTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.subTo(v1, v2, v3)
		}
	})
}

func BenchmarkSub(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.sub(v1, v2)
		}
	})
}

func BenchmarkMulTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulTo(v1, v2, v3)
		}
	})
}

func BenchmarkDivTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.divTo(v1, v2, v3)
		}
	})
}

func BenchmarkSqrtTo(b *testing.B) {
	runVectorBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.sqrtTo(v1, v2)
		}
	})
}

// BenchmarkMM times square n×n matrix products for every transpose
// combination, with and (on macOS ARM64) without AMX.
func BenchmarkMM(b *testing.B) {
	var feats = []Feature{0}
	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
		feats = append(feats, AMX)
	}
	for _, transA := range []bool{false, true} {
		for _, transB := range []bool{false, true} {
			for _, feat := range feats {
				b.Run(fmt.Sprintf("(%v,%v,%v)", transA, transB, feat.String()), func(b *testing.B) {
					for n := 16; n <= 128; n *= 2 {
						b.Run(strconv.Itoa(n), func(b *testing.B) {
							matA := initializeFloat32Array(n * n)
							matB := initializeFloat32Array(n * n)
							matC := make([]float32, n*n)
							b.ResetTimer()
							for i := 0; i < b.N; i++ {
								feat.mm(transA, transB, n, n, n, matA, n, matB, n, matC, n)
							}
						})
					}
				})
			}
		}
	}
}

================================================
FILE: common/floats/floats_avx.go
================================================
//go:build !noasm && amd64
// Code generated by GoAT. DO NOT EDIT.
// versions: // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) // objdump 2.38 // flags: -mavx -O3 // source: src/floats_avx.c package floats import "unsafe" //go:noescape func _mm256_mul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) //go:noescape func _mm256_mul_const_add(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm256_mul_const_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm256_mul_const(a, b unsafe.Pointer, n int64) //go:noescape func _mm256_add_const(a, b unsafe.Pointer, n int64) //go:noescape func _mm256_sub_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm256_sub(a, b unsafe.Pointer, n int64) //go:noescape func _mm256_mul_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm256_div_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm256_sqrt_to(a, b unsafe.Pointer, n int64) //go:noescape func _mm256_dot(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func _mm256_euclidean(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func _mm256_mm(transA, transB bool, m, n, k int64, a unsafe.Pointer, lda int64, b unsafe.Pointer, ldb int64, c unsafe.Pointer, ldc int64) ================================================ FILE: common/floats/floats_avx.s ================================================ //go:build !noasm && amd64 // Code generated by GoAT. DO NOT EDIT. 
// versions: // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) // objdump 2.38 // flags: -mavx -O3 // source: src/floats_avx.c TEXT ·_mm256_mul_const_add_to(SB), $0-40 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ dst+24(FP), CX MOVQ n+32(FP), R8 BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07488d4d // leaq 7(%r8), %r9 WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 LONG $0xc8490f4d // cmovnsq %r8, %r9 WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x03f8c148 // sarq $3, %rax LONG $0xf8e18349 // andq $-8, %r9 WORD $0x294d; BYTE $0xc8 // subq %r9, %r8 WORD $0xc085 // testl %eax, %eax JLE LBB0_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB0_4 WORD $0x8941; BYTE $0xc1 // movl %eax, %r9d LONG $0xfee18141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r9d # imm = 0x7FFFFFFE LBB0_3: LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0111fcc5 // vmovups %ymm0, (%rcx) LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x4759fcc5; BYTE $0x20 // vmulps 32(%rdi), %ymm0, %ymm0 LONG $0x4258fcc5; BYTE $0x20 // vaddps 32(%rdx), %ymm0, %ymm0 LONG $0x4111fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rcx) LONG $0x40c78348 // addq $64, %rdi LONG $0x40c28348 // addq $64, %rdx LONG $0x40c18348 // addq $64, %rcx LONG $0xfec18341 // addl $-2, %r9d JNE LBB0_3 LBB0_4: WORD $0x01a8 // testb $1, %al JE LBB0_6 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0111fcc5 // vmovups %ymm0, (%rcx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx LONG $0x20c18348 // addq $32, %rcx LBB0_6: WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 JLE LBB0_14 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = 
mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0258fac5 // vaddss (%rdx), %xmm0, %xmm0 LONG $0x0111fac5 // vmovss %xmm0, (%rcx) LONG $0x01f88348 // cmpq $1, %rax JE LBB0_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x04 // vaddss 4(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x04 // vmovss %xmm0, 4(%rcx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB0_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x08 // vaddss 8(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x08 // vmovss %xmm0, 8(%rcx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB0_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x0c // vaddss 12(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x0c // vmovss %xmm0, 12(%rcx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB0_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x10 // vaddss 16(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x10 // vmovss %xmm0, 16(%rcx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB0_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x14 // vaddss 20(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x14 // vmovss %xmm0, 20(%rcx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB0_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x18 // vaddss 24(%rdx), %xmm0, %xmm0 LONG $0x4111fac5; BYTE $0x18 // vmovss %xmm0, 24(%rcx) LBB0_14: WORD $0x8948; BYTE 
$0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_mul_const_add(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07418d4c // leaq 7(%rcx), %r8 WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f4c // cmovnsq %rcx, %r8 WORD $0x894c; BYTE $0xc0 // movq %r8, %rax LONG $0x03f8c148 // sarq $3, %rax LONG $0xf8e08349 // andq $-8, %r8 WORD $0x294c; BYTE $0xc1 // subq %r8, %rcx WORD $0xc085 // testl %eax, %eax JLE LBB1_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB1_4 WORD $0x8941; BYTE $0xc0 // movl %eax, %r8d LONG $0xfee08141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r8d # imm = 0x7FFFFFFE LBB1_3: LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x4759fcc5; BYTE $0x20 // vmulps 32(%rdi), %ymm0, %ymm0 LONG $0x4258fcc5; BYTE $0x20 // vaddps 32(%rdx), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdx) LONG $0x40c78348 // addq $64, %rdi LONG $0x40c28348 // addq $64, %rdx LONG $0xfec08341 // addl $-2, %r8d JNE LBB1_3 LBB1_4: WORD $0x01a8 // testb $1, %al JE LBB1_6 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx LBB1_6: WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB1_14 WORD $0xc889 // movl %ecx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0258fac5 // vaddss (%rdx), %xmm0, %xmm0 LONG $0x0211fac5 // vmovss %xmm0, (%rdx) LONG 
$0x01f88348 // cmpq $1, %rax JE LBB1_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x04 // vaddss 4(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB1_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x08 // vaddss 8(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB1_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x0c // vaddss 12(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB1_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x10 // vaddss 16(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB1_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x14 // vaddss 20(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB1_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4258fac5; BYTE $0x18 // vaddss 24(%rdx), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdx) LBB1_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_mul_const_to(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI 
MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07418d4c // leaq 7(%rcx), %r8 WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f4c // cmovnsq %rcx, %r8 WORD $0x894c; BYTE $0xc0 // movq %r8, %rax LONG $0x03f8c148 // sarq $3, %rax LONG $0xf8e08349 // andq $-8, %r8 WORD $0x294c; BYTE $0xc1 // subq %r8, %rcx WORD $0xc085 // testl %eax, %eax JLE LBB2_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB2_4 WORD $0x8941; BYTE $0xc0 // movl %eax, %r8d LONG $0xfee08141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r8d # imm = 0x7FFFFFFE LBB2_3: LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x4759fcc5; BYTE $0x20 // vmulps 32(%rdi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdx) LONG $0x40c78348 // addq $64, %rdi LONG $0x40c28348 // addq $64, %rdx LONG $0xfec08341 // addl $-2, %r8d JNE LBB2_3 LBB2_4: WORD $0x01a8 // testb $1, %al JE LBB2_6 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx LBB2_6: WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB2_14 WORD $0xc889 // movl %ecx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0211fac5 // vmovss %xmm0, (%rdx) LONG $0x01f88348 // cmpq $1, %rax JE LBB2_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB2_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG 
$0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB2_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB2_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB2_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB2_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdx) LBB2_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_mul_const(SB), $0-24 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x074a8d48 // leaq 7(%rdx), %rcx WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xca490f48 // cmovnsq %rdx, %rcx WORD $0x8948; BYTE $0xc8 // movq %rcx, %rax LONG $0x03f8c148 // sarq $3, %rax LONG $0xf8e18348 // andq $-8, %rcx WORD $0x2948; BYTE $0xca // subq %rcx, %rdx WORD $0xc085 // testl %eax, %eax JLE LBB3_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB3_4 WORD $0xc189 // movl %eax, %ecx LONG $0xfffee181; WORD $0x7fff // andl $2147483646, %ecx # imm = 0x7FFFFFFE LBB3_3: LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0711fcc5 // 
vmovups %ymm0, (%rdi) LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x4759fcc5; BYTE $0x20 // vmulps 32(%rdi), %ymm0, %ymm0 LONG $0x4711fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdi) LONG $0x40c78348 // addq $64, %rdi WORD $0xc183; BYTE $0xfe // addl $-2, %ecx JNE LBB3_3 LBB3_4: WORD $0x01a8 // testb $1, %al JE LBB3_6 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x20c78348 // addq $32, %rdi LBB3_6: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB3_14 WORD $0xd089 // movl %edx, %eax LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0759fac5 // vmulss (%rdi), %xmm0, %xmm0 LONG $0x0711fac5 // vmovss %xmm0, (%rdi) LONG $0x01f88348 // cmpq $1, %rax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x04 // vmulss 4(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdi) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x08 // vmulss 8(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdi) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x0c // vmulss 12(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdi) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x10 // vmulss 16(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdi) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x14 // vmulss 20(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdi) WORD $0xf883; BYTE 
$0x06 // cmpl $6, %eax JE LBB3_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4759fac5; BYTE $0x18 // vmulss 24(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdi) LBB3_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_add_const(SB), $0-24 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x074a8d48 // leaq 7(%rdx), %rcx WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xca490f48 // cmovnsq %rdx, %rcx WORD $0x8948; BYTE $0xc8 // movq %rcx, %rax LONG $0x03f8c148 // sarq $3, %rax LONG $0xf8e18348 // andq $-8, %rcx WORD $0x2948; BYTE $0xca // subq %rcx, %rdx WORD $0xc085 // testl %eax, %eax JLE LBB4_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB4_4 WORD $0xc189 // movl %eax, %ecx LONG $0xfffee181; WORD $0x7fff // andl $2147483646, %ecx # imm = 0x7FFFFFFE LBB4_3: LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0758fcc5 // vaddps (%rdi), %ymm0, %ymm0 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x4758fcc5; BYTE $0x20 // vaddps 32(%rdi), %ymm0, %ymm0 LONG $0x4711fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdi) LONG $0x40c78348 // addq $64, %rdi WORD $0xc183; BYTE $0xfe // addl $-2, %ecx JNE LBB4_3 LBB4_4: WORD $0x01a8 // testb $1, %al JE LBB4_6 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0758fcc5 // vaddps (%rdi), %ymm0, %ymm0 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x20c78348 // addq $32, %rdi LBB4_6: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB4_14 WORD $0xd089 // movl %edx, %eax LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0758fac5 // vaddss (%rdi), %xmm0, %xmm0 LONG $0x0711fac5 // vmovss %xmm0, (%rdi) LONG $0x01f88348 // cmpq $1, %rax JE LBB4_14 LONG $0x0610fac5 // 
vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x04 // vaddss 4(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdi) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB4_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x08 // vaddss 8(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdi) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB4_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x0c // vaddss 12(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdi) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB4_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x10 // vaddss 16(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdi) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB4_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x14 // vaddss 20(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdi) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB4_14 LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4758fac5; BYTE $0x18 // vaddss 24(%rdi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdi) LBB4_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_sub_to(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07418d48 // leaq 7(%rcx), %rax WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f48 // cmovnsq %rcx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x03f8c149 // sarq $3, %r8 LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx WORD $0x8545; BYTE 
$0xc0 // testl %r8d, %r8d JLE LBB5_6 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x04f88341 // cmpl $4, %r8d JB LBB5_4 LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC LBB5_3: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x4710fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm0 LONG $0x465cfcc5; BYTE $0x20 // vsubps 32(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdx) LONG $0x4710fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm0 LONG $0x465cfcc5; BYTE $0x40 // vsubps 64(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x40 // vmovups %ymm0, 64(%rdx) LONG $0x4710fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm0 LONG $0x465cfcc5; BYTE $0x60 // vsubps 96(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x60 // vmovups %ymm0, 96(%rdx) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi LONG $0x80ea8348 // subq $-128, %rdx LONG $0xfcc08341 // addl $-4, %r8d JNE LBB5_3 LBB5_4: WORD $0xc085 // testl %eax, %eax JE LBB5_6 LBB5_5: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x20c28348 // addq $32, %rdx WORD $0xc8ff // decl %eax JNE LBB5_5 LBB5_6: WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB5_14 WORD $0xc889 // movl %ecx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x065cfac5 // vsubss (%rsi), %xmm0, %xmm0 LONG $0x0211fac5 // vmovss %xmm0, (%rdx) LONG $0x01f88348 // cmpq $1, %rax JE LBB5_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x04 // vsubss 4(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB5_14 LONG $0x4710fac5; BYTE $0x08 // 
vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x08 // vsubss 8(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB5_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x0c // vsubss 12(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB5_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x10 // vsubss 16(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB5_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x14 // vsubss 20(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB5_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x18 // vsubss 24(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdx) LBB5_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_sub(SB), $0-24 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07428d48 // leaq 7(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0x03f9c148 // sarq $3, %rcx LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx WORD $0xc985 // testl %ecx, %ecx JLE LBB6_6 WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx JB LBB6_4 LONG $0xfffce181; WORD $0x7fff // andl $2147483644, 
%ecx # imm = 0x7FFFFFFC LBB6_3: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x4f10fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm1 LONG $0x5710fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm2 LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0 LONG $0x5f10fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm3 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x465cf4c5; BYTE $0x20 // vsubps 32(%rsi), %ymm1, %ymm0 LONG $0x4711fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdi) LONG $0x465cecc5; BYTE $0x40 // vsubps 64(%rsi), %ymm2, %ymm0 LONG $0x4711fcc5; BYTE $0x40 // vmovups %ymm0, 64(%rdi) LONG $0x465ce4c5; BYTE $0x60 // vsubps 96(%rsi), %ymm3, %ymm0 LONG $0x4711fcc5; BYTE $0x60 // vmovups %ymm0, 96(%rdi) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi WORD $0xc183; BYTE $0xfc // addl $-4, %ecx JNE LBB6_3 LBB6_4: WORD $0xc085 // testl %eax, %eax JE LBB6_6 LBB6_5: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi WORD $0xc8ff // decl %eax JNE LBB6_5 LBB6_6: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB6_14 WORD $0xd089 // movl %edx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x065cfac5 // vsubss (%rsi), %xmm0, %xmm0 LONG $0x0711fac5 // vmovss %xmm0, (%rdi) LONG $0x01f88348 // cmpq $1, %rax JE LBB6_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x04 // vsubss 4(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdi) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB6_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x08 // vsubss 8(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdi) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB6_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = 
mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x0c // vsubss 12(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdi) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB6_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x10 // vsubss 16(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdi) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB6_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x14 // vsubss 20(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdi) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB6_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465cfac5; BYTE $0x18 // vsubss 24(%rsi), %xmm0, %xmm0 LONG $0x4711fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdi) LBB6_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_mul_to(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07418d48 // leaq 7(%rcx), %rax WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f48 // cmovnsq %rcx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x03f8c149 // sarq $3, %r8 LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB7_6 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x04f88341 // cmpl $4, %r8d JB LBB7_4 LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC LBB7_3: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x0659fcc5 // vmulps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x4710fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm0 LONG $0x4659fcc5; 
BYTE $0x20 // vmulps 32(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdx) LONG $0x4710fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm0 LONG $0x4659fcc5; BYTE $0x40 // vmulps 64(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x40 // vmovups %ymm0, 64(%rdx) LONG $0x4710fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm0 LONG $0x4659fcc5; BYTE $0x60 // vmulps 96(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x60 // vmovups %ymm0, 96(%rdx) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi LONG $0x80ea8348 // subq $-128, %rdx LONG $0xfcc08341 // addl $-4, %r8d JNE LBB7_3 LBB7_4: WORD $0xc085 // testl %eax, %eax JE LBB7_6 LBB7_5: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x0659fcc5 // vmulps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x20c28348 // addq $32, %rdx WORD $0xc8ff // decl %eax JNE LBB7_5 LBB7_6: WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB7_14 WORD $0xc889 // movl %ecx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0211fac5 // vmovss %xmm0, (%rdx) LONG $0x01f88348 // cmpq $1, %rax JE LBB7_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x04 // vmulss 4(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB7_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x08 // vmulss 8(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB7_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x0c // vmulss 12(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE 
LBB7_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x10 // vmulss 16(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x10 // vmovss %xmm0, 16(%rdx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB7_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x14 // vmulss 20(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB7_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4659fac5; BYTE $0x18 // vmulss 24(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdx) LBB7_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_div_to(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07418d48 // leaq 7(%rcx), %rax WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f48 // cmovnsq %rcx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x03f8c149 // sarq $3, %r8 LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB8_6 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x04f88341 // cmpl $4, %r8d JB LBB8_4 LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC LBB8_3: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065efcc5 // vdivps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x4710fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm0 LONG $0x465efcc5; BYTE $0x20 // vdivps 32(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rdx) LONG $0x4710fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm0 LONG $0x465efcc5; BYTE $0x40 
// vdivps 64(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x40 // vmovups %ymm0, 64(%rdx) LONG $0x4710fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm0 LONG $0x465efcc5; BYTE $0x60 // vdivps 96(%rsi), %ymm0, %ymm0 LONG $0x4211fcc5; BYTE $0x60 // vmovups %ymm0, 96(%rdx) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi LONG $0x80ea8348 // subq $-128, %rdx LONG $0xfcc08341 // addl $-4, %r8d JNE LBB8_3 LBB8_4: WORD $0xc085 // testl %eax, %eax JE LBB8_6 LBB8_5: LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065efcc5 // vdivps (%rsi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x20c28348 // addq $32, %rdx WORD $0xc8ff // decl %eax JNE LBB8_5 LBB8_6: WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB8_14 WORD $0xc889 // movl %ecx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x065efac5 // vdivss (%rsi), %xmm0, %xmm0 LONG $0x0211fac5 // vmovss %xmm0, (%rdx) LONG $0x01f88348 // cmpq $1, %rax JE LBB8_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x04 // vdivss 4(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x04 // vmovss %xmm0, 4(%rdx) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB8_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x08 // vdivss 8(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x08 // vmovss %xmm0, 8(%rdx) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB8_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x0c // vdivss 12(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x0c // vmovss %xmm0, 12(%rdx) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB8_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x10 // vdivss 16(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x10 // 
vmovss %xmm0, 16(%rdx) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB8_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x14 // vdivss 20(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x14 // vmovss %xmm0, 20(%rdx) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB8_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x465efac5; BYTE $0x18 // vdivss 24(%rsi), %xmm0, %xmm0 LONG $0x4211fac5; BYTE $0x18 // vmovss %xmm0, 24(%rdx) LBB8_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm256_sqrt_to(SB), $0-24 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07428d48 // leaq 7(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0x03f9c148 // sarq $3, %rcx LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx WORD $0xc985 // testl %ecx, %ecx JLE LBB9_6 WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx JB LBB9_4 LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC LBB9_3: LONG $0x0751fcc5 // vsqrtps (%rdi), %ymm0 LONG $0x0611fcc5 // vmovups %ymm0, (%rsi) LONG $0x4751fcc5; BYTE $0x20 // vsqrtps 32(%rdi), %ymm0 LONG $0x4611fcc5; BYTE $0x20 // vmovups %ymm0, 32(%rsi) LONG $0x4751fcc5; BYTE $0x40 // vsqrtps 64(%rdi), %ymm0 LONG $0x4611fcc5; BYTE $0x40 // vmovups %ymm0, 64(%rsi) LONG $0x4751fcc5; BYTE $0x60 // vsqrtps 96(%rdi), %ymm0 LONG $0x4611fcc5; BYTE $0x60 // vmovups %ymm0, 96(%rsi) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi WORD $0xc183; BYTE $0xfc // addl $-4, %ecx JNE LBB9_3 LBB9_4: WORD $0xc085 // testl %eax, %eax JE LBB9_6 LBB9_5: LONG 
$0x0751fcc5 // vsqrtps (%rdi), %ymm0 LONG $0x0611fcc5 // vmovups %ymm0, (%rsi) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi WORD $0xc8ff // decl %eax JNE LBB9_5 LBB9_6: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB9_14 WORD $0xd089 // movl %edx, %eax LONG $0x0710fac5 // vmovss (%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x0611fac5 // vmovss %xmm0, (%rsi) LONG $0x01f88348 // cmpq $1, %rax JE LBB9_14 LONG $0x4710fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x04 // vmovss %xmm0, 4(%rsi) WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB9_14 LONG $0x4710fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x08 // vmovss %xmm0, 8(%rsi) WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB9_14 LONG $0x4710fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x0c // vmovss %xmm0, 12(%rsi) WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB9_14 LONG $0x4710fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x10 // vmovss %xmm0, 16(%rsi) WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB9_14 LONG $0x4710fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x14 // vmovss %xmm0, 20(%rsi) WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB9_14 LONG $0x4710fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 LONG $0x4611fac5; BYTE $0x18 // vmovss %xmm0, 24(%rsi) LBB9_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper 
RET TEXT ·_mm256_dot(SB), $8-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07428d48 // leaq 7(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x03f8c149 // sarq $3, %r8 LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB10_1 LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x0659fcc5 // vmulps (%rsi), %ymm0, %ymm0 LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x01f88341 // cmpl $1, %r8d JE LBB10_8 LONG $0xff488d41 // leal -1(%r8), %ecx LONG $0xfec08341 // addl $-2, %r8d WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x03f88341 // cmpl $3, %r8d JB LBB10_6 WORD $0xe183; BYTE $0xfc // andl $-4, %ecx LBB10_5: LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1 LONG $0x5710fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm2 LONG $0x5f10fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm3 LONG $0x6710fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm4 LONG $0x0e59f4c5 // vmulps (%rsi), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x4e59ecc5; BYTE $0x20 // vmulps 32(%rsi), %ymm2, %ymm1 LONG $0x5659e4c5; BYTE $0x40 // vmulps 64(%rsi), %ymm3, %ymm2 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0xc258fcc5 // vaddps %ymm2, %ymm0, %ymm0 LONG $0x4e59dcc5; BYTE $0x60 // vmulps 96(%rsi), %ymm4, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi WORD $0xc183; BYTE $0xfc // addl $-4, %ecx JNE LBB10_5 LBB10_6: WORD $0xc085 // testl %eax, %eax JE LBB10_8 LBB10_7: LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1 LONG $0x0e59f4c5 // vmulps (%rsi), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, 
%rsi WORD $0xc8ff // decl %eax JNE LBB10_7 JMP LBB10_8 LBB10_1: LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0 LBB10_8: LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB10_16 WORD $0xd089 // movl %edx, %eax LONG $0x0f10fac5 // vmovss (%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x0e59f2c5 // vmulss (%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 LONG $0x01f88348 // cmpq $1, %rax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x04 // vmulss 4(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x08 // vmulss 8(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x0c // vmulss 12(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x10 // vmulss 16(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x14 // vmulss 20(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x06 // cmpl 
$6, %eax JE LBB10_16 LONG $0x4f10fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e59f2c5; BYTE $0x18 // vmulss 24(%rsi), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 LBB10_16: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper MOVSS X0, result+24(FP) RET TEXT ·_mm256_euclidean(SB), $8-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x07428d48 // leaq 7(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x03f8c149 // sarq $3, %r8 LONG $0xf8e08348 // andq $-8, %rax WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB11_1 LONG $0x0710fcc5 // vmovups (%rdi), %ymm0 LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0 LONG $0xc059fcc5 // vmulps %ymm0, %ymm0, %ymm0 LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x01f88341 // cmpl $1, %r8d JE LBB11_8 LONG $0xff488d41 // leal -1(%r8), %ecx LONG $0xfec08341 // addl $-2, %r8d WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x03f88341 // cmpl $3, %r8d JB LBB11_6 WORD $0xe183; BYTE $0xfc // andl $-4, %ecx LBB11_5: LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1 LONG $0x5710fcc5; BYTE $0x20 // vmovups 32(%rdi), %ymm2 LONG $0x5f10fcc5; BYTE $0x40 // vmovups 64(%rdi), %ymm3 LONG $0x6710fcc5; BYTE $0x60 // vmovups 96(%rdi), %ymm4 LONG $0x0e5cf4c5 // vsubps (%rsi), %ymm1, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x4e5cecc5; BYTE $0x20 // vsubps 32(%rsi), %ymm2, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x4e5ce4c5; BYTE $0x40 // vsubps 64(%rsi), %ymm3, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 
LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x4e5cdcc5; BYTE $0x60 // vsubps 96(%rsi), %ymm4, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ee8348 // subq $-128, %rsi WORD $0xc183; BYTE $0xfc // addl $-4, %ecx JNE LBB11_5 LBB11_6: WORD $0xc085 // testl %eax, %eax JE LBB11_8 LBB11_7: LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1 LONG $0x0e5cf4c5 // vsubps (%rsi), %ymm1, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi WORD $0xc8ff // decl %eax JNE LBB11_7 JMP LBB11_8 LBB11_1: LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0 LBB11_8: LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB11_16 LONG $0x0f10fac5 // vmovss (%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x0e5cf2c5 // vsubss (%rsi), %xmm1, %xmm1 WORD $0xd089 // movl %edx, %eax LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 LONG $0x01f88348 // cmpq $1, %rax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x04 // vmovss 4(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x04 // vsubss 4(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x08 // vmovss 8(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x08 // vsubss 8(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, 
%xmm0 WORD $0xf883; BYTE $0x03 // cmpl $3, %eax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x0c // vmovss 12(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x0c // vsubss 12(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x04 // cmpl $4, %eax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x10 // vmovss 16(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x10 // vsubss 16(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x05 // cmpl $5, %eax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x14 // vmovss 20(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x14 // vsubss 20(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xf883; BYTE $0x06 // cmpl $6, %eax JE LBB11_16 LONG $0x4f10fac5; BYTE $0x18 // vmovss 24(%rdi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4e5cf2c5; BYTE $0x18 // vsubss 24(%rsi), %xmm1, %xmm1 LONG $0xc959f2c5 // vmulss %xmm1, %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 LBB11_16: LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper MOVSS X0, result+24(FP) RET TEXT ·_mm256_mm(SB), $0-88 MOVQ transA+0(FP), DI MOVQ transB+1(FP), SI MOVQ m+8(FP), DX MOVQ n+16(FP), CX MOVQ k+24(FP), R8 MOVQ a+32(FP), R9 PUSHQ ldc+72(FP) PUSHQ c+64(FP) PUSHQ ldb+56(FP) PUSHQ b+48(FP) PUSHQ lda+40(FP) PUSHQ $0 BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp WORD $0x5741 // pushq %r15 WORD $0x5641 // pushq %r14 WORD $0x5541 // pushq %r13 WORD $0x5441 // pushq %r12 BYTE $0x53 // pushq %rbx LONG $0xf0e48348 // andq $-16, %rsp LONG $0x90ec8148; WORD $0x0000; BYTE $0x00 // subq $144, %rsp LONG $0x244c894c; BYTE $0x08 // movq %r9, 8(%rsp) # 8-byte Spill WORD $0x8948; BYTE 
$0xcb // movq %rcx, %rbx LONG $0x24148948 // movq %rdx, (%rsp) # 8-byte Spill LONG $0x30458b48 // movq 48(%rbp), %rax LONG $0x24448948; BYTE $0x68 // movq %rax, 104(%rsp) # 8-byte Spill LONG $0x28458b48 // movq 40(%rbp), %rax LONG $0x24448948; BYTE $0x70 // movq %rax, 112(%rsp) # 8-byte Spill LONG $0x206d8b4c // movq 32(%rbp), %r13 LONG $0x18458b48 // movq 24(%rbp), %rax LONG $0x24448948; BYTE $0x60 // movq %rax, 96(%rsp) # 8-byte Spill WORD $0xf889 // movl %edi, %eax WORD $0x0840; BYTE $0xf0 // orb %sil, %al LONG $0x2444894c; BYTE $0x58 // movq %r8, 88(%rsp) # 8-byte Spill JE LBB12_1 WORD $0xf089 // movl %esi, %eax WORD $0x0134 // xorb $1, %al WORD $0x0840; BYTE $0xf8 // orb %dil, %al JE LBB12_16 WORD $0xf989 // movl %edi, %ecx WORD $0xf180; BYTE $0x01 // xorb $1, %cl WORD $0x0840; BYTE $0xf1 // orb %sil, %cl WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 WORD $0x9f0f; BYTE $0xc2 // setg %dl WORD $0x8548; BYTE $0xdb // testq %rbx, %rbx WORD $0x9f0f; BYTE $0xc0 // setg %al WORD $0xd020 // andb %dl, %al WORD $0xc984 // testb %cl, %cl JE LBB12_88 LONG $0x243c8348; BYTE $0x00 // cmpq $0, (%rsp) # 8-byte Folded Reload WORD $0x9f0f; BYTE $0xc1 // setg %cl WORD $0x2040; BYTE $0xf0 // andb %sil, %al WORD $0x2040; BYTE $0xf9 // andb %dil, %cl WORD $0xc120 // andb %al, %cl WORD $0xf980; BYTE $0x01 // cmpb $1, %cl JNE LBB12_117 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x00000000853c8d48 // leaq (,%rax,4), %rdi LONG $0xff408d49 // leaq -1(%r8), %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x180c8d49 // leaq (%r8,%rbx), %rcx LONG $0x18758b48 // movq 24(%rbp), %rsi LONG $0x8e0c8d48 // leaq (%rsi,%rcx,4), %rcx LONG $0xfcc18348 // addq $-4, %rcx LONG $0x244c8948; BYTE $0x28 // movq %rcx, 40(%rsp) # 8-byte Spill LONG $0x284d8b48 // movq 40(%rbp), %rcx LONG $0x99148d48 // leaq (%rcx,%rbx,4), %rdx LONG $0x24548948; BYTE $0x20 // movq %rdx, 32(%rsp) # 8-byte Spill LONG $0x24548b48; BYTE $0x08 // movq 8(%rsp), %rdx # 8-byte Reload LONG $0x82048d48 // leaq 
(%rdx,%rax,4), %rax LONG $0x04c08348 // addq $4, %rax LONG $0x24448948; BYTE $0x18 // movq %rax, 24(%rsp) # 8-byte Spill LONG $0x20fb8348 // cmpq $32, %rbx WORD $0x930f; BYTE $0xc0 // setae %al LONG $0x01fd8349 // cmpq $1, %r13 LONG $0xc0940f41 // sete %r8b WORD $0x2041; BYTE $0xc0 // andb %al, %r8b QUAD $0xffffffffffe0bc49; WORD $0x7fff // movabsq $9223372036854775776, %r12 # imm = 0x7FFFFFFFFFFFFFE0 WORD $0x2149; BYTE $0xdc // andq %rbx, %r12 LONG $0xff438d48 // leaq -1(%rbx), %rax LONG $0x24448948; BYTE $0x50 // movq %rax, 80(%rsp) # 8-byte Spill LONG $0x60468d48 // leaq 96(%rsi), %rax QUAD $0x0000008024848948 // movq %rax, 128(%rsp) # 8-byte Spill LONG $0x60718d4c // leaq 96(%rcx), %r14 QUAD $0x00000000ad048d4a // leaq (,%r13,4), %rax LONG $0x24448948; BYTE $0x40 // movq %rax, 64(%rsp) # 8-byte Spill QUAD $0x00000000ed1c8d4e // leaq (,%r13,8), %r11 LONG $0x04518d48 // leaq 4(%rcx), %rdx WORD $0xf641; BYTE $0xd0 // notb %r8b LONG $0x24448844; BYTE $0x78 // movb %r8b, 120(%rsp) # 1-byte Spill WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LONG $0x247c8948; BYTE $0x30 // movq %rdi, 48(%rsp) # 8-byte Spill JMP LBB12_104 LBB12_116: LONG $0x24448b4c; BYTE $0x38 // movq 56(%rsp), %r8 # 8-byte Reload WORD $0xff49; BYTE $0xc0 // incq %r8 LONG $0x247c8b48; BYTE $0x30 // movq 48(%rsp), %rdi # 8-byte Reload WORD $0x0149; BYTE $0xfe // addq %rdi, %r14 WORD $0x0148; BYTE $0xfa // addq %rdi, %rdx LONG $0x24043b4c // cmpq (%rsp), %r8 # 8-byte Folded Reload JE LBB12_117 LBB12_104: QUAD $0x000000000000b848; WORD $0x2000 // movabsq $2305843009213693952, %rax # imm = 0x2000000000000000 LONG $0x10458548 // testq %rax, 16(%rbp) WORD $0x950f; BYTE $0xc0 // setne %al WORD $0x8948; BYTE $0xf9 // movq %rdi, %rcx LONG $0xc8af0f49 // imulq %r8, %rcx LONG $0x284d8b4c // movq 40(%rbp), %r9 LONG $0x09348d49 // leaq (%r9,%rcx), %rsi LONG $0x244c0348; BYTE $0x20 // addq 32(%rsp), %rcx # 8-byte Folded Reload LONG $0x247c8b48; BYTE $0x08 // movq 8(%rsp), %rdi # 8-byte Reload LONG $0x87148d4e // 
leaq (%rdi,%r8,4), %r10 LONG $0x247c8b48; BYTE $0x18 // movq 24(%rsp), %rdi # 8-byte Reload LONG $0x873c8d4a // leaq (%rdi,%r8,4), %rdi LONG $0x2444894c; BYTE $0x38 // movq %r8, 56(%rsp) # 8-byte Spill LONG $0x45af0f4c; BYTE $0x30 // imulq 48(%rbp), %r8 LONG $0x81048d4f // leaq (%r9,%r8,4), %r8 LONG $0x2444894c; BYTE $0x48 // movq %r8, 72(%rsp) # 8-byte Spill WORD $0x3948; BYTE $0xfe // cmpq %rdi, %rsi LONG $0xc7920f40 // setb %dil WORD $0x3949; BYTE $0xca // cmpq %rcx, %r10 LONG $0xc0920f41 // setb %r8b WORD $0x2041; BYTE $0xf8 // andb %dil, %r8b WORD $0x0841; BYTE $0xc0 // orb %al, %r8b LONG $0x24743b48; BYTE $0x28 // cmpq 40(%rsp), %rsi # 8-byte Folded Reload WORD $0x920f; BYTE $0xc0 // setb %al LONG $0x187d8b48 // movq 24(%rbp), %rdi WORD $0x3948; BYTE $0xf9 // cmpq %rdi, %rcx WORD $0x970f; BYTE $0xc1 // seta %cl WORD $0xc120 // andb %al, %cl WORD $0x0844; BYTE $0xc1 // orb %r8b, %cl LONG $0x78244c0a // orb 120(%rsp), %cl # 1-byte Folded Reload LONG $0x10244c88 // movb %cl, 16(%rsp) # 1-byte Spill QUAD $0x00000080248c8b4c // movq 128(%rsp), %r9 # 8-byte Reload WORD $0xf631 // xorl %esi, %esi JMP LBB12_105 LBB12_115: WORD $0xff48; BYTE $0xc6 // incq %rsi LONG $0x04c18349 // addq $4, %r9 LONG $0x04c78348 // addq $4, %rdi LONG $0x24743b48; BYTE $0x58 // cmpq 88(%rsp), %rsi # 8-byte Folded Reload JE LBB12_116 LBB12_105: WORD $0x8948; BYTE $0xf0 // movq %rsi, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x102444f6; BYTE $0x01 // testb $1, 16(%rsp) # 1-byte Folded Reload JE LBB12_107 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d JMP LBB12_110 LBB12_107: LONG $0x187dc2c4; WORD $0x8204 // vbroadcastss (%r10,%rax,4), %ymm0 WORD $0xc931 // xorl %ecx, %ecx LBB12_108: LONG $0x597cc1c4; WORD $0x894c; BYTE $0xa0 // vmulps -96(%r9,%rcx,4), %ymm0, %ymm1 LONG $0x5874c1c4; WORD $0x8e4c; BYTE $0xa0 // vaddps -96(%r14,%rcx,4), %ymm1, %ymm1 LONG $0x597cc1c4; WORD $0x8954; BYTE $0xc0 // vmulps -64(%r9,%rcx,4), %ymm0, %ymm2 LONG $0x586cc1c4; WORD $0x8e54; BYTE $0xc0 
// vaddps -64(%r14,%rcx,4), %ymm2, %ymm2 LONG $0x597cc1c4; WORD $0x895c; BYTE $0xe0 // vmulps -32(%r9,%rcx,4), %ymm0, %ymm3 LONG $0x5864c1c4; WORD $0x8e5c; BYTE $0xe0 // vaddps -32(%r14,%rcx,4), %ymm3, %ymm3 LONG $0x597cc1c4; WORD $0x8924 // vmulps (%r9,%rcx,4), %ymm0, %ymm4 LONG $0x585cc1c4; WORD $0x8e24 // vaddps (%r14,%rcx,4), %ymm4, %ymm4 LONG $0x117cc1c4; WORD $0x8e4c; BYTE $0xa0 // vmovups %ymm1, -96(%r14,%rcx,4) LONG $0x117cc1c4; WORD $0x8e54; BYTE $0xc0 // vmovups %ymm2, -64(%r14,%rcx,4) LONG $0x117cc1c4; WORD $0x8e5c; BYTE $0xe0 // vmovups %ymm3, -32(%r14,%rcx,4) LONG $0x117cc1c4; WORD $0x8e24 // vmovups %ymm4, (%r14,%rcx,4) LONG $0x20c18348 // addq $32, %rcx WORD $0x3949; BYTE $0xcc // cmpq %rcx, %r12 JNE LBB12_108 WORD $0x894d; BYTE $0xe0 // movq %r12, %r8 WORD $0x3949; BYTE $0xdc // cmpq %rbx, %r12 JE LBB12_115 LBB12_110: WORD $0x894d; BYTE $0xc7 // movq %r8, %r15 WORD $0xc3f6; BYTE $0x01 // testb $1, %bl JE LBB12_112 LONG $0x184d8b48 // movq 24(%rbp), %rcx LONG $0xb10c8d48 // leaq (%rcx,%rsi,4), %rcx LONG $0x107ac1c4; WORD $0x8204 // vmovss (%r10,%rax,4), %xmm0 # xmm0 = mem[0],zero,zero,zero WORD $0x894d; BYTE $0xc7 // movq %r8, %r15 LONG $0x7daf0f4c; BYTE $0x20 // imulq 32(%rbp), %r15 LONG $0x597aa1c4; WORD $0xb904 // vmulss (%rcx,%r15,4), %xmm0, %xmm0 LONG $0x244c8b48; BYTE $0x48 // movq 72(%rsp), %rcx # 8-byte Reload LONG $0x587aa1c4; WORD $0x8104 // vaddss (%rcx,%r8,4), %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0x8104 // vmovss %xmm0, (%rcx,%r8,4) WORD $0x894d; BYTE $0xc7 // movq %r8, %r15 LONG $0x01cf8349 // orq $1, %r15 LBB12_112: LONG $0x24443b4c; BYTE $0x50 // cmpq 80(%rsp), %r8 # 8-byte Folded Reload JE LBB12_115 LONG $0x244c8b48; BYTE $0x40 // movq 64(%rsp), %rcx # 8-byte Reload WORD $0x8949; BYTE $0xcd // movq %rcx, %r13 LONG $0xefaf0f4d // imulq %r15, %r13 LONG $0x01478d4d // leaq 1(%r15), %r8 LONG $0xc1af0f4c // imulq %rcx, %r8 WORD $0x8948; BYTE $0xf9 // movq %rdi, %rcx LBB12_114: LONG $0x107ac1c4; WORD $0x8204 // vmovss (%r10,%rax,4), %xmm0 # 
xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0x2904 // vmulss (%rcx,%r13), %xmm0, %xmm0 LONG $0x587aa1c4; WORD $0xba44; BYTE $0xfc // vaddss -4(%rdx,%r15,4), %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0xba44; BYTE $0xfc // vmovss %xmm0, -4(%rdx,%r15,4) LONG $0x107ac1c4; WORD $0x8204 // vmovss (%r10,%rax,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0x0104 // vmulss (%rcx,%r8), %xmm0, %xmm0 LONG $0x587aa1c4; WORD $0xba04 // vaddss (%rdx,%r15,4), %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0xba04 // vmovss %xmm0, (%rdx,%r15,4) LONG $0x02c78349 // addq $2, %r15 WORD $0x014c; BYTE $0xd9 // addq %r11, %rcx WORD $0x394c; BYTE $0xfb // cmpq %r15, %rbx JNE LBB12_114 JMP LBB12_115 LBB12_1: LONG $0x243c8348; BYTE $0x00 // cmpq $0, (%rsp) # 8-byte Folded Reload WORD $0x9e0f; BYTE $0xc0 // setle %al WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 WORD $0x9e0f; BYTE $0xc1 // setle %cl WORD $0xc108 // orb %al, %cl WORD $0x8548; BYTE $0xdb // testq %rbx, %rbx WORD $0x9e0f; BYTE $0xc0 // setle %al WORD $0xc808 // orb %cl, %al JNE LBB12_117 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x00000000853c8d48 // leaq (,%rax,4), %rdi LONG $0x10458b48 // movq 16(%rbp), %rax QUAD $0x0000000085048d48 // leaq (,%rax,4), %rax LONG $0x24448948; BYTE $0x38 // movq %rax, 56(%rsp) # 8-byte Spill LONG $0xff408d49 // leaq -1(%r8), %rax LONG $0xc5af0f49 // imulq %r13, %rax WORD $0x0148; BYTE $0xd8 // addq %rbx, %rax LONG $0x18758b48 // movq 24(%rbp), %rsi LONG $0x86048d48 // leaq (%rsi,%rax,4), %rax LONG $0x24448948; BYTE $0x30 // movq %rax, 48(%rsp) # 8-byte Spill LONG $0x28458b48 // movq 40(%rbp), %rax LONG $0x980c8d48 // leaq (%rax,%rbx,4), %rcx LONG $0x244c8948; BYTE $0x28 // movq %rcx, 40(%rsp) # 8-byte Spill LONG $0x244c8b48; BYTE $0x08 // movq 8(%rsp), %rcx # 8-byte Reload LONG $0x810c8d4a // leaq (%rcx,%r8,4), %rcx LONG $0x244c8948; BYTE $0x20 // movq %rcx, 32(%rsp) # 8-byte Spill QUAD $0xffffffffffe0bc49; WORD $0x7fff // movabsq $9223372036854775776, %r12 # imm = 0x7FFFFFFFFFFFFFE0 
WORD $0x2149; BYTE $0xdc // andq %rbx, %r12 LONG $0xff4b8d48 // leaq -1(%rbx), %rcx LONG $0x244c8948; BYTE $0x10 // movq %rcx, 16(%rsp) # 8-byte Spill LONG $0x604e8d48 // leaq 96(%rsi), %rcx LONG $0x244c8948; BYTE $0x18 // movq %rcx, 24(%rsp) # 8-byte Spill QUAD $0x00000000ad148d4a // leaq (,%r13,4), %rdx LONG $0x60708d4c // leaq 96(%rax), %r14 LONG $0x044e8d48 // leaq 4(%rsi), %rcx LONG $0x244c8948; BYTE $0x78 // movq %rcx, 120(%rsp) # 8-byte Spill LONG $0x04588d4c // leaq 4(%rax), %r11 WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d LONG $0x247c8948; BYTE $0x40 // movq %rdi, 64(%rsp) # 8-byte Spill JMP LBB12_3 LBB12_14: LONG $0x247c8b4c; BYTE $0x50 // movq 80(%rsp), %r15 # 8-byte Reload WORD $0xff49; BYTE $0xc7 // incq %r15 LONG $0x247c8b48; BYTE $0x40 // movq 64(%rsp), %rdi # 8-byte Reload WORD $0x0149; BYTE $0xfe // addq %rdi, %r14 WORD $0x0149; BYTE $0xfb // addq %rdi, %r11 LONG $0x243c3b4c // cmpq (%rsp), %r15 # 8-byte Folded Reload LONG $0x206d8b4c // movq 32(%rbp), %r13 JE LBB12_117 LBB12_3: QUAD $0x000000000000b848; WORD $0x2000 // movabsq $2305843009213693952, %rax # imm = 0x2000000000000000 WORD $0x8549; BYTE $0xc5 // testq %rax, %r13 LONG $0x2444950f; BYTE $0x48 // setne 72(%rsp) # 1-byte Folded Spill WORD $0x8948; BYTE $0xf9 // movq %rdi, %rcx LONG $0xcfaf0f49 // imulq %r15, %rcx LONG $0x286d8b4c // movq 40(%rbp), %r13 LONG $0x29348d4a // leaq (%rcx,%r13), %rsi LONG $0x244c0348; BYTE $0x28 // addq 40(%rsp), %rcx # 8-byte Folded Reload LONG $0x247c8b48; BYTE $0x38 // movq 56(%rsp), %rdi # 8-byte Reload LONG $0xffaf0f49 // imulq %r15, %rdi LONG $0x24448b48; BYTE $0x08 // movq 8(%rsp), %rax # 8-byte Reload LONG $0x380c8d4c // leaq (%rax,%rdi), %r9 LONG $0x247c0348; BYTE $0x20 // addq 32(%rsp), %rdi # 8-byte Folded Reload WORD $0x894d; BYTE $0xfa // movq %r15, %r10 LONG $0x55af0f4c; BYTE $0x10 // imulq 16(%rbp), %r10 LONG $0x247c894c; BYTE $0x50 // movq %r15, 80(%rsp) # 8-byte Spill LONG $0x7daf0f4c; BYTE $0x30 // imulq 48(%rbp), %r15 WORD $0x3948; BYTE 
$0xfe // cmpq %rdi, %rsi LONG $0xc7920f40 // setb %dil WORD $0x3949; BYTE $0xc9 // cmpq %rcx, %r9 LONG $0x90148d4e // leaq (%rax,%r10,4), %r10 LONG $0xbd4c8d4f; BYTE $0x00 // leaq (%r13,%r15,4), %r9 LONG $0xc5920f41 // setb %r13b WORD $0x2041; BYTE $0xfd // andb %dil, %r13b LONG $0x24743b48; BYTE $0x30 // cmpq 48(%rsp), %rsi # 8-byte Folded Reload LONG $0xc6920f40 // setb %sil LONG $0x184d3b48 // cmpq 24(%rbp), %rcx WORD $0x970f; BYTE $0xc1 // seta %cl WORD $0x2040; BYTE $0xf1 // andb %sil, %cl LONG $0x246c0a44; BYTE $0x48 // orb 72(%rsp), %r13b # 1-byte Folded Reload WORD $0x0841; BYTE $0xcd // orb %cl, %r13b LONG $0x24448b48; BYTE $0x78 // movq 120(%rsp), %rax # 8-byte Reload LONG $0x244c8b48; BYTE $0x18 // movq 24(%rsp), %rcx # 8-byte Reload WORD $0xff31 // xorl %edi, %edi JMP LBB12_4 LBB12_13: WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x0148; BYTE $0xd1 // addq %rdx, %rcx WORD $0x0148; BYTE $0xd0 // addq %rdx, %rax WORD $0x394c; BYTE $0xc7 // cmpq %r8, %rdi JE LBB12_14 LBB12_4: LONG $0x20fb8348 // cmpq $32, %rbx LONG $0xc6920f40 // setb %sil WORD $0x0844; BYTE $0xee // orb %r13b, %sil LONG $0x01c6f640 // testb $1, %sil JE LBB12_6 WORD $0xf631 // xorl %esi, %esi JMP LBB12_9 LBB12_6: LONG $0x187dc2c4; WORD $0xba04 // vbroadcastss (%r10,%rdi,4), %ymm0 WORD $0xf631 // xorl %esi, %esi LBB12_7: LONG $0x4c59fcc5; WORD $0xa0b1 // vmulps -96(%rcx,%rsi,4), %ymm0, %ymm1 LONG $0x5874c1c4; WORD $0xb64c; BYTE $0xa0 // vaddps -96(%r14,%rsi,4), %ymm1, %ymm1 LONG $0x5459fcc5; WORD $0xc0b1 // vmulps -64(%rcx,%rsi,4), %ymm0, %ymm2 LONG $0x586cc1c4; WORD $0xb654; BYTE $0xc0 // vaddps -64(%r14,%rsi,4), %ymm2, %ymm2 LONG $0x5c59fcc5; WORD $0xe0b1 // vmulps -32(%rcx,%rsi,4), %ymm0, %ymm3 LONG $0x5864c1c4; WORD $0xb65c; BYTE $0xe0 // vaddps -32(%r14,%rsi,4), %ymm3, %ymm3 LONG $0x2459fcc5; BYTE $0xb1 // vmulps (%rcx,%rsi,4), %ymm0, %ymm4 LONG $0x585cc1c4; WORD $0xb624 // vaddps (%r14,%rsi,4), %ymm4, %ymm4 LONG $0x117cc1c4; WORD $0xb64c; BYTE $0xa0 // vmovups %ymm1, -96(%r14,%rsi,4) 
LONG $0x117cc1c4; WORD $0xb654; BYTE $0xc0 // vmovups %ymm2, -64(%r14,%rsi,4) LONG $0x117cc1c4; WORD $0xb65c; BYTE $0xe0 // vmovups %ymm3, -32(%r14,%rsi,4) LONG $0x117cc1c4; WORD $0xb624 // vmovups %ymm4, (%r14,%rsi,4) LONG $0x20c68348 // addq $32, %rsi WORD $0x3949; BYTE $0xf4 // cmpq %rsi, %r12 JNE LBB12_7 WORD $0x894c; BYTE $0xe6 // movq %r12, %rsi WORD $0x3949; BYTE $0xdc // cmpq %rbx, %r12 JE LBB12_13 LBB12_9: WORD $0x8949; BYTE $0xf7 // movq %rsi, %r15 WORD $0xc3f6; BYTE $0x01 // testb $1, %bl JE LBB12_11 WORD $0x8949; BYTE $0xff // movq %rdi, %r15 LONG $0x7daf0f4c; BYTE $0x20 // imulq 32(%rbp), %r15 LONG $0x18458b4c // movq 24(%rbp), %r8 LONG $0xb83c8d4f // leaq (%r8,%r15,4), %r15 LONG $0x24448b4c; BYTE $0x58 // movq 88(%rsp), %r8 # 8-byte Reload LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597ac1c4; WORD $0xb704 // vmulss (%r15,%rsi,4), %xmm0, %xmm0 LONG $0x587ac1c4; WORD $0xb104 // vaddss (%r9,%rsi,4), %xmm0, %xmm0 LONG $0x117ac1c4; WORD $0xb104 // vmovss %xmm0, (%r9,%rsi,4) WORD $0x8949; BYTE $0xf7 // movq %rsi, %r15 LONG $0x01cf8349 // orq $1, %r15 LBB12_11: LONG $0x24743b48; BYTE $0x10 // cmpq 16(%rsp), %rsi # 8-byte Folded Reload JE LBB12_13 LBB12_12: LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0xb844; BYTE $0xfc // vmulss -4(%rax,%r15,4), %xmm0, %xmm0 LONG $0x587a81c4; WORD $0xbb44; BYTE $0xfc // vaddss -4(%r11,%r15,4), %xmm0, %xmm0 LONG $0x117a81c4; WORD $0xbb44; BYTE $0xfc // vmovss %xmm0, -4(%r11,%r15,4) LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0xb804 // vmulss (%rax,%r15,4), %xmm0, %xmm0 LONG $0x587a81c4; WORD $0xbb04 // vaddss (%r11,%r15,4), %xmm0, %xmm0 LONG $0x117a81c4; WORD $0xbb04 // vmovss %xmm0, (%r11,%r15,4) LONG $0x02c78349 // addq $2, %r15 WORD $0x394c; BYTE $0xfb // cmpq %r15, %rbx JNE LBB12_12 JMP LBB12_13 LBB12_16: LONG $0x243c8348; 
BYTE $0x00 // cmpq $0, (%rsp) # 8-byte Folded Reload JLE LBB12_117 LONG $0x07408d49 // leaq 7(%r8), %rax WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 LONG $0xc0490f49 // cmovnsq %r8, %rax WORD $0x8548; BYTE $0xdb // testq %rbx, %rbx JLE LBB12_117 WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0xf8e18348 // andq $-8, %rcx WORD $0x2949; BYTE $0xc8 // subq %rcx, %r8 LONG $0x03f8c148 // sarq $3, %rax WORD $0xf883; BYTE $0x02 // cmpl $2, %eax JL LBB12_47 WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 JLE LBB12_20 WORD $0x8944; BYTE $0xc6 // movl %r8d, %esi WORD $0x788d; BYTE $0xff // leal -1(%rax), %edi LONG $0xfe408d44 // leal -2(%rax), %r8d WORD $0x8941; BYTE $0xfb // movl %edi, %r11d LONG $0x03e38341 // andl $3, %r11d WORD $0xe783; BYTE $0xfc // andl $-4, %edi WORD $0x3145; BYTE $0xc9 // xorl %r9d, %r9d JMP LBB12_25 LBB12_39: WORD $0xff49; BYTE $0xc1 // incq %r9 LONG $0x240c3b4c // cmpq (%rsp), %r9 # 8-byte Folded Reload JE LBB12_117 LBB12_25: WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x244c8b48; BYTE $0x08 // movq 8(%rsp), %rcx # 8-byte Reload LONG $0x813c8d4c // leaq (%rcx,%rax,4), %r15 LONG $0x81148d4c // leaq (%rcx,%rax,4), %r10 LONG $0x20c28349 // addq $32, %r10 WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x45af0f48; BYTE $0x30 // imulq 48(%rbp), %rax LONG $0x284d8b48 // movq 40(%rbp), %rcx LONG $0x81348d4c // leaq (%rcx,%rax,4), %r14 WORD $0xd231 // xorl %edx, %edx JMP LBB12_26 LBB12_38: LONG $0x117ac1c4; WORD $0x9604 // vmovss %xmm0, (%r14,%rdx,4) WORD $0xff48; BYTE $0xc2 // incq %rdx WORD $0x3948; BYTE $0xda // cmpq %rbx, %rdx LONG $0x206d8b4c // movq 32(%rbp), %r13 JE LBB12_39 LBB12_26: WORD $0x8948; BYTE $0xd0 // movq %rdx, %rax LONG $0xc5af0f49 // imulq %r13, %rax LONG $0x107cc1c4; BYTE $0x07 // vmovups (%r15), %ymm0 LONG $0x184d8b48 // movq 24(%rbp), %rcx LONG $0x0459fcc5; BYTE $0x81 // vmulps (%rcx,%rax,4), %ymm0, %ymm0 LONG $0x812c8d4c // leaq (%rcx,%rax,4), %r13 LONG $0x20c58349 // addq $32, 
%r13 WORD $0xf889 // movl %edi, %eax WORD $0x894d; BYTE $0xd4 // movq %r10, %r12 LONG $0x03f88341 // cmpl $3, %r8d JB LBB12_28 LBB12_27: LONG $0x107cc1c4; WORD $0x240c // vmovups (%r12), %ymm1 LONG $0x107cc1c4; WORD $0x2454; BYTE $0x20 // vmovups 32(%r12), %ymm2 LONG $0x107cc1c4; WORD $0x245c; BYTE $0x40 // vmovups 64(%r12), %ymm3 LONG $0x107cc1c4; WORD $0x2464; BYTE $0x60 // vmovups 96(%r12), %ymm4 LONG $0x5974c1c4; WORD $0x004d // vmulps (%r13), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x596cc1c4; WORD $0x204d // vmulps 32(%r13), %ymm2, %ymm1 LONG $0x5964c1c4; WORD $0x4055 // vmulps 64(%r13), %ymm3, %ymm2 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0xc258fcc5 // vaddps %ymm2, %ymm0, %ymm0 LONG $0x595cc1c4; WORD $0x604d // vmulps 96(%r13), %ymm4, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x80ec8349 // subq $-128, %r12 LONG $0x80ed8349 // subq $-128, %r13 WORD $0xc083; BYTE $0xfc // addl $-4, %eax JNE LBB12_27 LBB12_28: WORD $0x8545; BYTE $0xdb // testl %r11d, %r11d JE LBB12_31 WORD $0x8944; BYTE $0xd8 // movl %r11d, %eax LBB12_30: LONG $0x107cc1c4; WORD $0x240c // vmovups (%r12), %ymm1 LONG $0x5974c1c4; WORD $0x004d // vmulps (%r13), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x20c48349 // addq $32, %r12 LONG $0x20c58349 // addq $32, %r13 WORD $0xc8ff // decl %eax JNE LBB12_30 LBB12_31: LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x107ac1c4; WORD $0x240c // vmovss (%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x004d // vmulss (%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x01 // cmpl $1, %esi JE LBB12_38 LONG 
$0x107ac1c4; WORD $0x244c; BYTE $0x04 // vmovss 4(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x044d // vmulss 4(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x02 // cmpl $2, %esi JE LBB12_38 LONG $0x107ac1c4; WORD $0x244c; BYTE $0x08 // vmovss 8(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x084d // vmulss 8(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x03 // cmpl $3, %esi JE LBB12_38 LONG $0x107ac1c4; WORD $0x244c; BYTE $0x0c // vmovss 12(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x0c4d // vmulss 12(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x04 // cmpl $4, %esi JE LBB12_38 LONG $0x107ac1c4; WORD $0x244c; BYTE $0x10 // vmovss 16(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x104d // vmulss 16(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x05 // cmpl $5, %esi JE LBB12_38 LONG $0x107ac1c4; WORD $0x244c; BYTE $0x14 // vmovss 20(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x144d // vmulss 20(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x06 // cmpl $6, %esi JE LBB12_38 LONG $0x107ac1c4; WORD $0x244c; BYTE $0x18 // vmovss 24(%r12), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5972c1c4; WORD $0x184d // vmulss 24(%r13), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 JMP LBB12_38 LBB12_88: LONG $0x243c8348; BYTE $0x00 // cmpq $0, (%rsp) # 8-byte Folded Reload WORD $0x9e0f; BYTE $0xc1 // setle %cl WORD $0x0134 // xorb $1, %al WORD $0xc808 // orb %cl, %al JNE LBB12_117 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x00000000853c8d48 // leaq (,%rax,4), %rdi LONG $0xff408d49 // leaq -1(%r8), %rax WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0x4daf0f48; BYTE $0x10 // imulq 16(%rbp), %rcx LONG $0xc5af0f49 // imulq %r13, %rax WORD 
$0x0148; BYTE $0xd8 // addq %rbx, %rax LONG $0x18758b48 // movq 24(%rbp), %rsi LONG $0x86048d48 // leaq (%rsi,%rax,4), %rax LONG $0x24448948; BYTE $0x38 // movq %rax, 56(%rsp) # 8-byte Spill LONG $0x28458b48 // movq 40(%rbp), %rax LONG $0x98148d48 // leaq (%rax,%rbx,4), %rdx LONG $0x24548948; BYTE $0x30 // movq %rdx, 48(%rsp) # 8-byte Spill LONG $0x24548b48; BYTE $0x08 // movq 8(%rsp), %rdx # 8-byte Reload LONG $0x8a0c8d48 // leaq (%rdx,%rcx,4), %rcx LONG $0x04c18348 // addq $4, %rcx LONG $0x244c8948; BYTE $0x28 // movq %rcx, 40(%rsp) # 8-byte Spill QUAD $0xffffffffffe0bf49; WORD $0x7fff // movabsq $9223372036854775776, %r15 # imm = 0x7FFFFFFFFFFFFFE0 WORD $0x2149; BYTE $0xdf // andq %rbx, %r15 LONG $0xff4b8d48 // leaq -1(%rbx), %rcx LONG $0x244c8948; BYTE $0x10 // movq %rcx, 16(%rsp) # 8-byte Spill LONG $0x604e8d48 // leaq 96(%rsi), %rcx LONG $0x244c8948; BYTE $0x20 // movq %rcx, 32(%rsp) # 8-byte Spill QUAD $0x00000000ad148d4a // leaq (,%r13,4), %rdx LONG $0x60588d4c // leaq 96(%rax), %r11 LONG $0x044e8d48 // leaq 4(%rsi), %rcx LONG $0x244c8948; BYTE $0x18 // movq %rcx, 24(%rsp) # 8-byte Spill LONG $0x04708d48 // leaq 4(%rax), %rsi WORD $0x3145; BYTE $0xf6 // xorl %r14d, %r14d LONG $0x247c8948; BYTE $0x40 // movq %rdi, 64(%rsp) # 8-byte Spill JMP LBB12_90 LBB12_101: LONG $0x24748b4c; BYTE $0x48 // movq 72(%rsp), %r14 # 8-byte Reload WORD $0xff49; BYTE $0xc6 // incq %r14 LONG $0x247c8b48; BYTE $0x40 // movq 64(%rsp), %rdi # 8-byte Reload WORD $0x0149; BYTE $0xfb // addq %rdi, %r11 WORD $0x0148; BYTE $0xfe // addq %rdi, %rsi LONG $0x24343b4c // cmpq (%rsp), %r14 # 8-byte Folded Reload LONG $0x206d8b4c // movq 32(%rbp), %r13 JE LBB12_117 LBB12_90: QUAD $0x000000000000b948; WORD $0x2000 // movabsq $2305843009213693952, %rcx # imm = 0x2000000000000000 WORD $0x8549; BYTE $0xcd // testq %rcx, %r13 WORD $0x950f; BYTE $0xc0 // setne %al LONG $0x104d8548 // testq %rcx, 16(%rbp) WORD $0x950f; BYTE $0xc1 // setne %cl LONG $0xfeaf0f49 // imulq %r14, %rdi LONG $0x286d8b4c // 
movq 40(%rbp), %r13 LONG $0x2f048d4e // leaq (%rdi,%r13), %r8 LONG $0x247c0348; BYTE $0x30 // addq 48(%rsp), %rdi # 8-byte Folded Reload LONG $0x244c8b4c; BYTE $0x08 // movq 8(%rsp), %r9 # 8-byte Reload LONG $0xb1148d4f // leaq (%r9,%r14,4), %r10 LONG $0x244c8b4c; BYTE $0x28 // movq 40(%rsp), %r9 # 8-byte Reload LONG $0xb10c8d4f // leaq (%r9,%r14,4), %r9 LONG $0x2474894c; BYTE $0x48 // movq %r14, 72(%rsp) # 8-byte Spill LONG $0x75af0f4c; BYTE $0x30 // imulq 48(%rbp), %r14 WORD $0x394d; BYTE $0xc8 // cmpq %r9, %r8 LONG $0xc1920f41 // setb %r9b WORD $0x3949; BYTE $0xfa // cmpq %rdi, %r10 LONG $0xc4920f41 // setb %r12b WORD $0x2045; BYTE $0xcc // andb %r9b, %r12b QUAD $0x00000000b50c8d4e // leaq (,%r14,4), %r9 WORD $0x014d; BYTE $0xe9 // addq %r13, %r9 LONG $0x244c894c; BYTE $0x50 // movq %r9, 80(%rsp) # 8-byte Spill WORD $0x0841; BYTE $0xcc // orb %cl, %r12b LONG $0x24443b4c; BYTE $0x38 // cmpq 56(%rsp), %r8 # 8-byte Folded Reload WORD $0x920f; BYTE $0xc1 // setb %cl LONG $0x187d3b48 // cmpq 24(%rbp), %rdi LONG $0xc1970f41 // seta %r9b WORD $0x2041; BYTE $0xc9 // andb %cl, %r9b WORD $0x0841; BYTE $0xc1 // orb %al, %r9b WORD $0x0845; BYTE $0xe1 // orb %r12b, %r9b LONG $0x24448b48; BYTE $0x18 // movq 24(%rsp), %rax # 8-byte Reload LONG $0x244c8b48; BYTE $0x20 // movq 32(%rsp), %rcx # 8-byte Reload WORD $0x3145; BYTE $0xf6 // xorl %r14d, %r14d JMP LBB12_91 LBB12_100: WORD $0xff49; BYTE $0xc6 // incq %r14 WORD $0x0148; BYTE $0xd1 // addq %rdx, %rcx WORD $0x0148; BYTE $0xd0 // addq %rdx, %rax LONG $0x24743b4c; BYTE $0x58 // cmpq 88(%rsp), %r14 # 8-byte Folded Reload JE LBB12_101 LBB12_91: LONG $0x20fb8348 // cmpq $32, %rbx LONG $0xc0920f41 // setb %r8b WORD $0x894c; BYTE $0xf7 // movq %r14, %rdi LONG $0x7daf0f48; BYTE $0x10 // imulq 16(%rbp), %rdi WORD $0x0845; BYTE $0xc8 // orb %r9b, %r8b LONG $0x01c0f641 // testb $1, %r8b JE LBB12_93 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d JMP LBB12_96 LBB12_93: LONG $0x187dc2c4; WORD $0xba04 // vbroadcastss (%r10,%rdi,4), %ymm0 
WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LBB12_94: LONG $0x597ca1c4; WORD $0x814c; BYTE $0xa0 // vmulps -96(%rcx,%r8,4), %ymm0, %ymm1 LONG $0x587481c4; WORD $0x834c; BYTE $0xa0 // vaddps -96(%r11,%r8,4), %ymm1, %ymm1 LONG $0x597ca1c4; WORD $0x8154; BYTE $0xc0 // vmulps -64(%rcx,%r8,4), %ymm0, %ymm2 LONG $0x586c81c4; WORD $0x8354; BYTE $0xc0 // vaddps -64(%r11,%r8,4), %ymm2, %ymm2 LONG $0x597ca1c4; WORD $0x815c; BYTE $0xe0 // vmulps -32(%rcx,%r8,4), %ymm0, %ymm3 LONG $0x586481c4; WORD $0x835c; BYTE $0xe0 // vaddps -32(%r11,%r8,4), %ymm3, %ymm3 LONG $0x597ca1c4; WORD $0x8124 // vmulps (%rcx,%r8,4), %ymm0, %ymm4 LONG $0x585c81c4; WORD $0x8324 // vaddps (%r11,%r8,4), %ymm4, %ymm4 LONG $0x117c81c4; WORD $0x834c; BYTE $0xa0 // vmovups %ymm1, -96(%r11,%r8,4) LONG $0x117c81c4; WORD $0x8354; BYTE $0xc0 // vmovups %ymm2, -64(%r11,%r8,4) LONG $0x117c81c4; WORD $0x835c; BYTE $0xe0 // vmovups %ymm3, -32(%r11,%r8,4) LONG $0x117c81c4; WORD $0x8324 // vmovups %ymm4, (%r11,%r8,4) LONG $0x20c08349 // addq $32, %r8 WORD $0x394d; BYTE $0xc7 // cmpq %r8, %r15 JNE LBB12_94 WORD $0x894d; BYTE $0xf8 // movq %r15, %r8 WORD $0x3949; BYTE $0xdf // cmpq %rbx, %r15 JE LBB12_100 LBB12_96: WORD $0x894d; BYTE $0xc4 // movq %r8, %r12 WORD $0xc3f6; BYTE $0x01 // testb $1, %bl JE LBB12_98 WORD $0x894d; BYTE $0xf4 // movq %r14, %r12 LONG $0x65af0f4c; BYTE $0x20 // imulq 32(%rbp), %r12 LONG $0x186d8b4c // movq 24(%rbp), %r13 QUAD $0x00000000a5248d4e // leaq (,%r12,4), %r12 WORD $0x014d; BYTE $0xec // addq %r13, %r12 LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597a81c4; WORD $0x8404 // vmulss (%r12,%r8,4), %xmm0, %xmm0 LONG $0x24648b4c; BYTE $0x50 // movq 80(%rsp), %r12 # 8-byte Reload LONG $0x587a81c4; WORD $0x8404 // vaddss (%r12,%r8,4), %xmm0, %xmm0 LONG $0x117a81c4; WORD $0x8404 // vmovss %xmm0, (%r12,%r8,4) WORD $0x894d; BYTE $0xc4 // movq %r8, %r12 LONG $0x01cc8349 // orq $1, %r12 LBB12_98: LONG $0x24443b4c; BYTE $0x10 // cmpq 16(%rsp), 
%r8 # 8-byte Folded Reload JE LBB12_100 LBB12_99: LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0xa044; BYTE $0xfc // vmulss -4(%rax,%r12,4), %xmm0, %xmm0 LONG $0x587aa1c4; WORD $0xa644; BYTE $0xfc // vaddss -4(%rsi,%r12,4), %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0xa644; BYTE $0xfc // vmovss %xmm0, -4(%rsi,%r12,4) LONG $0x107ac1c4; WORD $0xba04 // vmovss (%r10,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x597aa1c4; WORD $0xa004 // vmulss (%rax,%r12,4), %xmm0, %xmm0 LONG $0x587aa1c4; WORD $0xa604 // vaddss (%rsi,%r12,4), %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0xa604 // vmovss %xmm0, (%rsi,%r12,4) LONG $0x02c48349 // addq $2, %r12 WORD $0x394c; BYTE $0xe3 // cmpq %r12, %rbx JNE LBB12_99 JMP LBB12_100 LBB12_47: WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 LONG $0x28758b4c // movq 40(%rbp), %r14 JLE LBB12_71 WORD $0x8944; BYTE $0xc6 // movl %r8d, %esi WORD $0xc085 // testl %eax, %eax JLE LBB12_49 LONG $0x24448348; WORD $0x3860 // addq $56, 96(%rsp) # 8-byte Folded Spill LONG $0x02e5c149 // shlq $2, %r13 LONG $0x2464c148; WORD $0x0268 // shlq $2, 104(%rsp) # 8-byte Folded Spill WORD $0xd231 // xorl %edx, %edx LONG $0x244c8b4c; BYTE $0x08 // movq 8(%rsp), %r9 # 8-byte Reload JMP LBB12_61 LBB12_70: WORD $0xff48; BYTE $0xc2 // incq %rdx LONG $0x2444034c; BYTE $0x68 // addq 104(%rsp), %r8 # 8-byte Folded Reload LONG $0x2444894c; BYTE $0x70 // movq %r8, 112(%rsp) # 8-byte Spill LONG $0x24143b48 // cmpq (%rsp), %rdx # 8-byte Folded Reload JE LBB12_117 LBB12_61: WORD $0x8948; BYTE $0xd0 // movq %rdx, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x244c8b48; BYTE $0x60 // movq 96(%rsp), %rcx # 8-byte Reload WORD $0xff31 // xorl %edi, %edi LONG $0x24448b4c; BYTE $0x70 // movq 112(%rsp), %r8 # 8-byte Reload JMP LBB12_62 LBB12_69: LONG $0x117ac1c4; WORD $0xb804 // vmovss %xmm0, (%r8,%rdi,4) WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x014c; BYTE $0xe9 // addq %r13, %rcx WORD $0x3948; BYTE 
$0xfb // cmpq %rdi, %rbx JE LBB12_70 LBB12_62: LONG $0x107cc1c4; WORD $0x8104 // vmovups (%r9,%rax,4), %ymm0 LONG $0x4159fcc5; BYTE $0xc8 // vmulps -56(%rcx), %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x20 // vmovss 32(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; BYTE $0xe8 // vmulss -24(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x01 // cmpl $1, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x24 // vmovss 36(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; BYTE $0xec // vmulss -20(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x02 // cmpl $2, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x28 // vmovss 40(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; BYTE $0xf0 // vmulss -16(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x03 // cmpl $3, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x2c // vmovss 44(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; BYTE $0xf4 // vmulss -12(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x04 // cmpl $4, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x30 // vmovss 48(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; BYTE $0xf8 // vmulss -8(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x05 // cmpl $5, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x34 // vmovss 52(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x4959f2c5; 
BYTE $0xfc // vmulss -4(%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 WORD $0xfe83; BYTE $0x06 // cmpl $6, %esi JE LBB12_69 LONG $0x107ac1c4; WORD $0x814c; BYTE $0x38 // vmovss 56(%r9,%rax,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x0959f2c5 // vmulss (%rcx), %xmm1, %xmm1 LONG $0xc058f2c5 // vaddss %xmm0, %xmm1, %xmm0 JMP LBB12_69 LBB12_20: WORD $0x508d; BYTE $0xff // leal -1(%rax), %edx WORD $0x708d; BYTE $0xfe // leal -2(%rax), %esi WORD $0xd789 // movl %edx, %edi WORD $0xe783; BYTE $0xfc // andl $-4, %edi WORD $0xc8fe // decb %al LONG $0xc0b60f44 // movzbl %al, %r8d LONG $0x03e08341 // andl $3, %r8d LONG $0x05e0c141 // shll $5, %r8d WORD $0x3145; BYTE $0xd2 // xorl %r10d, %r10d JMP LBB12_21 LBB12_46: WORD $0xff49; BYTE $0xc2 // incq %r10 LONG $0x24143b4c // cmpq (%rsp), %r10 # 8-byte Folded Reload JE LBB12_117 LBB12_21: WORD $0x894c; BYTE $0xd0 // movq %r10, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x244c8b48; BYTE $0x08 // movq 8(%rsp), %rcx # 8-byte Reload LONG $0x81048d48 // leaq (%rcx,%rax,4), %rax WORD $0x894c; BYTE $0xd1 // movq %r10, %rcx LONG $0x4daf0f48; BYTE $0x30 // imulq 48(%rbp), %rcx LONG $0x284d8b4c // movq 40(%rbp), %r9 LONG $0x890c8d49 // leaq (%r9,%rcx,4), %rcx WORD $0x3145; BYTE $0xc9 // xorl %r9d, %r9d JMP LBB12_22 LBB12_45: LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0x8904 // vmovss %xmm0, (%rcx,%r9,4) WORD $0xff49; BYTE $0xc1 // incq %r9 WORD $0x3949; BYTE $0xd9 // cmpq %rbx, %r9 JE LBB12_46 LBB12_22: WORD $0x894d; BYTE $0xce // movq %r9, %r14 LONG $0xf5af0f4d // imulq %r13, %r14 LONG $0x187d8b4c // movq 24(%rbp), %r15 LONG $0xb71c8d4f // leaq (%r15,%r14,4), %r11 
LONG $0x0010fcc5 // vmovups (%rax), %ymm0 LONG $0x597c81c4; WORD $0xb704 // vmulps (%r15,%r14,4), %ymm0, %ymm0 WORD $0xfe83; BYTE $0x03 // cmpl $3, %esi JAE LBB12_40 WORD $0x8949; BYTE $0xc6 // movq %rax, %r14 JMP LBB12_42 LBB12_40: WORD $0x8941; BYTE $0xff // movl %edi, %r15d WORD $0x8949; BYTE $0xc6 // movq %rax, %r14 LBB12_41: LONG $0x107cc1c4; WORD $0x204e // vmovups 32(%r14), %ymm1 LONG $0x107cc1c4; WORD $0x4056 // vmovups 64(%r14), %ymm2 LONG $0x107cc1c4; WORD $0x605e // vmovups 96(%r14), %ymm3 QUAD $0x000080a6107cc1c4; BYTE $0x00 // vmovups 128(%r14), %ymm4 LONG $0x5974c1c4; WORD $0x204b // vmulps 32(%r11), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x596cc1c4; WORD $0x404b // vmulps 64(%r11), %ymm2, %ymm1 LONG $0x5964c1c4; WORD $0x6053 // vmulps 96(%r11), %ymm3, %ymm2 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0xc258fcc5 // vaddps %ymm2, %ymm0, %ymm0 QUAD $0x0000808b595cc1c4; BYTE $0x00 // vmulps 128(%r11), %ymm4, %ymm1 LONG $0x80ee8349 // subq $-128, %r14 LONG $0x80eb8349 // subq $-128, %r11 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0xfcc78341 // addl $-4, %r15d JNE LBB12_41 LBB12_42: WORD $0xc2f6; BYTE $0x03 // testb $3, %dl JE LBB12_45 WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d LBB12_44: LONG $0x107c81c4; WORD $0x3e4c; BYTE $0x20 // vmovups 32(%r14,%r15), %ymm1 LONG $0x597481c4; WORD $0x3b4c; BYTE $0x20 // vmulps 32(%r11,%r15), %ymm1, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x20c78349 // addq $32, %r15 WORD $0x3945; BYTE $0xf8 // cmpl %r15d, %r8d JNE LBB12_44 JMP LBB12_45 LBB12_71: WORD $0xc085 // testl %eax, %eax JLE LBB12_75 QUAD $0xffffffffffe0b848; WORD $0x7fff // movabsq $9223372036854775776, %rax # imm = 0x7FFFFFFFFFFFFFE0 LONG $0x1ec88348 // orq $30, %rax WORD $0x2148; BYTE $0xd8 // andq %rbx, %rax QUAD $0x00000000ed0c8d4a // leaq (,%r13,8), %rcx LONG $0x04568d49 // leaq 4(%r14), %rdx LONG $0x30758b48 // movq 48(%rbp), %rsi QUAD $0x00000000b5348d48 // leaq (,%rsi,4), %rsi WORD 
$0xff31 // xorl %edi, %edi JMP LBB12_73 LBB12_81: WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x0148; BYTE $0xf2 // addq %rsi, %rdx LONG $0x243c3b48 // cmpq (%rsp), %rdi # 8-byte Folded Reload JE LBB12_117 LBB12_73: WORD $0x8949; BYTE $0xf9 // movq %rdi, %r9 LONG $0x4daf0f4c; BYTE $0x10 // imulq 16(%rbp), %r9 LONG $0x01fb8348 // cmpq $1, %rbx JNE LBB12_77 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LONG $0x24748b4c; BYTE $0x08 // movq 8(%rsp), %r14 # 8-byte Reload JMP LBB12_79 LBB12_77: LONG $0x18558b4c // movq 24(%rbp), %r10 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LONG $0x24748b4c; BYTE $0x08 // movq 8(%rsp), %r14 # 8-byte Reload LBB12_78: LONG $0x107c81c4; WORD $0x8e04 // vmovups (%r14,%r9,4), %ymm0 LONG $0x597cc1c4; BYTE $0x02 // vmulps (%r10), %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0x8244; BYTE $0xfc // vmovss %xmm0, -4(%rdx,%r8,4) LONG $0x107c81c4; WORD $0x8e04 // vmovups (%r14,%r9,4), %ymm0 LONG $0x597c81c4; WORD $0xaa04 // vmulps (%r10,%r13,4), %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x117aa1c4; WORD $0x8204 // vmovss %xmm0, (%rdx,%r8,4) LONG $0x02c08349 // addq $2, %r8 WORD $0x0149; BYTE $0xca // addq %rcx, %r10 WORD $0x394c; BYTE $0xc0 // cmpq %r8, %rax JNE LBB12_78 LBB12_79: WORD $0xc3f6; BYTE $0x01 // testb $1, %bl JE LBB12_81 WORD $0x8949; BYTE $0xfa // movq %rdi, %r10 LONG 
$0x55af0f4c; BYTE $0x30 // imulq 48(%rbp), %r10 WORD $0x894d; BYTE $0xc3 // movq %r8, %r11 LONG $0xddaf0f4d // imulq %r13, %r11 LONG $0x107c81c4; WORD $0x8e04 // vmovups (%r14,%r9,4), %ymm0 LONG $0x184d8b4c // movq 24(%rbp), %r9 LONG $0x597c81c4; WORD $0x9904 // vmulps (%r9,%r11,4), %ymm0, %ymm0 LONG $0x284d8b4c // movq 40(%rbp), %r9 LONG $0x910c8d4f // leaq (%r9,%r10,4), %r9 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc058f0c5 // vaddps %xmm0, %xmm1, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x117a81c4; WORD $0x8104 // vmovss %xmm0, (%r9,%r8,4) JMP LBB12_81 LBB12_49: LONG $0x24448348; WORD $0x1860 // addq $24, 96(%rsp) # 8-byte Folded Spill LONG $0x02e5c149 // shlq $2, %r13 LONG $0x2464c148; WORD $0x0268 // shlq $2, 104(%rsp) # 8-byte Folded Spill WORD $0xc031 // xorl %eax, %eax LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0 LONG $0x244c8b4c; BYTE $0x08 // movq 8(%rsp), %r9 # 8-byte Reload JMP LBB12_50 LBB12_59: WORD $0xff48; BYTE $0xc0 // incq %rax LONG $0x2444034c; BYTE $0x68 // addq 104(%rsp), %r8 # 8-byte Folded Reload LONG $0x2444894c; BYTE $0x70 // movq %r8, 112(%rsp) # 8-byte Spill LONG $0x24043b48 // cmpq (%rsp), %rax # 8-byte Folded Reload JE LBB12_117 LBB12_50: WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0x4daf0f48; BYTE $0x10 // imulq 16(%rbp), %rcx LONG $0x24548b48; BYTE $0x60 // movq 96(%rsp), %rdx # 8-byte Reload WORD $0xff31 // xorl %edi, %edi LONG $0x24448b4c; BYTE $0x70 // movq 112(%rsp), %r8 # 8-byte Reload JMP LBB12_51 LBB12_58: LONG $0x117ac1c4; WORD $0xb80c // vmovss %xmm1, (%r8,%rdi,4) WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x014c; BYTE $0xea // addq %r13, %rdx WORD $0x3948; BYTE $0xfb // cmpq %rdi, %rbx JE LBB12_59 LBB12_51: LONG $0x107ac1c4; WORD $0x890c // vmovss (%r9,%rcx,4), %xmm1 # xmm1 = 
mem[0],zero,zero,zero LONG $0x4a59f2c5; BYTE $0xe8 // vmulss -24(%rdx), %xmm1, %xmm1 LONG $0xc858f2c5 // vaddss %xmm0, %xmm1, %xmm1 WORD $0xfe83; BYTE $0x01 // cmpl $1, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x04 // vmovss 4(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x5259eac5; BYTE $0xec // vmulss -20(%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 WORD $0xfe83; BYTE $0x02 // cmpl $2, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x08 // vmovss 8(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x5259eac5; BYTE $0xf0 // vmulss -16(%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 WORD $0xfe83; BYTE $0x03 // cmpl $3, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x0c // vmovss 12(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x5259eac5; BYTE $0xf4 // vmulss -12(%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 WORD $0xfe83; BYTE $0x04 // cmpl $4, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x10 // vmovss 16(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x5259eac5; BYTE $0xf8 // vmulss -8(%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 WORD $0xfe83; BYTE $0x05 // cmpl $5, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x14 // vmovss 20(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x5259eac5; BYTE $0xfc // vmulss -4(%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 WORD $0xfe83; BYTE $0x06 // cmpl $6, %esi JE LBB12_58 LONG $0x107ac1c4; WORD $0x8954; BYTE $0x18 // vmovss 24(%r9,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x1259eac5 // vmulss (%rdx), %xmm2, %xmm2 LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1 JMP LBB12_58 LBB12_75: LONG $0x304d8b48 // movq 48(%rbp), %rcx QUAD $0x000000008d248d4c // leaq (,%rcx,4), %r12 LONG $0x02e3c148 // shlq $2, %rbx LONG $0x24048b48 // movq (%rsp), %rax # 8-byte Reload WORD $0xc289 // movl %eax, %edx WORD $0xe283; BYTE $0x07 // andl 
$7, %edx LONG $0x24548948; BYTE $0x10 // movq %rdx, 16(%rsp) # 8-byte Spill LONG $0x08f88348 // cmpq $8, %rax JAE LBB12_82 WORD $0x3145; BYTE $0xed // xorl %r13d, %r13d JMP LBB12_84 LBB12_82: QUAD $0xffffffffffe0b848; WORD $0x7fff // movabsq $9223372036854775776, %rax # imm = 0x7FFFFFFFFFFFFFE0 LONG $0x18c88348 // orq $24, %rax LONG $0x24042148 // andq %rax, (%rsp) # 8-byte Folded Spill LONG $0x05e1c148 // shlq $5, %rcx LONG $0x244c8948; BYTE $0x58 // movq %rcx, 88(%rsp) # 8-byte Spill WORD $0x3145; BYTE $0xed // xorl %r13d, %r13d LBB12_83: WORD $0x894c; BYTE $0xf7 // movq %r14, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT LONG $0x263c8d4f // leaq (%r14,%r12), %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD 
$0x014d; BYTE $0xe7 // addq %r12, %r15 WORD $0x894c; BYTE $0xff // movq %r15, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT LONG $0x24048b48 // movq (%rsp), %rax # 8-byte Reload LONG $0x08c58349 // addq $8, %r13 LONG $0x2474034c; BYTE $0x58 // addq 88(%rsp), %r14 # 8-byte Folded Reload WORD $0x394c; BYTE $0xe8 // cmpq %r13, %rax JNE LBB12_83 LBB12_84: LONG $0x247c8b4c; BYTE $0x10 // movq 16(%rsp), %r15 # 8-byte Reload WORD $0x854d; BYTE $0xff // testq %r15, %r15 JE LBB12_117 LONG $0x6daf0f4c; BYTE $0x30 // imulq 48(%rbp), %r13 LONG $0x28458b48 // movq 40(%rbp), %rax LONG $0xa8348d4e // leaq (%rax,%r13,4), %r14 LBB12_86: WORD $0x894c; BYTE $0xf7 // movq %r14, %rdi WORD $0xf631 // xorl %esi, %esi WORD $0x8948; BYTE $0xda // movq %rbx, %rdx LONG $0x000000e8; BYTE $0x00 // callq memset@PLT WORD $0x014d; BYTE $0xe6 // addq %r12, %r14 WORD $0xff49; BYTE $0xcf // decq %r15 JNE LBB12_86 LBB12_117: LONG $0xd8658d48 // leaq -40(%rbp), %rsp BYTE $0x5b // popq %rbx WORD $0x5c41 // popq %r12 WORD $0x5d41 // popq %r13 WORD $0x5e41 // popq %r14 WORD $0x5f41 // popq %r15 BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper POPQ DI POPQ DI POPQ DI POPQ DI POPQ DI POPQ DI RET ================================================ FILE: common/floats/floats_avx512.go ================================================ //go:build !noasm && amd64 // Code generated by GoAT. DO NOT EDIT. 
// versions: // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) // objdump 2.38 // flags: -mavx -mfma -mavx512f -O3 // source: src/floats_avx512.c package floats import "unsafe" //go:noescape func _mm512_mul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) //go:noescape func _mm512_mul_const_add(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm512_mul_const_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm512_mul_const(a, b unsafe.Pointer, n int64) //go:noescape func _mm512_add_const(a, b unsafe.Pointer, n int64) //go:noescape func _mm512_sub_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm512_sub(a, b unsafe.Pointer, n int64) //go:noescape func _mm512_mul_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm512_div_to(a, b, c unsafe.Pointer, n int64) //go:noescape func _mm512_sqrt_to(a, b unsafe.Pointer, n int64) //go:noescape func _mm512_dot(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func _mm512_euclidean(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func _mm512_mm(transA, transB bool, m, n, k int64, a unsafe.Pointer, lda int64, b unsafe.Pointer, ldb int64, c unsafe.Pointer, ldc int64) ================================================ FILE: common/floats/floats_avx512.s ================================================ //go:build !noasm && amd64 // Code generated by GoAT. DO NOT EDIT. 
// versions: // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) // objdump 2.38 // flags: -mavx -mfma -mavx512f -O3 // source: src/floats_avx512.c TEXT ·_mm512_mul_const_add_to(SB), $0-40 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ dst+24(FP), CX MOVQ n+32(FP), R8 BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x0f488d4d // leaq 15(%r8), %r9 WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 LONG $0xc8490f4d // cmovnsq %r8, %r9 WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x04f8c148 // sarq $4, %rax LONG $0xf0e18349 // andq $-16, %r9 WORD $0x294d; BYTE $0xc8 // subq %r9, %r8 WORD $0xc085 // testl %eax, %eax JLE LBB0_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB0_4 WORD $0x8941; BYTE $0xc1 // movl %eax, %r9d LONG $0xfee18141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r9d # imm = 0x7FFFFFFE LBB0_3: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x487df262; WORD $0x0e18 // vbroadcastss (%rsi), %zmm1 LONG $0x487df262; WORD $0x0aa8 // vfmadd213ps (%rdx), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem LONG $0x487cf162; WORD $0x0911 // vmovups %zmm1, (%rcx) LONG $0x487df262; WORD $0x0618 // vbroadcastss (%rsi), %zmm0 LONG $0x487cf162; WORD $0x4f10; BYTE $0x01 // vmovups 64(%rdi), %zmm1 LONG $0x4875f262; WORD $0x42a8; BYTE $0x01 // vfmadd213ps 64(%rdx), %zmm1, %zmm0 # zmm0 = (zmm1 * zmm0) + mem LONG $0x487cf162; WORD $0x4111; BYTE $0x01 // vmovups %zmm0, 64(%rcx) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ea8348 // subq $-128, %rdx LONG $0x80e98348 // subq $-128, %rcx LONG $0xfec18341 // addl $-2, %r9d JNE LBB0_3 LBB0_4: WORD $0x01a8 // testb $1, %al JE LBB0_6 LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x487df262; WORD $0x0e18 // vbroadcastss (%rsi), %zmm1 LONG $0x487df262; WORD $0x0aa8 // vfmadd213ps (%rdx), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem LONG $0x487cf162; WORD $0x0911 // vmovups %zmm1, (%rcx) LONG $0x40c78348 // addq $64, 
%rdi LONG $0x40c28348 // addq $64, %rdx LONG $0x40c18348 // addq $64, %rcx LBB0_6: LONG $0x07f88349 // cmpq $7, %r8 JLE LBB0_8 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0111fcc5 // vmovups %ymm0, (%rcx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx LONG $0x20c18348 // addq $32, %rcx LONG $0xf8c08341 // addl $-8, %r8d LBB0_8: WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB0_13 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax LONG $0x01f88341 // cmpl $1, %r8d JNE LBB0_14 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d JMP LBB0_11 LBB0_14: WORD $0x8941; BYTE $0xc1 // movl %eax, %r9d LONG $0xfee18141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r9d # imm = 0x7FFFFFFE WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LBB0_15: LONG $0x107aa1c4; WORD $0x8704 // vmovss (%rdi,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979a2c4; WORD $0x820c // vfmadd213ss (%rdx,%r8,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0x810c // vmovss %xmm1, (%rcx,%r8,4) LONG $0x107aa1c4; WORD $0x8744; BYTE $0x04 // vmovss 4(%rdi,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979a2c4; WORD $0x824c; BYTE $0x04 // vfmadd213ss 4(%rdx,%r8,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0x814c; BYTE $0x04 // vmovss %xmm1, 4(%rcx,%r8,4) LONG $0x02c08349 // addq $2, %r8 WORD $0x394d; BYTE $0xc1 // cmpq %r8, %r9 JNE LBB0_15 LBB0_11: WORD $0x01a8 // testb $1, %al JE LBB0_13 LONG $0x107aa1c4; WORD $0x8704 // vmovss (%rdi,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979a2c4; WORD $0x820c // vfmadd213ss (%rdx,%r8,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0x810c // 
vmovss %xmm1, (%rcx,%r8,4) LBB0_13: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm512_mul_const_add(SB), $0-32 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x0f418d4c // leaq 15(%rcx), %r8 WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f4c // cmovnsq %rcx, %r8 WORD $0x894c; BYTE $0xc0 // movq %r8, %rax LONG $0x04f8c148 // sarq $4, %rax LONG $0xf0e08349 // andq $-16, %r8 WORD $0x294c; BYTE $0xc1 // subq %r8, %rcx WORD $0xc085 // testl %eax, %eax JLE LBB1_6 WORD $0xf883; BYTE $0x01 // cmpl $1, %eax JE LBB1_4 WORD $0x8941; BYTE $0xc0 // movl %eax, %r8d LONG $0xfee08141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r8d # imm = 0x7FFFFFFE LBB1_3: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x487df262; WORD $0x0e18 // vbroadcastss (%rsi), %zmm1 LONG $0x487df262; WORD $0x0aa8 // vfmadd213ps (%rdx), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem LONG $0x487cf162; WORD $0x0a11 // vmovups %zmm1, (%rdx) LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0 LONG $0x487df262; WORD $0x0e18 // vbroadcastss (%rsi), %zmm1 LONG $0x487df262; WORD $0x4aa8; BYTE $0x01 // vfmadd213ps 64(%rdx), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem LONG $0x487cf162; WORD $0x4a11; BYTE $0x01 // vmovups %zmm1, 64(%rdx) LONG $0x80ef8348 // subq $-128, %rdi LONG $0x80ea8348 // subq $-128, %rdx LONG $0xfec08341 // addl $-2, %r8d JNE LBB1_3 LBB1_4: WORD $0x01a8 // testb $1, %al JE LBB1_6 LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x487df262; WORD $0x0e18 // vbroadcastss (%rsi), %zmm1 LONG $0x487df262; WORD $0x0aa8 // vfmadd213ps (%rdx), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem LONG $0x487cf162; WORD $0x0a11 // vmovups %zmm1, (%rdx) LONG $0x40c78348 // addq $64, %rdi LONG $0x40c28348 // addq $64, %rdx LBB1_6: LONG $0x07f98348 // cmpq $7, 
%rcx JLE LBB1_8 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0258fcc5 // vaddps (%rdx), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx LBB1_8: WORD $0xc985 // testl %ecx, %ecx JLE LBB1_13 WORD $0xc889 // movl %ecx, %eax WORD $0xf983; BYTE $0x01 // cmpl $1, %ecx JNE LBB1_14 WORD $0xc931 // xorl %ecx, %ecx JMP LBB1_11 LBB1_14: WORD $0x8941; BYTE $0xc0 // movl %eax, %r8d LONG $0xfee08141; WORD $0xffff; BYTE $0x7f // andl $2147483646, %r8d # imm = 0x7FFFFFFE WORD $0xc931 // xorl %ecx, %ecx LBB1_15: LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979e2c4; WORD $0x8a0c // vfmadd213ss (%rdx,%rcx,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x0c11fac5; BYTE $0x8a // vmovss %xmm1, (%rdx,%rcx,4) LONG $0x4410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979e2c4; WORD $0x8a4c; BYTE $0x04 // vfmadd213ss 4(%rdx,%rcx,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x4c11fac5; WORD $0x048a // vmovss %xmm1, 4(%rdx,%rcx,4) LONG $0x02c18348 // addq $2, %rcx WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8 JNE LBB1_15 LBB1_11: WORD $0x01a8 // testb $1, %al JE LBB1_13 LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0e10fac5 // vmovss (%rsi), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979e2c4; WORD $0x8a0c // vfmadd213ss (%rdx,%rcx,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x0c11fac5; BYTE $0x8a // vmovss %xmm1, (%rdx,%rcx,4) LBB1_13: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm512_mul_const_to(SB), $0-32 MOVQ a+0(FP), 
DI MOVQ b+8(FP), SI MOVQ c+16(FP), DX MOVQ n+24(FP), CX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x0f418d48 // leaq 15(%rcx), %rax WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx LONG $0xc1490f48 // cmovnsq %rcx, %rax WORD $0x8949; BYTE $0xc0 // movq %rax, %r8 LONG $0x04f8c149 // sarq $4, %r8 LONG $0xf0e08348 // andq $-16, %rax WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d JLE LBB2_6 WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax LONG $0x04f88341 // cmpl $4, %r8d JB LBB2_4 LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC LBB2_3: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx) LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x4211; BYTE $0x01 // vmovups %zmm0, 64(%rdx) LONG $0x487cf162; WORD $0x4710; BYTE $0x02 // vmovups 128(%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x4211; BYTE $0x02 // vmovups %zmm0, 128(%rdx) LONG $0x487cf162; WORD $0x4710; BYTE $0x03 // vmovups 192(%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x4211; BYTE $0x03 // vmovups %zmm0, 192(%rdx) LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100 LONG $0x00c28148; WORD $0x0001; BYTE $0x00 // addq $256, %rdx # imm = 0x100 LONG $0xfcc08341 // addl $-4, %r8d JNE LBB2_3 LBB2_4: WORD $0xc085 // testl %eax, %eax JE LBB2_6 LBB2_5: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx) LONG $0x40c78348 // addq 
$64, %rdi LONG $0x40c28348 // addq $64, %rdx WORD $0xc8ff // decl %eax JNE LBB2_5 LBB2_6: LONG $0x07f98348 // cmpq $7, %rcx JLE LBB2_8 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0211fcc5 // vmovups %ymm0, (%rdx) LONG $0x20c78348 // addq $32, %rdi LONG $0x20c28348 // addq $32, %rdx WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx LBB2_8: WORD $0xc985 // testl %ecx, %ecx JLE LBB2_14 WORD $0x8941; BYTE $0xc8 // movl %ecx, %r8d WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx JAE LBB2_15 WORD $0xc931 // xorl %ecx, %ecx JMP LBB2_11 LBB2_15: LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC WORD $0xc931 // xorl %ecx, %ecx LBB2_16: LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0411fac5; BYTE $0x8a // vmovss %xmm0, (%rdx,%rcx,4) LONG $0x4410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x048a // vmovss %xmm0, 4(%rdx,%rcx,4) LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x088a // vmovss %xmm0, 8(%rdx,%rcx,4) LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x0c8a // vmovss %xmm0, 12(%rdx,%rcx,4) LONG $0x04c18348 // addq $4, %rcx WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8 JNE LBB2_16 LBB2_11: WORD $0x8548; BYTE $0xc0 // testq %rax, %rax JE LBB2_14 LONG $0x8a148d48 // leaq (%rdx,%rcx,4), %rdx LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx WORD $0xff31 // xorl %edi, %edi LBB2_13: LONG $0x0410fac5; BYTE $0xb9 // vmovss (%rcx,%rdi,4), %xmm0 # xmm0 = 
mem[0],zero,zero,zero LONG $0x0659fac5 // vmulss (%rsi), %xmm0, %xmm0 LONG $0x0411fac5; BYTE $0xba // vmovss %xmm0, (%rdx,%rdi,4) WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x3948; BYTE $0xf8 // cmpq %rdi, %rax JNE LBB2_13 LBB2_14: WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper RET TEXT ·_mm512_mul_const(SB), $0-24 MOVQ a+0(FP), DI MOVQ b+8(FP), SI MOVQ n+16(FP), DX BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp LONG $0xf8e48348 // andq $-8, %rsp LONG $0x0f428d48 // leaq 15(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx LONG $0x04f9c148 // sarq $4, %rcx LONG $0xf0e08348 // andq $-16, %rax WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx WORD $0xc985 // testl %ecx, %ecx JLE LBB3_6 WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx JB LBB3_4 LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC LBB3_3: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x487cf162; WORD $0x4f10; BYTE $0x01 // vmovups 64(%rdi), %zmm1 LONG $0x487cf162; WORD $0x5710; BYTE $0x02 // vmovups 128(%rdi), %zmm2 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x5f10; BYTE $0x03 // vmovups 192(%rdi), %zmm3 LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi) LONG $0x5874f162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm1, %zmm0 LONG $0x487cf162; WORD $0x4711; BYTE $0x01 // vmovups %zmm0, 64(%rdi) LONG $0x586cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm2, %zmm0 LONG $0x487cf162; WORD $0x4711; BYTE $0x02 // vmovups %zmm0, 128(%rdi) LONG $0x5864f162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm3, %zmm0 LONG $0x487cf162; WORD $0x4711; BYTE $0x03 // vmovups %zmm0, 192(%rdi) LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100 WORD $0xc183; BYTE $0xfc // addl $-4, 
%ecx JNE LBB3_3 LBB3_4: WORD $0xc085 // testl %eax, %eax JE LBB3_6 LBB3_5: LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0 LONG $0x587cf162; WORD $0x0659 // vmulps (%rsi){1to16}, %zmm0, %zmm0 LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi) LONG $0x40c78348 // addq $64, %rdi WORD $0xc8ff // decl %eax JNE LBB3_5 LBB3_6: LONG $0x07fa8348 // cmpq $7, %rdx JLE LBB3_8 LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0 LONG $0x0759fcc5 // vmulps (%rdi), %ymm0, %ymm0 LONG $0x0711fcc5 // vmovups %ymm0, (%rdi) LONG $0x20c78348 // addq $32, %rdi WORD $0xc283; BYTE $0xf8 // addl $-8, %edx LBB3_8: WORD $0xd285 // testl %edx, %edx JLE LBB3_14 WORD $0xd189 // movl %edx, %ecx WORD $0xc889 // movl %ecx, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xfa83; BYTE $0x04 // cmpl $4, %edx JAE LBB3_15 WORD $0xd231 // xorl %edx, %edx JMP LBB3_11 LBB3_15: LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC WORD $0xd231 // xorl %edx, %edx LBB3_16: LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x0459fac5; BYTE $0x97 // vmulss (%rdi,%rdx,4), %xmm0, %xmm0 LONG $0x0411fac5; BYTE $0x97 // vmovss %xmm0, (%rdi,%rdx,4) LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4459fac5; WORD $0x0497 // vmulss 4(%rdi,%rdx,4), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x0497 // vmovss %xmm0, 4(%rdi,%rdx,4) LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4459fac5; WORD $0x0897 // vmulss 8(%rdi,%rdx,4), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x0897 // vmovss %xmm0, 8(%rdi,%rdx,4) LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x4459fac5; WORD $0x0c97 // vmulss 12(%rdi,%rdx,4), %xmm0, %xmm0 LONG $0x4411fac5; WORD $0x0c97 // vmovss %xmm0, 12(%rdi,%rdx,4) LONG $0x04c28348 // addq $4, %rdx WORD $0x3948; BYTE $0xd1 // cmpq %rdx, %rcx JNE LBB3_16 LBB3_11: WORD $0x8548; BYTE $0xc0 // testq %rax, %rax JE LBB3_14 LONG $0x970c8d48 // leaq 
(%rdi,%rdx,4), %rcx
	WORD $0xd231 // xorl %edx, %edx

LBB3_13:
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x0459fac5; BYTE $0x91 // vmulss (%rcx,%rdx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x91 // vmovss %xmm0, (%rcx,%rdx,4)
	WORD $0xff48; BYTE $0xc2 // incq %rdx
	WORD $0x3948; BYTE $0xd0 // cmpq %rdx, %rax
	JNE LBB3_13

LBB3_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// The kernels below are raw machine code.
// Each BYTE/WORD/LONG/QUAD line emits encoded instruction bytes, and its
// trailing comment is the AT&T disassembly of those bytes. Only the
// TEXT/MOVQ/MOVSS/Jcc/RET lines are assembled from Go assembler mnemonics.
// NOTE(review): this looks generated, presumably from src/floats_avx512.c by
// a C-to-Go-assembly tool. Regenerate it rather than hand-editing the bytes.
// NOTE(review): it uses EVEX zmm instructions (AVX-512F) and, in
// _mm512_dot, FMA. Presumably the Go side dispatches here only after CPU
// feature detection; confirm.
//
// Shared layout of the element-wise kernels (add_const, sub_to, sub, mul_to,
// div_to, sqrt_to):
//   - Prologue: push RBP, copy RSP to RBP, align RSP down to 8 bytes.
//     The epilogue restores RSP/RBP and runs vzeroupper before RET.
//   - blocks = n/16, a signed truncating division
//     (leaq 15 / cmovns / sarq 4).
//   - rem = n - 16*blocks.
//   - Main loop: 4 blocks (64 float32, 256 bytes) per iteration in zmm
//     registers, then up to 3 single 16-float blocks.
//   - If rem > 7: one 8-float ymm step.
//   - Scalar tail: 4 elements per iteration, then 1 at a time.
//   - For n <= 0 every loop is skipped and nothing is stored.
//   NOTE(review): blocks is computed in 64 bits but tested and counted down
//   through 32-bit registers (testl/movl/andl/addl on ecx or r8d). So
//   n >= 2^35 would be mishandled; presumably never reached, but confirm.

// _mm512_add_const adds the float32 value b[0] to each of the n float32
// values at a, in place: a[i] += b[0].
// b is reloaded from memory at every step: a {1to16} broadcast, vbroadcastss,
// or a scalar vmovss.
TEXT ·_mm512_add_const(SB), $0-24
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), DX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f428d48 // leaq 15(%rdx), %rax
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	LONG $0xc2490f48 // cmovnsq %rdx, %rax
	WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx
	LONG $0x04f9c148 // sarq $4, %rcx
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB4_6
	WORD $0xc889 // movl %ecx, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JB LBB4_4
	LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC

LBB4_3:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4f10; BYTE $0x01 // vmovups 64(%rdi), %zmm1
	LONG $0x487cf162; WORD $0x5710; BYTE $0x02 // vmovups 128(%rdi), %zmm2
	LONG $0x587cf162; WORD $0x0658 // vaddps (%rsi){1to16}, %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x5f10; BYTE $0x03 // vmovups 192(%rdi), %zmm3
	LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi)
	LONG $0x5874f162; WORD $0x0658 // vaddps (%rsi){1to16}, %zmm1, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x01 // vmovups %zmm0, 64(%rdi)
	LONG $0x586cf162; WORD $0x0658 // vaddps (%rsi){1to16}, %zmm2, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x02 // vmovups %zmm0, 128(%rdi)
	LONG $0x5864f162; WORD $0x0658 // vaddps (%rsi){1to16}, %zmm3, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x03 // vmovups %zmm0, 192(%rdi)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	WORD $0xc183; BYTE $0xfc // addl $-4, %ecx
	JNE LBB4_3

LBB4_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB4_6

LBB4_5:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x587cf162; WORD $0x0658 // vaddps (%rsi){1to16}, %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi)
	LONG $0x40c78348 // addq $64, %rdi
	WORD $0xc8ff // decl %eax
	JNE LBB4_5

LBB4_6:
	LONG $0x07fa8348 // cmpq $7, %rdx
	JLE LBB4_8
	LONG $0x187de2c4; BYTE $0x06 // vbroadcastss (%rsi), %ymm0
	LONG $0x0758fcc5 // vaddps (%rdi), %ymm0, %ymm0
	LONG $0x0711fcc5 // vmovups %ymm0, (%rdi)
	LONG $0x20c78348 // addq $32, %rdi
	WORD $0xc283; BYTE $0xf8 // addl $-8, %edx

LBB4_8:
	WORD $0xd285 // testl %edx, %edx
	JLE LBB4_14
	WORD $0xd189 // movl %edx, %ecx
	WORD $0xc889 // movl %ecx, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xfa83; BYTE $0x04 // cmpl $4, %edx
	JAE LBB4_15
	WORD $0xd231 // xorl %edx, %edx
	JMP LBB4_11

LBB4_15:
	LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC
	WORD $0xd231 // xorl %edx, %edx

LBB4_16:
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x0458fac5; BYTE $0x97 // vaddss (%rdi,%rdx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x97 // vmovss %xmm0, (%rdi,%rdx,4)
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4458fac5; WORD $0x0497 // vaddss 4(%rdi,%rdx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0497 // vmovss %xmm0, 4(%rdi,%rdx,4)
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4458fac5; WORD $0x0897 // vaddss 8(%rdi,%rdx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0897 // vmovss %xmm0, 8(%rdi,%rdx,4)
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4458fac5; WORD $0x0c97 // vaddss 12(%rdi,%rdx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c97 // vmovss %xmm0, 12(%rdi,%rdx,4)
	LONG $0x04c28348 // addq $4, %rdx
	WORD $0x3948; BYTE $0xd1 // cmpq %rdx, %rcx
	JNE LBB4_16

LBB4_11:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB4_14
	LONG $0x970c8d48 // leaq (%rdi,%rdx,4), %rcx
	WORD $0xd231 // xorl %edx, %edx

LBB4_13:
	LONG $0x0610fac5 // vmovss (%rsi), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x0458fac5; BYTE $0x91 // vaddss (%rcx,%rdx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x91 // vmovss %xmm0, (%rcx,%rdx,4)
	WORD $0xff48; BYTE $0xc2 // incq %rdx
	WORD $0x3948; BYTE $0xd0 // cmpq %rdx, %rax
	JNE LBB4_13

LBB4_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_sub_to stores c[i] = a[i] - b[i] for 0 <= i < n (float32).
// a and b are only read; results are written through c.
TEXT ·_mm512_sub_to(SB), $0-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ c+16(FP), DX
	MOVQ n+24(FP), CX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f418d48 // leaq 15(%rcx), %rax
	WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx
	LONG $0xc1490f48 // cmovnsq %rcx, %rax
	WORD $0x8949; BYTE $0xc0 // movq %rax, %r8
	LONG $0x04f8c149 // sarq $4, %r8
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx
	WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d
	JLE LBB5_6
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	LONG $0x04f88341 // cmpl $4, %r8d
	JB LBB5_4
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC

LBB5_3:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065c // vsubps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465c; BYTE $0x01 // vsubps 64(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x01 // vmovups %zmm0, 64(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x02 // vmovups 128(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465c; BYTE $0x02 // vsubps 128(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x02 // vmovups %zmm0, 128(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x03 // vmovups 192(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465c; BYTE $0x03 // vsubps 192(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x03 // vmovups %zmm0, 192(%rdx)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi # imm = 0x100
	LONG $0x00c28148; WORD $0x0001; BYTE $0x00 // addq $256, %rdx # imm = 0x100
	LONG $0xfcc08341 // addl $-4, %r8d
	JNE LBB5_3

LBB5_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB5_6

LBB5_5:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065c // vsubps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	LONG $0x40c28348 // addq $64, %rdx
	WORD $0xc8ff // decl %eax
	JNE LBB5_5

LBB5_6:
	LONG $0x07f98348 // cmpq $7, %rcx
	JLE LBB5_8
	LONG $0x0710fcc5 // vmovups (%rdi), %ymm0
	LONG $0x065cfcc5 // vsubps (%rsi), %ymm0, %ymm0
	LONG $0x0211fcc5 // vmovups %ymm0, (%rdx)
	LONG $0x20c78348 // addq $32, %rdi
	LONG $0x20c68348 // addq $32, %rsi
	LONG $0x20c28348 // addq $32, %rdx
	WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx

LBB5_8:
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB5_14
	WORD $0x8941; BYTE $0xc8 // movl %ecx, %r8d
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JAE LBB5_15
	WORD $0xc931 // xorl %ecx, %ecx
	JMP LBB5_11

LBB5_15:
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC
	WORD $0xc931 // xorl %ecx, %ecx

LBB5_16:
	LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045cfac5; BYTE $0x8e // vsubss (%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x8a // vmovss %xmm0, (%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445cfac5; WORD $0x048e // vsubss 4(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x048a // vmovss %xmm0, 4(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445cfac5; WORD $0x088e // vsubss 8(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x088a // vmovss %xmm0, 8(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445cfac5; WORD $0x0c8e // vsubss 12(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c8a // vmovss %xmm0, 12(%rdx,%rcx,4)
	LONG $0x04c18348 // addq $4, %rcx
	WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8
	JNE LBB5_16

LBB5_11:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB5_14
	LONG $0x8a148d48 // leaq (%rdx,%rcx,4), %rdx
	LONG $0x8e348d48 // leaq (%rsi,%rcx,4), %rsi
	LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx
	WORD $0xff31 // xorl %edi, %edi

LBB5_13:
	LONG $0x0410fac5; BYTE $0xb9 // vmovss (%rcx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045cfac5; BYTE $0xbe // vsubss (%rsi,%rdi,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0xba // vmovss %xmm0, (%rdx,%rdi,4)
	WORD $0xff48; BYTE $0xc7 // incq %rdi
	WORD $0x3948; BYTE $0xf8 // cmpq %rdi, %rax
	JNE LBB5_13

LBB5_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_sub subtracts in place: a[i] -= b[i] for 0 <= i < n (float32).
// Unlike the other element-wise kernels it has no 8-wide ymm step. After the
// 16-float blocks, the remaining elements (up to 15) go straight to the
// scalar tail.
TEXT ·_mm512_sub(SB), $0-24
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), DX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f428d48 // leaq 15(%rdx), %rax
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	LONG $0xc2490f48 // cmovnsq %rdx, %rax
	WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx
	LONG $0x04f9c148 // sarq $4, %rcx
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB6_6
	WORD $0xc889 // movl %ecx, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JB LBB6_4
	LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC

LBB6_3:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4f10; BYTE $0x01 // vmovups 64(%rdi), %zmm1
	LONG $0x487cf162; WORD $0x5710; BYTE $0x02 // vmovups 128(%rdi), %zmm2
	LONG $0x487cf162; WORD $0x065c // vsubps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x5f10; BYTE $0x03 // vmovups 192(%rdi), %zmm3
	LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi)
	LONG $0x4874f162; WORD $0x465c; BYTE $0x01 // vsubps 64(%rsi), %zmm1, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x01 // vmovups %zmm0, 64(%rdi)
	LONG $0x486cf162; WORD $0x465c; BYTE $0x02 // vsubps 128(%rsi), %zmm2, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x02 // vmovups %zmm0, 128(%rdi)
	LONG $0x4864f162; WORD $0x465c; BYTE $0x03 // vsubps 192(%rsi), %zmm3, %zmm0
	LONG $0x487cf162; WORD $0x4711; BYTE $0x03 // vmovups %zmm0, 192(%rdi)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi # imm = 0x100
	WORD $0xc183; BYTE $0xfc // addl $-4, %ecx
	JNE LBB6_3

LBB6_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB6_6

LBB6_5:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065c // vsubps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0711 // vmovups %zmm0, (%rdi)
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	WORD $0xc8ff // decl %eax
	JNE LBB6_5

// Scalar tail for the remaining elements (no 8-wide step in this kernel).
LBB6_6:
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	JLE LBB6_12
	WORD $0xd089 // movl %edx, %eax
	LONG $0xff488d48 // leaq -1(%rax), %rcx
	WORD $0xe283; BYTE $0x03 // andl $3, %edx
	LONG $0x03f98348 // cmpq $3, %rcx
	JAE LBB6_13
	WORD $0xc931 // xorl %ecx, %ecx
	JMP LBB6_9

LBB6_13:
	WORD $0x2948; BYTE $0xd0 // subq %rdx, %rax
	WORD $0xc931 // xorl %ecx, %ecx

LBB6_14:
	LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045cfac5; BYTE $0x8e // vsubss (%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4c10fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm1 # xmm1 = mem[0],zero,zero,zero
	LONG $0x0411fac5; BYTE $0x8f // vmovss %xmm0, (%rdi,%rcx,4)
	LONG $0x445cf2c5; WORD $0x048e // vsubss 4(%rsi,%rcx,4), %xmm1, %xmm0
	LONG $0x4411fac5; WORD $0x048f // vmovss %xmm0, 4(%rdi,%rcx,4)
	LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445cfac5; WORD $0x088e // vsubss 8(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x088f // vmovss %xmm0, 8(%rdi,%rcx,4)
	LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445cfac5; WORD $0x0c8e // vsubss 12(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c8f // vmovss %xmm0, 12(%rdi,%rcx,4)
	LONG $0x04c18348 // addq $4, %rcx
	WORD $0x3948; BYTE $0xc8 // cmpq %rcx, %rax
	JNE LBB6_14

LBB6_9:
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	JE LBB6_12
	LONG $0x8f048d48 // leaq (%rdi,%rcx,4), %rax
	LONG $0x8e0c8d48 // leaq (%rsi,%rcx,4), %rcx
	WORD $0xf631 // xorl %esi, %esi

LBB6_11:
	LONG $0x0410fac5; BYTE $0xb0 // vmovss (%rax,%rsi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045cfac5; BYTE $0xb1 // vsubss (%rcx,%rsi,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0xb0 // vmovss %xmm0, (%rax,%rsi,4)
	WORD $0xff48; BYTE $0xc6 // incq %rsi
	WORD $0x3948; BYTE $0xf2 // cmpq %rsi, %rdx
	JNE LBB6_11

LBB6_12:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_mul_to stores c[i] = a[i] * b[i] for 0 <= i < n (float32).
// a and b are only read; results are written through c.
TEXT ·_mm512_mul_to(SB), $0-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ c+16(FP), DX
	MOVQ n+24(FP), CX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f418d48 // leaq 15(%rcx), %rax
	WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx
	LONG $0xc1490f48 // cmovnsq %rcx, %rax
	WORD $0x8949; BYTE $0xc0 // movq %rax, %r8
	LONG $0x04f8c149 // sarq $4, %r8
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx
	WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d
	JLE LBB7_6
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	LONG $0x04f88341 // cmpl $4, %r8d
	JB LBB7_4
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC

LBB7_3:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x0659 // vmulps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4659; BYTE $0x01 // vmulps 64(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x01 // vmovups %zmm0, 64(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x02 // vmovups 128(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4659; BYTE $0x02 // vmulps 128(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x02 // vmovups %zmm0, 128(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x03 // vmovups 192(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4659; BYTE $0x03 // vmulps 192(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x03 // vmovups %zmm0, 192(%rdx)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi # imm = 0x100
	LONG $0x00c28148; WORD $0x0001; BYTE $0x00 // addq $256, %rdx # imm = 0x100
	LONG $0xfcc08341 // addl $-4, %r8d
	JNE LBB7_3

LBB7_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB7_6

LBB7_5:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x0659 // vmulps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	LONG $0x40c28348 // addq $64, %rdx
	WORD $0xc8ff // decl %eax
	JNE LBB7_5

LBB7_6:
	LONG $0x07f98348 // cmpq $7, %rcx
	JLE LBB7_8
	LONG $0x0710fcc5 // vmovups (%rdi), %ymm0
	LONG $0x0659fcc5 // vmulps (%rsi), %ymm0, %ymm0
	LONG $0x0211fcc5 // vmovups %ymm0, (%rdx)
	LONG $0x20c78348 // addq $32, %rdi
	LONG $0x20c68348 // addq $32, %rsi
	LONG $0x20c28348 // addq $32, %rdx
	WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx

LBB7_8:
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB7_14
	WORD $0x8941; BYTE $0xc8 // movl %ecx, %r8d
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JAE LBB7_15
	WORD $0xc931 // xorl %ecx, %ecx
	JMP LBB7_11

LBB7_15:
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC
	WORD $0xc931 // xorl %ecx, %ecx

LBB7_16:
	LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x0459fac5; BYTE $0x8e // vmulss (%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x8a // vmovss %xmm0, (%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4459fac5; WORD $0x048e // vmulss 4(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x048a // vmovss %xmm0, 4(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4459fac5; WORD $0x088e // vmulss 8(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x088a // vmovss %xmm0, 8(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x4459fac5; WORD $0x0c8e // vmulss 12(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c8a // vmovss %xmm0, 12(%rdx,%rcx,4)
	LONG $0x04c18348 // addq $4, %rcx
	WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8
	JNE LBB7_16

LBB7_11:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB7_14
	LONG $0x8a148d48 // leaq (%rdx,%rcx,4), %rdx
	LONG $0x8e348d48 // leaq (%rsi,%rcx,4), %rsi
	LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx
	WORD $0xff31 // xorl %edi, %edi

LBB7_13:
	LONG $0x0410fac5; BYTE $0xb9 // vmovss (%rcx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x0459fac5; BYTE $0xbe // vmulss (%rsi,%rdi,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0xba // vmovss %xmm0, (%rdx,%rdi,4)
	WORD $0xff48; BYTE $0xc7 // incq %rdi
	WORD $0x3948; BYTE $0xf8 // cmpq %rdi, %rax
	JNE LBB7_13

LBB7_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_div_to stores c[i] = a[i] / b[i] for 0 <= i < n (float32).
// a and b are only read; results are written through c.
TEXT ·_mm512_div_to(SB), $0-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ c+16(FP), DX
	MOVQ n+24(FP), CX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f418d48 // leaq 15(%rcx), %rax
	WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx
	LONG $0xc1490f48 // cmovnsq %rcx, %rax
	WORD $0x8949; BYTE $0xc0 // movq %rax, %r8
	LONG $0x04f8c149 // sarq $4, %r8
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc1 // subq %rax, %rcx
	WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d
	JLE LBB8_6
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	LONG $0x04f88341 // cmpl $4, %r8d
	JB LBB8_4
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC

LBB8_3:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065e // vdivps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465e; BYTE $0x01 // vdivps 64(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x01 // vmovups %zmm0, 64(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x02 // vmovups 128(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465e; BYTE $0x02 // vdivps 128(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x02 // vmovups %zmm0, 128(%rdx)
	LONG $0x487cf162; WORD $0x4710; BYTE $0x03 // vmovups 192(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x465e; BYTE $0x03 // vdivps 192(%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x4211; BYTE $0x03 // vmovups %zmm0, 192(%rdx)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi # imm = 0x100
	LONG $0x00c28148; WORD $0x0001; BYTE $0x00 // addq $256, %rdx # imm = 0x100
	LONG $0xfcc08341 // addl $-4, %r8d
	JNE LBB8_3

LBB8_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB8_6

LBB8_5:
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065e // vdivps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	LONG $0x40c28348 // addq $64, %rdx
	WORD $0xc8ff // decl %eax
	JNE LBB8_5

LBB8_6:
	LONG $0x07f98348 // cmpq $7, %rcx
	JLE LBB8_8
	LONG $0x0710fcc5 // vmovups (%rdi), %ymm0
	LONG $0x065efcc5 // vdivps (%rsi), %ymm0, %ymm0
	LONG $0x0211fcc5 // vmovups %ymm0, (%rdx)
	LONG $0x20c78348 // addq $32, %rdi
	LONG $0x20c68348 // addq $32, %rsi
	LONG $0x20c28348 // addq $32, %rdx
	WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx

LBB8_8:
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB8_14
	WORD $0x8941; BYTE $0xc8 // movl %ecx, %r8d
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JAE LBB8_15
	WORD $0xc931 // xorl %ecx, %ecx
	JMP LBB8_11

LBB8_15:
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC
	WORD $0xc931 // xorl %ecx, %ecx

LBB8_16:
	LONG $0x0410fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045efac5; BYTE $0x8e // vdivss (%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x8a // vmovss %xmm0, (%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445efac5; WORD $0x048e // vdivss 4(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x048a // vmovss %xmm0, 4(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445efac5; WORD $0x088e // vdivss 8(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x088a // vmovss %xmm0, 8(%rdx,%rcx,4)
	LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x445efac5; WORD $0x0c8e // vdivss 12(%rsi,%rcx,4), %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c8a // vmovss %xmm0, 12(%rdx,%rcx,4)
	LONG $0x04c18348 // addq $4, %rcx
	WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8
	JNE LBB8_16

LBB8_11:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB8_14
	LONG $0x8a148d48 // leaq (%rdx,%rcx,4), %rdx
	LONG $0x8e348d48 // leaq (%rsi,%rcx,4), %rsi
	LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx
	WORD $0xff31 // xorl %edi, %edi

LBB8_13:
	LONG $0x0410fac5; BYTE $0xb9 // vmovss (%rcx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x045efac5; BYTE $0xbe // vdivss (%rsi,%rdi,4), %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0xba // vmovss %xmm0, (%rdx,%rdi,4)
	WORD $0xff48; BYTE $0xc7 // incq %rdi
	WORD $0x3948; BYTE $0xf8 // cmpq %rdi, %rax
	JNE LBB8_13

LBB8_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_sqrt_to stores b[i] = sqrt(a[i]) for 0 <= i < n (float32).
// The output pointer is the second argument: reads go through a (DI) and
// writes through b (SI).
TEXT ·_mm512_sqrt_to(SB), $0-24
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), DX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f428d48 // leaq 15(%rdx), %rax
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	LONG $0xc2490f48 // cmovnsq %rdx, %rax
	WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx
	LONG $0x04f9c148 // sarq $4, %rcx
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB9_6
	WORD $0xc889 // movl %ecx, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xf983; BYTE $0x04 // cmpl $4, %ecx
	JB LBB9_4
	LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC

LBB9_3:
	LONG $0x487cf162; WORD $0x0751 // vsqrtps (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x0611 // vmovups %zmm0, (%rsi)
	LONG $0x487cf162; WORD $0x4751; BYTE $0x01 // vsqrtps 64(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4611; BYTE $0x01 // vmovups %zmm0, 64(%rsi)
	LONG $0x487cf162; WORD $0x4751; BYTE $0x02 // vsqrtps 128(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4611; BYTE $0x02 // vmovups %zmm0, 128(%rsi)
	LONG $0x487cf162; WORD $0x4751; BYTE $0x03 // vsqrtps 192(%rdi), %zmm0
	LONG $0x487cf162; WORD $0x4611; BYTE $0x03 // vmovups %zmm0, 192(%rsi)
	LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi # imm = 0x100
	LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi # imm = 0x100
	WORD $0xc183; BYTE $0xfc // addl $-4, %ecx
	JNE LBB9_3

LBB9_4:
	WORD $0xc085 // testl %eax, %eax
	JE LBB9_6

LBB9_5:
	LONG $0x487cf162; WORD $0x0751 // vsqrtps (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x0611 // vmovups %zmm0, (%rsi)
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	WORD $0xc8ff // decl %eax
	JNE LBB9_5

LBB9_6:
	LONG $0x07fa8348 // cmpq $7, %rdx
	JLE LBB9_8
	LONG $0x0751fcc5 // vsqrtps (%rdi), %ymm0
	LONG $0x0611fcc5 // vmovups %ymm0, (%rsi)
	LONG $0x20c78348 // addq $32, %rdi
	LONG $0x20c68348 // addq $32, %rsi
	WORD $0xc283; BYTE $0xf8 // addl $-8, %edx

LBB9_8:
	WORD $0xd285 // testl %edx, %edx
	JLE LBB9_14
	WORD $0xd189 // movl %edx, %ecx
	WORD $0xc889 // movl %ecx, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xfa83; BYTE $0x04 // cmpl $4, %edx
	JAE LBB9_15
	WORD $0xd231 // xorl %edx, %edx
	JMP LBB9_11

LBB9_15:
	LONG $0xfffce181; WORD $0x7fff // andl $2147483644, %ecx # imm = 0x7FFFFFFC
	WORD $0xd231 // xorl %edx, %edx

LBB9_16:
	LONG $0x0410fac5; BYTE $0x97 // vmovss (%rdi,%rdx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0x96 // vmovss %xmm0, (%rsi,%rdx,4)
	LONG $0x4410fac5; WORD $0x0497 // vmovss 4(%rdi,%rdx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0496 // vmovss %xmm0, 4(%rsi,%rdx,4)
	LONG $0x4410fac5; WORD $0x0897 // vmovss 8(%rdi,%rdx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0896 // vmovss %xmm0, 8(%rsi,%rdx,4)
	LONG $0x4410fac5; WORD $0x0c97 // vmovss 12(%rdi,%rdx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0
	LONG $0x4411fac5; WORD $0x0c96 // vmovss %xmm0, 12(%rsi,%rdx,4)
	LONG $0x04c28348 // addq $4, %rdx
	WORD $0x3948; BYTE $0xd1 // cmpq %rdx, %rcx
	JNE LBB9_16

LBB9_11:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB9_14
	LONG $0x960c8d48 // leaq (%rsi,%rdx,4), %rcx
	LONG $0x97148d48 // leaq (%rdi,%rdx,4), %rdx
	WORD $0xf631 // xorl %esi, %esi

LBB9_13:
	LONG $0x0410fac5; BYTE $0xb2 // vmovss (%rdx,%rsi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0
	LONG $0x0411fac5; BYTE $0xb1 // vmovss %xmm0, (%rcx,%rsi,4)
	WORD $0xff48; BYTE $0xc6 // incq %rsi
	WORD $0x3948; BYTE $0xf0 // cmpq %rsi, %rax
	JNE LBB9_13

LBB9_14:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	RET

// _mm512_dot returns sum(a[i]*b[i]) for 0 <= i < n. The float32 result is
// stored to result+24(FP); for n <= 0 the result is 0.
// Computation order:
//   - Products accumulate per lane in a 16-wide zmm accumulator: vmulps for
//     the first block, then vfmadd for the rest.
//   - The accumulator is reduced horizontally into xmm0.
//   - The optional 8-float ymm block is reduced and added.
//   - The scalar tail is added with vfmadd*ss.
// The summation order differs from a sequential loop, and FMA skips the
// intermediate rounding. So the result can differ in the low bits from a
// naive scalar dot product.
TEXT ·_mm512_dot(SB), $8-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), DX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f428d48 // leaq 15(%rdx), %rax
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	LONG $0xc2490f48 // cmovnsq %rdx, %rax
	WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx
	LONG $0x04f9c148 // sarq $4, %rcx
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB10_1
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x0e59 // vmulps (%rsi), %zmm0, %zmm1
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	WORD $0xf983; BYTE $0x01 // cmpl $1, %ecx
	JE LBB10_9
	WORD $0x8949; BYTE $0xc8 // movq %rcx, %r8
	LONG $0x06e0c149 // shlq $6, %r8
	QUAD $0x003fffffff80b848; WORD $0x0000 // movabsq $274877906816, %rax # imm = 0x3FFFFFFF80
	WORD $0x0149; BYTE $0xc0 // addq %rax, %r8
	LONG $0x40c88348 // orq $64, %rax
	WORD $0x214c; BYTE $0xc0 // andq %r8, %rax
	LONG $0xff518d44 // leal -1(%rcx), %r10d
	LONG $0xfe418d44 // leal -2(%rcx), %r8d
	LONG $0x03f88341 // cmpl $3, %r8d
	JAE LBB10_14
	WORD $0x8949; BYTE $0xf8 // movq %rdi, %r8
	WORD $0x8949; BYTE $0xf1 // movq %rsi, %r9
	LONG $0x487cf162; WORD $0xc128 // vmovaps %zmm1, %zmm0
	JMP LBB10_5

LBB10_1:
	LONG $0xc957f0c5 // vxorps %xmm1, %xmm1, %xmm1
	JMP LBB10_9

LBB10_14:
	WORD $0x8945; BYTE $0xd3 // movl %r10d, %r11d
	LONG $0xfce38341 // andl $-4, %r11d
	WORD $0x8949; BYTE $0xf8 // movq %rdi, %r8
	WORD $0x8949; BYTE $0xf1 // movq %rsi, %r9

LBB10_15:
	LONG $0x487cd162; WORD $0x0010 // vmovups (%r8), %zmm0
	LONG $0x487cd162; WORD $0x5010; BYTE $0x01 // vmovups 64(%r8), %zmm2
	LONG $0x487cd162; WORD $0x5810; BYTE $0x02 // vmovups 128(%r8), %zmm3
	LONG $0x487cd162; WORD $0x6010; BYTE $0x03 // vmovups 192(%r8), %zmm4
	LONG $0x4875d262; WORD $0x0198 // vfmadd132ps (%r9), %zmm1, %zmm0 # zmm0 = (zmm0 * mem) + zmm1
	LONG $0x486dd262; WORD $0x41b8; BYTE $0x01 // vfmadd231ps 64(%r9), %zmm2, %zmm0 # zmm0 = (zmm2 * mem) + zmm0
	LONG $0x4865d262; WORD $0x41b8; BYTE $0x02 // vfmadd231ps 128(%r9), %zmm3, %zmm0 # zmm0 = (zmm3 * mem) + zmm0
	LONG $0x485dd262; WORD $0x41b8; BYTE $0x03 // vfmadd231ps 192(%r9), %zmm4, %zmm0 # zmm0 = (zmm4 * mem) + zmm0
	LONG $0x00c08149; WORD $0x0001; BYTE $0x00 // addq $256, %r8 # imm = 0x100
	LONG $0x00c18149; WORD $0x0001; BYTE $0x00 // addq $256, %r9 # imm = 0x100
	LONG $0x487cf162; WORD $0xc828 // vmovaps %zmm0, %zmm1
	LONG $0xfcc38341 // addl $-4, %r11d
	JNE LBB10_15

LBB10_5:
	LONG $0x40588d4c // leaq 64(%rax), %r11
	LONG $0x03c2f641 // testb $3, %r10b
	JE LBB10_8
	WORD $0xc9fe // decb %cl
	WORD $0xb60f; BYTE $0xc9 // movzbl %cl, %ecx
	WORD $0xe183; BYTE $0x03 // andl $3, %ecx
	WORD $0xe1c1; BYTE $0x06 // shll $6, %ecx
	WORD $0x3145; BYTE $0xd2 // xorl %r10d, %r10d

LBB10_7:
	LONG $0x487c9162; WORD $0x0c10; BYTE $0x10 // vmovups (%r8,%r10), %zmm1
	LONG $0x48759262; WORD $0x04b8; BYTE $0x11 // vfmadd231ps (%r9,%r10), %zmm1, %zmm0 # zmm0 = (zmm1 * mem) + zmm0
	LONG $0x40c28349 // addq $64, %r10
	WORD $0x3944; BYTE $0xd1 // cmpl %r10d, %ecx
	JNE LBB10_7

LBB10_8:
	WORD $0x0148; BYTE $0xc7 // addq %rax, %rdi
	LONG $0x40c78348 // addq $64, %rdi
	WORD $0x014c; BYTE $0xde // addq %r11, %rsi
	LONG $0x487cf162; WORD $0xc828 // vmovaps %zmm0, %zmm1

// Reduce the 16-lane accumulator (zmm1) horizontally into xmm0[0].
// Then add the optional 8-float block and the scalar tail.
LBB10_9:
	LONG $0x48fdf362; WORD $0xc81b; BYTE $0x01 // vextractf64x4 $1, %zmm1, %ymm0
	LONG $0xc058f4c5 // vaddps %ymm0, %ymm1, %ymm0
	LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1
	LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0
	LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0]
	LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0
	LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3]
	LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0
	LONG $0x07fa8348 // cmpq $7, %rdx
	JLE LBB10_11
	LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1
	LONG $0x0e59f4c5 // vmulps (%rsi), %ymm1, %ymm1
	LONG $0x20c78348 // addq $32, %rdi
	LONG $0x20c68348 // addq $32, %rsi
	LONG $0x197de3c4; WORD $0x01ca // vextractf128 $1, %ymm1, %xmm2
	LONG $0xc958e8c5 // vaddps %xmm1, %xmm2, %xmm1
	LONG $0xd1c6f1c5; BYTE $0x01 // vshufpd $1, %xmm1, %xmm1, %xmm2 # xmm2 = xmm1[1,0]
	LONG $0xca58f0c5 // vaddps %xmm2, %xmm1, %xmm1
	LONG $0xd116fac5 // vmovshdup %xmm1, %xmm2 # xmm2 = xmm1[1,1,3,3]
	LONG $0xca58f2c5 // vaddss %xmm2, %xmm1, %xmm1
	LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0
	WORD $0xc283; BYTE $0xf8 // addl $-8, %edx

LBB10_11:
	WORD $0xd285 // testl %edx, %edx
	JLE LBB10_21
	WORD $0x8941; BYTE $0xd0 // movl %edx, %r8d
	WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax
	WORD $0xe083; BYTE $0x03 // andl $3, %eax
	WORD $0xfa83; BYTE $0x04 // cmpl $4, %edx
	JAE LBB10_16
	WORD $0xc931 // xorl %ecx, %ecx
	JMP LBB10_18

LBB10_16:
	LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC
	WORD $0xc931 // xorl %ecx, %ecx

LBB10_17:
	LONG $0x0c10fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm1 # xmm1 = mem[0],zero,zero,zero
	LONG $0x5410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero
	LONG $0x9979e2c4; WORD $0x8e0c // vfmadd132ss (%rsi,%rcx,4), %xmm0, %xmm1 # xmm1 = (xmm1 * mem) + xmm0
	LONG $0xb969e2c4; WORD $0x8e4c; BYTE $0x04 // vfmadd231ss 4(%rsi,%rcx,4), %xmm2, %xmm1 # xmm1 = (xmm2 * mem) + xmm1
	LONG $0x5410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero
	LONG $0x9971e2c4; WORD $0x8e54; BYTE $0x08 // vfmadd132ss 8(%rsi,%rcx,4), %xmm1, %xmm2 # xmm2 = (xmm2 * mem) + xmm1
	LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero
	LONG $0x9969e2c4; WORD $0x8e44; BYTE $0x0c // vfmadd132ss 12(%rsi,%rcx,4), %xmm2, %xmm0 # xmm0 = (xmm0 * mem) + xmm2
	LONG $0x04c18348 // addq $4, %rcx
	WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8
	JNE LBB10_17

LBB10_18:
	WORD $0x8548; BYTE $0xc0 // testq %rax, %rax
	JE LBB10_21
	LONG $0x8e148d48 // leaq (%rsi,%rcx,4), %rdx
	LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx
	WORD $0xf631 // xorl %esi, %esi

LBB10_20:
	LONG $0x0c10fac5; BYTE $0xb1 // vmovss (%rcx,%rsi,4), %xmm1 # xmm1 = mem[0],zero,zero,zero
	LONG $0xb971e2c4; WORD $0xb204 // vfmadd231ss (%rdx,%rsi,4), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0
	WORD $0xff48; BYTE $0xc6 // incq %rsi
	WORD $0x3948; BYTE $0xf0 // cmpq %rsi, %rax
	JNE LBB10_20

LBB10_21:
	WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
	BYTE $0x5d // popq %rbp
	WORD $0xf8c5; BYTE $0x77 // vzeroupper
	MOVSS X0, result+24(FP)
	RET

TEXT ·_mm512_euclidean(SB), $8-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ n+16(FP), DX
	BYTE $0x55 // pushq %rbp
	WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
	LONG $0xf8e48348 // andq $-8, %rsp
	LONG $0x0f428d48 // leaq 15(%rdx), %rax
	WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
	LONG $0xc2490f48 // cmovnsq %rdx, %rax
	WORD $0x8948; BYTE $0xc1 // movq %rax, %rcx
	LONG $0x04f9c148 // sarq $4, %rcx
	LONG $0xf0e08348 // andq $-16, %rax
	WORD $0x2948; BYTE $0xc2 // subq %rax, %rdx
	WORD $0xc985 // testl %ecx, %ecx
	JLE LBB11_1
	LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
	LONG $0x487cf162; WORD $0x065c // vsubps (%rsi), %zmm0, %zmm0
	LONG $0x487cf162; WORD $0xc059 // vmulps %zmm0, %zmm0, %zmm0
	LONG $0x40c78348 // addq $64, %rdi
	LONG $0x40c68348 // addq $64, %rsi
	WORD $0xf983; BYTE $0x01 // cmpl $1, %ecx
	JE LBB11_9
	WORD $0x8949; BYTE $0xc8 // movq %rcx, %r8
	LONG $0x06e0c149 // shlq $6, %r8
	QUAD $0x003fffffff80b848; WORD $0x0000 // movabsq $274877906816, %rax # imm = 0x3FFFFFFF80
	WORD $0x0149; BYTE $0xc0 // addq
%rax, %r8 LONG $0x40c88348 // orq $64, %rax WORD $0x214c; BYTE $0xc0 // andq %r8, %rax LONG $0xff518d44 // leal -1(%rcx), %r10d LONG $0xfe418d44 // leal -2(%rcx), %r8d LONG $0x03f88341 // cmpl $3, %r8d JAE LBB11_18 WORD $0x8949; BYTE $0xf8 // movq %rdi, %r8 WORD $0x8949; BYTE $0xf1 // movq %rsi, %r9 JMP LBB11_5 LBB11_1: LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0 JMP LBB11_9 LBB11_18: WORD $0x8945; BYTE $0xd3 // movl %r10d, %r11d LONG $0xfce38341 // andl $-4, %r11d WORD $0x8949; BYTE $0xf8 // movq %rdi, %r8 WORD $0x8949; BYTE $0xf1 // movq %rsi, %r9 LBB11_19: LONG $0x487cd162; WORD $0x0810 // vmovups (%r8), %zmm1 LONG $0x487cd162; WORD $0x5010; BYTE $0x01 // vmovups 64(%r8), %zmm2 LONG $0x487cd162; WORD $0x5810; BYTE $0x02 // vmovups 128(%r8), %zmm3 LONG $0x487cd162; WORD $0x6010; BYTE $0x03 // vmovups 192(%r8), %zmm4 LONG $0x4874d162; WORD $0x095c // vsubps (%r9), %zmm1, %zmm1 LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1 LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0 LONG $0x486cd162; WORD $0x495c; BYTE $0x01 // vsubps 64(%r9), %zmm2, %zmm1 LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1 LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0 LONG $0x4864d162; WORD $0x495c; BYTE $0x02 // vsubps 128(%r9), %zmm3, %zmm1 LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1 LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0 LONG $0x485cd162; WORD $0x495c; BYTE $0x03 // vsubps 192(%r9), %zmm4, %zmm1 LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1 LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0 LONG $0x00c08149; WORD $0x0001; BYTE $0x00 // addq $256, %r8 # imm = 0x100 LONG $0x00c18149; WORD $0x0001; BYTE $0x00 // addq $256, %r9 # imm = 0x100 LONG $0xfcc38341 // addl $-4, %r11d JNE LBB11_19 LBB11_5: LONG $0x40588d4c // leaq 64(%rax), %r11 LONG $0x03c2f641 // testb $3, %r10b JE LBB11_8 WORD $0xc9fe // decb %cl WORD $0xb60f; BYTE $0xc9 // movzbl %cl, %ecx WORD $0xe183; BYTE $0x03 
// andl $3, %ecx WORD $0xe1c1; BYTE $0x06 // shll $6, %ecx WORD $0x3145; BYTE $0xd2 // xorl %r10d, %r10d LBB11_7: LONG $0x487c9162; WORD $0x0c10; BYTE $0x10 // vmovups (%r8,%r10), %zmm1 LONG $0x48749162; WORD $0x0c5c; BYTE $0x11 // vsubps (%r9,%r10), %zmm1, %zmm1 LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1 LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0 LONG $0x40c28349 // addq $64, %r10 WORD $0x3944; BYTE $0xd1 // cmpl %r10d, %ecx JNE LBB11_7 LBB11_8: WORD $0x0148; BYTE $0xc7 // addq %rax, %rdi LONG $0x40c78348 // addq $64, %rdi WORD $0x014c; BYTE $0xde // addq %r11, %rsi LBB11_9: LONG $0x48fdf362; WORD $0xc11b; BYTE $0x01 // vextractf64x4 $1, %zmm0, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc8c6f9c5; BYTE $0x03 // vshufpd $3, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,1] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0x07fa8348 // cmpq $7, %rdx JLE LBB11_11 LONG $0x0f10fcc5 // vmovups (%rdi), %ymm1 LONG $0x0e5cf4c5 // vsubps (%rsi), %ymm1, %ymm1 LONG $0xc959f4c5 // vmulps %ymm1, %ymm1, %ymm1 LONG $0x20c78348 // addq $32, %rdi LONG $0x20c68348 // addq $32, %rsi LONG $0x197de3c4; WORD $0x01ca // vextractf128 $1, %ymm1, %xmm2 LONG $0xc958e8c5 // vaddps %xmm1, %xmm2, %xmm1 LONG $0xd1c6f1c5; BYTE $0x01 // vshufpd $1, %xmm1, %xmm1, %xmm2 # xmm2 = xmm1[1,0] LONG $0xca58f0c5 // vaddps %xmm2, %xmm1, %xmm1 LONG $0xd116fac5 // vmovshdup %xmm1, %xmm2 # xmm2 = xmm1[1,1,3,3] LONG $0xca58f2c5 // vaddss %xmm2, %xmm1, %xmm1 LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 WORD $0xc283; BYTE $0xf8 // addl $-8, %edx LBB11_11: WORD $0xd285 // testl %edx, %edx JLE LBB11_17 WORD $0x8941; BYTE $0xd0 // movl %edx, %r8d WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0x03 // andl $3, %eax WORD $0xfa83; BYTE 
$0x04 // cmpl $4, %edx JAE LBB11_20 WORD $0xc931 // xorl %ecx, %ecx JMP LBB11_14 LBB11_20: LONG $0xfce08141; WORD $0xffff; BYTE $0x7f // andl $2147483644, %r8d # imm = 0x7FFFFFFC WORD $0xc931 // xorl %ecx, %ecx LBB11_21: LONG $0x0c10fac5; BYTE $0x8f // vmovss (%rdi,%rcx,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x5410fac5; WORD $0x048f // vmovss 4(%rdi,%rcx,4), %xmm2 # xmm2 = mem[0],zero,zero,zero LONG $0x0c5cf2c5; BYTE $0x8e // vsubss (%rsi,%rcx,4), %xmm1, %xmm1 LONG $0x545ceac5; WORD $0x048e // vsubss 4(%rsi,%rcx,4), %xmm2, %xmm2 LONG $0xa971e2c4; BYTE $0xc8 // vfmadd213ss %xmm0, %xmm1, %xmm1 # xmm1 = (xmm1 * xmm1) + xmm0 LONG $0x4410fac5; WORD $0x088f // vmovss 8(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x5c5cfac5; WORD $0x088e // vsubss 8(%rsi,%rcx,4), %xmm0, %xmm3 LONG $0xa969e2c4; BYTE $0xd1 // vfmadd213ss %xmm1, %xmm2, %xmm2 # xmm2 = (xmm2 * xmm2) + xmm1 LONG $0x4410fac5; WORD $0x0c8f // vmovss 12(%rdi,%rcx,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x445cfac5; WORD $0x0c8e // vsubss 12(%rsi,%rcx,4), %xmm0, %xmm0 LONG $0xa961e2c4; BYTE $0xda // vfmadd213ss %xmm2, %xmm3, %xmm3 # xmm3 = (xmm3 * xmm3) + xmm2 LONG $0xa979e2c4; BYTE $0xc3 // vfmadd213ss %xmm3, %xmm0, %xmm0 # xmm0 = (xmm0 * xmm0) + xmm3 LONG $0x04c18348 // addq $4, %rcx WORD $0x3949; BYTE $0xc8 // cmpq %rcx, %r8 JNE LBB11_21 LBB11_14: WORD $0x8548; BYTE $0xc0 // testq %rax, %rax JE LBB11_17 LONG $0x8e148d48 // leaq (%rsi,%rcx,4), %rdx LONG $0x8f0c8d48 // leaq (%rdi,%rcx,4), %rcx WORD $0xf631 // xorl %esi, %esi LBB11_16: LONG $0x0c10fac5; BYTE $0xb1 // vmovss (%rcx,%rsi,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x0c5cf2c5; BYTE $0xb2 // vsubss (%rdx,%rsi,4), %xmm1, %xmm1 LONG $0xb971e2c4; BYTE $0xc1 // vfmadd231ss %xmm1, %xmm1, %xmm0 # xmm0 = (xmm1 * xmm1) + xmm0 WORD $0xff48; BYTE $0xc6 // incq %rsi WORD $0x3948; BYTE $0xf0 // cmpq %rsi, %rax JNE LBB11_16 LBB11_17: LONG $0xc051fac5 // vsqrtss %xmm0, %xmm0, %xmm0 WORD $0x8948; BYTE $0xec // movq %rbp, %rsp BYTE $0x5d 
// popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper MOVSS X0, result+24(FP) RET TEXT ·_mm512_mm(SB), $0-88 MOVQ transA+0(FP), DI MOVQ transB+1(FP), SI MOVQ m+8(FP), DX MOVQ n+16(FP), CX MOVQ k+24(FP), R8 MOVQ a+32(FP), R9 PUSHQ ldc+72(FP) PUSHQ c+64(FP) PUSHQ ldb+56(FP) PUSHQ b+48(FP) PUSHQ lda+40(FP) PUSHQ $0 BYTE $0x55 // pushq %rbp WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp WORD $0x5741 // pushq %r15 WORD $0x5641 // pushq %r14 WORD $0x5541 // pushq %r13 WORD $0x5441 // pushq %r12 BYTE $0x53 // pushq %rbx LONG $0xf8e48348 // andq $-8, %rsp LONG $0x88ec8148; WORD $0x0000; BYTE $0x00 // subq $136, %rsp LONG $0x244c894c; BYTE $0x18 // movq %r9, 24(%rsp) # 8-byte Spill LONG $0x30458b48 // movq 48(%rbp), %rax LONG $0x24448948; BYTE $0x70 // movq %rax, 112(%rsp) # 8-byte Spill LONG $0x28458b48 // movq 40(%rbp), %rax LONG $0x24448948; BYTE $0x68 // movq %rax, 104(%rsp) # 8-byte Spill WORD $0xf889 // movl %edi, %eax WORD $0x0840; BYTE $0xf0 // orb %sil, %al LONG $0x2404894c // movq %r8, (%rsp) # 8-byte Spill LONG $0x24548948; BYTE $0x38 // movq %rdx, 56(%rsp) # 8-byte Spill JE LBB12_1 WORD $0xf089 // movl %esi, %eax WORD $0x0134 // xorb $1, %al WORD $0x0840; BYTE $0xf8 // orb %dil, %al JE LBB12_22 WORD $0x8941; BYTE $0xf9 // movl %edi, %r9d LONG $0x01f18041 // xorb $1, %r9b WORD $0x0841; BYTE $0xf1 // orb %sil, %r9b WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 LONG $0xc09f0f41 // setg %r8b WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx WORD $0x9f0f; BYTE $0xc0 // setg %al WORD $0x2044; BYTE $0xc0 // andb %r8b, %al WORD $0x8445; BYTE $0xc9 // testb %r9b, %r9b JE LBB12_99 WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx WORD $0x9f0f; BYTE $0xc2 // setg %dl WORD $0x2040; BYTE $0xf0 // andb %sil, %al WORD $0x2040; BYTE $0xfa // andb %dil, %dl WORD $0xc220 // andb %al, %dl WORD $0xfa80; BYTE $0x01 // cmpb $1, %dl JNE LBB12_140 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x00000000850c8d4c // leaq (,%rax,4), %r9 LONG $0x24148b48 // movq (%rsp), %rdx # 8-byte Reload LONG $0xff428d48 
// leaq -1(%rdx), %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax WORD $0x0148; BYTE $0xca // addq %rcx, %rdx LONG $0x187d8b48 // movq 24(%rbp), %rdi LONG $0x97548d48; BYTE $0xfc // leaq -4(%rdi,%rdx,4), %rdx LONG $0x24548948; BYTE $0x48 // movq %rdx, 72(%rsp) # 8-byte Spill LONG $0x28458b4c // movq 40(%rbp), %r8 LONG $0x88148d49 // leaq (%r8,%rcx,4), %rdx LONG $0x24548948; BYTE $0x40 // movq %rdx, 64(%rsp) # 8-byte Spill LONG $0x24548b48; BYTE $0x18 // movq 24(%rsp), %rdx # 8-byte Reload LONG $0x82048d48 // leaq (%rdx,%rax,4), %rax LONG $0x04c08348 // addq $4, %rax QUAD $0x0000008024848948 // movq %rax, 128(%rsp) # 8-byte Spill LONG $0x08f98348 // cmpq $8, %rcx WORD $0x930f; BYTE $0xc0 // setae %al LONG $0x20558b48 // movq 32(%rbp), %rdx LONG $0x01fa8348 // cmpq $1, %rdx LONG $0xc2940f41 // sete %r10b WORD $0x2041; BYTE $0xc2 // andb %al, %r10b QUAD $0xffffffffffc0be48; WORD $0x7fff // movabsq $9223372036854775744, %rsi # imm = 0x7FFFFFFFFFFFFFC0 WORD $0x8948; BYTE $0xc8 // movq %rcx, %rax WORD $0x2148; BYTE $0xf0 // andq %rsi, %rax LONG $0x24448948; BYTE $0x28 // movq %rax, 40(%rsp) # 8-byte Spill LONG $0x38ce8348 // orq $56, %rsi WORD $0x2148; BYTE $0xce // andq %rcx, %rsi LONG $0xff418d48 // leaq -1(%rcx), %rax LONG $0x24448948; BYTE $0x20 // movq %rax, 32(%rsp) # 8-byte Spill LONG $0xc0878d48; WORD $0x0000; BYTE $0x00 // leaq 192(%rdi), %rax LONG $0x24448948; BYTE $0x78 // movq %rax, 120(%rsp) # 8-byte Spill LONG $0xc0808d49; WORD $0x0000; BYTE $0x00 // leaq 192(%r8), %rax QUAD $0x00000000953c8d48 // leaq (,%rdx,4), %rdi LONG $0x247c8948; BYTE $0x60 // movq %rdi, 96(%rsp) # 8-byte Spill QUAD $0x00000000d53c8d48 // leaq (,%rdx,8), %rdi WORD $0xf641; BYTE $0xd2 // notb %r10b LONG $0x24548844; BYTE $0x0f // movb %r10b, 15(%rsp) # 1-byte Spill WORD $0x3145; BYTE $0xf6 // xorl %r14d, %r14d WORD $0x894c; BYTE $0xc2 // movq %r8, %rdx LONG $0x244c894c; BYTE $0x50 // movq %r9, 80(%rsp) # 8-byte Spill JMP LBB12_121 LBB12_139: LONG $0x24748b4c; BYTE $0x58 // 
movq 88(%rsp), %r14 # 8-byte Reload WORD $0xff49; BYTE $0xc6 // incq %r14 LONG $0x244c8b4c; BYTE $0x50 // movq 80(%rsp), %r9 # 8-byte Reload WORD $0x014c; BYTE $0xc8 // addq %r9, %rax WORD $0x014c; BYTE $0xca // addq %r9, %rdx LONG $0x24743b4c; BYTE $0x38 // cmpq 56(%rsp), %r14 # 8-byte Folded Reload JE LBB12_140 LBB12_121: QUAD $0x000000000000b849; WORD $0x2000 // movabsq $2305843009213693952, %r8 # imm = 0x2000000000000000 LONG $0x1045854c // testq %r8, 16(%rbp) LONG $0xc0950f41 // setne %r8b LONG $0xceaf0f4d // imulq %r14, %r9 LONG $0x287d8b4c // movq 40(%rbp), %r15 LONG $0x0f148d4f // leaq (%r15,%r9), %r10 LONG $0x244c034c; BYTE $0x40 // addq 64(%rsp), %r9 # 8-byte Folded Reload LONG $0x245c8b4c; BYTE $0x18 // movq 24(%rsp), %r11 # 8-byte Reload LONG $0xb31c8d4b // leaq (%r11,%r14,4), %rbx QUAD $0x00000080249c8b4c // movq 128(%rsp), %r11 # 8-byte Reload LONG $0xb31c8d4f // leaq (%r11,%r14,4), %r11 LONG $0x2474894c; BYTE $0x58 // movq %r14, 88(%rsp) # 8-byte Spill LONG $0x75af0f4c; BYTE $0x30 // imulq 48(%rbp), %r14 LONG $0xb7348d4f // leaq (%r15,%r14,4), %r14 LONG $0x2474894c; BYTE $0x30 // movq %r14, 48(%rsp) # 8-byte Spill WORD $0x394d; BYTE $0xda // cmpq %r11, %r10 LONG $0xc3920f41 // setb %r11b WORD $0x394c; BYTE $0xcb // cmpq %r9, %rbx LONG $0xc6920f41 // setb %r14b WORD $0x2045; BYTE $0xde // andb %r11b, %r14b WORD $0x0845; BYTE $0xc6 // orb %r8b, %r14b LONG $0x24543b4c; BYTE $0x48 // cmpq 72(%rsp), %r10 # 8-byte Folded Reload LONG $0xc0920f41 // setb %r8b LONG $0x18658b4c // movq 24(%rbp), %r12 WORD $0x394d; BYTE $0xe1 // cmpq %r12, %r9 LONG $0xc1970f41 // seta %r9b WORD $0x2045; BYTE $0xc1 // andb %r8b, %r9b WORD $0x0845; BYTE $0xf1 // orb %r14b, %r9b LONG $0x244c0a44; BYTE $0x0f // orb 15(%rsp), %r9b # 1-byte Folded Reload LONG $0x244c8844; BYTE $0x10 // movb %r9b, 16(%rsp) # 1-byte Spill LONG $0x244c8b4c; BYTE $0x78 // movq 120(%rsp), %r9 # 8-byte Reload WORD $0x3145; BYTE $0xed // xorl %r13d, %r13d JMP LBB12_122 LBB12_138: WORD $0xff49; BYTE $0xc5 // 
incq %r13 LONG $0x04c18349 // addq $4, %r9 LONG $0x04c48349 // addq $4, %r12 LONG $0x242c3b4c // cmpq (%rsp), %r13 # 8-byte Folded Reload JE LBB12_139 LBB12_122: WORD $0x894d; BYTE $0xe8 // movq %r13, %r8 LONG $0x45af0f4c; BYTE $0x10 // imulq 16(%rbp), %r8 LONG $0x102444f6; BYTE $0x01 // testb $1, 16(%rsp) # 1-byte Folded Reload JE LBB12_124 WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d JMP LBB12_133 LBB12_124: LONG $0x40f98348 // cmpq $64, %rcx JAE LBB12_126 WORD $0x3145; BYTE $0xd2 // xorl %r10d, %r10d JMP LBB12_130 LBB12_126: LONG $0x487db262; WORD $0x0418; BYTE $0x83 // vbroadcastss (%rbx,%r8,4), %zmm0 WORD $0x3145; BYTE $0xdb // xorl %r11d, %r11d LONG $0x24548b4c; BYTE $0x28 // movq 40(%rsp), %r10 # 8-byte Reload LBB12_127: QUAD $0xfd994c10487c9162 // vmovups -192(%r9,%r11,4), %zmm1 QUAD $0xfe995410487c9162 // vmovups -128(%r9,%r11,4), %zmm2 QUAD $0xff995c10487c9162 // vmovups -64(%r9,%r11,4), %zmm3 LONG $0x487c9162; WORD $0x2410; BYTE $0x99 // vmovups (%r9,%r11,4), %zmm4 QUAD $0xfd984ca8487db262 // vfmadd213ps -192(%rax,%r11,4), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem QUAD $0xfe9854a8487db262 // vfmadd213ps -128(%rax,%r11,4), %zmm0, %zmm2 # zmm2 = (zmm0 * zmm2) + mem QUAD $0xff985ca8487db262 // vfmadd213ps -64(%rax,%r11,4), %zmm0, %zmm3 # zmm3 = (zmm0 * zmm3) + mem LONG $0x487db262; WORD $0x24a8; BYTE $0x98 // vfmadd213ps (%rax,%r11,4), %zmm0, %zmm4 # zmm4 = (zmm0 * zmm4) + mem QUAD $0xfd984c11487cb162 // vmovups %zmm1, -192(%rax,%r11,4) QUAD $0xfe985411487cb162 // vmovups %zmm2, -128(%rax,%r11,4) QUAD $0xff985c11487cb162 // vmovups %zmm3, -64(%rax,%r11,4) LONG $0x487cb162; WORD $0x2411; BYTE $0x98 // vmovups %zmm4, (%rax,%r11,4) LONG $0x40c38349 // addq $64, %r11 WORD $0x394d; BYTE $0xda // cmpq %r11, %r10 JNE LBB12_127 WORD $0x3949; BYTE $0xca // cmpq %rcx, %r10 JE LBB12_138 LONG $0x247c8b4c; BYTE $0x28 // movq 40(%rsp), %r15 # 8-byte Reload WORD $0x894d; BYTE $0xfa // movq %r15, %r10 WORD $0xc1f6; BYTE $0x38 // testb $56, %cl JE LBB12_133 LBB12_130: LONG 
$0x187da2c4; WORD $0x8304 // vbroadcastss (%rbx,%r8,4), %ymm0 LBB12_131: LONG $0x107c81c4; WORD $0x940c // vmovups (%r12,%r10,4), %ymm1 LONG $0xa87da2c4; WORD $0x920c // vfmadd213ps (%rdx,%r10,4), %ymm0, %ymm1 # ymm1 = (ymm0 * ymm1) + mem LONG $0x117ca1c4; WORD $0x920c // vmovups %ymm1, (%rdx,%r10,4) LONG $0x08c28349 // addq $8, %r10 WORD $0x394c; BYTE $0xd6 // cmpq %r10, %rsi JNE LBB12_131 WORD $0x8949; BYTE $0xf7 // movq %rsi, %r15 WORD $0x3948; BYTE $0xce // cmpq %rcx, %rsi JE LBB12_138 LBB12_133: WORD $0x894d; BYTE $0xfe // movq %r15, %r14 WORD $0xc1f6; BYTE $0x01 // testb $1, %cl JE LBB12_135 LONG $0x18558b4c // movq 24(%rbp), %r10 LONG $0xaa148d4f // leaq (%r10,%r13,4), %r10 LONG $0x107aa1c4; WORD $0x8304 // vmovss (%rbx,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero WORD $0x894d; BYTE $0xfb // movq %r15, %r11 LONG $0x5daf0f4c; BYTE $0x20 // imulq 32(%rbp), %r11 LONG $0x107a81c4; WORD $0x9a0c // vmovss (%r10,%r11,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x24548b4c; BYTE $0x30 // movq 48(%rsp), %r10 # 8-byte Reload LONG $0xa97982c4; WORD $0xba0c // vfmadd213ss (%r10,%r15,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117a81c4; WORD $0xba0c // vmovss %xmm1, (%r10,%r15,4) WORD $0x894d; BYTE $0xfe // movq %r15, %r14 LONG $0x01ce8349 // orq $1, %r14 LBB12_135: LONG $0x247c3b4c; BYTE $0x20 // cmpq 32(%rsp), %r15 # 8-byte Folded Reload JE LBB12_138 LONG $0x24548b4c; BYTE $0x60 // movq 96(%rsp), %r10 # 8-byte Reload WORD $0x894d; BYTE $0xd7 // movq %r10, %r15 LONG $0xfeaf0f4d // imulq %r14, %r15 LONG $0x015e8d4d // leaq 1(%r14), %r11 LONG $0xdaaf0f4d // imulq %r10, %r11 WORD $0x894d; BYTE $0xe2 // movq %r12, %r10 LBB12_137: LONG $0x107aa1c4; WORD $0x8304 // vmovss (%rbx,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107a81c4; WORD $0x3a0c // vmovss (%r10,%r15), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979a2c4; WORD $0xb20c // vfmadd213ss (%rdx,%r14,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0xb20c // vmovss %xmm1, 
(%rdx,%r14,4) LONG $0x107aa1c4; WORD $0x8304 // vmovss (%rbx,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107a81c4; WORD $0x1a0c // vmovss (%r10,%r11), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979a2c4; WORD $0xb24c; BYTE $0x04 // vfmadd213ss 4(%rdx,%r14,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0xb24c; BYTE $0x04 // vmovss %xmm1, 4(%rdx,%r14,4) LONG $0x02c68349 // addq $2, %r14 WORD $0x0149; BYTE $0xfa // addq %rdi, %r10 WORD $0x394c; BYTE $0xf1 // cmpq %r14, %rcx JNE LBB12_137 JMP LBB12_138 LBB12_1: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx WORD $0x9e0f; BYTE $0xc0 // setle %al WORD $0x854d; BYTE $0xc0 // testq %r8, %r8 WORD $0x9e0f; BYTE $0xc2 // setle %dl WORD $0xc208 // orb %al, %dl WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx WORD $0x9e0f; BYTE $0xc0 // setle %al WORD $0xd008 // orb %dl, %al JNE LBB12_140 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x0000000085048d4c // leaq (,%rax,4), %r8 LONG $0x10458b48 // movq 16(%rbp), %rax QUAD $0x0000000085048d48 // leaq (,%rax,4), %rax LONG $0x24448948; BYTE $0x60 // movq %rax, 96(%rsp) # 8-byte Spill LONG $0x243c8b48 // movq (%rsp), %rdi # 8-byte Reload LONG $0xff478d48 // leaq -1(%rdi), %rax LONG $0x20558b48 // movq 32(%rbp), %rdx LONG $0xc2af0f48 // imulq %rdx, %rax WORD $0x0148; BYTE $0xc8 // addq %rcx, %rax LONG $0x18758b48 // movq 24(%rbp), %rsi LONG $0x86048d48 // leaq (%rsi,%rax,4), %rax LONG $0x24448948; BYTE $0x58 // movq %rax, 88(%rsp) # 8-byte Spill LONG $0x284d8b4c // movq 40(%rbp), %r9 LONG $0x89048d49 // leaq (%r9,%rcx,4), %rax LONG $0x24448948; BYTE $0x50 // movq %rax, 80(%rsp) # 8-byte Spill LONG $0x24448b48; BYTE $0x18 // movq 24(%rsp), %rax # 8-byte Reload LONG $0xb8048d48 // leaq (%rax,%rdi,4), %rax LONG $0x24448948; BYTE $0x48 // movq %rax, 72(%rsp) # 8-byte Spill QUAD $0xffffffffffc0be49; WORD $0x7fff // movabsq $9223372036854775744, %r14 # imm = 0x7FFFFFFFFFFFFFC0 WORD $0x8949; BYTE $0xcc // movq %rcx, %r12 WORD $0x214d; BYTE $0xf4 // andq %r14, %r12 
LONG $0x38ce8349 // orq $56, %r14 WORD $0x2149; BYTE $0xce // andq %rcx, %r14 LONG $0xff418d48 // leaq -1(%rcx), %rax LONG $0x24448948; BYTE $0x10 // movq %rax, 16(%rsp) # 8-byte Spill LONG $0xc0868d48; WORD $0x0000; BYTE $0x00 // leaq 192(%rsi), %rax LONG $0x24448948; BYTE $0x40 // movq %rax, 64(%rsp) # 8-byte Spill QUAD $0x0000000095048d48 // leaq (,%rdx,4), %rax LONG $0xc0b18d49; WORD $0x0000; BYTE $0x00 // leaq 192(%r9), %rsi WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d LONG $0x2444894c; BYTE $0x28 // movq %r8, 40(%rsp) # 8-byte Spill JMP LBB12_3 LBB12_20: LONG $0x247c8b4c; BYTE $0x30 // movq 48(%rsp), %r15 # 8-byte Reload WORD $0xff49; BYTE $0xc7 // incq %r15 LONG $0x24448b4c; BYTE $0x28 // movq 40(%rsp), %r8 # 8-byte Reload WORD $0x014c; BYTE $0xc6 // addq %r8, %rsi WORD $0x014d; BYTE $0xc1 // addq %r8, %r9 LONG $0x247c3b4c; BYTE $0x38 // cmpq 56(%rsp), %r15 # 8-byte Folded Reload JE LBB12_140 LBB12_3: QUAD $0x000000000000ba48; WORD $0x2000 // movabsq $2305843009213693952, %rdx # imm = 0x2000000000000000 LONG $0x20558548 // testq %rdx, 32(%rbp) WORD $0x950f; BYTE $0xc2 // setne %dl WORD $0x894c; BYTE $0xc7 // movq %r8, %rdi LONG $0xffaf0f49 // imulq %r15, %rdi LONG $0x28458b4c // movq 40(%rbp), %r8 WORD $0x0149; BYTE $0xf8 // addq %rdi, %r8 LONG $0x247c0348; BYTE $0x50 // addq 80(%rsp), %rdi # 8-byte Folded Reload LONG $0x24548b4c; BYTE $0x60 // movq 96(%rsp), %r10 # 8-byte Reload LONG $0xd7af0f4d // imulq %r15, %r10 LONG $0x246c8b4c; BYTE $0x18 // movq 24(%rsp), %r13 # 8-byte Reload LONG $0x2a1c8d4f // leaq (%r10,%r13), %r11 LONG $0x2454034c; BYTE $0x48 // addq 72(%rsp), %r10 # 8-byte Folded Reload WORD $0x894c; BYTE $0xfb // movq %r15, %rbx LONG $0x5daf0f48; BYTE $0x10 // imulq 16(%rbp), %rbx LONG $0x247c894c; BYTE $0x30 // movq %r15, 48(%rsp) # 8-byte Spill LONG $0x7daf0f4c; BYTE $0x30 // imulq 48(%rbp), %r15 WORD $0x394d; BYTE $0xd0 // cmpq %r10, %r8 LONG $0xc2920f41 // setb %r10b WORD $0x3949; BYTE $0xfb // cmpq %rdi, %r11 LONG $0x9d5c8d49; BYTE $0x00 
// leaq (%r13,%rbx,4), %rbx LONG $0x285d8b4c // movq 40(%rbp), %r11 LONG $0xbb1c8d4f // leaq (%r11,%r15,4), %r11 LONG $0x245c894c; BYTE $0x20 // movq %r11, 32(%rsp) # 8-byte Spill LONG $0xc5920f41 // setb %r13b WORD $0x2045; BYTE $0xd5 // andb %r10b, %r13b LONG $0x24443b4c; BYTE $0x58 // cmpq 88(%rsp), %r8 # 8-byte Folded Reload LONG $0xc0920f41 // setb %r8b LONG $0x18558b4c // movq 24(%rbp), %r10 WORD $0x394c; BYTE $0xd7 // cmpq %r10, %rdi LONG $0xc7970f40 // seta %dil WORD $0x2044; BYTE $0xc7 // andb %r8b, %dil WORD $0x0841; BYTE $0xd5 // orb %dl, %r13b WORD $0x0841; BYTE $0xfd // orb %dil, %r13b LONG $0x24548b48; BYTE $0x40 // movq 64(%rsp), %rdx # 8-byte Reload WORD $0xff31 // xorl %edi, %edi LONG $0x241c8b4c // movq (%rsp), %r11 # 8-byte Reload JMP LBB12_4 LBB12_19: WORD $0xff48; BYTE $0xc7 // incq %rdi WORD $0x0148; BYTE $0xc2 // addq %rax, %rdx WORD $0x0149; BYTE $0xc2 // addq %rax, %r10 WORD $0x394c; BYTE $0xdf // cmpq %r11, %rdi JE LBB12_20 LBB12_4: LONG $0x08f98348 // cmpq $8, %rcx LONG $0xc0920f41 // setb %r8b WORD $0x0845; BYTE $0xe8 // orb %r13b, %r8b LONG $0x01c0f641 // testb $1, %r8b JE LBB12_6 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d JMP LBB12_15 LBB12_6: LONG $0x40f98348 // cmpq $64, %rcx JAE LBB12_8 WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d JMP LBB12_12 LBB12_8: LONG $0x487df262; WORD $0x0418; BYTE $0xbb // vbroadcastss (%rbx,%rdi,4), %zmm0 WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d LBB12_9: QUAD $0xfd824c10487cb162 // vmovups -192(%rdx,%r8,4), %zmm1 QUAD $0xfe825410487cb162 // vmovups -128(%rdx,%r8,4), %zmm2 QUAD $0xff825c10487cb162 // vmovups -64(%rdx,%r8,4), %zmm3 LONG $0x487cb162; WORD $0x2410; BYTE $0x82 // vmovups (%rdx,%r8,4), %zmm4 QUAD $0xfd864ca8487db262 // vfmadd213ps -192(%rsi,%r8,4), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem QUAD $0xfe8654a8487db262 // vfmadd213ps -128(%rsi,%r8,4), %zmm0, %zmm2 # zmm2 = (zmm0 * zmm2) + mem QUAD $0xff865ca8487db262 // vfmadd213ps -64(%rsi,%r8,4), %zmm0, %zmm3 # zmm3 = (zmm0 * zmm3) + mem LONG 
$0x487db262; WORD $0x24a8; BYTE $0x86 // vfmadd213ps (%rsi,%r8,4), %zmm0, %zmm4 # zmm4 = (zmm0 * zmm4) + mem QUAD $0xfd864c11487cb162 // vmovups %zmm1, -192(%rsi,%r8,4) QUAD $0xfe865411487cb162 // vmovups %zmm2, -128(%rsi,%r8,4) QUAD $0xff865c11487cb162 // vmovups %zmm3, -64(%rsi,%r8,4) LONG $0x487cb162; WORD $0x2411; BYTE $0x86 // vmovups %zmm4, (%rsi,%r8,4) LONG $0x40c08349 // addq $64, %r8 WORD $0x394d; BYTE $0xc4 // cmpq %r8, %r12 JNE LBB12_9 WORD $0x3949; BYTE $0xcc // cmpq %rcx, %r12 JE LBB12_19 WORD $0x894d; BYTE $0xe7 // movq %r12, %r15 WORD $0x894d; BYTE $0xe0 // movq %r12, %r8 WORD $0xc1f6; BYTE $0x38 // testb $56, %cl JE LBB12_15 LBB12_12: LONG $0x187de2c4; WORD $0xbb04 // vbroadcastss (%rbx,%rdi,4), %ymm0 LBB12_13: LONG $0x107c81c4; WORD $0xba0c // vmovups (%r10,%r15,4), %ymm1 LONG $0xa87d82c4; WORD $0xb90c // vfmadd213ps (%r9,%r15,4), %ymm0, %ymm1 # ymm1 = (ymm0 * ymm1) + mem LONG $0x117c81c4; WORD $0xb90c // vmovups %ymm1, (%r9,%r15,4) LONG $0x08c78349 // addq $8, %r15 WORD $0x394d; BYTE $0xfe // cmpq %r15, %r14 JNE LBB12_13 WORD $0x894d; BYTE $0xf0 // movq %r14, %r8 WORD $0x3949; BYTE $0xce // cmpq %rcx, %r14 JE LBB12_19 LBB12_15: WORD $0x894d; BYTE $0xc7 // movq %r8, %r15 WORD $0xc1f6; BYTE $0x01 // testb $1, %cl JE LBB12_17 WORD $0x8949; BYTE $0xff // movq %rdi, %r15 LONG $0x7daf0f4c; BYTE $0x20 // imulq 32(%rbp), %r15 LONG $0x185d8b4c // movq 24(%rbp), %r11 LONG $0xbb3c8d4f // leaq (%r11,%r15,4), %r15 LONG $0x241c8b4c // movq (%rsp), %r11 # 8-byte Reload LONG $0x0410fac5; BYTE $0xbb // vmovss (%rbx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107a81c4; WORD $0x870c // vmovss (%r15,%r8,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x247c8b4c; BYTE $0x20 // movq 32(%rsp), %r15 # 8-byte Reload LONG $0xa97982c4; WORD $0x870c // vfmadd213ss (%r15,%r8,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117a81c4; WORD $0x870c // vmovss %xmm1, (%r15,%r8,4) WORD $0x894d; BYTE $0xc7 // movq %r8, %r15 LONG $0x01cf8349 // orq $1, %r15 LBB12_17: 
LONG $0x24443b4c; BYTE $0x10 // cmpq 16(%rsp), %r8 # 8-byte Folded Reload JE LBB12_19 LBB12_18: LONG $0x0410fac5; BYTE $0xbb // vmovss (%rbx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107a81c4; WORD $0xba0c // vmovss (%r10,%r15,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa97982c4; WORD $0xb90c // vfmadd213ss (%r9,%r15,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117a81c4; WORD $0xb90c // vmovss %xmm1, (%r9,%r15,4) LONG $0x0410fac5; BYTE $0xbb // vmovss (%rbx,%rdi,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107a81c4; WORD $0xba4c; BYTE $0x04 // vmovss 4(%r10,%r15,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa97982c4; WORD $0xb94c; BYTE $0x04 // vfmadd213ss 4(%r9,%r15,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117a81c4; WORD $0xb94c; BYTE $0x04 // vmovss %xmm1, 4(%r9,%r15,4) LONG $0x02c78349 // addq $2, %r15 WORD $0x394c; BYTE $0xf9 // cmpq %r15, %rcx JNE LBB12_18 JMP LBB12_19 LBB12_22: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx JLE LBB12_140 LONG $0x24148b48 // movq (%rsp), %rdx # 8-byte Reload LONG $0x0f428d48 // leaq 15(%rdx), %rax WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx LONG $0xc2490f48 // cmovnsq %rdx, %rax WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx JLE LBB12_140 WORD $0x8948; BYTE $0xc6 // movq %rax, %rsi LONG $0x04fec148 // sarq $4, %rsi LONG $0xf0e08348 // andq $-16, %rax LONG $0x24248b4c // movq (%rsp), %r12 # 8-byte Reload WORD $0x2949; BYTE $0xc4 // subq %rax, %r12 LONG $0x07fc8349 // cmpq $7, %r12 JLE LBB12_25 LONG $0x247c8d41; BYTE $0xf8 // leal -8(%r12), %edi WORD $0xfe83; BYTE $0x01 // cmpl $1, %esi JLE LBB12_30 LONG $0x08fc8349 // cmpq $8, %r12 JLE LBB12_35 LONG $0xff468d44 // leal -1(%rsi), %r8d WORD $0x468d; BYTE $0xfe // leal -2(%rsi), %eax WORD $0x0489; BYTE $0x24 // movl %eax, (%rsp) # 4-byte Spill WORD $0x8944; BYTE $0xc0 // movl %r8d, %eax WORD $0xe083; BYTE $0xfc // andl $-4, %eax LONG $0x20244489 // movl %eax, 32(%rsp) # 4-byte Spill LONG $0x10458b48 // movq 16(%rbp), %rax QUAD 
$0x0000000085048d48 // leaq (,%rax,4), %rax LONG $0x24448948; BYTE $0x28 // movq %rax, 40(%rsp) # 8-byte Spill LONG $0x20458b48 // movq 32(%rbp), %rax QUAD $0x00000000852c8d4c // leaq (,%rax,4), %r13 WORD $0xfe40; BYTE $0xce // decb %sil LONG $0xf6b60f40 // movzbl %sil, %esi WORD $0xe683; BYTE $0x03 // andl $3, %esi WORD $0xe6c1; BYTE $0x06 // shll $6, %esi WORD $0x3145; BYTE $0xdb // xorl %r11d, %r11d LONG $0x244c8b4c; BYTE $0x18 // movq 24(%rsp), %r9 # 8-byte Reload JMP LBB12_40 LBB12_58: LONG $0x245c8b4c; BYTE $0x30 // movq 48(%rsp), %r11 # 8-byte Reload WORD $0xff49; BYTE $0xc3 // incq %r11 LONG $0x244c034c; BYTE $0x28 // addq 40(%rsp), %r9 # 8-byte Folded Reload LONG $0x245c3b4c; BYTE $0x38 // cmpq 56(%rsp), %r11 # 8-byte Folded Reload JE LBB12_140 LBB12_40: WORD $0x894c; BYTE $0xd8 // movq %r11, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x24548b4c; BYTE $0x18 // movq 24(%rsp), %r10 # 8-byte Reload LONG $0x82148d49 // leaq (%r10,%rax,4), %rdx LONG $0x82048d49 // leaq (%r10,%rax,4), %rax LONG $0x40c08348 // addq $64, %rax LONG $0x24448948; BYTE $0x10 // movq %rax, 16(%rsp) # 8-byte Spill LONG $0x245c894c; BYTE $0x30 // movq %r11, 48(%rsp) # 8-byte Spill WORD $0x894c; BYTE $0xd8 // movq %r11, %rax LONG $0x45af0f48; BYTE $0x30 // imulq 48(%rbp), %rax LONG $0x28558b4c // movq 40(%rbp), %r10 LONG $0x82248d4d // leaq (%r10,%rax,4), %r12 LONG $0x185d8b48 // movq 24(%rbp), %rbx WORD $0xc031 // xorl %eax, %eax JMP LBB12_41 LBB12_57: LONG $0x117ac1c4; WORD $0x8404 // vmovss %xmm0, (%r12,%rax,4) WORD $0xff48; BYTE $0xc0 // incq %rax WORD $0x014c; BYTE $0xeb // addq %r13, %rbx WORD $0x3948; BYTE $0xc8 // cmpq %rcx, %rax JE LBB12_58 LBB12_41: WORD $0x8949; BYTE $0xc2 // movq %rax, %r10 LONG $0x55af0f4c; BYTE $0x20 // imulq 32(%rbp), %r10 LONG $0x487cf162; WORD $0x0210 // vmovups (%rdx), %zmm0 LONG $0x185d8b4c // movq 24(%rbp), %r11 LONG $0x487c9162; WORD $0x0459; BYTE $0x93 // vmulps (%r11,%r10,4), %zmm0, %zmm0 LONG $0x03243c83 // cmpl $3, (%rsp) # 
4-byte Folded Reload JAE LBB12_49 LONG $0x93348d4f // leaq (%r11,%r10,4), %r14 LONG $0x40c68349 // addq $64, %r14 LONG $0x247c8b4c; BYTE $0x10 // movq 16(%rsp), %r15 # 8-byte Reload LONG $0x03c0f641 // testb $3, %r8b JNE LBB12_53 JMP LBB12_56 LBB12_49: LONG $0x24548b44; BYTE $0x20 // movl 32(%rsp), %r10d # 4-byte Reload WORD $0x3145; BYTE $0xf6 // xorl %r14d, %r14d LBB12_50: QUAD $0x01314c10487c9162 // vmovups 64(%r9,%r14), %zmm1 QUAD $0x02315410487c9162 // vmovups 128(%r9,%r14), %zmm2 QUAD $0x03315c10487c9162 // vmovups 192(%r9,%r14), %zmm3 QUAD $0x01334c98487db262 // vfmadd132ps 64(%rbx,%r14), %zmm0, %zmm1 # zmm1 = (zmm1 * mem) + zmm0 QUAD $0x02334cb8486db262 // vfmadd231ps 128(%rbx,%r14), %zmm2, %zmm1 # zmm1 = (zmm2 * mem) + zmm1 QUAD $0x04315410487c9162 // vmovups 256(%r9,%r14), %zmm2 QUAD $0x03334cb84865b262 // vfmadd231ps 192(%rbx,%r14), %zmm3, %zmm1 # zmm1 = (zmm3 * mem) + zmm1 LONG $0x487cf162; WORD $0xc128 // vmovaps %zmm1, %zmm0 QUAD $0x043344b8486db262 // vfmadd231ps 256(%rbx,%r14), %zmm2, %zmm0 # zmm0 = (zmm2 * mem) + zmm0 LONG $0x00c68149; WORD $0x0001; BYTE $0x00 // addq $256, %r14 # imm = 0x100 LONG $0xfcc28341 // addl $-4, %r10d JNE LBB12_50 LONG $0x31148d4f // leaq (%r9,%r14), %r10 LONG $0x331c8d4e // leaq (%rbx,%r14), %r11 LONG $0x313c8d4f // leaq (%r9,%r14), %r15 LONG $0x40c78349 // addq $64, %r15 WORD $0x0149; BYTE $0xde // addq %rbx, %r14 LONG $0x40c68349 // addq $64, %r14 LONG $0x03c0f641 // testb $3, %r8b JE LBB12_56 LBB12_53: WORD $0x3145; BYTE $0xd2 // xorl %r10d, %r10d WORD $0x3145; BYTE $0xdb // xorl %r11d, %r11d LBB12_54: LONG $0x487cd162; WORD $0x0f10 // vmovups (%r15), %zmm1 LONG $0x48759262; WORD $0x04b8; BYTE $0x16 // vfmadd231ps (%r14,%r10), %zmm1, %zmm0 # zmm0 = (zmm1 * mem) + zmm0 LONG $0x40c78349 // addq $64, %r15 LONG $0xc0c38349 // addq $-64, %r11 LONG $0x40c28349 // addq $64, %r10 WORD $0x3944; BYTE $0xd6 // cmpl %r10d, %esi JNE LBB12_54 LONG $0xc0578d4d // leaq -64(%r15), %r10 WORD $0x294d; BYTE $0xde // subq %r11, %r14 LONG 
$0xc05e8d4d // leaq -64(%r14), %r11 LBB12_56: LONG $0x48fdf362; WORD $0xc11b; BYTE $0x01 // vextractf64x4 $1, %zmm0, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0x107cc1c4; BYTE $0x0f // vmovups (%r15), %ymm1 LONG $0x5974c1c4; BYTE $0x0e // vmulps (%r14), %ymm1, %ymm1 LONG $0x197de3c4; WORD $0x01ca // vextractf128 $1, %ymm1, %xmm2 LONG $0xc958e8c5 // vaddps %xmm1, %xmm2, %xmm1 LONG $0xd1c6f1c5; BYTE $0x01 // vshufpd $1, %xmm1, %xmm1, %xmm2 # xmm2 = xmm1[1,0] LONG $0xca58f0c5 // vaddps %xmm2, %xmm1, %xmm1 LONG $0xd116fac5 // vmovshdup %xmm1, %xmm2 # xmm2 = xmm1[1,1,3,3] LONG $0xca58f2c5 // vaddss %xmm2, %xmm1, %xmm1 LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2 # xmm2 = xmm0[1,1,3,3] LONG $0xc258fac5 // vaddss %xmm2, %xmm0, %xmm0 LONG $0xc958fac5 // vaddss %xmm1, %xmm0, %xmm1 LONG $0x107ac1c4; WORD $0x6042 // vmovss 96(%r10), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x9971c2c4; WORD $0x6043 // vfmadd132ss 96(%r11), %xmm1, %xmm0 # xmm0 = (xmm0 * mem) + xmm1 WORD $0xff83; BYTE $0x01 // cmpl $1, %edi JE LBB12_57 LONG $0x107ac1c4; WORD $0x644a // vmovss 100(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x6443 // vfmadd231ss 100(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x02 // cmpl $2, %edi JE LBB12_57 LONG $0x107ac1c4; WORD $0x684a // vmovss 104(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x6843 // vfmadd231ss 104(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x03 // cmpl $3, %edi JE LBB12_57 LONG $0x107ac1c4; WORD $0x6c4a // vmovss 108(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x6c43 // vfmadd231ss 108(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x04 // cmpl $4, %edi JE 
LBB12_57 LONG $0x107ac1c4; WORD $0x704a // vmovss 112(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x7043 // vfmadd231ss 112(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x05 // cmpl $5, %edi JE LBB12_57 LONG $0x107ac1c4; WORD $0x744a // vmovss 116(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x7443 // vfmadd231ss 116(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x06 // cmpl $6, %edi JE LBB12_57 LONG $0x107ac1c4; WORD $0x784a // vmovss 120(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x7843 // vfmadd231ss 120(%r11), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 JMP LBB12_57 LBB12_99: WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx WORD $0x9e0f; BYTE $0xc2 // setle %dl WORD $0x0134 // xorb $1, %al WORD $0xd008 // orb %dl, %al JNE LBB12_140 LONG $0x30458b48 // movq 48(%rbp), %rax QUAD $0x00000000850c8d4c // leaq (,%rax,4), %r9 LONG $0x24048b48 // movq (%rsp), %rax # 8-byte Reload WORD $0xff48; BYTE $0xc8 // decq %rax WORD $0x8948; BYTE $0xc2 // movq %rax, %rdx LONG $0x55af0f48; BYTE $0x10 // imulq 16(%rbp), %rdx LONG $0x20758b48 // movq 32(%rbp), %rsi LONG $0xc6af0f48 // imulq %rsi, %rax WORD $0x0148; BYTE $0xc8 // addq %rcx, %rax LONG $0x18458b4c // movq 24(%rbp), %r8 LONG $0x80048d49 // leaq (%r8,%rax,4), %rax LONG $0x24448948; BYTE $0x58 // movq %rax, 88(%rsp) # 8-byte Spill LONG $0x287d8b48 // movq 40(%rbp), %rdi LONG $0x8f048d48 // leaq (%rdi,%rcx,4), %rax LONG $0x24448948; BYTE $0x50 // movq %rax, 80(%rsp) # 8-byte Spill LONG $0x24448b48; BYTE $0x18 // movq 24(%rsp), %rax # 8-byte Reload LONG $0x90448d48; BYTE $0x04 // leaq 4(%rax,%rdx,4), %rax LONG $0x24448948; BYTE $0x48 // movq %rax, 72(%rsp) # 8-byte Spill QUAD $0xffffffffffc0bc49; WORD $0x7fff // movabsq $9223372036854775744, %r12 # imm = 0x7FFFFFFFFFFFFFC0 WORD $0x8949; BYTE $0xcf // movq %rcx, %r15 WORD $0x214d; BYTE $0xe7 // andq %r12, %r15 LONG $0x38cc8349 // orq $56, %r12 WORD $0x2149; BYTE 
$0xcc // andq %rcx, %r12 LONG $0xff418d48 // leaq -1(%rcx), %rax LONG $0x24448948; BYTE $0x20 // movq %rax, 32(%rsp) # 8-byte Spill LONG $0xc0808d49; WORD $0x0000; BYTE $0x00 // leaq 192(%r8), %rax LONG $0x24448948; BYTE $0x40 // movq %rax, 64(%rsp) # 8-byte Spill QUAD $0x00000000b52c8d4c // leaq (,%rsi,4), %r13 LONG $0xc0878d48; WORD $0x0000; BYTE $0x00 // leaq 192(%rdi), %rax WORD $0xdb31 // xorl %ebx, %ebx LONG $0x244c894c; BYTE $0x60 // movq %r9, 96(%rsp) # 8-byte Spill JMP LBB12_101 LBB12_118: LONG $0x245c8b48; BYTE $0x28 // movq 40(%rsp), %rbx # 8-byte Reload WORD $0xff48; BYTE $0xc3 // incq %rbx LONG $0x244c8b4c; BYTE $0x60 // movq 96(%rsp), %r9 # 8-byte Reload WORD $0x014c; BYTE $0xc8 // addq %r9, %rax WORD $0x014c; BYTE $0xcf // addq %r9, %rdi LONG $0x245c3b48; BYTE $0x38 // cmpq 56(%rsp), %rbx # 8-byte Folded Reload JE LBB12_140 LBB12_101: QUAD $0x000000000000ba48; WORD $0x2000 // movabsq $2305843009213693952, %rdx # imm = 0x2000000000000000 LONG $0x20558548 // testq %rdx, 32(%rbp) LONG $0x2444950f; BYTE $0x10 // setne 16(%rsp) # 1-byte Folded Spill LONG $0x10558548 // testq %rdx, 16(%rbp) LONG $0xc6950f40 // setne %sil WORD $0x894d; BYTE $0xc8 // movq %r9, %r8 LONG $0xc3af0f4c // imulq %rbx, %r8 LONG $0x28558b48 // movq 40(%rbp), %rdx LONG $0x02148d4e // leaq (%rdx,%r8), %r10 LONG $0x2444034c; BYTE $0x50 // addq 80(%rsp), %r8 # 8-byte Folded Reload LONG $0x244c8b4c; BYTE $0x18 // movq 24(%rsp), %r9 # 8-byte Reload LONG $0x990c8d4d // leaq (%r9,%rbx,4), %r9 LONG $0x245c8b4c; BYTE $0x48 // movq 72(%rsp), %r11 # 8-byte Reload LONG $0x9b1c8d4d // leaq (%r11,%rbx,4), %r11 LONG $0x245c8948; BYTE $0x28 // movq %rbx, 40(%rsp) # 8-byte Spill LONG $0x5daf0f48; BYTE $0x30 // imulq 48(%rbp), %rbx WORD $0x394d; BYTE $0xda // cmpq %r11, %r10 LONG $0xc3920f41 // setb %r11b WORD $0x394d; BYTE $0xc1 // cmpq %r8, %r9 LONG $0xc6920f41 // setb %r14b WORD $0x2045; BYTE $0xde // andb %r11b, %r14b LONG $0x9a148d48 // leaq (%rdx,%rbx,4), %rdx LONG $0x24548948; BYTE $0x30 // 
movq %rdx, 48(%rsp) # 8-byte Spill WORD $0x0841; BYTE $0xf6 // orb %sil, %r14b LONG $0x24543b4c; BYTE $0x58 // cmpq 88(%rsp), %r10 # 8-byte Folded Reload LONG $0xc6920f40 // setb %sil LONG $0x18558b4c // movq 24(%rbp), %r10 WORD $0x394d; BYTE $0xd0 // cmpq %r10, %r8 WORD $0x970f; BYTE $0xc2 // seta %dl WORD $0x2040; BYTE $0xf2 // andb %sil, %dl LONG $0x1024540a // orb 16(%rsp), %dl # 1-byte Folded Reload WORD $0x0844; BYTE $0xf2 // orb %r14b, %dl LONG $0x10245488 // movb %dl, 16(%rsp) # 1-byte Spill LONG $0x24548b48; BYTE $0x40 // movq 64(%rsp), %rdx # 8-byte Reload WORD $0xf631 // xorl %esi, %esi LONG $0x241c8b4c // movq (%rsp), %r11 # 8-byte Reload JMP LBB12_102 LBB12_117: WORD $0xff48; BYTE $0xc6 // incq %rsi WORD $0x014c; BYTE $0xea // addq %r13, %rdx WORD $0x014d; BYTE $0xea // addq %r13, %r10 WORD $0x394c; BYTE $0xde // cmpq %r11, %rsi JE LBB12_118 LBB12_102: LONG $0x08f98348 // cmpq $8, %rcx WORD $0x920f; BYTE $0xc3 // setb %bl WORD $0x8949; BYTE $0xf0 // movq %rsi, %r8 LONG $0x45af0f4c; BYTE $0x10 // imulq 16(%rbp), %r8 LONG $0x10245c0a // orb 16(%rsp), %bl # 1-byte Folded Reload WORD $0xc3f6; BYTE $0x01 // testb $1, %bl JE LBB12_104 WORD $0x3145; BYTE $0xf6 // xorl %r14d, %r14d JMP LBB12_113 LBB12_104: LONG $0x40f98348 // cmpq $64, %rcx JAE LBB12_106 WORD $0xdb31 // xorl %ebx, %ebx JMP LBB12_110 LBB12_106: LONG $0x487d9262; WORD $0x0418; BYTE $0x81 // vbroadcastss (%r9,%r8,4), %zmm0 WORD $0xdb31 // xorl %ebx, %ebx LBB12_107: QUAD $0xfd9a4c10487cf162 // vmovups -192(%rdx,%rbx,4), %zmm1 QUAD $0xfe9a5410487cf162 // vmovups -128(%rdx,%rbx,4), %zmm2 QUAD $0xff9a5c10487cf162 // vmovups -64(%rdx,%rbx,4), %zmm3 LONG $0x487cf162; WORD $0x2410; BYTE $0x9a // vmovups (%rdx,%rbx,4), %zmm4 QUAD $0xfd984ca8487df262 // vfmadd213ps -192(%rax,%rbx,4), %zmm0, %zmm1 # zmm1 = (zmm0 * zmm1) + mem QUAD $0xfe9854a8487df262 // vfmadd213ps -128(%rax,%rbx,4), %zmm0, %zmm2 # zmm2 = (zmm0 * zmm2) + mem QUAD $0xff985ca8487df262 // vfmadd213ps -64(%rax,%rbx,4), %zmm0, %zmm3 # zmm3 = 
(zmm0 * zmm3) + mem LONG $0x487df262; WORD $0x24a8; BYTE $0x98 // vfmadd213ps (%rax,%rbx,4), %zmm0, %zmm4 # zmm4 = (zmm0 * zmm4) + mem QUAD $0xfd984c11487cf162 // vmovups %zmm1, -192(%rax,%rbx,4) QUAD $0xfe985411487cf162 // vmovups %zmm2, -128(%rax,%rbx,4) QUAD $0xff985c11487cf162 // vmovups %zmm3, -64(%rax,%rbx,4) LONG $0x487cf162; WORD $0x2411; BYTE $0x98 // vmovups %zmm4, (%rax,%rbx,4) LONG $0x40c38348 // addq $64, %rbx WORD $0x3949; BYTE $0xdf // cmpq %rbx, %r15 JNE LBB12_107 WORD $0x3949; BYTE $0xcf // cmpq %rcx, %r15 JE LBB12_117 WORD $0x894c; BYTE $0xfb // movq %r15, %rbx WORD $0x894d; BYTE $0xfe // movq %r15, %r14 WORD $0xc1f6; BYTE $0x38 // testb $56, %cl JE LBB12_113 LBB12_110: LONG $0x187d82c4; WORD $0x8104 // vbroadcastss (%r9,%r8,4), %ymm0 LBB12_111: LONG $0x107cc1c4; WORD $0x9a0c // vmovups (%r10,%rbx,4), %ymm1 LONG $0xa87de2c4; WORD $0x9f0c // vfmadd213ps (%rdi,%rbx,4), %ymm0, %ymm1 # ymm1 = (ymm0 * ymm1) + mem LONG $0x0c11fcc5; BYTE $0x9f // vmovups %ymm1, (%rdi,%rbx,4) LONG $0x08c38348 // addq $8, %rbx WORD $0x3949; BYTE $0xdc // cmpq %rbx, %r12 JNE LBB12_111 WORD $0x894d; BYTE $0xe6 // movq %r12, %r14 WORD $0x3949; BYTE $0xcc // cmpq %rcx, %r12 JE LBB12_117 LBB12_113: WORD $0x894c; BYTE $0xf3 // movq %r14, %rbx WORD $0xc1f6; BYTE $0x01 // testb $1, %cl JE LBB12_115 WORD $0x8948; BYTE $0xf3 // movq %rsi, %rbx LONG $0x5daf0f48; BYTE $0x20 // imulq 32(%rbp), %rbx LONG $0x185d8b4c // movq 24(%rbp), %r11 LONG $0x9b1c8d49 // leaq (%r11,%rbx,4), %rbx LONG $0x241c8b4c // movq (%rsp), %r11 # 8-byte Reload LONG $0x107a81c4; WORD $0x8104 // vmovss (%r9,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107aa1c4; WORD $0xb30c // vmovss (%rbx,%r14,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0x245c8b48; BYTE $0x30 // movq 48(%rsp), %rbx # 8-byte Reload LONG $0xa979a2c4; WORD $0xb30c // vfmadd213ss (%rbx,%r14,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x117aa1c4; WORD $0xb30c // vmovss %xmm1, (%rbx,%r14,4) WORD $0x894c; BYTE $0xf3 // movq %r14, 
%rbx LONG $0x01cb8348 // orq $1, %rbx LBB12_115: LONG $0x24743b4c; BYTE $0x20 // cmpq 32(%rsp), %r14 # 8-byte Folded Reload JE LBB12_117 LBB12_116: LONG $0x107a81c4; WORD $0x8104 // vmovss (%r9,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107ac1c4; WORD $0x9a0c // vmovss (%r10,%rbx,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979e2c4; WORD $0x9f0c // vfmadd213ss (%rdi,%rbx,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x0c11fac5; BYTE $0x9f // vmovss %xmm1, (%rdi,%rbx,4) LONG $0x107a81c4; WORD $0x8104 // vmovss (%r9,%r8,4), %xmm0 # xmm0 = mem[0],zero,zero,zero LONG $0x107ac1c4; WORD $0x9a4c; BYTE $0x04 // vmovss 4(%r10,%rbx,4), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xa979e2c4; WORD $0x9f4c; BYTE $0x04 // vfmadd213ss 4(%rdi,%rbx,4), %xmm0, %xmm1 # xmm1 = (xmm0 * xmm1) + mem LONG $0x4c11fac5; WORD $0x049f // vmovss %xmm1, 4(%rdi,%rbx,4) LONG $0x02c38348 // addq $2, %rbx WORD $0x3948; BYTE $0xd9 // cmpq %rbx, %rcx JNE LBB12_116 JMP LBB12_117 LBB12_25: WORD $0x8948; BYTE $0xf0 // movq %rsi, %rax LONG $0x06e0c148 // shlq $6, %rax QUAD $0x003fffffff80ba48; WORD $0x0000 // movabsq $274877906816, %rdx # imm = 0x3FFFFFFF80 WORD $0x0148; BYTE $0xd0 // addq %rdx, %rax LONG $0x40ca8348 // orq $64, %rdx WORD $0x2148; BYTE $0xc2 // andq %rax, %rdx LONG $0x40c28348 // addq $64, %rdx LONG $0x24148948 // movq %rdx, (%rsp) # 8-byte Spill WORD $0x894c; BYTE $0xe7 // movq %r12, %rdi WORD $0x8944; BYTE $0xe2 // movl %r12d, %edx WORD $0x468d; BYTE $0xff // leal -1(%rsi), %eax LONG $0xfe468d44 // leal -2(%rsi), %r8d LONG $0x24448944; BYTE $0x20 // movl %r8d, 32(%rsp) # 4-byte Spill LONG $0x10244489 // movl %eax, 16(%rsp) # 4-byte Spill WORD $0xe083; BYTE $0xfc // andl $-4, %eax LONG $0x30244489 // movl %eax, 48(%rsp) # 4-byte Spill WORD $0xf089 // movl %esi, %eax WORD $0xc8fe // decb %al LONG $0xe8b60f44 // movzbl %al, %r13d LONG $0x03e58341 // andl $3, %r13d LONG $0x06e5c141 // shll $6, %r13d WORD $0x3145; BYTE $0xc9 // xorl %r9d, %r9d JMP LBB12_26 LBB12_89: LONG 
$0x244c8b4c; BYTE $0x28 // movq 40(%rsp), %r9 # 8-byte Reload WORD $0xff49; BYTE $0xc1 // incq %r9 LONG $0x244c3b4c; BYTE $0x38 // cmpq 56(%rsp), %r9 # 8-byte Folded Reload JE LBB12_140 LBB12_26: WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x45af0f48; BYTE $0x10 // imulq 16(%rbp), %rax LONG $0x24448b4c; BYTE $0x18 // movq 24(%rsp), %r8 # 8-byte Reload LONG $0x801c8d49 // leaq (%r8,%rax,4), %rbx LONG $0x80348d4d // leaq (%r8,%rax,4), %r14 LONG $0x40c68349 // addq $64, %r14 LONG $0x244c894c; BYTE $0x28 // movq %r9, 40(%rsp) # 8-byte Spill WORD $0x894c; BYTE $0xc8 // movq %r9, %rax LONG $0x45af0f48; BYTE $0x30 // imulq 48(%rbp), %rax LONG $0x28458b4c // movq 40(%rbp), %r8 LONG $0x80048d49 // leaq (%r8,%rax,4), %rax LONG $0x24048b4c // movq (%rsp), %r8 # 8-byte Reload LONG $0x183c8d4d // leaq (%r8,%rbx), %r15 LONG $0x40c78349 // addq $64, %r15 WORD $0x3145; BYTE $0xc9 // xorl %r9d, %r9d LONG $0x185d8b4c // movq 24(%rbp), %r11 JMP LBB12_27 LBB12_88: LONG $0x117aa1c4; WORD $0x8804 // vmovss %xmm0, (%rax,%r9,4) WORD $0xff49; BYTE $0xc1 // incq %r9 WORD $0x3949; BYTE $0xc9 // cmpq %rcx, %r9 JE LBB12_89 LBB12_27: WORD $0x894d; BYTE $0xc8 // movq %r9, %r8 LONG $0x45af0f4c; BYTE $0x20 // imulq 32(%rbp), %r8 LONG $0x83048d4f // leaq (%r11,%r8,4), %r8 WORD $0xf685 // testl %esi, %esi JLE LBB12_28 LONG $0x40508d4d // leaq 64(%r8), %r10 LONG $0x487cf162; WORD $0x0310 // vmovups (%rbx), %zmm0 LONG $0x487cd162; WORD $0x0859 // vmulps (%r8), %zmm0, %zmm1 WORD $0xfe83; BYTE $0x02 // cmpl $2, %esi JL LBB12_81 WORD $0x894d; BYTE $0xf0 // movq %r14, %r8 WORD $0x894d; BYTE $0xd3 // movq %r10, %r11 LONG $0x24648b44; BYTE $0x30 // movl 48(%rsp), %r12d # 4-byte Reload LONG $0x20247c83; BYTE $0x03 // cmpl $3, 32(%rsp) # 4-byte Folded Reload JB LBB12_83 LBB12_97: LONG $0x487cd162; WORD $0x0010 // vmovups (%r8), %zmm0 LONG $0x487cd162; WORD $0x5010; BYTE $0x01 // vmovups 64(%r8), %zmm2 LONG $0x487cd162; WORD $0x5810; BYTE $0x02 // vmovups 128(%r8), %zmm3 LONG $0x487cd162; WORD $0x6010; 
BYTE $0x03 // vmovups 192(%r8), %zmm4 LONG $0x4875d262; WORD $0x0398 // vfmadd132ps (%r11), %zmm1, %zmm0 # zmm0 = (zmm0 * mem) + zmm1 LONG $0x486dd262; WORD $0x43b8; BYTE $0x01 // vfmadd231ps 64(%r11), %zmm2, %zmm0 # zmm0 = (zmm2 * mem) + zmm0 LONG $0x4865d262; WORD $0x43b8; BYTE $0x02 // vfmadd231ps 128(%r11), %zmm3, %zmm0 # zmm0 = (zmm3 * mem) + zmm0 LONG $0x485dd262; WORD $0x43b8; BYTE $0x03 // vfmadd231ps 192(%r11), %zmm4, %zmm0 # zmm0 = (zmm4 * mem) + zmm0 LONG $0x00c08149; WORD $0x0001; BYTE $0x00 // addq $256, %r8 # imm = 0x100 LONG $0x00c38149; WORD $0x0001; BYTE $0x00 // addq $256, %r11 # imm = 0x100 LONG $0x487cf162; WORD $0xc828 // vmovaps %zmm0, %zmm1 LONG $0xfcc48341 // addl $-4, %r12d JNE LBB12_97 LBB12_83: LONG $0x102444f6; BYTE $0x03 // testb $3, 16(%rsp) # 1-byte Folded Reload JE LBB12_86 WORD $0x3145; BYTE $0xe4 // xorl %r12d, %r12d LONG $0x487cf162; WORD $0xc128 // vmovaps %zmm1, %zmm0 LBB12_85: LONG $0x487c9162; WORD $0x0c10; BYTE $0x20 // vmovups (%r8,%r12), %zmm1 LONG $0x48759262; WORD $0x04b8; BYTE $0x23 // vfmadd231ps (%r11,%r12), %zmm1, %zmm0 # zmm0 = (zmm1 * mem) + zmm0 LONG $0x40c48349 // addq $64, %r12 WORD $0x3945; BYTE $0xe5 // cmpl %r12d, %r13d JNE LBB12_85 LBB12_86: LONG $0x2414034c // addq (%rsp), %r10 # 8-byte Folded Reload LONG $0x487cf162; WORD $0xc828 // vmovaps %zmm0, %zmm1 WORD $0x894d; BYTE $0xd0 // movq %r10, %r8 WORD $0x894d; BYTE $0xfa // movq %r15, %r10 LONG $0x185d8b4c // movq 24(%rbp), %r11 JMP LBB12_87 LBB12_28: LONG $0xc957f0c5 // vxorps %xmm1, %xmm1, %xmm1 WORD $0x8949; BYTE $0xda // movq %rbx, %r10 JMP LBB12_87 LBB12_81: WORD $0x894d; BYTE $0xd0 // movq %r10, %r8 WORD $0x894d; BYTE $0xf2 // movq %r14, %r10 LBB12_87: LONG $0x48fdf362; WORD $0xc81b; BYTE $0x01 // vextractf64x4 $1, %zmm1, %ymm0 LONG $0xc058f4c5 // vaddps %ymm0, %ymm1, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 
= xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1 # xmm1 = xmm0[1,1,3,3] LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 WORD $0x8548; BYTE $0xff // testq %rdi, %rdi JLE LBB12_88 LONG $0x107ac1c4; BYTE $0x0a // vmovss (%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; BYTE $0x00 // vfmadd231ss (%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x01 // cmpl $1, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x044a // vmovss 4(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x0440 // vfmadd231ss 4(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x02 // cmpl $2, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x084a // vmovss 8(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x0840 // vfmadd231ss 8(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x03 // cmpl $3, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x0c4a // vmovss 12(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x0c40 // vfmadd231ss 12(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x04 // cmpl $4, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x104a // vmovss 16(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x1040 // vfmadd231ss 16(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x05 // cmpl $5, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x144a // vmovss 20(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x1440 // vfmadd231ss 20(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xfa83; BYTE $0x06 // cmpl $6, %edx JE LBB12_88 LONG $0x107ac1c4; WORD $0x184a // vmovss 24(%r10), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971c2c4; WORD $0x1840 // vfmadd231ss 24(%r8), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 JMP LBB12_88 LBB12_30: LONG $0x20458b48 // movq 32(%rbp), %rax QUAD $0x0000000085048d48 // leaq (,%rax,4), %rax LONG $0x2464c148; WORD $0x0270 // shlq $2, 
112(%rsp) # 8-byte Folded Spill WORD $0xd231 // xorl %edx, %edx JMP LBB12_31 LBB12_72: WORD $0xff48; BYTE $0xc2 // incq %rdx LONG $0x24448b4c; BYTE $0x68 // movq 104(%rsp), %r8 # 8-byte Reload LONG $0x2444034c; BYTE $0x70 // addq 112(%rsp), %r8 # 8-byte Folded Reload LONG $0x2444894c; BYTE $0x68 // movq %r8, 104(%rsp) # 8-byte Spill LONG $0x24543b48; BYTE $0x38 // cmpq 56(%rsp), %rdx # 8-byte Folded Reload JE LBB12_140 LBB12_31: WORD $0x8949; BYTE $0xd1 // movq %rdx, %r9 LONG $0x4daf0f4c; BYTE $0x10 // imulq 16(%rbp), %r9 LONG $0x24548b4c; BYTE $0x18 // movq 24(%rsp), %r10 # 8-byte Reload LONG $0x8a048d4f // leaq (%r10,%r9,4), %r8 LONG $0x8a0c8d4f // leaq (%r10,%r9,4), %r9 LONG $0x40c18349 // addq $64, %r9 LONG $0x187d8b4c // movq 24(%rbp), %r15 WORD $0x894d; BYTE $0xfa // movq %r15, %r10 WORD $0x3145; BYTE $0xdb // xorl %r11d, %r11d JMP LBB12_32 LBB12_71: LONG $0x245c8b48; BYTE $0x68 // movq 104(%rsp), %rbx # 8-byte Reload LONG $0x117aa1c4; WORD $0x9b04 // vmovss %xmm0, (%rbx,%r11,4) WORD $0xff49; BYTE $0xc3 // incq %r11 WORD $0x0149; BYTE $0xc2 // addq %rax, %r10 WORD $0x394c; BYTE $0xd9 // cmpq %r11, %rcx JE LBB12_72 LBB12_32: WORD $0xf685 // testl %esi, %esi JLE LBB12_33 WORD $0x894c; BYTE $0xdb // movq %r11, %rbx LONG $0x5daf0f48; BYTE $0x20 // imulq 32(%rbp), %rbx LONG $0x9f1c8d49 // leaq (%r15,%rbx,4), %rbx LONG $0x487cd162; WORD $0x0010 // vmovups (%r8), %zmm0 LONG $0x487cd162; WORD $0x0259 // vmulps (%r10), %zmm0, %zmm0 LONG $0x40c38348 // addq $64, %rbx WORD $0x894d; BYTE $0xce // movq %r9, %r14 JMP LBB12_70 LBB12_33: LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0 WORD $0x894c; BYTE $0xd3 // movq %r10, %rbx WORD $0x894d; BYTE $0xc6 // movq %r8, %r14 LBB12_70: LONG $0x48fdf362; WORD $0xc11b; BYTE $0x01 // vextractf64x4 $1, %zmm0, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # 
xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0x107cc1c4; BYTE $0x0e // vmovups (%r14), %ymm1 LONG $0x0b59f4c5 // vmulps (%rbx), %ymm1, %ymm1 LONG $0x197de3c4; WORD $0x01ca // vextractf128 $1, %ymm1, %xmm2 LONG $0xc958e8c5 // vaddps %xmm1, %xmm2, %xmm1 LONG $0xd1c6f1c5; BYTE $0x01 // vshufpd $1, %xmm1, %xmm1, %xmm2 # xmm2 = xmm1[1,0] LONG $0xca58f0c5 // vaddps %xmm2, %xmm1, %xmm1 LONG $0xd116fac5 // vmovshdup %xmm1, %xmm2 # xmm2 = xmm1[1,1,3,3] LONG $0xca58f2c5 // vaddss %xmm2, %xmm1, %xmm1 LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2 # xmm2 = xmm0[1,1,3,3] LONG $0xc258fac5 // vaddss %xmm2, %xmm0, %xmm0 LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x08fc8349 // cmpq $8, %r12 JLE LBB12_71 LONG $0x107ac1c4; WORD $0x204e // vmovss 32(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x2043 // vfmadd231ss 32(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x01 // cmpl $1, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x244e // vmovss 36(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x2443 // vfmadd231ss 36(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x02 // cmpl $2, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x284e // vmovss 40(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x2843 // vfmadd231ss 40(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x03 // cmpl $3, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x2c4e // vmovss 44(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x2c43 // vfmadd231ss 44(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x04 // cmpl $4, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x304e // vmovss 48(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x3043 // vfmadd231ss 48(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x05 // cmpl $5, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x344e // vmovss 52(%r14), %xmm1 # xmm1 = 
mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x3443 // vfmadd231ss 52(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 WORD $0xff83; BYTE $0x06 // cmpl $6, %edi JE LBB12_71 LONG $0x107ac1c4; WORD $0x384e // vmovss 56(%r14), %xmm1 # xmm1 = mem[0],zero,zero,zero LONG $0xb971e2c4; WORD $0x3843 // vfmadd231ss 56(%rbx), %xmm1, %xmm0 # xmm0 = (xmm1 * mem) + xmm0 JMP LBB12_71 LBB12_35: WORD $0x8948; BYTE $0xf2 // movq %rsi, %rdx LONG $0x06e2c148 // shlq $6, %rdx QUAD $0x003fffffff80b848; WORD $0x0000 // movabsq $274877906816, %rax # imm = 0x3FFFFFFF80 WORD $0x0148; BYTE $0xc2 // addq %rax, %rdx LONG $0x40c88348 // orq $64, %rax WORD $0x2148; BYTE $0xd0 // andq %rdx, %rax WORD $0x568d; BYTE $0xff // leal -1(%rsi), %edx WORD $0x7e8d; BYTE $0xfe // leal -2(%rsi), %edi WORD $0x8941; BYTE $0xd0 // movl %edx, %r8d LONG $0xfce08341 // andl $-4, %r8d LONG $0x24448944; BYTE $0x10 // movl %r8d, 16(%rsp) # 4-byte Spill WORD $0x8941; BYTE $0xf0 // movl %esi, %r8d WORD $0xfe41; BYTE $0xc8 // decb %r8b LONG $0xc8b60f45 // movzbl %r8b, %r9d LONG $0x03e18341 // andl $3, %r9d LONG $0x06e1c141 // shll $6, %r9d WORD $0xdb31 // xorl %ebx, %ebx JMP LBB12_36 LBB12_68: LONG $0x245c8b48; BYTE $0x20 // movq 32(%rsp), %rbx # 8-byte Reload WORD $0xff48; BYTE $0xc3 // incq %rbx LONG $0x245c3b48; BYTE $0x38 // cmpq 56(%rsp), %rbx # 8-byte Folded Reload JE LBB12_140 LBB12_36: WORD $0x8949; BYTE $0xd8 // movq %rbx, %r8 LONG $0x45af0f4c; BYTE $0x10 // imulq 16(%rbp), %r8 LONG $0x245c8b4c; BYTE $0x18 // movq 24(%rsp), %r11 # 8-byte Reload LONG $0x83148d4f // leaq (%r11,%r8,4), %r10 LONG $0x83048d4f // leaq (%r11,%r8,4), %r8 LONG $0x40c08349 // addq $64, %r8 LONG $0x2404894c // movq %r8, (%rsp) # 8-byte Spill LONG $0x245c8948; BYTE $0x20 // movq %rbx, 32(%rsp) # 8-byte Spill WORD $0x8949; BYTE $0xd8 // movq %rbx, %r8 LONG $0x45af0f4c; BYTE $0x30 // imulq 48(%rbp), %r8 LONG $0x285d8b4c // movq 40(%rbp), %r11 LONG $0x83348d4f // leaq (%r11,%r8,4), %r14 WORD $0x3145; BYTE $0xff // xorl %r15d, %r15d JMP 
LBB12_37 LBB12_67: LONG $0x48fdf362; WORD $0xc11b; BYTE $0x01 // vextractf64x4 $1, %zmm0, %ymm1 LONG $0xc158fcc5 // vaddps %ymm1, %ymm0, %ymm0 LONG $0x197de3c4; WORD $0x01c1 // vextractf128 $1, %ymm0, %xmm1 LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0xc8c6f9c5; BYTE $0x01 // vshufpd $1, %xmm0, %xmm0, %xmm1 # xmm1 = xmm0[1,0] LONG $0xc158f8c5 // vaddps %xmm1, %xmm0, %xmm0 LONG $0x4c10fcc5; WORD $0x4003 // vmovups 64(%rbx,%rax), %ymm1 LONG $0x5974c1c4; WORD $0x004c; BYTE $0x40 // vmulps 64(%r8,%rax), %ymm1, %ymm1 LONG $0x197de3c4; WORD $0x01ca // vextractf128 $1, %ymm1, %xmm2 LONG $0xc958e8c5 // vaddps %xmm1, %xmm2, %xmm1 LONG $0xd1c6f1c5; BYTE $0x01 // vshufpd $1, %xmm1, %xmm1, %xmm2 # xmm2 = xmm1[1,0] LONG $0xca58f0c5 // vaddps %xmm2, %xmm1, %xmm1 LONG $0xd116fac5 // vmovshdup %xmm1, %xmm2 # xmm2 = xmm1[1,1,3,3] LONG $0xca58f2c5 // vaddss %xmm2, %xmm1, %xmm1 LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2 # xmm2 = xmm0[1,1,3,3] LONG $0xc258fac5 // vaddss %xmm2, %xmm0, %xmm0 LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0 LONG $0x117a81c4; WORD $0xbe04 // vmovss %xmm0, (%r14,%r15,4) WORD $0xff49; BYTE $0xc7 // incq %r15 WORD $0x3949; BYTE $0xcf // cmpq %rcx, %r15 JE LBB12_68 LBB12_37: WORD $0x894d; BYTE $0xf8 // movq %r15, %r8 LONG $0x45af0f4c; BYTE $0x20 // imulq 32(%rbp), %r8 LONG $0x185d8b4c // movq 24(%rbp), %r11 LONG $0x83048d4f // leaq (%r11,%r8,4), %r8 WORD $0xf685 // testl %esi, %esi JLE LBB12_38 LONG $0x487cd162; WORD $0x0210 // vmovups (%r10), %zmm0 LONG $0x487cd162; WORD $0x0859 // vmulps (%r8), %zmm0, %zmm1 LONG $0x40c08349 // addq $64, %r8 LONG $0x241c8b48 // movq (%rsp), %rbx # 8-byte Reload WORD $0xff83; BYTE $0x03 // cmpl $3, %edi JAE LBB12_62 LBB12_61: WORD $0x8949; BYTE $0xdd // movq %rbx, %r13 WORD $0x894d; BYTE $0xc4 // movq %r8, %r12 LONG $0x487cf162; WORD $0xc128 // vmovaps %zmm1, %zmm0 JMP LBB12_64 LBB12_38: LONG $0xc957f0c5 // vxorps %xmm1, %xmm1, %xmm1 WORD $0x894c; BYTE $0xd3 // movq %r10, %rbx WORD $0xff83; BYTE $0x03 // cmpl $3, %edi 
JB LBB12_61 LBB12_62: LONG $0x245c8b44; BYTE $0x10 // movl 16(%rsp), %r11d # 4-byte Reload WORD $0x8949; BYTE $0xdd // movq %rbx, %r13 WORD $0x894d; BYTE $0xc4 // movq %r8, %r12 LBB12_63: LONG $0x487cd162; WORD $0x4510; BYTE $0x00 // vmovups (%r13), %zmm0 LONG $0x487cd162; WORD $0x5510; BYTE $0x01 // vmovups 64(%r13), %zmm2 LONG $0x487cd162; WORD $0x5d10; BYTE $0x02 // vmovups 128(%r13), %zmm3 LONG $0x487cd162; WORD $0x6510; BYTE $0x03 // vmovups 192(%r13), %zmm4 LONG $0x4875d262; WORD $0x0498; BYTE $0x24 // vfmadd132ps (%r12), %zmm1, %zmm0 # zmm0 = (zmm0 * mem) + zmm1 QUAD $0x012444b8486dd262 // vfmadd231ps 64(%r12), %zmm2, %zmm0 # zmm0 = (zmm2 * mem) + zmm0 QUAD $0x022444b84865d262 // vfmadd231ps 128(%r12), %zmm3, %zmm0 # zmm0 = (zmm3 * mem) + zmm0 QUAD $0x032444b8485dd262 // vfmadd231ps 192(%r12), %zmm4, %zmm0 # zmm0 = (zmm4 * mem) + zmm0 LONG $0x00c58149; WORD $0x0001; BYTE $0x00 // addq $256, %r13 # imm = 0x100 LONG $0x00c48149; WORD $0x0001; BYTE $0x00 // addq $256, %r12 # imm = 0x100 LONG $0x487cf162; WORD $0xc828 // vmovaps %zmm0, %zmm1 LONG $0xfcc38341 // addl $-4, %r11d JNE LBB12_63 LBB12_64: WORD $0xc2f6; BYTE $0x03 // testb $3, %dl JE LBB12_67 WORD $0x3145; BYTE $0xdb // xorl %r11d, %r11d LBB12_66: QUAD $0x001d4c10487c9162 // vmovups (%r13,%r11), %zmm1 LONG $0x48759262; WORD $0x04b8; BYTE $0x1c // vfmadd231ps (%r12,%r11), %zmm1, %zmm0 # zmm0 = (zmm1 * mem) + zmm0 LONG $0x40c38349 // addq $64, %r11 WORD $0x3945; BYTE $0xd9 // cmpl %r11d, %r9d JNE LBB12_66 JMP LBB12_67 LBB12_140: LONG $0xd8658d48 // leaq -40(%rbp), %rsp BYTE $0x5b // popq %rbx WORD $0x5c41 // popq %r12 WORD $0x5d41 // popq %r13 WORD $0x5e41 // popq %r14 WORD $0x5f41 // popq %r15 BYTE $0x5d // popq %rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper POPQ DI POPQ DI POPQ DI POPQ DI POPQ DI POPQ DI RET ================================================ FILE: common/floats/floats_neon.go ================================================ //go:build !noasm && arm64 // Code generated by GoAT. DO NOT EDIT. 
// versions: // clang 18.1.3 (1ubuntu1) // objdump 2.42 // flags: -O3 // source: src/floats_neon.c package floats import "unsafe" //go:noescape func vmul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) //go:noescape func vmul_const_add(a, b, c unsafe.Pointer, n int64) //go:noescape func vmul_const_to(a, b, c unsafe.Pointer, n int64) //go:noescape func vmul_const(a, b unsafe.Pointer, n int64) //go:noescape func vadd_const(a, b unsafe.Pointer, n int64) //go:noescape func vsub_to(a, b, c unsafe.Pointer, n int64) //go:noescape func vsub(a, b unsafe.Pointer, n int64) //go:noescape func vmul_to(a, b, c unsafe.Pointer, n int64) //go:noescape func vdiv_to(a, b, c unsafe.Pointer, n int64) //go:noescape func vsqrt_to(a, b unsafe.Pointer, n int64) //go:noescape func vdot(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func veuclidean(a, b unsafe.Pointer, n int64) (result float32) //go:noescape func vmm(transA, transB bool, m, n, k int64, a unsafe.Pointer, lda int64, b unsafe.Pointer, ldb int64, c unsafe.Pointer, ldc int64) ================================================ FILE: common/floats/floats_neon.s ================================================ //go:build !noasm && arm64 // Code generated by GoAT. DO NOT EDIT. // versions: // clang 18.1.3 (1ubuntu1) // objdump 2.42 // flags: -O3 // source: src/floats_neon.c TEXT ·vmul_const_add_to(SB), $0-40 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD dst+24(FP), R3 MOVD n+32(FP), R4 WORD $0xf100049f // cmp x4, #1 BLT LBB0_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100309f // cmp x4, #12 WORD $0x910003fd // mov x29, sp BHS LBB0_7 WORD $0xaa1f03e8 // mov x8, xzr LBB0_3: WORD $0xd37ef50b // lsl x11, x8, #2 WORD $0xcb080088 // sub x8, x4, x8 WORD $0x8b0b0069 // add x9, x3, x11 WORD $0x8b0b004a // add x10, x2, x11 WORD $0x8b0b000b // add x11, x0, x11 LBB0_4: WORD $0xbc404560 // ldr s0, [x11], #4 WORD $0xbd400021 // ldr s1, [x1] WORD $0xbc404542 // ldr s2, [x10], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1f010800 // fmadd s0, s0, s1, s2 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB0_4 LBB0_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB0_6: RET LBB0_7: WORD $0xd37ef488 // lsl x8, x4, #2 WORD $0x91001029 // add x9, x1, #4 WORD $0xeb03013f // cmp x9, x3 WORD $0x8b08006b // add x11, x3, x8 WORD $0x8b08004a // add x10, x2, x8 WORD $0x8b080008 // add x8, x0, x8 WORD $0xfa418160 // ccmp x11, x1, #0, hi WORD $0x1a9f97e9 // cset w9, hi WORD $0xeb03015f // cmp x10, x3 WORD $0xfa428160 // ccmp x11, x2, #0, hi WORD $0x1a9f97ea // cset w10, hi WORD $0xeb00017f // cmp x11, x0 WORD $0xfa438100 // ccmp x8, x3, #0, hi WORD $0xaa1f03e8 // mov x8, xzr BHI LBB0_3 WORD $0x3707fc69 // tbnz w9, #0, .LBB0_3 WORD $0x3707fc4a // tbnz w10, #0, .LBB0_3 WORD $0x4d40c820 // ld1r { v0.4s }, [x1] WORD $0x927dec88 // and x8, x4, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100404a // add x10, x2, #16 WORD $0x9100406b // add x11, x3, #16 WORD $0xaa0803ec // mov x12, x8 LBB0_11: WORD $0xad7f9141 // ldp q1, q4, [x10, #-16] WORD $0xf100218c // subs x12, x12, #8 WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x9100814a // add x10, x10, #32 WORD $0x4e22cc01 // fmla v1.4s, v0.4s, v2.4s WORD $0x4e23cc04 // fmla v4.4s, v0.4s, v3.4s WORD $0xad3f9161 // stp q1, q4, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 BNE LBB0_11 WORD $0xeb04011f // cmp x8, x4 BNE LBB0_3 B LBB0_5 TEXT ·vmul_const_add(SB), $0-32 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD n+24(FP), R3 WORD $0xf100047f // cmp 
x3, #1 BLT LBB1_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! WORD $0xf100207f // cmp x3, #8 WORD $0x910003fd // mov x29, sp BHS LBB1_7 WORD $0xaa1f03e8 // mov x8, xzr LBB1_3: WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0xcb080068 // sub x8, x3, x8 WORD $0x8b0a0049 // add x9, x2, x10 WORD $0x8b0a000a // add x10, x0, x10 LBB1_4: WORD $0xbc404540 // ldr s0, [x10], #4 WORD $0xbd400021 // ldr s1, [x1] WORD $0xbd400122 // ldr s2, [x9] WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1f010800 // fmadd s0, s0, s1, s2 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB1_4 LBB1_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB1_6: RET LBB1_7: WORD $0xd37ef468 // lsl x8, x3, #2 WORD $0x91001029 // add x9, x1, #4 WORD $0xeb02013f // cmp x9, x2 WORD $0x8b08004a // add x10, x2, x8 WORD $0x8b080008 // add x8, x0, x8 WORD $0xfa418140 // ccmp x10, x1, #0, hi WORD $0x1a9f97e9 // cset w9, hi WORD $0xeb00015f // cmp x10, x0 WORD $0xfa428100 // ccmp x8, x2, #0, hi WORD $0xaa1f03e8 // mov x8, xzr BHI LBB1_3 WORD $0x3707fd09 // tbnz w9, #0, .LBB1_3 WORD $0x4d40c820 // ld1r { v0.4s }, [x1] WORD $0x927dec68 // and x8, x3, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100404a // add x10, x2, #16 WORD $0xaa0803eb // mov x11, x8 LBB1_10: WORD $0xad7f9141 // ldp q1, q4, [x10, #-16] WORD $0xf100216b // subs x11, x11, #8 WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x4e22cc01 // fmla v1.4s, v0.4s, v2.4s WORD $0x4e23cc04 // fmla v4.4s, v0.4s, v3.4s WORD $0xad3f9141 // stp q1, q4, [x10, #-16] WORD $0x9100814a // add x10, x10, #32 BNE LBB1_10 WORD $0xeb03011f // cmp x8, x3 BNE LBB1_3 B LBB1_5 TEXT ·vmul_const_to(SB), $0-32 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD n+24(FP), R3 WORD $0xf100047f // cmp x3, #1 BLT LBB2_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100207f // cmp x3, #8 WORD $0x910003fd // mov x29, sp BHS LBB2_7 WORD $0xaa1f03e8 // mov x8, xzr LBB2_3: WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0xcb080068 // sub x8, x3, x8 WORD $0x8b0a0049 // add x9, x2, x10 WORD $0x8b0a000a // add x10, x0, x10 LBB2_4: WORD $0xbc404540 // ldr s0, [x10], #4 WORD $0xbd400021 // ldr s1, [x1] WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1e210800 // fmul s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB2_4 LBB2_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB2_6: RET LBB2_7: WORD $0xd37ef468 // lsl x8, x3, #2 WORD $0x91001029 // add x9, x1, #4 WORD $0xeb02013f // cmp x9, x2 WORD $0x8b08004a // add x10, x2, x8 WORD $0x8b080008 // add x8, x0, x8 WORD $0xfa418140 // ccmp x10, x1, #0, hi WORD $0x1a9f97e9 // cset w9, hi WORD $0xeb00015f // cmp x10, x0 WORD $0xfa428100 // ccmp x8, x2, #0, hi WORD $0xaa1f03e8 // mov x8, xzr BHI LBB2_3 WORD $0x3707fd29 // tbnz w9, #0, .LBB2_3 WORD $0x4d40c820 // ld1r { v0.4s }, [x1] WORD $0x927dec68 // and x8, x3, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100404a // add x10, x2, #16 WORD $0xaa0803eb // mov x11, x8 LBB2_10: WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0xf100216b // subs x11, x11, #8 WORD $0x91008129 // add x9, x9, #32 WORD $0x6e20dc21 // fmul v1.4s, v1.4s, v0.4s WORD $0x6e20dc42 // fmul v2.4s, v2.4s, v0.4s WORD $0xad3f8941 // stp q1, q2, [x10, #-16] WORD $0x9100814a // add x10, x10, #32 BNE LBB2_10 WORD $0xeb03011f // cmp x8, x3 BNE LBB2_3 B LBB2_5 TEXT ·vmul_const(SB), $0-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xf100045f // cmp x2, #1 BLT LBB3_8 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100205f // cmp x2, #8 WORD $0x910003fd // mov x29, sp BLO LBB3_4 WORD $0x91001028 // add x8, x1, #4 WORD $0xeb00011f // cmp x8, x0 BLS LBB3_9 WORD $0x8b020808 // add x8, x0, x2, lsl #2 WORD $0xeb01011f // cmp x8, x1 BLS LBB3_9 LBB3_4: WORD $0xaa1f03e8 // mov x8, xzr LBB3_5: WORD $0x8b080809 // add x9, x0, x8, lsl #2 WORD $0xcb080048 // sub x8, x2, x8 LBB3_6: WORD $0xbd400020 // ldr s0, [x1] WORD $0xbd400121 // ldr s1, [x9] WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1e210800 // fmul s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB3_6 LBB3_7: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB3_8: RET LBB3_9: WORD $0x4d40c820 // ld1r { v0.4s }, [x1] WORD $0x927dec48 // and x8, x2, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0xaa0803ea // mov x10, x8 LBB3_10: WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0xf100214a // subs x10, x10, #8 WORD $0x6e21dc01 // fmul v1.4s, v0.4s, v1.4s WORD $0x6e22dc02 // fmul v2.4s, v0.4s, v2.4s WORD $0xad3f8921 // stp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 BNE LBB3_10 WORD $0xeb02011f // cmp x8, x2 BEQ LBB3_7 B LBB3_5 TEXT ·vadd_const(SB), $0-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xf100045f // cmp x2, #1 BLT LBB4_8 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100205f // cmp x2, #8 WORD $0x910003fd // mov x29, sp BLO LBB4_4 WORD $0x91001028 // add x8, x1, #4 WORD $0xeb00011f // cmp x8, x0 BLS LBB4_9 WORD $0x8b020808 // add x8, x0, x2, lsl #2 WORD $0xeb01011f // cmp x8, x1 BLS LBB4_9 LBB4_4: WORD $0xaa1f03e8 // mov x8, xzr LBB4_5: WORD $0x8b080809 // add x9, x0, x8, lsl #2 WORD $0xcb080048 // sub x8, x2, x8 LBB4_6: WORD $0xbd400020 // ldr s0, [x1] WORD $0xbd400121 // ldr s1, [x9] WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1e212800 // fadd s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB4_6 LBB4_7: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB4_8: RET LBB4_9: WORD $0x4d40c820 // ld1r { v0.4s }, [x1] WORD $0x927dec48 // and x8, x2, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0xaa0803ea // mov x10, x8 LBB4_10: WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0xf100214a // subs x10, x10, #8 WORD $0x4e21d401 // fadd v1.4s, v0.4s, v1.4s WORD $0x4e22d402 // fadd v2.4s, v0.4s, v2.4s WORD $0xad3f8921 // stp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 BNE LBB4_10 WORD $0xeb02011f // cmp x8, x2 BEQ LBB4_7 B LBB4_5 TEXT ·vsub_to(SB), $0-32 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD n+24(FP), R3 WORD $0xf100047f // cmp x3, #1 BLT LBB5_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100207f // cmp x3, #8 WORD $0x910003fd // mov x29, sp BHS LBB5_7 WORD $0xaa1f03e8 // mov x8, xzr LBB5_3: WORD $0xd37ef50b // lsl x11, x8, #2 WORD $0xcb080068 // sub x8, x3, x8 WORD $0x8b0b0049 // add x9, x2, x11 WORD $0x8b0b002a // add x10, x1, x11 WORD $0x8b0b000b // add x11, x0, x11 LBB5_4: WORD $0xbc404560 // ldr s0, [x11], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0xbc404541 // ldr s1, [x10], #4 WORD $0x1e213800 // fsub s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB5_4 LBB5_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB5_6: RET LBB5_7: WORD $0xcb000048 // sub x8, x2, x0 WORD $0xf100811f // cmp x8, #32 WORD $0xaa1f03e8 // mov x8, xzr BLO LBB5_3 WORD $0xcb010049 // sub x9, x2, x1 WORD $0xf100813f // cmp x9, #32 BLO LBB5_3 WORD $0x927dec68 // and x8, x3, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100402a // add x10, x1, #16 WORD $0x9100404b // add x11, x2, #16 WORD $0xaa0803ec // mov x12, x8 LBB5_10: WORD $0xad7f8d40 // ldp q0, q3, [x10, #-16] WORD $0xf100218c // subs x12, x12, #8 WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x9100814a // add x10, x10, #32 WORD $0x4ea0d420 // fsub v0.4s, v1.4s, v0.4s WORD $0x4ea3d441 // fsub v1.4s, v2.4s, v3.4s WORD $0xad3f8560 // stp q0, q1, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 BNE LBB5_10 WORD $0xeb03011f // cmp x8, x3 BNE LBB5_3 B LBB5_5 TEXT ·vsub(SB), $0-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xf100045f // cmp x2, #1 BLT LBB6_8 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100205f // cmp x2, #8 WORD $0x910003fd // mov x29, sp BLO LBB6_4 WORD $0xd37ef448 // lsl x8, x2, #2 WORD $0x8b080029 // add x9, x1, x8 WORD $0xeb00013f // cmp x9, x0 BLS LBB6_9 WORD $0x8b080008 // add x8, x0, x8 WORD $0xeb01011f // cmp x8, x1 BLS LBB6_9 LBB6_4: WORD $0xaa1f03e8 // mov x8, xzr LBB6_5: WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0xcb080048 // sub x8, x2, x8 WORD $0x8b0a0009 // add x9, x0, x10 WORD $0x8b0a002a // add x10, x1, x10 LBB6_6: WORD $0xbc404540 // ldr s0, [x10], #4 WORD $0xbd400121 // ldr s1, [x9] WORD $0xf1000508 // subs x8, x8, #1 WORD $0x1e203820 // fsub s0, s1, s0 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB6_6 LBB6_7: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB6_8: RET LBB6_9: WORD $0x927dec48 // and x8, x2, #0x7ffffffffffffff8 WORD $0x91004029 // add x9, x1, #16 WORD $0x9100400a // add x10, x0, #16 WORD $0xaa0803eb // mov x11, x8 LBB6_10: WORD $0xad7f8d40 // ldp q0, q3, [x10, #-16] WORD $0xf100216b // subs x11, x11, #8 WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x4ea1d400 // fsub v0.4s, v0.4s, v1.4s WORD $0x4ea2d461 // fsub v1.4s, v3.4s, v2.4s WORD $0xad3f8540 // stp q0, q1, [x10, #-16] WORD $0x9100814a // add x10, x10, #32 BNE LBB6_10 WORD $0xeb02011f // cmp x8, x2 BEQ LBB6_7 B LBB6_5 TEXT ·vmul_to(SB), $0-32 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD n+24(FP), R3 WORD $0xf100047f // cmp x3, #1 BLT LBB7_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100207f // cmp x3, #8 WORD $0x910003fd // mov x29, sp BHS LBB7_7 WORD $0xaa1f03e8 // mov x8, xzr LBB7_3: WORD $0xd37ef50b // lsl x11, x8, #2 WORD $0xcb080068 // sub x8, x3, x8 WORD $0x8b0b0049 // add x9, x2, x11 WORD $0x8b0b002a // add x10, x1, x11 WORD $0x8b0b000b // add x11, x0, x11 LBB7_4: WORD $0xbc404560 // ldr s0, [x11], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0xbc404541 // ldr s1, [x10], #4 WORD $0x1e210800 // fmul s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB7_4 LBB7_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB7_6: RET LBB7_7: WORD $0xcb000048 // sub x8, x2, x0 WORD $0xf100811f // cmp x8, #32 WORD $0xaa1f03e8 // mov x8, xzr BLO LBB7_3 WORD $0xcb010049 // sub x9, x2, x1 WORD $0xf100813f // cmp x9, #32 BLO LBB7_3 WORD $0x927dec68 // and x8, x3, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100402a // add x10, x1, #16 WORD $0x9100404b // add x11, x2, #16 WORD $0xaa0803ec // mov x12, x8 LBB7_10: WORD $0xad7f8d40 // ldp q0, q3, [x10, #-16] WORD $0xf100218c // subs x12, x12, #8 WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x9100814a // add x10, x10, #32 WORD $0x6e20dc20 // fmul v0.4s, v1.4s, v0.4s WORD $0x6e23dc41 // fmul v1.4s, v2.4s, v3.4s WORD $0xad3f8560 // stp q0, q1, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 BNE LBB7_10 WORD $0xeb03011f // cmp x8, x3 BNE LBB7_3 B LBB7_5 TEXT ·vdiv_to(SB), $0-32 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD c+16(FP), R2 MOVD n+24(FP), R3 WORD $0xf100047f // cmp x3, #1 BLT LBB8_6 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0xf100207f // cmp x3, #8 WORD $0x910003fd // mov x29, sp BHS LBB8_7 WORD $0xaa1f03e8 // mov x8, xzr LBB8_3: WORD $0xd37ef50b // lsl x11, x8, #2 WORD $0xcb080068 // sub x8, x3, x8 WORD $0x8b0b0049 // add x9, x2, x11 WORD $0x8b0b002a // add x10, x1, x11 WORD $0x8b0b000b // add x11, x0, x11 LBB8_4: WORD $0xbc404560 // ldr s0, [x11], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0xbc404541 // ldr s1, [x10], #4 WORD $0x1e211800 // fdiv s0, s0, s1 WORD $0xbc004520 // str s0, [x9], #4 BNE LBB8_4 LBB8_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 LBB8_6: RET LBB8_7: WORD $0xcb000048 // sub x8, x2, x0 WORD $0xf100811f // cmp x8, #32 WORD $0xaa1f03e8 // mov x8, xzr BLO LBB8_3 WORD $0xcb010049 // sub x9, x2, x1 WORD $0xf100813f // cmp x9, #32 BLO LBB8_3 WORD $0x927dec68 // and x8, x3, #0x7ffffffffffffff8 WORD $0x91004009 // add x9, x0, #16 WORD $0x9100402a // add x10, x1, #16 WORD $0x9100404b // add x11, x2, #16 WORD $0xaa0803ec // mov x12, x8 LBB8_10: WORD $0xad7f8d40 // ldp q0, q3, [x10, #-16] WORD $0xf100218c // subs x12, x12, #8 WORD $0xad7f8921 // ldp q1, q2, [x9, #-16] WORD $0x91008129 // add x9, x9, #32 WORD $0x9100814a // add x10, x10, #32 WORD $0x6e20fc20 // fdiv v0.4s, v1.4s, v0.4s WORD $0x6e23fc41 // fdiv v1.4s, v2.4s, v3.4s WORD $0xad3f8560 // stp q0, q1, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 BNE LBB8_10 WORD $0xeb03011f // cmp x8, x3 BNE LBB8_3 B LBB8_5 TEXT ·vsqrt_to(SB), $0-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! 
WORD $0x91000c48 // add x8, x2, #3 WORD $0xf100005f // cmp x2, #0 WORD $0x910003fd // mov x29, sp WORD $0x9a82b108 // csel x8, x8, x2, lt WORD $0xd342fd09 // lsr x9, x8, #2 WORD $0x927ef508 // and x8, x8, #0xfffffffffffffffc WORD $0xcb080048 // sub x8, x2, x8 WORD $0x7100053f // cmp w9, #1 BLT LBB9_2 LBB9_1: WORD $0x3cc10400 // ldr q0, [x0], #16 WORD $0x71000529 // subs w9, w9, #1 WORD $0x6ea1f800 // fsqrt v0.4s, v0.4s WORD $0x3c810420 // str q0, [x1], #16 BNE LBB9_1 LBB9_2: WORD $0x7100051f // cmp w8, #1 BLT LBB9_5 WORD $0x92407d08 // and x8, x8, #0xffffffff LBB9_4: WORD $0x0ddfc800 // ld1r { v0.2s }, [x0], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0x2ea1f800 // fsqrt v0.2s, v0.2s WORD $0x0d9f8020 // st1 { v0.s }[0], [x1], #4 BNE LBB9_4 LBB9_5: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 RET TEXT ·vdot(SB), $8-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! WORD $0x91000c48 // add x8, x2, #3 WORD $0xf100005f // cmp x2, #0 WORD $0x910003fd // mov x29, sp WORD $0x9a82b108 // csel x8, x8, x2, lt WORD $0x9342fd0a // asr x10, x8, #2 WORD $0x927ef508 // and x8, x8, #0xfffffffffffffffc WORD $0xcb080048 // sub x8, x2, x8 WORD $0x7100055f // cmp w10, #1 BLT LBB10_5 WORD $0x3cc10400 // ldr q0, [x0], #16 WORD $0x71000549 // subs w9, w10, #1 WORD $0x3cc10421 // ldr q1, [x1], #16 WORD $0x6e21dc00 // fmul v0.4s, v0.4s, v1.4s BEQ LBB10_6 WORD $0xb27b7beb // mov x11, #68719476704 WORD $0xaa0103ec // mov x12, x1 WORD $0x8b0a116a // add x10, x11, x10, lsl #4 WORD $0xaa0003eb // mov x11, x0 WORD $0x927c7d4a // and x10, x10, #0xffffffff0 WORD $0x9100414a // add x10, x10, #16 LBB10_3: WORD $0x3cc10561 // ldr q1, [x11], #16 WORD $0x71000529 // subs w9, w9, #1 WORD $0x3cc10582 // ldr q2, [x12], #16 WORD $0x4e21cc40 // fmla v0.4s, v2.4s, v1.4s BNE LBB10_3 WORD $0x8b0a0000 // add x0, x0, x10 WORD $0x8b0a0021 // add x1, x1, x10 B LBB10_6 LBB10_5: WORD $0x6f00e400 // movi v0.2d, #0000000000000000 LBB10_6: WORD $0x2f00e401 // movi d1, 
#0000000000000000 WORD $0x5e0c0402 // mov s2, v0.s[1] WORD $0x7100011f // cmp w8, #0 WORD $0x1e212801 // fadd s1, s0, s1 WORD $0x1e212841 // fadd s1, s2, s1 WORD $0x5e140402 // mov s2, v0.s[2] WORD $0x5e1c0400 // mov s0, v0.s[3] WORD $0x1e212841 // fadd s1, s2, s1 WORD $0x1e212800 // fadd s0, s0, s1 BLE LBB10_14 WORD $0x92407d09 // and x9, x8, #0xffffffff WORD $0xf100213f // cmp x9, #8 BHS LBB10_9 WORD $0xaa1f03e8 // mov x8, xzr B LBB10_12 LBB10_9: WORD $0x9240090a // and x10, x8, #0x7 WORD $0x9100400b // add x11, x0, #16 WORD $0x9100402c // add x12, x1, #16 WORD $0xcb0a0128 // sub x8, x9, x10 WORD $0xaa0803ed // mov x13, x8 LBB10_10: WORD $0xad7f9181 // ldp q1, q4, [x12, #-16] WORD $0xf10021ad // subs x13, x13, #8 WORD $0xad7f8d62 // ldp q2, q3, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 WORD $0x9100818c // add x12, x12, #32 WORD $0x6e21dc41 // fmul v1.4s, v2.4s, v1.4s WORD $0x5e0c0422 // mov s2, v1.s[1] WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e140425 // mov s5, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x6e24dc62 // fmul v2.4s, v3.4s, v4.4s WORD $0x1e252800 // fadd s0, s0, s5 WORD $0x5e140443 // mov s3, v2.s[2] WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e0c0441 // mov s1, v2.s[1] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e1c0441 // mov s1, v2.s[3] WORD $0x1e232800 // fadd s0, s0, s3 WORD $0x1e212800 // fadd s0, s0, s1 BNE LBB10_10 WORD $0xb400014a // cbz x10, .LBB10_14 LBB10_12: WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0xcb080128 // sub x8, x9, x8 WORD $0x8b0a0029 // add x9, x1, x10 WORD $0x8b0a000a // add x10, x0, x10 LBB10_13: WORD $0xbc404541 // ldr s1, [x10], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0xbc404522 // ldr s2, [x9], #4 WORD $0x1f020020 // fmadd s0, s1, s2, s0 BNE LBB10_13 LBB10_14: WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 FMOVS F0, result+24(FP) RET TEXT ·veuclidean(SB), $8-24 MOVD a+0(FP), R0 MOVD b+8(FP), R1 MOVD n+16(FP), R2 WORD $0xa9bf7bfd 
// stp x29, x30, [sp, #-16]! WORD $0x910003fd // mov x29, sp WORD $0xd10043ff // sub sp, sp, #16 WORD $0x91000c48 // add x8, x2, #3 WORD $0xf100005f // cmp x2, #0 WORD $0x9a82b108 // csel x8, x8, x2, lt WORD $0x9342fd0a // asr x10, x8, #2 WORD $0x927ef508 // and x8, x8, #0xfffffffffffffffc WORD $0xcb080049 // sub x9, x2, x8 WORD $0x7100055f // cmp w10, #1 BLT LBB11_5 WORD $0x3cc10400 // ldr q0, [x0], #16 WORD $0x71000548 // subs w8, w10, #1 WORD $0x3cc10421 // ldr q1, [x1], #16 WORD $0x4ea1d400 // fsub v0.4s, v0.4s, v1.4s WORD $0x6e20dc00 // fmul v0.4s, v0.4s, v0.4s BEQ LBB11_6 WORD $0xb27b7beb // mov x11, #68719476704 WORD $0xaa0103ec // mov x12, x1 WORD $0x8b0a116a // add x10, x11, x10, lsl #4 WORD $0x927c7d4a // and x10, x10, #0xffffffff0 WORD $0x9100414b // add x11, x10, #16 WORD $0x8b0b000a // add x10, x0, x11 LBB11_3: WORD $0x3cc10401 // ldr q1, [x0], #16 WORD $0x71000508 // subs w8, w8, #1 WORD $0x3cc10582 // ldr q2, [x12], #16 WORD $0x4ea2d421 // fsub v1.4s, v1.4s, v2.4s WORD $0x4e21cc20 // fmla v0.4s, v1.4s, v1.4s BNE LBB11_3 WORD $0x8b0b0021 // add x1, x1, x11 WORD $0xaa0a03e0 // mov x0, x10 B LBB11_6 LBB11_5: WORD $0x6f00e400 // movi v0.2d, #0000000000000000 LBB11_6: WORD $0x7e30d801 // faddp s1, v0.2s WORD $0x5e140402 // mov s2, v0.s[2] WORD $0x7100013f // cmp w9, #0 WORD $0x5e1c0400 // mov s0, v0.s[3] WORD $0x1e212841 // fadd s1, s2, s1 WORD $0x1e212800 // fadd s0, s0, s1 WORD $0xbd000be0 // str s0, [sp, #8] BLE LBB11_15 WORD $0x92407d28 // and x8, x9, #0xffffffff WORD $0xf100211f // cmp x8, #8 BHS LBB11_9 WORD $0xaa1f03e9 // mov x9, xzr B LBB11_12 LBB11_9: WORD $0x9240092a // and x10, x9, #0x7 WORD $0x9100400b // add x11, x0, #16 WORD $0x9100402c // add x12, x1, #16 WORD $0xcb0a0109 // sub x9, x8, x10 WORD $0xaa0903ed // mov x13, x9 LBB11_10: WORD $0xad7f9181 // ldp q1, q4, [x12, #-16] WORD $0xf10021ad // subs x13, x13, #8 WORD $0xad7f8d62 // ldp q2, q3, [x11, #-16] WORD $0x9100816b // add x11, x11, #32 WORD $0x9100818c // add x12, x12, #32 WORD 
$0x4ea1d441 // fsub v1.4s, v2.4s, v1.4s WORD $0x6e21dc21 // fmul v1.4s, v1.4s, v1.4s WORD $0x5e0c0422 // mov s2, v1.s[1] WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e140425 // mov s5, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x4ea4d462 // fsub v2.4s, v3.4s, v4.4s WORD $0x1e252800 // fadd s0, s0, s5 WORD $0x6e22dc42 // fmul v2.4s, v2.4s, v2.4s WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e0c0441 // mov s1, v2.s[1] WORD $0x5e140443 // mov s3, v2.s[2] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e1c0441 // mov s1, v2.s[3] WORD $0x1e232800 // fadd s0, s0, s3 WORD $0x1e212800 // fadd s0, s0, s1 BNE LBB11_10 WORD $0xb400016a // cbz x10, .LBB11_14 LBB11_12: WORD $0xd37ef52b // lsl x11, x9, #2 WORD $0xcb090108 // sub x8, x8, x9 WORD $0x8b0b002a // add x10, x1, x11 WORD $0x8b0b000b // add x11, x0, x11 LBB11_13: WORD $0xbc404561 // ldr s1, [x11], #4 WORD $0xf1000508 // subs x8, x8, #1 WORD $0xbc404542 // ldr s2, [x10], #4 WORD $0x1e223821 // fsub s1, s1, s2 WORD $0x1f010020 // fmadd s0, s1, s1, s0 BNE LBB11_13 LBB11_14: WORD $0xbd000be0 // str s0, [sp, #8] LBB11_15: WORD $0xfd4007e0 // ldr d0, [sp, #8] WORD $0x2ea1f800 // fsqrt v0.2s, v0.2s WORD $0x910043ff // add sp, sp, #16 WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 FMOVS F0, result+24(FP) RET TEXT ·vmm(SB), $24-88 MOVD transA+0(FP), R0 MOVD transB+1(FP), R1 MOVD m+8(FP), R2 MOVD n+16(FP), R3 MOVD k+24(FP), R4 MOVD a+32(FP), R5 MOVD lda+40(FP), R6 MOVD b+48(FP), R7 MOVD ldb+56(FP), R8 MOVD R8, 0(RSP) MOVD c+64(FP), R8 MOVD R8, 8(RSP) MOVD ldc+72(FP), R8 MOVD R8, 16(RSP) WORD $0xa9ba7bfd // stp x29, x30, [sp, #-96]! 
WORD $0x910003fd // mov x29, sp WORD $0xa9016ffc // stp x28, x27, [sp, #16] WORD $0xa946a7a8 // ldp x8, x9, [x29, #104] WORD $0xf94033ab // ldr x11, [x29, #96] WORD $0xa90267fa // stp x26, x25, [sp, #32] WORD $0xa9035ff8 // stp x24, x23, [sp, #48] WORD $0xa90457f6 // stp x22, x21, [sp, #64] WORD $0xa9054ff4 // stp x20, x19, [sp, #80] WORD $0x37000a80 // tbnz w0, #0, .LBB12_16 WORD $0x37000a61 // tbnz w1, #0, .LBB12_16 WORD $0xf100045f // cmp x2, #1 BLT LBB12_101 WORD $0xf100048c // subs x12, x4, #1 BLT LBB12_101 WORD $0xf100047f // cmp x3, #1 BLT LBB12_101 WORD $0x9b0b0d8f // madd x15, x12, x11, x3 WORD $0xd37ef529 // lsl x9, x9, #2 WORD $0xd37ef4cc // lsl x12, x6, #2 WORD $0x8b03090d // add x13, x8, x3, lsl #2 WORD $0x8b0408ae // add x14, x5, x4, lsl #2 WORD $0xd37ef571 // lsl x17, x11, #2 WORD $0xd37df56b // ubfx x11, x11, #61, #1 WORD $0xaa1f03ea // mov x10, xzr WORD $0x927df070 // and x16, x3, #0xfffffffffffffff8 WORD $0x910040e0 // add x0, x7, #16 WORD $0x91004101 // add x1, x8, #16 WORD $0xaa0803f3 // mov x19, x8 WORD $0x8b0f08ef // add x15, x7, x15, lsl #2 B LBB12_7 LBB12_6: WORD $0x9100054a // add x10, x10, #1 WORD $0x8b090021 // add x1, x1, x9 WORD $0x8b090273 // add x19, x19, x9 WORD $0xeb02015f // cmp x10, x2 BEQ LBB12_101 LBB12_7: WORD $0x9b0a7d35 // mul x21, x9, x10 WORD $0xaa1f03f4 // mov x20, xzr WORD $0x9b0a7d96 // mul x22, x12, x10 WORD $0x9b067d57 // mul x23, x10, x6 WORD $0x8b150118 // add x24, x8, x21 WORD $0x8b1501b9 // add x25, x13, x21 WORD $0xeb0f031f // cmp x24, x15 WORD $0x8b1601d5 // add x21, x14, x22 WORD $0xfa473320 // ccmp x25, x7, #0, lo WORD $0x8b1600b6 // add x22, x5, x22 WORD $0x1a9f957a // csinc w26, w11, wzr, ls WORD $0xeb15031f // cmp x24, x21 WORD $0xaa0003f8 // mov x24, x0 WORD $0x8b1708b5 // add x21, x5, x23, lsl #2 WORD $0xfa5932c2 // ccmp x22, x25, #2, lo WORD $0xaa0703f7 // mov x23, x7 WORD $0x1a9f2756 // csinc w22, w26, wzr, hs B LBB12_9 LBB12_8: WORD $0x91000694 // add x20, x20, #1 WORD $0x8b110318 // add x24, x24, x17 
WORD $0x8b1102f7 // add x23, x23, x17 WORD $0xeb04029f // cmp x20, x4 BEQ LBB12_6 LBB12_9: WORD $0xf100207f // cmp x3, #8 WORD $0x1a9f26d9 // csinc w25, w22, wzr, hs WORD $0x36000079 // tbz w25, #0, .LBB12_11 WORD $0xaa1f03fb // mov x27, xzr B LBB12_14 LBB12_11: WORD $0x8b140ab9 // add x25, x21, x20, lsl #2 WORD $0xaa0103fa // mov x26, x1 WORD $0xaa1803fb // mov x27, x24 WORD $0x4d40cb20 // ld1r { v0.4s }, [x25] WORD $0xaa1003f9 // mov x25, x16 LBB12_12: WORD $0xad7f9341 // ldp q1, q4, [x26, #-16] WORD $0xf1002339 // subs x25, x25, #8 WORD $0xad7f8f62 // ldp q2, q3, [x27, #-16] WORD $0x9100837b // add x27, x27, #32 WORD $0x4e20cc41 // fmla v1.4s, v2.4s, v0.4s WORD $0x4e20cc64 // fmla v4.4s, v3.4s, v0.4s WORD $0xad3f9341 // stp q1, q4, [x26, #-16] WORD $0x9100835a // add x26, x26, #32 BNE LBB12_12 WORD $0xeb03021f // cmp x16, x3 WORD $0xaa1003fb // mov x27, x16 BEQ LBB12_8 LBB12_14: WORD $0xd37ef77a // lsl x26, x27, #2 WORD $0xcb1b007b // sub x27, x3, x27 WORD $0x8b1a02f9 // add x25, x23, x26 WORD $0x8b1a027a // add x26, x19, x26 LBB12_15: WORD $0xbc747aa0 // ldr s0, [x21, x20, lsl #2] WORD $0xbc404721 // ldr s1, [x25], #4 WORD $0xbd400342 // ldr s2, [x26] WORD $0xf100077b // subs x27, x27, #1 WORD $0x1f010800 // fmadd s0, s0, s1, s2 WORD $0xbc004740 // str s0, [x26], #4 BNE LBB12_15 B LBB12_8 LBB12_16: WORD $0x36000ca1 // tbz w1, #0, .LBB12_34 WORD $0x37000c80 // tbnz w0, #0, .LBB12_34 WORD $0xf100045f // cmp x2, #1 BLT LBB12_101 WORD $0x91000c8a // add x10, x4, #3 WORD $0xf100009f // cmp x4, #0 WORD $0x9a84b14a // csel x10, x10, x4, lt WORD $0xf100047f // cmp x3, #1 BLT LBB12_101 WORD $0x9342fd4f // asr x15, x10, #2 WORD $0x927ef54a // and x10, x10, #0xfffffffffffffffc WORD $0xcb0a008c // sub x12, x4, x10 WORD $0x710005ed // subs w13, w15, #1 WORD $0x92407d8a // and x10, x12, #0xffffffff BLT LBB12_62 BNE LBB12_74 WORD $0x7100019f // cmp w12, #0 BLE LBB12_97 WORD $0x2f00e400 // movi d0, #0000000000000000 WORD $0x9240098c // and x12, x12, #0x7 WORD $0xd37ef4d0 // 
lsl x16, x6, #2 WORD $0xd37ef571 // lsl x17, x11, #2 WORD $0xaa1f03ed // mov x13, xzr WORD $0xcb0c014e // sub x14, x10, x12 WORD $0x910080af // add x15, x5, #32 WORD $0x910080e0 // add x0, x7, #32 WORD $0x910040e1 // add x1, x7, #16 WORD $0x910040a4 // add x4, x5, #16 B LBB12_25 LBB12_24: WORD $0x910005ad // add x13, x13, #1 WORD $0x8b1001ef // add x15, x15, x16 WORD $0x8b100084 // add x4, x4, x16 WORD $0xeb0201bf // cmp x13, x2 BEQ LBB12_101 LBB12_25: WORD $0x9b067db4 // mul x20, x13, x6 WORD $0xaa1f03f3 // mov x19, xzr WORD $0xaa0103f6 // mov x22, x1 WORD $0xaa0003f7 // mov x23, x0 WORD $0x9b097db5 // mul x21, x13, x9 WORD $0x8b1408b4 // add x20, x5, x20, lsl #2 WORD $0x8b150915 // add x21, x8, x21, lsl #2 B LBB12_27 LBB12_26: WORD $0xbc337aa1 // str s1, [x21, x19, lsl #2] WORD $0x91000673 // add x19, x19, #1 WORD $0x8b1102f7 // add x23, x23, x17 WORD $0xeb03027f // cmp x19, x3 WORD $0x8b1102d6 // add x22, x22, x17 BEQ LBB12_24 LBB12_27: WORD $0x9b0b7e78 // mul x24, x19, x11 WORD $0x3dc00281 // ldr q1, [x20] WORD $0xf100215f // cmp x10, #8 WORD $0xd37ef718 // lsl x24, x24, #2 WORD $0x3cf868e2 // ldr q2, [x7, x24] WORD $0x6e22dc21 // fmul v1.4s, v1.4s, v2.4s WORD $0x1e202822 // fadd s2, s1, s0 WORD $0x5e0c0423 // mov s3, v1.s[1] WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x5e140423 // mov s3, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x1e222821 // fadd s1, s1, s2 BHS LBB12_29 WORD $0xaa1f03f8 // mov x24, xzr B LBB12_32 LBB12_29: WORD $0xaa1703f8 // mov x24, x23 WORD $0xaa0f03f9 // mov x25, x15 WORD $0xaa0e03fa // mov x26, x14 LBB12_30: WORD $0xad7f9702 // ldp q2, q5, [x24, #-16] WORD $0xf100235a // subs x26, x26, #8 WORD $0xad7f9323 // ldp q3, q4, [x25, #-16] WORD $0x91008339 // add x25, x25, #32 WORD $0x91008318 // add x24, x24, #32 WORD $0x6e22dc62 // fmul v2.4s, v3.4s, v2.4s WORD $0x5e0c0443 // mov s3, v2.s[1] WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e140446 // mov s6, v2.s[2] WORD $0x5e1c0442 // mov s2, v2.s[3] 
WORD $0x1e232821 // fadd s1, s1, s3 WORD $0x6e25dc83 // fmul v3.4s, v4.4s, v5.4s WORD $0x1e262821 // fadd s1, s1, s6 WORD $0x5e140464 // mov s4, v3.s[2] WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e0c0462 // mov s2, v3.s[1] WORD $0x1e232821 // fadd s1, s1, s3 WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e1c0462 // mov s2, v3.s[3] WORD $0x1e242821 // fadd s1, s1, s4 WORD $0x1e222821 // fadd s1, s1, s2 BNE LBB12_30 WORD $0xaa0e03f8 // mov x24, x14 WORD $0xb4fffa0c // cbz x12, .LBB12_26 LBB12_32: WORD $0xd37ef71a // lsl x26, x24, #2 WORD $0xcb180158 // sub x24, x10, x24 WORD $0x8b1a02d9 // add x25, x22, x26 WORD $0x8b1a009a // add x26, x4, x26 LBB12_33: WORD $0xbc404742 // ldr s2, [x26], #4 WORD $0xf1000718 // subs x24, x24, #1 WORD $0xbc404723 // ldr s3, [x25], #4 WORD $0x1f030441 // fmadd s1, s2, s3, s1 BNE LBB12_33 B LBB12_26 LBB12_34: WORD $0xf100009f // cmp x4, #0 WORD $0xfa40c864 // ccmp x3, #0, #4, gt WORD $0x1a9fd7ea // cset w10, gt WORD $0x36000aa0 // tbz w0, #0, .LBB12_48 WORD $0x37000a81 // tbnz w1, #0, .LBB12_48 WORD $0xf100045f // cmp x2, #1 WORD $0x5200014a // eor w10, w10, #0x1 WORD $0x1a9fa54a // csinc w10, w10, wzr, ge WORD $0x370030ea // tbnz w10, #0, .LBB12_101 WORD $0xd100048c // sub x12, x4, #1 WORD $0xd37ef529 // lsl x9, x9, #2 WORD $0xd37ef56e // lsl x14, x11, #2 WORD $0x9b067d8d // mul x13, x12, x6 WORD $0xd37df4cf // ubfx x15, x6, #61, #1 WORD $0xaa1f03ea // mov x10, xzr WORD $0x910040e0 // add x0, x7, #16 WORD $0x91004101 // add x1, x8, #16 WORD $0xaa0803f3 // mov x19, x8 WORD $0x9b0b0d90 // madd x16, x12, x11, x3 WORD $0x8b03090c // add x12, x8, x3, lsl #2 WORD $0xd37df56b // ubfx x11, x11, #61, #1 WORD $0x8b0d08b1 // add x17, x5, x13, lsl #2 WORD $0x927df06d // and x13, x3, #0xfffffffffffffff8 WORD $0x8b1008f0 // add x16, x7, x16, lsl #2 WORD $0x91001231 // add x17, x17, #4 B LBB12_39 LBB12_38: WORD $0x9100054a // add x10, x10, #1 WORD $0x8b090021 // add x1, x1, x9 WORD $0x8b090273 // add x19, x19, x9 WORD $0xeb02015f // cmp x10, x2 BEQ 
LBB12_101 LBB12_39: WORD $0x9b0a7d35 // mul x21, x9, x10 WORD $0xd37ef556 // lsl x22, x10, #2 WORD $0xaa1f03f4 // mov x20, xzr WORD $0x8b160238 // add x24, x17, x22 WORD $0x8b150117 // add x23, x8, x21 WORD $0x8b150199 // add x25, x12, x21 WORD $0x8b1600b5 // add x21, x5, x22 WORD $0xeb1802ff // cmp x23, x24 WORD $0xaa0003f8 // mov x24, x0 WORD $0xfa5932a2 // ccmp x21, x25, #2, lo WORD $0x1a9f25f6 // csinc w22, w15, wzr, hs WORD $0xeb1002ff // cmp x23, x16 WORD $0xfa473320 // ccmp x25, x7, #0, lo WORD $0x1a9f9577 // csinc w23, w11, wzr, ls WORD $0x2a1702d6 // orr w22, w22, w23 WORD $0xaa0703f7 // mov x23, x7 B LBB12_41 LBB12_40: WORD $0x91000694 // add x20, x20, #1 WORD $0x8b0e0318 // add x24, x24, x14 WORD $0x8b0e02f7 // add x23, x23, x14 WORD $0xeb04029f // cmp x20, x4 BEQ LBB12_38 LBB12_41: WORD $0x9b067e99 // mul x25, x20, x6 WORD $0xf100207f // cmp x3, #8 WORD $0x1a9f26da // csinc w26, w22, wzr, hs WORD $0x3600007a // tbz w26, #0, .LBB12_43 WORD $0xaa1f03fc // mov x28, xzr B LBB12_46 LBB12_43: WORD $0x8b190aba // add x26, x21, x25, lsl #2 WORD $0xaa0103fb // mov x27, x1 WORD $0xaa1803fc // mov x28, x24 WORD $0x4d40cb40 // ld1r { v0.4s }, [x26] WORD $0xaa0d03fa // mov x26, x13 LBB12_44: WORD $0xad7f9361 // ldp q1, q4, [x27, #-16] WORD $0xf100235a // subs x26, x26, #8 WORD $0xad7f8f82 // ldp q2, q3, [x28, #-16] WORD $0x9100839c // add x28, x28, #32 WORD $0x4e20cc41 // fmla v1.4s, v2.4s, v0.4s WORD $0x4e20cc64 // fmla v4.4s, v3.4s, v0.4s WORD $0xad3f9361 // stp q1, q4, [x27, #-16] WORD $0x9100837b // add x27, x27, #32 BNE LBB12_44 WORD $0xeb0301bf // cmp x13, x3 WORD $0xaa0d03fc // mov x28, x13 BEQ LBB12_40 LBB12_46: WORD $0xd37ef79b // lsl x27, x28, #2 WORD $0xcb1c007c // sub x28, x3, x28 WORD $0x8b1b02fa // add x26, x23, x27 WORD $0x8b1b027b // add x27, x19, x27 LBB12_47: WORD $0xbc797aa0 // ldr s0, [x21, x25, lsl #2] WORD $0xbc404741 // ldr s1, [x26], #4 WORD $0xbd400362 // ldr s2, [x27] WORD $0xf100079c // subs x28, x28, #1 WORD $0x1f010800 // fmadd s0, s0, 
s1, s2 WORD $0xbc004760 // str s0, [x27], #4 BNE LBB12_47 B LBB12_40 LBB12_48: WORD $0x0a01014a // and w10, w10, w1 WORD $0x7100055f // cmp w10, #1 BNE LBB12_101 WORD $0xf100045f // cmp x2, #1 BLT LBB12_101 WORD $0x36002640 // tbz w0, #0, .LBB12_101 WORD $0xd100048c // sub x12, x4, #1 WORD $0x8b030090 // add x16, x4, x3 WORD $0xf1001c7f // cmp x3, #7 WORD $0x9b067d8f // mul x15, x12, x6 WORD $0x8b1008f1 // add x17, x7, x16, lsl #2 WORD $0xfa418960 // ccmp x11, #1, #0, hi WORD $0xd37ef529 // lsl x9, x9, #2 WORD $0x8b03090c // add x12, x8, x3, lsl #2 WORD $0xd37df4d0 // ubfx x16, x6, #61, #1 WORD $0x1a9f17e1 // cset w1, eq WORD $0xaa1f03ea // mov x10, xzr WORD $0x927df06d // and x13, x3, #0xfffffffffffffff8 WORD $0x910040ee // add x14, x7, #16 WORD $0x52000021 // eor w1, w1, #0x1 WORD $0xaa0803f3 // mov x19, x8 WORD $0x8b0f08a0 // add x0, x5, x15, lsl #2 WORD $0xd37ef56f // lsl x15, x11, #2 WORD $0xd100122b // sub x11, x17, #4 WORD $0x91001011 // add x17, x0, #4 WORD $0x91004100 // add x0, x8, #16 B LBB12_53 LBB12_52: WORD $0x9100054a // add x10, x10, #1 WORD $0x8b090000 // add x0, x0, x9 WORD $0x8b090273 // add x19, x19, x9 WORD $0xeb02015f // cmp x10, x2 BEQ LBB12_101 LBB12_53: WORD $0x9b0a7d35 // mul x21, x9, x10 WORD $0xd37ef556 // lsl x22, x10, #2 WORD $0xaa1f03f4 // mov x20, xzr WORD $0x8b160238 // add x24, x17, x22 WORD $0x8b150117 // add x23, x8, x21 WORD $0x8b150199 // add x25, x12, x21 WORD $0x8b1600b5 // add x21, x5, x22 WORD $0xeb1802ff // cmp x23, x24 WORD $0xaa0e03f8 // mov x24, x14 WORD $0xfa5932a2 // ccmp x21, x25, #2, lo WORD $0x1a9f2616 // csinc w22, w16, wzr, hs WORD $0xeb0b02ff // cmp x23, x11 WORD $0xaa0703f7 // mov x23, x7 WORD $0xfa473320 // ccmp x25, x7, #0, lo WORD $0x1a9f96d6 // csinc w22, w22, wzr, ls WORD $0x2a160036 // orr w22, w1, w22 B LBB12_55 LBB12_54: WORD $0x91000694 // add x20, x20, #1 WORD $0x91001318 // add x24, x24, #4 WORD $0x910012f7 // add x23, x23, #4 WORD $0xeb04029f // cmp x20, x4 BEQ LBB12_52 LBB12_55: WORD $0x9b067e99 // 
mul x25, x20, x6 WORD $0x36000076 // tbz w22, #0, .LBB12_57 WORD $0xaa1f03fc // mov x28, xzr B LBB12_60 LBB12_57: WORD $0x8b190aba // add x26, x21, x25, lsl #2 WORD $0xaa0003fb // mov x27, x0 WORD $0xaa1803fc // mov x28, x24 WORD $0x4d40cb40 // ld1r { v0.4s }, [x26] WORD $0xaa0d03fa // mov x26, x13 LBB12_58: WORD $0xad7f9361 // ldp q1, q4, [x27, #-16] WORD $0xf100235a // subs x26, x26, #8 WORD $0xad7f8f82 // ldp q2, q3, [x28, #-16] WORD $0x9100839c // add x28, x28, #32 WORD $0x4e20cc41 // fmla v1.4s, v2.4s, v0.4s WORD $0x4e20cc64 // fmla v4.4s, v3.4s, v0.4s WORD $0xad3f9361 // stp q1, q4, [x27, #-16] WORD $0x9100837b // add x27, x27, #32 BNE LBB12_58 WORD $0xeb0301bf // cmp x13, x3 WORD $0xaa0d03fc // mov x28, x13 BEQ LBB12_54 LBB12_60: WORD $0x9b1c5dfa // madd x26, x15, x28, x23 WORD $0x8b1c0a7b // add x27, x19, x28, lsl #2 WORD $0xcb1c007c // sub x28, x3, x28 LBB12_61: WORD $0xbc797aa0 // ldr s0, [x21, x25, lsl #2] WORD $0xbd400341 // ldr s1, [x26] WORD $0xf100079c // subs x28, x28, #1 WORD $0xbd400362 // ldr s2, [x27] WORD $0x8b0f035a // add x26, x26, x15 WORD $0x1f010800 // fmadd s0, s0, s1, s2 WORD $0xbc004760 // str s0, [x27], #4 BNE LBB12_61 B LBB12_54 LBB12_62: WORD $0x7100019f // cmp w12, #0 BLE LBB12_88 WORD $0x9240098c // and x12, x12, #0x7 WORD $0xd37ef4ce // lsl x14, x6, #2 WORD $0xd37ef56b // lsl x11, x11, #2 WORD $0xaa1f03ed // mov x13, xzr WORD $0xcb0c014f // sub x15, x10, x12 WORD $0x910040b0 // add x16, x5, #16 WORD $0x910040f1 // add x17, x7, #16 B LBB12_65 LBB12_64: WORD $0x910005ad // add x13, x13, #1 WORD $0x8b0e0210 // add x16, x16, x14 WORD $0x8b0e00a5 // add x5, x5, x14 WORD $0xeb0201bf // cmp x13, x2 BEQ LBB12_101 LBB12_65: WORD $0x9b097da1 // mul x1, x13, x9 WORD $0xaa1f03e0 // mov x0, xzr WORD $0xaa0703e4 // mov x4, x7 WORD $0xaa1103e6 // mov x6, x17 WORD $0x8b010901 // add x1, x8, x1, lsl #2 B LBB12_67 LBB12_66: WORD $0xbc207820 // str s0, [x1, x0, lsl #2] WORD $0x91000400 // add x0, x0, #1 WORD $0x8b0b00c6 // add x6, x6, x11 WORD 
$0xeb03001f // cmp x0, x3 WORD $0x8b0b0084 // add x4, x4, x11 BEQ LBB12_64 LBB12_67: WORD $0x2f00e400 // movi d0, #0000000000000000 WORD $0xf100215f // cmp x10, #8 BHS LBB12_69 WORD $0xaa1f03f3 // mov x19, xzr B LBB12_72 LBB12_69: WORD $0xaa0f03f3 // mov x19, x15 WORD $0xaa0603f4 // mov x20, x6 WORD $0xaa1003f5 // mov x21, x16 LBB12_70: WORD $0xad7f9281 // ldp q1, q4, [x20, #-16] WORD $0xf1002273 // subs x19, x19, #8 WORD $0xad7f8ea2 // ldp q2, q3, [x21, #-16] WORD $0x910082b5 // add x21, x21, #32 WORD $0x91008294 // add x20, x20, #32 WORD $0x6e21dc41 // fmul v1.4s, v2.4s, v1.4s WORD $0x5e0c0422 // mov s2, v1.s[1] WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e140425 // mov s5, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x6e24dc62 // fmul v2.4s, v3.4s, v4.4s WORD $0x1e252800 // fadd s0, s0, s5 WORD $0x5e140443 // mov s3, v2.s[2] WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e0c0441 // mov s1, v2.s[1] WORD $0x1e222800 // fadd s0, s0, s2 WORD $0x1e212800 // fadd s0, s0, s1 WORD $0x5e1c0441 // mov s1, v2.s[3] WORD $0x1e232800 // fadd s0, s0, s3 WORD $0x1e212800 // fadd s0, s0, s1 BNE LBB12_70 WORD $0xaa0f03f3 // mov x19, x15 WORD $0xb4fffb6c // cbz x12, .LBB12_66 LBB12_72: WORD $0xd37ef675 // lsl x21, x19, #2 WORD $0xcb130153 // sub x19, x10, x19 WORD $0x8b150094 // add x20, x4, x21 WORD $0x8b1500b5 // add x21, x5, x21 LBB12_73: WORD $0xbc4046a1 // ldr s1, [x21], #4 WORD $0xf1000673 // subs x19, x19, #1 WORD $0xbc404682 // ldr s2, [x20], #4 WORD $0x1f020020 // fmadd s0, s1, s2, s0 BNE LBB12_73 B LBB12_66 LBB12_74: WORD $0xb27b7bf0 // mov x16, #68719476704 WORD $0x2f00e400 // movi d0, #0000000000000000 WORD $0xd37ef571 // lsl x17, x11, #2 WORD $0x8b0f1210 // add x16, x16, x15, lsl #4 WORD $0x9240098f // and x15, x12, #0x7 WORD $0xaa1f03ee // mov x14, xzr WORD $0xcb0f0140 // sub x0, x10, x15 WORD $0x927c7e01 // and x1, x16, #0xffffffff0 WORD $0xd37ef4d0 // lsl x16, x6, #2 WORD $0x9100c024 // add x4, x1, #48 WORD $0x91008034 // 
add x20, x1, #32 WORD $0x8b0400a1 // add x1, x5, x4 WORD $0x8b0400e4 // add x4, x7, x4 WORD $0x8b1400f3 // add x19, x7, x20 WORD $0x8b1400b4 // add x20, x5, x20 B LBB12_76 LBB12_75: WORD $0x910005ce // add x14, x14, #1 WORD $0x8b100021 // add x1, x1, x16 WORD $0x8b100294 // add x20, x20, x16 WORD $0xeb0201df // cmp x14, x2 BEQ LBB12_101 LBB12_76: WORD $0x9b067dd6 // mul x22, x14, x6 WORD $0xaa1f03f5 // mov x21, xzr WORD $0xaa1303f9 // mov x25, x19 WORD $0xaa0403fa // mov x26, x4 WORD $0x9b097dd7 // mul x23, x14, x9 WORD $0x8b1608b6 // add x22, x5, x22, lsl #2 WORD $0x8b170917 // add x23, x8, x23, lsl #2 WORD $0x910042d8 // add x24, x22, #16 B LBB12_78 LBB12_77: WORD $0xbc357ae1 // str s1, [x23, x21, lsl #2] WORD $0x910006b5 // add x21, x21, #1 WORD $0x8b11035a // add x26, x26, x17 WORD $0xeb0302bf // cmp x21, x3 WORD $0x8b110339 // add x25, x25, x17 BEQ LBB12_75 LBB12_78: WORD $0x9b0b7ebb // mul x27, x21, x11 WORD $0x3dc002c1 // ldr q1, [x22] WORD $0x2a0d03fc // mov w28, w13 WORD $0xaa1803fe // mov x30, x24 WORD $0x8b1b08fb // add x27, x7, x27, lsl #2 WORD $0x3cc10762 // ldr q2, [x27], #16 WORD $0x6e22dc21 // fmul v1.4s, v1.4s, v2.4s LBB12_79: WORD $0x3cc107c2 // ldr q2, [x30], #16 WORD $0x7100079c // subs w28, w28, #1 WORD $0x3cc10763 // ldr q3, [x27], #16 WORD $0x4e22cc61 // fmla v1.4s, v3.4s, v2.4s BNE LBB12_79 WORD $0x1e202822 // fadd s2, s1, s0 WORD $0x5e0c0423 // mov s3, v1.s[1] WORD $0x7100059f // cmp w12, #1 WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x5e140423 // mov s3, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x1e222821 // fadd s1, s1, s2 BLT LBB12_77 WORD $0xf100215f // cmp x10, #8 BHS LBB12_83 WORD $0xaa1f03fb // mov x27, xzr B LBB12_86 LBB12_83: WORD $0xaa1a03fb // mov x27, x26 WORD $0xaa0103fc // mov x28, x1 WORD $0xaa0003fe // mov x30, x0 LBB12_84: WORD $0xad7f9762 // ldp q2, q5, [x27, #-16] WORD $0xf10023de // subs x30, x30, #8 WORD $0xad7f9383 // ldp q3, q4, [x28, #-16] WORD $0x9100839c // add x28, x28, 
#32 WORD $0x9100837b // add x27, x27, #32 WORD $0x6e22dc62 // fmul v2.4s, v3.4s, v2.4s WORD $0x5e0c0443 // mov s3, v2.s[1] WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e140446 // mov s6, v2.s[2] WORD $0x5e1c0442 // mov s2, v2.s[3] WORD $0x1e232821 // fadd s1, s1, s3 WORD $0x6e25dc83 // fmul v3.4s, v4.4s, v5.4s WORD $0x1e262821 // fadd s1, s1, s6 WORD $0x5e140464 // mov s4, v3.s[2] WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e0c0462 // mov s2, v3.s[1] WORD $0x1e232821 // fadd s1, s1, s3 WORD $0x1e222821 // fadd s1, s1, s2 WORD $0x5e1c0462 // mov s2, v3.s[3] WORD $0x1e242821 // fadd s1, s1, s4 WORD $0x1e222821 // fadd s1, s1, s2 BNE LBB12_84 WORD $0xaa0003fb // mov x27, x0 WORD $0xb4fff8ef // cbz x15, .LBB12_77 LBB12_86: WORD $0xd37ef77e // lsl x30, x27, #2 WORD $0xcb1b015b // sub x27, x10, x27 WORD $0x8b1e033c // add x28, x25, x30 WORD $0x8b1e029e // add x30, x20, x30 LBB12_87: WORD $0xbc4047c2 // ldr s2, [x30], #4 WORD $0xf100077b // subs x27, x27, #1 WORD $0xbc404783 // ldr s3, [x28], #4 WORD $0x1f030441 // fmadd s1, s2, s3, s1 BNE LBB12_87 B LBB12_77 LBB12_88: WORD $0x6f00e400 // movi v0.2d, #0000000000000000 WORD $0xd37ef529 // lsl x9, x9, #2 WORD $0xaa1f03ea // mov x10, xzr WORD $0x927dec6b // and x11, x3, #0x7ffffffffffffff8 WORD $0x9100410c // add x12, x8, #16 B LBB12_90 LBB12_89: WORD $0x9100054a // add x10, x10, #1 WORD $0x8b09018c // add x12, x12, x9 WORD $0x8b090108 // add x8, x8, x9 WORD $0xeb02015f // cmp x10, x2 BEQ LBB12_101 LBB12_90: WORD $0xf100207f // cmp x3, #8 BHS LBB12_92 WORD $0xaa1f03ee // mov x14, xzr B LBB12_95 LBB12_92: WORD $0xaa0b03ed // mov x13, x11 WORD $0xaa0c03ee // mov x14, x12 LBB12_93: WORD $0xad3f81c0 // stp q0, q0, [x14, #-16] WORD $0xf10021ad // subs x13, x13, #8 WORD $0x910081ce // add x14, x14, #32 BNE LBB12_93 WORD $0xeb03017f // cmp x11, x3 WORD $0xaa0b03ee // mov x14, x11 BEQ LBB12_89 LBB12_95: WORD $0x8b0e090d // add x13, x8, x14, lsl #2 WORD $0xcb0e006e // sub x14, x3, x14 LBB12_96: WORD $0xf10005ce // subs x14, x14, #1 
WORD $0xb80045bf // str wzr, [x13], #4 BNE LBB12_96 B LBB12_89 LBB12_97: WORD $0x2f00e400 // movi d0, #0000000000000000 WORD $0xd37ef56a // lsl x10, x11, #2 WORD $0xd37ef529 // lsl x9, x9, #2 WORD $0xaa1f03eb // mov x11, xzr LBB12_98: WORD $0x9b067d6c // mul x12, x11, x6 WORD $0xaa0303ed // mov x13, x3 WORD $0xaa0803ee // mov x14, x8 WORD $0xaa0703ef // mov x15, x7 WORD $0x8b0c08ac // add x12, x5, x12, lsl #2 LBB12_99: WORD $0x3dc00181 // ldr q1, [x12] WORD $0x3dc001e2 // ldr q2, [x15] WORD $0xf10005ad // subs x13, x13, #1 WORD $0x8b0a01ef // add x15, x15, x10 WORD $0x6e22dc21 // fmul v1.4s, v1.4s, v2.4s WORD $0x1e202822 // fadd s2, s1, s0 WORD $0x5e0c0423 // mov s3, v1.s[1] WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x5e140423 // mov s3, v1.s[2] WORD $0x5e1c0421 // mov s1, v1.s[3] WORD $0x1e222862 // fadd s2, s3, s2 WORD $0x1e222821 // fadd s1, s1, s2 WORD $0xbc0045c1 // str s1, [x14], #4 BNE LBB12_99 WORD $0x9100056b // add x11, x11, #1 WORD $0x8b090108 // add x8, x8, x9 WORD $0xeb02017f // cmp x11, x2 BNE LBB12_98 LBB12_101: WORD $0xa9454ff4 // ldp x20, x19, [sp, #80] WORD $0xa94457f6 // ldp x22, x21, [sp, #64] WORD $0xa9435ff8 // ldp x24, x23, [sp, #48] WORD $0xa94267fa // ldp x26, x25, [sp, #32] WORD $0xa9416ffc // ldp x28, x27, [sp, #16] WORD $0xa8c67bfd // ldp x29, x30, [sp], #96 RET ================================================ FILE: common/floats/floats_noasm.go ================================================ //go:build noasm || (!amd64 && !arm64 && !riscv64) // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package floats type Feature uint64 const OPENBLAS Feature = 1 << iota var feature Feature func (Feature) String() string { return "NOASM" } func (Feature) mulConstAddTo(a []float32, b float32, c, dst []float32) { mulConstAddTo(a, b, c, dst) } func (Feature) mulConstAdd(a []float32, b float32, c []float32) { mulConstAdd(a, b, c) } func (Feature) mulConstTo(a []float32, b float32, c []float32) { mulConstTo(a, b, c) } func (Feature) addConst(a []float32, b float32) { addConst(a, b) } func (Feature) sub(a, b []float32) { sub(a, b) } func (Feature) subTo(a, b, c []float32) { subTo(a, b, c) } func (Feature) mulTo(a, b, c []float32) { mulTo(a, b, c) } func (Feature) mulConst(a []float32, b float32) { mulConst(a, b) } func (Feature) divTo(a, b, c []float32) { divTo(a, b, c) } func (Feature) sqrtTo(a, b []float32) { sqrtTo(a, b) } func (Feature) dot(a, b []float32) float32 { return dot(a, b) } func (Feature) euclidean(a, b []float32) float32 { return euclidean(a, b) } func (Feature) mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { mm(transA, transB, m, n, k, a, lda, b, ldb, c, ldc) } ================================================ FILE: common/floats/floats_riscv64.go ================================================ //go:build !noasm // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package floats import ( "unsafe" "golang.org/x/sys/cpu" ) //go:generate goat src/floats_rvv.c -O3 -march=rv64imafdv type Feature uint64 const ( V Feature = 1 << iota OPENBLAS ) var feature Feature func init() { if cpu.RISCV64.HasV { feature = feature | V } } func (feature Feature) String() string { if feature == V { return "RVV" } else { return "RV" } } func (feature Feature) mulConstAddTo(a []float32, b float32, c []float32, dst []float32) { if feature&V == V { vmul_const_add_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), unsafe.Pointer(&dst[0]), int64(len(a))) } else { mulConstAddTo(a, b, c, dst) } } func (feature Feature) mulConstAdd(a []float32, b float32, c []float32) { if feature&V == V { vmul_const_add(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a))) } else { mulConstAdd(a, b, c) } } func (feature Feature) mulConstTo(a []float32, b float32, c []float32) { if feature&V == V { vmul_const_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), unsafe.Pointer(&c[0]), int64(len(a))) } else { mulConstTo(a, b, c) } } func (feature Feature) addConst(a []float32, b float32) { if feature&V == V { vadd_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a))) } else { addConst(a, b) } } func (feature Feature) sub(a, b []float32) { if feature&V == V { vsub(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } else { sub(a, b) } } func (feature Feature) subTo(a, b, c []float32) { if feature&V == V { vsub_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } else { subTo(a, b, c) } } func (feature Feature) mulTo(a, b, c []float32) { if feature&V == V { vmul_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } else { mulTo(a, b, c) } } func (feature Feature) mulConst(a []float32, b float32) { if feature&V == V { 
vmul_const(unsafe.Pointer(&a[0]), unsafe.Pointer(&b), int64(len(a))) } else { mulConst(a, b) } } func (feature Feature) divTo(a, b, c []float32) { if feature&V == V { vdiv_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), int64(len(a))) } else { divTo(a, b, c) } } func (feature Feature) sqrtTo(a, b []float32) { if feature&V == V { vsqrt_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } else { sqrtTo(a, b) } } func (feature Feature) dot(a, b []float32) float32 { if feature&V == V { return vdot(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } else { return dot(a, b) } } func (feature Feature) euclidean(a, b []float32) float32 { if feature&V == V { return veuclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), int64(len(a))) } else { return euclidean(a, b) } } func (feature Feature) mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { if feature&V == V && feature&OPENBLAS == 0 { vmm(transA, transB, int64(m), int64(n), int64(k), unsafe.Pointer(&a[0]), int64(lda), unsafe.Pointer(&b[0]), int64(ldb), unsafe.Pointer(&c[0]), int64(ldc)) } else { mm(transA, transB, m, n, k, a, lda, b, ldb, c, ldc) } } ================================================ FILE: common/floats/floats_riscv64_test.go ================================================ //go:build !noasm // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package floats

import (
	"fmt"
	"math/rand"
	"strconv"
	"testing"

	"github.com/stretchr/testify/suite"
)

// TestV runs the shared SIMD test suite against the RVV code paths.
func TestV(t *testing.T) {
	suite.Run(t, &SIMDTestSuite{Feature: V})
}

// initializeFloat32Array returns a slice of n values drawn from rand.Float32.
func initializeFloat32Array(n int) []float32 {
	x := make([]float32, n)
	for i := 0; i < n; i++ {
		x[i] = rand.Float32()
	}
	return x
}

// runRVVBenchmarks invokes body once per (feature, vector length) pair. It
// runs the portable path (feature 0) and RVV (V), each at lengths 16, 32, 64
// and 128, naming sub-benchmarks "<feature>/<length>".
//
// body owns its setup, b.ResetTimer and the b.N loop, so the helper adds no
// call overhead inside the timed region.
func runRVVBenchmarks(b *testing.B, body func(b *testing.B, feat Feature, n int)) {
	for _, feat := range []Feature{0, V} {
		b.Run(feat.String(), func(b *testing.B) {
			for n := 16; n <= 128; n *= 2 {
				b.Run(strconv.Itoa(n), func(b *testing.B) {
					body(b, feat, n)
				})
			}
		})
	}
}

func BenchmarkDot(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.dot(v1, v2)
		}
	})
}

func BenchmarkEuclidean(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.euclidean(v1, v2)
		}
	})
}

func BenchmarkMulConstAddTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstAddTo(v1, 2, v2, v3)
		}
	})
}

func BenchmarkMulConstAdd(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstAdd(v1, 2, v2)
		}
	})
}

func BenchmarkMulConst(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConst(v1, 2)
		}
	})
}

func BenchmarkMulConstTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulConstTo(v1, 2, v2)
		}
	})
}

func BenchmarkAddConst(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.addConst(v1, 2)
		}
	})
}

func BenchmarkSub(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.sub(v1, v2)
		}
	})
}

func BenchmarkSubTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.subTo(v1, v2, v3)
		}
	})
}

func BenchmarkMulTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.mulTo(v1, v2, v3)
		}
	})
}

func BenchmarkDivTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := initializeFloat32Array(n)
		v3 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.divTo(v1, v2, v3)
		}
	})
}

func BenchmarkSqrtTo(b *testing.B) {
	runRVVBenchmarks(b, func(b *testing.B, feat Feature, n int) {
		v1 := initializeFloat32Array(n)
		v2 := make([]float32, n)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			feat.sqrtTo(v1, v2)
		}
	})
}

// BenchmarkMM times n×n×n matrix multiplication for every transpose
// combination and feature, naming sub-benchmarks "(transA,transB,feature)/n".
func BenchmarkMM(b *testing.B) {
	for _, transA := range []bool{false, true} {
		for _, transB := range []bool{false, true} {
			for _, feat := range []Feature{0, V} {
				b.Run(fmt.Sprintf("(%v,%v,%v)", transA, transB, feat.String()), func(b *testing.B) {
					for n := 16; n <= 128; n *= 2 {
						b.Run(strconv.Itoa(n), func(b *testing.B) {
							matA := initializeFloat32Array(n * n)
							matB := initializeFloat32Array(n * n)
							matC := make([]float32, n*n)
							b.ResetTimer()
							for i := 0; i < b.N; i++ {
								feat.mm(transA, transB, n, n, n, matA, n, matB, n, matC, n)
							}
						})
					}
				})
			}
		}
	}
}

================================================
FILE: common/floats/floats_rvv.go
================================================
//go:build !noasm && riscv64
// Code generated by GoAT. DO NOT EDIT.
// versions:
//   clang   18.1.8 (11bb4)
//   objdump 2.42
// flags: -march=rv64imafdv -O3
// source: src/floats_rvv.c

package floats

import "unsafe"

// NOTE(review): this file is generated by GoAT (see the go:generate directive
// in floats_riscv64.go), so the comments below are overwritten whenever it is
// regenerated.
//
// Each function below is implemented in floats_rvv.s. Pointer arguments
// address float32 data; scalar operands (b in the *_const kernels) are passed
// as a pointer to a single float32. n is the element count: the callers in
// floats_riscv64.go pass int64(len(a)), and the visible kernels return
// immediately when n <= 0.

// vmul_const_add_to stores a[i]*(*b) + c[i] into dst[i] for i in [0, n).
//
//go:noescape
func vmul_const_add_to(a, b, c, dst unsafe.Pointer, n int64)

// vmul_const_add updates c[i] = a[i]*(*b) + c[i] for i in [0, n).
//
//go:noescape
func vmul_const_add(a, b, c unsafe.Pointer, n int64)

// vmul_const_to stores a[i]*(*b) into c[i] for i in [0, n).
//
//go:noescape
func vmul_const_to(a, b, c unsafe.Pointer, n int64)

// vmul_const multiplies a[i] by *b in place for i in [0, n).
//
//go:noescape
func vmul_const(a, b unsafe.Pointer, n int64)

// vadd_const adds *b to a[i] in place for i in [0, n).
//
//go:noescape
func vadd_const(a, b unsafe.Pointer, n int64)

// vsub_to stores a[i] - b[i] into c[i] for i in [0, n).
//
//go:noescape
func vsub_to(a, b, c unsafe.Pointer, n int64)

// vsub updates a[i] = a[i] - b[i] in place for i in [0, n).
//
//go:noescape
func vsub(a, b unsafe.Pointer, n int64)

// vmul_to is the RVV counterpart of the portable mulTo, called from
// Feature.mulTo (presumably c[i] = a[i]*b[i] — confirm against src/floats_rvv.c).
//
//go:noescape
func vmul_to(a, b, c unsafe.Pointer, n int64)

// vdiv_to is the RVV counterpart of the portable divTo, called from
// Feature.divTo (presumably c[i] = a[i]/b[i] — confirm against src/floats_rvv.c).
//
//go:noescape
func vdiv_to(a, b, c unsafe.Pointer, n int64)

// vsqrt_to is the RVV counterpart of the portable sqrtTo, called from
// Feature.sqrtTo.
//
//go:noescape
func vsqrt_to(a, b unsafe.Pointer, n int64)

// vdot is the RVV counterpart of the portable dot, called from Feature.dot.
//
//go:noescape
func vdot(a, b unsafe.Pointer, n int64) (result float32)

// veuclidean is the RVV counterpart of the portable euclidean, called from
// Feature.euclidean.
//
//go:noescape
func veuclidean(a, b unsafe.Pointer, n int64) (result float32)

// vmm is the RVV matrix-multiplication kernel used by Feature.mm when the
// OPENBLAS bit is clear. Its arguments mirror the portable mm (transA, transB,
// m, n, k, a, lda, b, ldb, c, ldc), with integers widened to int64.
//
//go:noescape
func vmm(transA, transB bool, m, n, k int64, a unsafe.Pointer, lda int64, b unsafe.Pointer, ldb int64, c unsafe.Pointer, ldc int64)

================================================
FILE: common/floats/floats_rvv.s
================================================
//go:build !noasm && riscv64
// Code generated by GoAT. DO NOT EDIT.
// versions: // clang 18.1.8 (11bb4) // objdump 2.42 // flags: -march=rv64imafdv -O3 // source: src/floats_rvv.c TEXT ·vmul_const_add_to(SB), $0-40 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV dst+24(FP), A3 MOV n+32(FP), A4 BLEZ A4, LBB0_8 WORD $0xc2202873 // csrr a6, vlenb WORD $0x00185793 // srli a5, a6, 1 WORD $0x01800893 // li a7, 24 BLTU A7, A5, LBB0_3 WORD $0x01800793 // li a5, 24 LBB0_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A4, A5, LBB0_9 WORD $0x00000793 // li a5, 0 LBB0_5: WORD $0x00279813 // slli a6, a5, 2 WORD $0x010686b3 // add a3, a3, a6 WORD $0x01060633 // add a2, a2, a6 WORD $0x01050533 // add a0, a0, a6 WORD $0x40f70733 // sub a4, a4, a5 LBB0_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x00062687 // flw fa3, 0(a2) WORD $0x68e7f7c3 // fmadd.s fa5, fa5, fa4, fa3 WORD $0x00f6a027 // fsw fa5, 0(a3) WORD $0x00468693 // addi a3, a3, 4 WORD $0x00460613 // addi a2, a2, 4 WORD $0xfff70713 // addi a4, a4, -1 WORD $0x00450513 // addi a0, a0, 4 BNEZ A4, LBB0_6 LBB0_7: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB0_8: RET LBB0_9: WORD $0x00271293 // slli t0, a4, 2 WORD $0x005688b3 // add a7, a3, t0 WORD $0x005507b3 // add a5, a0, t0 WORD $0x00f6b7b3 // sltu a5, a3, a5 WORD $0x01153333 // sltu t1, a0, a7 WORD $0x0067f333 // and t1, a5, t1 WORD $0x00000793 // li a5, 0 BNEZ T1, LBB0_5 WORD $0x00458313 // addi t1, a1, 4 WORD $0x0066b333 // sltu t1, a3, t1 WORD $0x0115b3b3 // sltu t2, a1, a7 WORD $0x00737333 // and t1, t1, t2 BNEZ T1, LBB0_5 WORD $0x005602b3 // add t0, a2, t0 WORD $0x0056b2b3 // sltu t0, a3, t0 WORD $0x011638b3 // sltu a7, a2, a7 WORD $0x0112f8b3 // and a7, t0, a7 BNEZ A7, LBB0_5 
WORD $0x00185893 // srli a7, a6, 1 WORD $0x0d1077d7 // vsetvli a5, zero, e32, m2, ta, ma WORD $0x0a05e407 // vlse32.v v8, (a1), zero WORD $0x411007b3 // neg a5, a7 WORD $0x00e7f7b3 // and a5, a5, a4 WORD $0x00181813 // slli a6, a6, 1 WORD $0x00078293 // mv t0, a5 WORD $0x00068313 // mv t1, a3 WORD $0x00060393 // mv t2, a2 WORD $0x00050e13 // mv t3, a0 LBB0_13: WORD $0x228e6507 // vl2re32.v v10, (t3) WORD $0x2283e607 // vl2re32.v v12, (t2) WORD $0xb2a41657 // vfmacc.vv v12, v8, v10 WORD $0x22830627 // vs2r.v v12, (t1) WORD $0x010e0e33 // add t3, t3, a6 WORD $0x010383b3 // add t2, t2, a6 WORD $0x411282b3 // sub t0, t0, a7 WORD $0x01030333 // add t1, t1, a6 BNEZ T0, LBB0_13 BNE A5, A4, LBB0_5 JMP LBB0_7 TEXT ·vmul_const_add(SB), $0-32 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV n+24(FP), A3 BLEZ A3, LBB1_8 WORD $0xc22027f3 // csrr a5, vlenb WORD $0x0017d713 // srli a4, a5, 1 WORD $0x01000813 // li a6, 16 BLTU A6, A4, LBB1_3 WORD $0x01000713 // li a4, 16 LBB1_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A3, A4, LBB1_9 WORD $0x00000713 // li a4, 0 LBB1_5: WORD $0x00271793 // slli a5, a4, 2 WORD $0x00f60633 // add a2, a2, a5 WORD $0x00f50533 // add a0, a0, a5 WORD $0x40e686b3 // sub a3, a3, a4 LBB1_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x00062687 // flw fa3, 0(a2) WORD $0x68e7f7c3 // fmadd.s fa5, fa5, fa4, fa3 WORD $0x00f62027 // fsw fa5, 0(a2) WORD $0x00460613 // addi a2, a2, 4 WORD $0xfff68693 // addi a3, a3, -1 WORD $0x00450513 // addi a0, a0, 4 BNEZ A3, LBB1_6 LBB1_7: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB1_8: RET LBB1_9: WORD $0x00269713 // slli a4, a3, 2 WORD $0x00e60833 // add a6, a2, a4 
WORD $0x00e50733 // add a4, a0, a4 WORD $0x00e63733 // sltu a4, a2, a4 WORD $0x010538b3 // sltu a7, a0, a6 WORD $0x011778b3 // and a7, a4, a7 WORD $0x00000713 // li a4, 0 BNEZ A7, LBB1_5 WORD $0x00458893 // addi a7, a1, 4 WORD $0x011638b3 // sltu a7, a2, a7 WORD $0x0105b833 // sltu a6, a1, a6 WORD $0x0108f833 // and a6, a7, a6 BNEZ A6, LBB1_5 WORD $0x0017d813 // srli a6, a5, 1 WORD $0x0d107757 // vsetvli a4, zero, e32, m2, ta, ma WORD $0x0a05e407 // vlse32.v v8, (a1), zero WORD $0x41000733 // neg a4, a6 WORD $0x00d77733 // and a4, a4, a3 WORD $0x00179793 // slli a5, a5, 1 WORD $0x00070893 // mv a7, a4 WORD $0x00060293 // mv t0, a2 WORD $0x00050313 // mv t1, a0 LBB1_12: WORD $0x22836507 // vl2re32.v v10, (t1) WORD $0x2282e607 // vl2re32.v v12, (t0) WORD $0xb2a41657 // vfmacc.vv v12, v8, v10 WORD $0x22828627 // vs2r.v v12, (t0) WORD $0x00f30333 // add t1, t1, a5 WORD $0x410888b3 // sub a7, a7, a6 WORD $0x00f282b3 // add t0, t0, a5 BNEZ A7, LBB1_12 BNE A4, A3, LBB1_5 JMP LBB1_7 TEXT ·vmul_const_to(SB), $0-32 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV n+24(FP), A3 BLEZ A3, LBB2_8 WORD $0xc22027f3 // csrr a5, vlenb WORD $0x0017d713 // srli a4, a5, 1 WORD $0x01000813 // li a6, 16 BLTU A6, A4, LBB2_3 WORD $0x01000713 // li a4, 16 LBB2_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A3, A4, LBB2_9 WORD $0x00000713 // li a4, 0 LBB2_5: WORD $0x00271793 // slli a5, a4, 2 WORD $0x00f60633 // add a2, a2, a5 WORD $0x00f50533 // add a0, a0, a5 WORD $0x40e686b3 // sub a3, a3, a4 LBB2_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x10e7f7d3 // fmul.s fa5, fa5, fa4 WORD $0x00f62027 // fsw fa5, 0(a2) WORD $0x00460613 // addi a2, a2, 4 WORD $0xfff68693 // addi a3, a3, -1 WORD $0x00450513 // addi a0, a0, 4 BNEZ A3, LBB2_6 LBB2_7: WORD $0xff040113 // addi sp, s0, -16 
WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB2_8: RET LBB2_9: WORD $0x00269713 // slli a4, a3, 2 WORD $0x00e60833 // add a6, a2, a4 WORD $0x00e50733 // add a4, a0, a4 WORD $0x00e63733 // sltu a4, a2, a4 WORD $0x010538b3 // sltu a7, a0, a6 WORD $0x011778b3 // and a7, a4, a7 WORD $0x00000713 // li a4, 0 BNEZ A7, LBB2_5 WORD $0x00458893 // addi a7, a1, 4 WORD $0x011638b3 // sltu a7, a2, a7 WORD $0x0105b833 // sltu a6, a1, a6 WORD $0x0108f833 // and a6, a7, a6 BNEZ A6, LBB2_5 WORD $0x0017d813 // srli a6, a5, 1 WORD $0x41000733 // neg a4, a6 WORD $0x0005a787 // flw fa5, 0(a1) WORD $0x00d77733 // and a4, a4, a3 WORD $0x00179793 // slli a5, a5, 1 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x00070893 // mv a7, a4 WORD $0x00060293 // mv t0, a2 WORD $0x00050313 // mv t1, a0 LBB2_12: WORD $0x22836407 // vl2re32.v v8, (t1) WORD $0x9287d457 // vfmul.vf v8, v8, fa5 WORD $0x22828427 // vs2r.v v8, (t0) WORD $0x00f30333 // add t1, t1, a5 WORD $0x410888b3 // sub a7, a7, a6 WORD $0x00f282b3 // add t0, t0, a5 BNEZ A7, LBB2_12 BNE A4, A3, LBB2_5 JMP LBB2_7 TEXT ·vmul_const(SB), $0-24 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV n+16(FP), A2 BLEZ A2, LBB3_10 WORD $0xc2202773 // csrr a4, vlenb WORD $0x00175693 // srli a3, a4, 1 WORD $0x00800793 // li a5, 8 BLTU A5, A3, LBB3_3 WORD $0x00800693 // li a3, 8 LBB3_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BLTU A2, A3, LBB3_6 WORD $0x00458693 // addi a3, a1, 4 BGEU A0, A3, LBB3_11 WORD $0x00261693 // slli a3, a2, 2 WORD $0x00d506b3 // add a3, a0, a3 BGEU A1, A3, LBB3_11 LBB3_6: WORD $0x00000693 // li a3, 0 LBB3_7: WORD $0x00269713 // slli a4, a3, 2 WORD $0x00e50533 // add a0, a0, a4 WORD $0x40d60633 // sub a2, a2, a3 LBB3_8: WORD $0x0005a787 // flw fa5, 
0(a1) WORD $0x00052707 // flw fa4, 0(a0) WORD $0x10e7f7d3 // fmul.s fa5, fa5, fa4 WORD $0x00f52027 // fsw fa5, 0(a0) WORD $0xfff60613 // addi a2, a2, -1 WORD $0x00450513 // addi a0, a0, 4 BNEZ A2, LBB3_8 LBB3_9: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB3_10: RET LBB3_11: WORD $0x00175793 // srli a5, a4, 1 WORD $0x40f006b3 // neg a3, a5 WORD $0x0005a787 // flw fa5, 0(a1) WORD $0x00c6f6b3 // and a3, a3, a2 WORD $0x00171713 // slli a4, a4, 1 WORD $0x0d107857 // vsetvli a6, zero, e32, m2, ta, ma WORD $0x00068813 // mv a6, a3 WORD $0x00050893 // mv a7, a0 LBB3_12: WORD $0x2288e407 // vl2re32.v v8, (a7) WORD $0x9287d457 // vfmul.vf v8, v8, fa5 WORD $0x22888427 // vs2r.v v8, (a7) WORD $0x40f80833 // sub a6, a6, a5 WORD $0x00e888b3 // add a7, a7, a4 BNEZ A6, LBB3_12 BEQ A3, A2, LBB3_9 JMP LBB3_7 TEXT ·vadd_const(SB), $0-24 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV n+16(FP), A2 BLEZ A2, LBB4_10 WORD $0xc2202773 // csrr a4, vlenb WORD $0x00175693 // srli a3, a4, 1 WORD $0x00800793 // li a5, 8 BLTU A5, A3, LBB4_3 WORD $0x00800693 // li a3, 8 LBB4_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BLTU A2, A3, LBB4_6 WORD $0x00458693 // addi a3, a1, 4 BGEU A0, A3, LBB4_11 WORD $0x00261693 // slli a3, a2, 2 WORD $0x00d506b3 // add a3, a0, a3 BGEU A1, A3, LBB4_11 LBB4_6: WORD $0x00000693 // li a3, 0 LBB4_7: WORD $0x00269713 // slli a4, a3, 2 WORD $0x00e50533 // add a0, a0, a4 WORD $0x40d60633 // sub a2, a2, a3 LBB4_8: WORD $0x0005a787 // flw fa5, 0(a1) WORD $0x00052707 // flw fa4, 0(a0) WORD $0x00e7f7d3 // fadd.s fa5, fa5, fa4 WORD $0x00f52027 // fsw fa5, 0(a0) WORD $0xfff60613 // addi a2, a2, -1 WORD $0x00450513 // addi a0, a0, 4 BNEZ A2, LBB4_8 LBB4_9: WORD $0xff040113 
// addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB4_10: RET LBB4_11: WORD $0x00175793 // srli a5, a4, 1 WORD $0x40f006b3 // neg a3, a5 WORD $0x0005a787 // flw fa5, 0(a1) WORD $0x00c6f6b3 // and a3, a3, a2 WORD $0x00171713 // slli a4, a4, 1 WORD $0x0d107857 // vsetvli a6, zero, e32, m2, ta, ma WORD $0x00068813 // mv a6, a3 WORD $0x00050893 // mv a7, a0 LBB4_12: WORD $0x2288e407 // vl2re32.v v8, (a7) WORD $0x0287d457 // vfadd.vf v8, v8, fa5 WORD $0x22888427 // vs2r.v v8, (a7) WORD $0x40f80833 // sub a6, a6, a5 WORD $0x00e888b3 // add a7, a7, a4 BNEZ A6, LBB4_12 BEQ A3, A2, LBB4_9 JMP LBB4_7 TEXT ·vsub_to(SB), $0-32 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV n+24(FP), A3 BLEZ A3, LBB5_8 WORD $0xc2202873 // csrr a6, vlenb WORD $0x00185713 // srli a4, a6, 1 WORD $0x01000793 // li a5, 16 BLTU A5, A4, LBB5_3 WORD $0x01000713 // li a4, 16 LBB5_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A3, A4, LBB5_9 WORD $0x00000713 // li a4, 0 LBB5_5: WORD $0x40e686b3 // sub a3, a3, a4 WORD $0x00271713 // slli a4, a4, 2 WORD $0x00e60633 // add a2, a2, a4 WORD $0x00e585b3 // add a1, a1, a4 WORD $0x00e50533 // add a0, a0, a4 LBB5_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x08e7f7d3 // fsub.s fa5, fa5, fa4 WORD $0x00f62027 // fsw fa5, 0(a2) WORD $0xfff68693 // addi a3, a3, -1 WORD $0x00460613 // addi a2, a2, 4 WORD $0x00458593 // addi a1, a1, 4 WORD $0x00450513 // addi a0, a0, 4 BNEZ A3, LBB5_6 LBB5_7: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB5_8: RET LBB5_9: WORD $0x00181793 // slli a5, a6, 1 WORD 
$0x40a608b3 // sub a7, a2, a0 WORD $0x00000713 // li a4, 0 BLTU A7, A5, LBB5_5 WORD $0x40b608b3 // sub a7, a2, a1 BLTU A7, A5, LBB5_5 WORD $0x00185813 // srli a6, a6, 1 WORD $0x41000733 // neg a4, a6 WORD $0x00d77733 // and a4, a4, a3 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x00070893 // mv a7, a4 WORD $0x00060293 // mv t0, a2 WORD $0x00058313 // mv t1, a1 WORD $0x00050393 // mv t2, a0 LBB5_12: WORD $0x2283e407 // vl2re32.v v8, (t2) WORD $0x22836507 // vl2re32.v v10, (t1) WORD $0x0a851457 // vfsub.vv v8, v8, v10 WORD $0x22828427 // vs2r.v v8, (t0) WORD $0x00f383b3 // add t2, t2, a5 WORD $0x00f30333 // add t1, t1, a5 WORD $0x410888b3 // sub a7, a7, a6 WORD $0x00f282b3 // add t0, t0, a5 BNEZ A7, LBB5_12 BNE A4, A3, LBB5_5 JMP LBB5_7 TEXT ·vsub(SB), $0-24 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV n+16(FP), A2 BLEZ A2, LBB6_10 WORD $0xc22027f3 // csrr a5, vlenb WORD $0x0017d693 // srli a3, a5, 1 WORD $0x01000713 // li a4, 16 BLTU A4, A3, LBB6_3 WORD $0x01000693 // li a3, 16 LBB6_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BLTU A2, A3, LBB6_6 WORD $0x00261693 // slli a3, a2, 2 WORD $0x00d58733 // add a4, a1, a3 BGEU A0, A4, LBB6_11 WORD $0x00d506b3 // add a3, a0, a3 BGEU A1, A3, LBB6_11 LBB6_6: WORD $0x00000693 // li a3, 0 LBB6_7: WORD $0x40d60633 // sub a2, a2, a3 WORD $0x00269693 // slli a3, a3, 2 WORD $0x00d50533 // add a0, a0, a3 WORD $0x00d585b3 // add a1, a1, a3 LBB6_8: WORD $0x0005a787 // flw fa5, 0(a1) WORD $0x00052707 // flw fa4, 0(a0) WORD $0x08f777d3 // fsub.s fa5, fa4, fa5 WORD $0x00f52027 // fsw fa5, 0(a0) WORD $0xfff60613 // addi a2, a2, -1 WORD $0x00450513 // addi a0, a0, 4 WORD $0x00458593 // addi a1, a1, 4 BNEZ A2, LBB6_8 LBB6_9: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte 
Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB6_10: RET LBB6_11: WORD $0x0017d713 // srli a4, a5, 1 WORD $0x40e006b3 // neg a3, a4 WORD $0x00c6f6b3 // and a3, a3, a2 WORD $0x00179793 // slli a5, a5, 1 WORD $0x0d107857 // vsetvli a6, zero, e32, m2, ta, ma WORD $0x00068813 // mv a6, a3 WORD $0x00050893 // mv a7, a0 WORD $0x00058293 // mv t0, a1 LBB6_12: WORD $0x2282e407 // vl2re32.v v8, (t0) WORD $0x2288e507 // vl2re32.v v10, (a7) WORD $0x0aa41457 // vfsub.vv v8, v10, v8 WORD $0x22888427 // vs2r.v v8, (a7) WORD $0x00f282b3 // add t0, t0, a5 WORD $0x40e80833 // sub a6, a6, a4 WORD $0x00f888b3 // add a7, a7, a5 BNEZ A6, LBB6_12 BEQ A3, A2, LBB6_9 JMP LBB6_7 TEXT ·vmul_to(SB), $0-32 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV n+24(FP), A3 BLEZ A3, LBB7_8 WORD $0xc2202873 // csrr a6, vlenb WORD $0x00185713 // srli a4, a6, 1 WORD $0x01000793 // li a5, 16 BLTU A5, A4, LBB7_3 WORD $0x01000713 // li a4, 16 LBB7_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A3, A4, LBB7_9 WORD $0x00000713 // li a4, 0 LBB7_5: WORD $0x40e686b3 // sub a3, a3, a4 WORD $0x00271713 // slli a4, a4, 2 WORD $0x00e60633 // add a2, a2, a4 WORD $0x00e585b3 // add a1, a1, a4 WORD $0x00e50533 // add a0, a0, a4 LBB7_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x10e7f7d3 // fmul.s fa5, fa5, fa4 WORD $0x00f62027 // fsw fa5, 0(a2) WORD $0xfff68693 // addi a3, a3, -1 WORD $0x00460613 // addi a2, a2, 4 WORD $0x00458593 // addi a1, a1, 4 WORD $0x00450513 // addi a0, a0, 4 BNEZ A3, LBB7_6 LBB7_7: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB7_8: RET LBB7_9: WORD $0x00181793 // slli a5, a6, 1 WORD $0x40a608b3 // sub a7, a2, a0 WORD $0x00000713 // 
li a4, 0 BLTU A7, A5, LBB7_5 WORD $0x40b608b3 // sub a7, a2, a1 BLTU A7, A5, LBB7_5 WORD $0x00185813 // srli a6, a6, 1 WORD $0x41000733 // neg a4, a6 WORD $0x00d77733 // and a4, a4, a3 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x00070893 // mv a7, a4 WORD $0x00060293 // mv t0, a2 WORD $0x00058313 // mv t1, a1 WORD $0x00050393 // mv t2, a0 LBB7_12: WORD $0x2283e407 // vl2re32.v v8, (t2) WORD $0x22836507 // vl2re32.v v10, (t1) WORD $0x92851457 // vfmul.vv v8, v8, v10 WORD $0x22828427 // vs2r.v v8, (t0) WORD $0x00f383b3 // add t2, t2, a5 WORD $0x00f30333 // add t1, t1, a5 WORD $0x410888b3 // sub a7, a7, a6 WORD $0x00f282b3 // add t0, t0, a5 BNEZ A7, LBB7_12 BNE A4, A3, LBB7_5 JMP LBB7_7 TEXT ·vdiv_to(SB), $0-32 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV c+16(FP), A2 MOV n+24(FP), A3 BLEZ A3, LBB8_8 WORD $0xc2202873 // csrr a6, vlenb WORD $0x00185713 // srli a4, a6, 1 WORD $0x01000793 // li a5, 16 BLTU A5, A4, LBB8_3 WORD $0x01000713 // li a4, 16 LBB8_3: WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 BGEU A3, A4, LBB8_9 WORD $0x00000713 // li a4, 0 LBB8_5: WORD $0x40e686b3 // sub a3, a3, a4 WORD $0x00271713 // slli a4, a4, 2 WORD $0x00e60633 // add a2, a2, a4 WORD $0x00e585b3 // add a1, a1, a4 WORD $0x00e50533 // add a0, a0, a4 LBB8_6: WORD $0x00052787 // flw fa5, 0(a0) WORD $0x0005a707 // flw fa4, 0(a1) WORD $0x18e7f7d3 // fdiv.s fa5, fa5, fa4 WORD $0x00f62027 // fsw fa5, 0(a2) WORD $0xfff68693 // addi a3, a3, -1 WORD $0x00460613 // addi a2, a2, 4 WORD $0x00458593 // addi a1, a1, 4 WORD $0x00450513 // addi a0, a0, 4 BNEZ A3, LBB8_6 LBB8_7: WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB8_8: RET LBB8_9: WORD $0x00181793 // slli a5, a6, 1 WORD 
$0x40a608b3 // sub a7, a2, a0 WORD $0x00000713 // li a4, 0 BLTU A7, A5, LBB8_5 WORD $0x40b608b3 // sub a7, a2, a1 BLTU A7, A5, LBB8_5 WORD $0x00185813 // srli a6, a6, 1 WORD $0x41000733 // neg a4, a6 WORD $0x00d77733 // and a4, a4, a3 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x00070893 // mv a7, a4 WORD $0x00060293 // mv t0, a2 WORD $0x00058313 // mv t1, a1 WORD $0x00050393 // mv t2, a0 LBB8_12: WORD $0x2283e407 // vl2re32.v v8, (t2) WORD $0x22836507 // vl2re32.v v10, (t1) WORD $0x82851457 // vfdiv.vv v8, v8, v10 WORD $0x22828427 // vs2r.v v8, (t0) WORD $0x00f383b3 // add t2, t2, a5 WORD $0x00f30333 // add t1, t1, a5 WORD $0x410888b3 // sub a7, a7, a6 WORD $0x00f282b3 // add t0, t0, a5 BNEZ A7, LBB8_12 BNE A4, A3, LBB8_5 JMP LBB8_7 TEXT ·vsqrt_to(SB), $0-24 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV n+16(FP), A2 BLEZ A2, LBB9_4 WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 LBB9_2: WORD $0x0d0676d7 // vsetvli a3, a2, e32, m1, ta, ma WORD $0x02056407 // vle32.v v8, (a0) WORD $0x4e801457 // vfsqrt.v v8, v8 WORD $0x0205e427 // vse32.v v8, (a1) WORD $0x00269713 // slli a4, a3, 2 WORD $0x00e50533 // add a0, a0, a4 WORD $0x40d60633 // sub a2, a2, a3 WORD $0x00e585b3 // add a1, a1, a4 BGTZ A2, LBB9_2 WORD $0xff040113 // addi sp, s0, -16 WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload WORD $0x01010113 // addi sp, sp, 16 LBB9_4: RET TEXT ·vdot(SB), $8-24 MOV a+0(FP), A0 MOV b+8(FP), A1 MOV n+16(FP), A2 WORD $0xff010113 // addi sp, sp, -16 WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill WORD $0x01010413 // addi s0, sp, 16 WORD $0xff017113 // andi sp, sp, -16 WORD $0x0d0076d7 // vsetvli a3, zero, e32, m1, ta, ma WORD $0x02d65733 // divu a4, a2, a3 WORD $0x0007079b // 
sext.w a5, a4
	WORD $0x5e003457 // vmv.v.i v8, 0
	BLEZ A5, LBB10_3
	WORD $0x00000813 // li a6, 0
	WORD $0x00269893 // slli a7, a3, 2
	WORD $0x9e8034d7 // vmv1r.v v9, v8

LBB10_2:
	WORD $0x02056507 // vle32.v v10, (a0)
	WORD $0x0205e587 // vle32.v v11, (a1)
	WORD $0xb2a594d7 // vfmacc.vv v9, v11, v10
	WORD $0x0018081b // addiw a6, a6, 1
	WORD $0x01150533 // add a0, a0, a7
	WORD $0x011585b3 // add a1, a1, a7
	BLT A6, A5, LBB10_2
	JMP LBB10_4

LBB10_3:
	WORD $0x9e8034d7 // vmv1r.v v9, v8

LBB10_4:
	WORD $0x02d706b3 // mul a3, a4, a3
	WORD $0x40d60633 // sub a2, a2, a3
	WORD $0x0e941457 // vfredosum.vs v8, v9, v8
	WORD $0x0d067057 // vsetvli zero, a2, e32, m1, ta, ma
	WORD $0x02056487 // vle32.v v9, (a0)
	WORD $0x0205e507 // vle32.v v10, (a1)
	WORD $0x929514d7 // vfmul.vv v9, v9, v10
	WORD $0x0e941457 // vfredosum.vs v8, v9, v8
	WORD $0x42801557 // vfmv.f.s fa0, v8
	WORD $0xff040113 // addi sp, s0, -16
	WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload
	WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload
	WORD $0x01010113 // addi sp, sp, 16
	MOVF FA0, result+24(FP)
	RET

// veuclidean returns sqrt(sum_{i<n} (a[i]-b[i])^2) as a float32 (e32 loads).
// Full VLMAX-wide chunks are accumulated with vfsub.vv/vfmacc.vv, and the
// n mod VLMAX tail is handled by a final strip with vl set to the remainder.
TEXT ·veuclidean(SB), $8-24
	MOV a+0(FP), A0
	MOV b+8(FP), A1
	MOV n+16(FP), A2
	WORD $0xff010113 // addi sp, sp, -16
	WORD $0x00113423 // sd ra, 8(sp) # 8-byte Folded Spill
	WORD $0x00813023 // sd s0, 0(sp) # 8-byte Folded Spill
	WORD $0x01010413 // addi s0, sp, 16
	WORD $0xff017113 // andi sp, sp, -16
	WORD $0x0d0076d7 // vsetvli a3, zero, e32, m1, ta, ma
	WORD $0x02d65733 // divu a4, a2, a3
	WORD $0x0007079b // sext.w a5, a4
	WORD $0x5e003457 // vmv.v.i v8, 0
	BLEZ A5, LBB11_3
	WORD $0x00000813 // li a6, 0
	WORD $0x00269893 // slli a7, a3, 2
	WORD $0x9e8034d7 // vmv1r.v v9, v8

LBB11_2:
	WORD $0x02056507 // vle32.v v10, (a0)
	WORD $0x0205e587 // vle32.v v11, (a1)
	WORD $0x0aa59557 // vfsub.vv v10, v10, v11
	WORD $0xb2a514d7 // vfmacc.vv v9, v10, v10
	WORD $0x0018081b // addiw a6, a6, 1
	WORD $0x01150533 // add a0, a0, a7
	WORD $0x011585b3 // add a1, a1, a7
	BLT A6, A5, LBB11_2
	JMP LBB11_4

LBB11_3:
	WORD $0x9e8034d7 // vmv1r.v v9, v8

LBB11_4:
	WORD $0x02d706b3 // mul a3, a4, a3
	WORD $0x40d60633 // sub a2, a2, a3
	WORD $0x0e941457 // vfredosum.vs v8, v9, v8
	WORD $0x0d067657 // vsetvli a2, a2, e32, m1, ta, ma
	WORD $0x02056487 // vle32.v v9, (a0)
	WORD $0x0205e507 // vle32.v v10, (a1)
	WORD $0x0d007557 // vsetvli a0, zero, e32, m1, ta, ma
	WORD $0x0a9514d7 // vfsub.vv v9, v9, v10
	WORD $0x0d067057 // vsetvli zero, a2, e32, m1, ta, ma
	WORD $0x929494d7 // vfmul.vv v9, v9, v9
	WORD $0x0e941457 // vfredosum.vs v8, v9, v8
	// Force vl=1 before the square root. The tail ops above run with
	// vl = n mod VLMAX; when that is 0 (n a positive multiple of VLMAX)
	// vfredosum.vs and vfsqrt.v are no-ops under RVV 1.0, while
	// vfmv.f.s still reads element 0, so the un-rooted sum of squares
	// would be returned. The reduced sum is already in v8[0] either way.
	// NOTE(review): hand-patched generated code; the generator source
	// (presumably src/floats_rvv.c) needs the same fix or regeneration
	// will reintroduce the bug.
	WORD $0xcd00f057 // vsetivli zero, 1, e32, m1, ta, ma
	WORD $0x4e801457 // vfsqrt.v v8, v8
	WORD $0x42801557 // vfmv.f.s fa0, v8
	WORD $0xff040113 // addi sp, s0, -16
	WORD $0x00813083 // ld ra, 8(sp) # 8-byte Folded Reload
	WORD $0x00013403 // ld s0, 0(sp) # 8-byte Folded Reload
	WORD $0x01010113 // addi sp, sp, 16
	MOVF FA0, result+24(FP)
	RET

TEXT ·vmm(SB), $0-88
	MOVB transA+0(FP), A0
	MOVB transB+1(FP), A1
	MOV m+8(FP), A2
	MOV n+16(FP), A3
	MOV k+24(FP), A4
	MOV a+32(FP), A5
	MOV lda+40(FP), A6
	MOV b+48(FP), A7
	ADDI -24, SP, SP
	MOV ldb+80(FP), T0
	MOV T0, 0(SP)
	MOV c+88(FP), T0
	MOV T0, 8(SP)
	MOV ldc+96(FP), T0
	MOV T0, 16(SP)
	WORD $0xf8010113 // addi sp, sp, -128
	WORD $0x06113c23 // sd ra, 120(sp) # 8-byte Folded Spill
	WORD $0x06813823 // sd s0, 112(sp) # 8-byte Folded Spill
	WORD $0x07213423 // sd s2, 104(sp) # 8-byte Folded Spill
	WORD $0x07313023 // sd s3, 96(sp) # 8-byte Folded Spill
	WORD $0x05413c23 // sd s4, 88(sp) # 8-byte Folded Spill
	WORD $0x05513823 // sd s5, 80(sp) # 8-byte Folded Spill
	WORD $0x05613423 // sd s6, 72(sp) # 8-byte Folded Spill
	WORD $0x05713023 // sd s7, 64(sp) # 8-byte Folded Spill
	WORD $0x03813c23 // sd s8, 56(sp) # 8-byte Folded Spill
	WORD $0x03913823 // sd s9, 48(sp) # 8-byte Folded Spill
	WORD $0x03a13423 // sd s10, 40(sp) # 8-byte Folded Spill
	WORD $0x08010413 // addi s0, sp, 128
	WORD $0xff017113 // andi sp, sp, -16
	WORD $0x01043303 // ld t1, 16(s0)
	WORD $0x00843283 // ld t0, 8(s0)
	WORD $0x02513023 // sd t0, 32(sp) # 8-byte Folded Spill
	WORD $0x00043383 // ld t2, 0(s0)
	WORD $0x01113823 // sd a7, 16(sp) # 8-byte Folded Spill
	WORD $0x00c13c23
// sd a2, 24(sp) # 8-byte Folded Spill BNEZ A0, LBB12_18 BNEZ A1, LBB12_18 WORD $0x01813503 // ld a0, 24(sp) # 8-byte Folded Reload BLEZ A0, LBB12_65 BLEZ A4, LBB12_65 BLEZ A3, LBB12_65 WORD $0x00269f13 // slli t5, a3, 2 WORD $0x00271f93 // slli t6, a4, 2 WORD $0xfff70513 // addi a0, a4, -1 WORD $0x02750533 // mul a0, a0, t2 WORD $0x00d50eb3 // add t4, a0, a3 WORD $0xc2202973 // csrr s2, vlenb WORD $0x00195513 // srli a0, s2, 1 WORD $0x00800593 // li a1, 8 WORD $0x002e9e93 // slli t4, t4, 2 WORD $0x00050a13 // mv s4, a0 BLTU A1, A0, LBB12_7 WORD $0x00800a13 // li s4, 8 LBB12_7: WORD $0x00000593 // li a1, 0 WORD $0x00231293 // slli t0, t1, 2 WORD $0x00281613 // slli a2, a6, 2 WORD $0x01013883 // ld a7, 16(sp) # 8-byte Folded Reload WORD $0x01d88eb3 // add t4, a7, t4 WORD $0x02013303 // ld t1, 32(sp) # 8-byte Folded Reload WORD $0x01e30f33 // add t5, t1, t5 WORD $0x01f78fb3 // add t6, a5, t6 WORD $0x00100893 // li a7, 1 WORD $0x03d89893 // slli a7, a7, 61 WORD $0x0113f8b3 // and a7, t2, a7 WORD $0x00239e13 // slli t3, t2, 2 WORD $0x00191913 // slli s2, s2, 1 WORD $0x011039b3 // snez s3, a7 WORD $0x0146ba33 // sltu s4, a3, s4 JMP LBB12_9 LBB12_8: WORD $0x00158593 // addi a1, a1, 1 WORD $0x00530333 // add t1, t1, t0 WORD $0x01813883 // ld a7, 24(sp) # 8-byte Folded Reload BEQ A1, A7, LBB12_65 LBB12_9: WORD $0x00000b13 // li s6, 0 WORD $0x02b288b3 // mul a7, t0, a1 WORD $0x02013383 // ld t2, 32(sp) # 8-byte Folded Reload WORD $0x011383b3 // add t2, t2, a7 WORD $0x011f08b3 // add a7, t5, a7 WORD $0x02b60ab3 // mul s5, a2, a1 WORD $0x01578bb3 // add s7, a5, s5 WORD $0x015f8ab3 // add s5, t6, s5 WORD $0x0153bab3 // sltu s5, t2, s5 WORD $0x011bbbb3 // sltu s7, s7, a7 WORD $0x017afab3 // and s5, s5, s7 WORD $0x03058bb3 // mul s7, a1, a6 WORD $0x002b9b93 // slli s7, s7, 2 WORD $0x01778bb3 // add s7, a5, s7 WORD $0x01d3b3b3 // sltu t2, t2, t4 WORD $0x01013c83 // ld s9, 16(sp) # 8-byte Folded Reload WORD $0x011cb8b3 // sltu a7, s9, a7 WORD $0x0113f8b3 // and a7, t2, a7 WORD 
$0x0138e8b3 // or a7, a7, s3 WORD $0x015a63b3 // or t2, s4, s5 WORD $0x0113ec33 // or s8, t2, a7 WORD $0x000c8393 // mv t2, s9 JMP LBB12_11 LBB12_10: WORD $0x001b0b13 // addi s6, s6, 1 WORD $0x01c383b3 // add t2, t2, t3 BEQ S6, A4, LBB12_8 LBB12_11: WORD $0x002b1d13 // slli s10, s6, 2 WORD $0x01ab8d33 // add s10, s7, s10 BEQZ S8, LBB12_13 WORD $0x00000093 // li ra, 0 JMP LBB12_16 LBB12_13: WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x0a0d6407 // vlse32.v v8, (s10), zero WORD $0x40a008b3 // neg a7, a0 WORD $0x00d8f0b3 // and ra, a7, a3 WORD $0x00008a93 // mv s5, ra WORD $0x00030c93 // mv s9, t1 WORD $0x00038893 // mv a7, t2 LBB12_14: WORD $0x2288e507 // vl2re32.v v10, (a7) WORD $0x228ce607 // vl2re32.v v12, (s9) WORD $0xb2a41657 // vfmacc.vv v12, v8, v10 WORD $0x228c8627 // vs2r.v v12, (s9) WORD $0x012888b3 // add a7, a7, s2 WORD $0x40aa8ab3 // sub s5, s5, a0 WORD $0x012c8cb3 // add s9, s9, s2 BNEZ S5, LBB12_14 BEQ RA, A3, LBB12_10 LBB12_16: WORD $0x00209a93 // slli s5, ra, 2 WORD $0x01538cb3 // add s9, t2, s5 WORD $0x01530ab3 // add s5, t1, s5 WORD $0x401680b3 // sub ra, a3, ra LBB12_17: WORD $0x000d2787 // flw fa5, 0(s10) WORD $0x000ca707 // flw fa4, 0(s9) WORD $0x000aa687 // flw fa3, 0(s5) WORD $0x68e7f7c3 // fmadd.s fa5, fa5, fa4, fa3 WORD $0x00faa027 // fsw fa5, 0(s5) WORD $0x004c8c93 // addi s9, s9, 4 WORD $0xfff08093 // addi ra, ra, -1 WORD $0x004a8a93 // addi s5, s5, 4 BNEZ RA, LBB12_17 JMP LBB12_10 LBB12_18: BEQZ A1, LBB12_29 BNEZ A0, LBB12_29 WORD $0x01813503 // ld a0, 24(sp) # 8-byte Folded Reload BLEZ A0, LBB12_65 BLEZ A3, LBB12_65 WORD $0x0d007e57 // vsetvli t3, zero, e32, m1, ta, ma WORD $0x03c75533 // divu a0, a4, t3 WORD $0x0005059b // sext.w a1, a0 WORD $0x5e003457 // vmv.v.i v8, 0 WORD $0x03c50533 // mul a0, a0, t3 WORD $0x40a70733 // sub a4, a4, a0 WORD $0x0d077557 // vsetvli a0, a4, e32, m1, ta, ma BLEZ A1, LBB12_61 WORD $0x00000713 // li a4, 0 WORD $0x00281813 // slli a6, a6, 2 WORD $0x002e1e13 // slli t3, t3, 2 WORD $0x00239393 
// slli t2, t2, 2 LBB12_24: WORD $0x00000e93 // li t4, 0 WORD $0x02670633 // mul a2, a4, t1 WORD $0x00261613 // slli a2, a2, 2 WORD $0x02013f03 // ld t5, 32(sp) # 8-byte Folded Reload WORD $0x00cf0f33 // add t5, t5, a2 WORD $0x01013f83 // ld t6, 16(sp) # 8-byte Folded Reload LBB12_25: WORD $0x00000293 // li t0, 0 WORD $0x0d007657 // vsetvli a2, zero, e32, m1, ta, ma WORD $0x000f8913 // mv s2, t6 WORD $0x00078893 // mv a7, a5 WORD $0x9e8034d7 // vmv1r.v v9, v8 LBB12_26: WORD $0x0208e507 // vle32.v v10, (a7) WORD $0x02096587 // vle32.v v11, (s2) WORD $0xb2a594d7 // vfmacc.vv v9, v11, v10 WORD $0x0012829b // addiw t0, t0, 1 WORD $0x01c888b3 // add a7, a7, t3 WORD $0x01c90933 // add s2, s2, t3 BLT T0, A1, LBB12_26 WORD $0x0e9414d7 // vfredosum.vs v9, v9, v8 WORD $0x0d057057 // vsetvli zero, a0, e32, m1, ta, ma WORD $0x0208e507 // vle32.v v10, (a7) WORD $0x02096587 // vle32.v v11, (s2) WORD $0x92a59557 // vfmul.vv v10, v10, v11 WORD $0x0ea494d7 // vfredosum.vs v9, v10, v9 WORD $0x002e9613 // slli a2, t4, 2 WORD $0x00cf0633 // add a2, t5, a2 WORD $0xcd00f057 // vsetivli zero, 1, e32, m1, ta, ma WORD $0x020664a7 // vse32.v v9, (a2) WORD $0x001e8e93 // addi t4, t4, 1 WORD $0x007f8fb3 // add t6, t6, t2 BNE T4, A3, LBB12_25 WORD $0x00170713 // addi a4, a4, 1 WORD $0x010787b3 // add a5, a5, a6 WORD $0x01813603 // ld a2, 24(sp) # 8-byte Folded Reload BNE A4, A2, LBB12_24 JMP LBB12_65 LBB12_29: WORD $0x00e02633 // sgtz a2, a4 WORD $0x00d028b3 // sgtz a7, a3 WORD $0x01167633 // and a2, a2, a7 BEQZ A0, LBB12_45 BNEZ A1, LBB12_45 WORD $0x01813503 // ld a0, 24(sp) # 8-byte Folded Reload WORD $0x00a02533 // sgtz a0, a0 WORD $0x00c57533 // and a0, a0, a2 BEQZ A0, LBB12_65 WORD $0x00269f13 // slli t5, a3, 2 WORD $0xfff70513 // addi a0, a4, -1 WORD $0x030505b3 // mul a1, a0, a6 WORD $0x00259f93 // slli t6, a1, 2 WORD $0x02750533 // mul a0, a0, t2 WORD $0x00d50533 // add a0, a0, a3 WORD $0x00251e93 // slli t4, a0, 2 WORD $0xc2202573 // csrr a0, vlenb WORD $0x00155593 // srli a1, a0, 1 
WORD $0x00800613 // li a2, 8 WORD $0x01f78fb3 // add t6, a5, t6 WORD $0x00058a93 // mv s5, a1 BLTU A2, A1, LBB12_34 WORD $0x00800a93 // li s5, 8 LBB12_34: WORD $0x00000e13 // li t3, 0 WORD $0x00231313 // slli t1, t1, 2 WORD $0x01013603 // ld a2, 16(sp) # 8-byte Folded Reload WORD $0x01d60633 // add a2, a2, t4 WORD $0x00c13423 // sd a2, 8(sp) # 8-byte Folded Spill WORD $0x02013b03 // ld s6, 32(sp) # 8-byte Folded Reload WORD $0x01eb0f33 // add t5, s6, t5 WORD $0x004f8f93 // addi t6, t6, 4 WORD $0x00100613 // li a2, 1 WORD $0x03d61613 // slli a2, a2, 61 WORD $0x00c878b3 // and a7, a6, a2 WORD $0x00c3f633 // and a2, t2, a2 WORD $0x00239393 // slli t2, t2, 2 WORD $0x00151913 // slli s2, a0, 1 WORD $0x00c039b3 // snez s3, a2 WORD $0x01103a33 // snez s4, a7 WORD $0x0156bab3 // sltu s5, a3, s5 JMP LBB12_36 LBB12_35: WORD $0x001e0e13 // addi t3, t3, 1 WORD $0x006b0b33 // add s6, s6, t1 WORD $0x01813603 // ld a2, 24(sp) # 8-byte Folded Reload BEQ T3, A2, LBB12_65 LBB12_36: WORD $0x00000b93 // li s7, 0 WORD $0x03c30633 // mul a2, t1, t3 WORD $0x02013883 // ld a7, 32(sp) # 8-byte Folded Reload WORD $0x00c888b3 // add a7, a7, a2 WORD $0x00cf0633 // add a2, t5, a2 WORD $0x002e1293 // slli t0, t3, 2 WORD $0x00578c33 // add s8, a5, t0 WORD $0x005f82b3 // add t0, t6, t0 WORD $0x0058b2b3 // sltu t0, a7, t0 WORD $0x00cc3eb3 // sltu t4, s8, a2 WORD $0x01d2f2b3 // and t0, t0, t4 WORD $0x0142e2b3 // or t0, t0, s4 WORD $0x00813e83 // ld t4, 8(sp) # 8-byte Folded Reload WORD $0x01d8b8b3 // sltu a7, a7, t4 WORD $0x01013d03 // ld s10, 16(sp) # 8-byte Folded Reload WORD $0x00cd3633 // sltu a2, s10, a2 WORD $0x00c8f633 // and a2, a7, a2 WORD $0x01366633 // or a2, a2, s3 WORD $0x00c2e633 // or a2, t0, a2 WORD $0x00caecb3 // or s9, s5, a2 JMP LBB12_38 LBB12_37: WORD $0x001b8b93 // addi s7, s7, 1 WORD $0x007d0d33 // add s10, s10, t2 BEQ S7, A4, LBB12_35 LBB12_38: WORD $0x030b8633 // mul a2, s7, a6 WORD $0x00261613 // slli a2, a2, 2 WORD $0x00cc00b3 // add ra, s8, a2 BEQZ S9, LBB12_40 WORD 
$0x00000e93 // li t4, 0 JMP LBB12_43 LBB12_40: WORD $0x00155613 // srli a2, a0, 1 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x0a00e407 // vlse32.v v8, (ra), zero WORD $0x40c00633 // neg a2, a2 WORD $0x00d67eb3 // and t4, a2, a3 WORD $0x000e8893 // mv a7, t4 WORD $0x000b0293 // mv t0, s6 WORD $0x000d0613 // mv a2, s10 LBB12_41: WORD $0x22866507 // vl2re32.v v10, (a2) WORD $0x2282e607 // vl2re32.v v12, (t0) WORD $0xb2a41657 // vfmacc.vv v12, v8, v10 WORD $0x22828627 // vs2r.v v12, (t0) WORD $0x01260633 // add a2, a2, s2 WORD $0x40b888b3 // sub a7, a7, a1 WORD $0x012282b3 // add t0, t0, s2 BNEZ A7, LBB12_41 BEQ T4, A3, LBB12_37 LBB12_43: WORD $0x002e9293 // slli t0, t4, 2 WORD $0x005d08b3 // add a7, s10, t0 WORD $0x005b02b3 // add t0, s6, t0 WORD $0x41d68eb3 // sub t4, a3, t4 LBB12_44: WORD $0x0000a787 // flw fa5, 0(ra) WORD $0x0008a707 // flw fa4, 0(a7) WORD $0x0002a687 // flw fa3, 0(t0) WORD $0x68e7f7c3 // fmadd.s fa5, fa5, fa4, fa3 WORD $0x00f2a027 // fsw fa5, 0(t0) WORD $0x00488893 // addi a7, a7, 4 WORD $0xfffe8e93 // addi t4, t4, -1 WORD $0x00428293 // addi t0, t0, 4 BNEZ T4, LBB12_44 JMP LBB12_37 LBB12_45: WORD $0x00b675b3 // and a1, a2, a1 BEQZ A1, LBB12_65 WORD $0x01813583 // ld a1, 24(sp) # 8-byte Folded Reload BLEZ A1, LBB12_65 BEQZ A0, LBB12_65 WORD $0x00269f13 // slli t5, a3, 2 WORD $0xfff70513 // addi a0, a4, -1 WORD $0x03050533 // mul a0, a0, a6 WORD $0x00251f93 // slli t6, a0, 2 WORD $0x00d70533 // add a0, a4, a3 WORD $0x00251513 // slli a0, a0, 2 WORD $0x01013e83 // ld t4, 16(sp) # 8-byte Folded Reload WORD $0x00ae8eb3 // add t4, t4, a0 WORD $0xc2202573 // csrr a0, vlenb WORD $0x00155593 // srli a1, a0, 1 WORD $0x00800613 // li a2, 8 WORD $0x01f78fb3 // add t6, a5, t6 WORD $0x00058913 // mv s2, a1 BLTU A2, A1, LBB12_50 WORD $0x00800913 // li s2, 8 LBB12_50: WORD $0x00000e13 // li t3, 0 WORD $0x00231313 // slli t1, t1, 2 WORD $0xffce8e93 // addi t4, t4, -4 WORD $0x01d13423 // sd t4, 8(sp) # 8-byte Folded Spill WORD $0x02013a83 // ld 
s5, 32(sp) # 8-byte Folded Reload WORD $0x01ea8f33 // add t5, s5, t5 WORD $0x004f8f93 // addi t6, t6, 4 WORD $0x0126b633 // sltu a2, a3, s2 WORD $0xfff64613 // not a2, a2 WORD $0xfff38893 // addi a7, t2, -1 WORD $0x0018b893 // seqz a7, a7 WORD $0x01167633 // and a2, a2, a7 WORD $0x00100893 // li a7, 1 WORD $0x03d89893 // slli a7, a7, 61 WORD $0x011878b3 // and a7, a6, a7 WORD $0x00151913 // slli s2, a0, 1 WORD $0x00239e93 // slli t4, t2, 2 WORD $0x011039b3 // snez s3, a7 WORD $0x00164a13 // xori s4, a2, 1 JMP LBB12_52 LBB12_51: WORD $0x001e0e13 // addi t3, t3, 1 WORD $0x006a8ab3 // add s5, s5, t1 WORD $0x01813603 // ld a2, 24(sp) # 8-byte Folded Reload BEQ T3, A2, LBB12_65 LBB12_52: WORD $0x00000b13 // li s6, 0 WORD $0x03c30633 // mul a2, t1, t3 WORD $0x02013883 // ld a7, 32(sp) # 8-byte Folded Reload WORD $0x00c888b3 // add a7, a7, a2 WORD $0x00cf0633 // add a2, t5, a2 WORD $0x002e1293 // slli t0, t3, 2 WORD $0x00578bb3 // add s7, a5, t0 WORD $0x005f82b3 // add t0, t6, t0 WORD $0x0058b2b3 // sltu t0, a7, t0 WORD $0x00cbb3b3 // sltu t2, s7, a2 WORD $0x0072f2b3 // and t0, t0, t2 WORD $0x0132e2b3 // or t0, t0, s3 WORD $0x00813383 // ld t2, 8(sp) # 8-byte Folded Reload WORD $0x0078b8b3 // sltu a7, a7, t2 WORD $0x01013c83 // ld s9, 16(sp) # 8-byte Folded Reload WORD $0x00ccb633 // sltu a2, s9, a2 WORD $0x00c8f633 // and a2, a7, a2 WORD $0x00ca6633 // or a2, s4, a2 WORD $0x00566c33 // or s8, a2, t0 JMP LBB12_54 LBB12_53: WORD $0x001b0b13 // addi s6, s6, 1 WORD $0x004c8c93 // addi s9, s9, 4 BEQ S6, A4, LBB12_51 LBB12_54: WORD $0x030b0633 // mul a2, s6, a6 WORD $0x00261613 // slli a2, a2, 2 WORD $0x00cb8d33 // add s10, s7, a2 BEQZ S8, LBB12_56 WORD $0x00000093 // li ra, 0 JMP LBB12_59 LBB12_56: WORD $0x00155613 // srli a2, a0, 1 WORD $0x0d1078d7 // vsetvli a7, zero, e32, m2, ta, ma WORD $0x0a0d6407 // vlse32.v v8, (s10), zero WORD $0x40c00633 // neg a2, a2 WORD $0x00d670b3 // and ra, a2, a3 WORD $0x00008893 // mv a7, ra WORD $0x000a8293 // mv t0, s5 WORD $0x000c8613 // mv 
a2, s9 LBB12_57: WORD $0x22866507 // vl2re32.v v10, (a2) WORD $0x2282e607 // vl2re32.v v12, (t0) WORD $0xb2a41657 // vfmacc.vv v12, v8, v10 WORD $0x22828627 // vs2r.v v12, (t0) WORD $0x01260633 // add a2, a2, s2 WORD $0x40b888b3 // sub a7, a7, a1 WORD $0x012282b3 // add t0, t0, s2 BNEZ A7, LBB12_57 BEQ RA, A3, LBB12_53 LBB12_59: WORD $0x021e83b3 // mul t2, t4, ra WORD $0x007c83b3 // add t2, s9, t2 WORD $0x00209893 // slli a7, ra, 2 WORD $0x011a88b3 // add a7, s5, a7 WORD $0x401682b3 // sub t0, a3, ra LBB12_60: WORD $0x000d2787 // flw fa5, 0(s10) WORD $0x0003a707 // flw fa4, 0(t2) WORD $0x0008a687 // flw fa3, 0(a7) WORD $0x68e7f7c3 // fmadd.s fa5, fa5, fa4, fa3 WORD $0x00f8a027 // fsw fa5, 0(a7) WORD $0x01d383b3 // add t2, t2, t4 WORD $0xfff28293 // addi t0, t0, -1 WORD $0x00488893 // addi a7, a7, 4 BNEZ T0, LBB12_60 JMP LBB12_53 LBB12_61: WORD $0x00000593 // li a1, 0 WORD $0x0d007657 // vsetvli a2, zero, e32, m1, ta, ma WORD $0x0e841457 // vfredosum.vs v8, v8, v8 WORD $0x00239393 // slli t2, t2, 2 WORD $0x00231313 // slli t1, t1, 2 LBB12_62: WORD $0x03058633 // mul a2, a1, a6 WORD $0x00261613 // slli a2, a2, 2 WORD $0x00c78733 // add a4, a5, a2 WORD $0x00068893 // mv a7, a3 WORD $0x02013283 // ld t0, 32(sp) # 8-byte Folded Reload WORD $0x01013e03 // ld t3, 16(sp) # 8-byte Folded Reload LBB12_63: WORD $0x0d057057 // vsetvli zero, a0, e32, m1, ta, ma WORD $0x02076487 // vle32.v v9, (a4) WORD $0x020e6507 // vle32.v v10, (t3) WORD $0x929514d7 // vfmul.vv v9, v9, v10 WORD $0x0e9414d7 // vfredosum.vs v9, v9, v8 WORD $0xcd00f057 // vsetivli zero, 1, e32, m1, ta, ma WORD $0x0202e4a7 // vse32.v v9, (t0) WORD $0x007e0e33 // add t3, t3, t2 WORD $0xfff88893 // addi a7, a7, -1 WORD $0x00428293 // addi t0, t0, 4 BNEZ A7, LBB12_63 WORD $0x00158593 // addi a1, a1, 1 WORD $0x02013603 // ld a2, 32(sp) # 8-byte Folded Reload WORD $0x00660633 // add a2, a2, t1 WORD $0x02c13023 // sd a2, 32(sp) # 8-byte Folded Spill WORD $0x01813603 // ld a2, 24(sp) # 8-byte Folded Reload BNE A1, A2, 
LBB12_62 LBB12_65: WORD $0xf8040113 // addi sp, s0, -128 WORD $0x07813083 // ld ra, 120(sp) # 8-byte Folded Reload WORD $0x07013403 // ld s0, 112(sp) # 8-byte Folded Reload WORD $0x06813903 // ld s2, 104(sp) # 8-byte Folded Reload WORD $0x06013983 // ld s3, 96(sp) # 8-byte Folded Reload WORD $0x05813a03 // ld s4, 88(sp) # 8-byte Folded Reload WORD $0x05013a83 // ld s5, 80(sp) # 8-byte Folded Reload WORD $0x04813b03 // ld s6, 72(sp) # 8-byte Folded Reload WORD $0x04013b83 // ld s7, 64(sp) # 8-byte Folded Reload WORD $0x03813c03 // ld s8, 56(sp) # 8-byte Folded Reload WORD $0x03013c83 // ld s9, 48(sp) # 8-byte Folded Reload WORD $0x02813d03 // ld s10, 40(sp) # 8-byte Folded Reload WORD $0x08010113 // addi sp, sp, 128 ADDI 24, SP, SP RET ================================================ FILE: common/floats/floats_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package floats import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) func TestMatZero(t *testing.T) { a := [][]float32{ {3, 2, 5, 6, 0, 0}, {1, 2, 3, 4, 5, 6}, } MatZero(a) assert.Equal(t, [][]float32{ {0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0}, }, a) } func TestZero(t *testing.T) { a := []float32{3, 2, 5, 6, 0, 0} Zero(a) assert.Equal(t, []float32{0, 0, 0, 0, 0, 0}, a) } func TestAdd(t *testing.T) { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} Add(a, b) assert.Equal(t, []float32{6, 8, 10, 12}, a) assert.Panics(t, func() { Add([]float32{1}, nil) }) } func TestSub(t *testing.T) { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} Sub(a, b) assert.Equal(t, []float32{-4, -4, -4, -4}, a) assert.Panics(t, func() { Sub([]float32{1}, nil) }) } func TestSubTo(t *testing.T) { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} c := make([]float32, 4) SubTo(a, b, c) assert.Equal(t, []float32{-4, -4, -4, -4}, c) assert.Panics(t, func() { SubTo([]float32{1}, nil, nil) }) } func TestMulTo(t *testing.T) { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} c := make([]float32, 4) MulTo(a, b, c) assert.Equal(t, []float32{5, 12, 21, 32}, c) assert.Panics(t, func() { MulTo([]float32{1}, nil, nil) }) } func TestMulConst(t *testing.T) { a := []float32{1, 2, 3, 4} MulConst(a, 2) assert.Equal(t, []float32{2, 4, 6, 8}, a) } func TestDiv(t *testing.T) { a := []float32{1, 4, 9, 16} b := []float32{1, 2, 3, 4} Div(a, b) assert.Equal(t, []float32{1, 2, 3, 4}, a) assert.Panics(t, func() { Div([]float32{1}, nil) }) } func TestMulConstTo(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} dst := make([]float32, 11) target := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} MulConstTo(a, 2, dst) assert.Equal(t, target, dst) assert.Panics(t, func() { MulConstTo(nil, 2, dst) }) } func TestMulConstAdd(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} dst := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} target := 
[]float32{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30} MulConstAdd(a, 2, dst) assert.Equal(t, target, dst) assert.Panics(t, func() { MulConstAdd(nil, 1, dst) }) } func TestMulConstAddTo(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} dst := make([]float32, 11) target := []float32{0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50} MulConstAddTo(a, 3, b, dst) assert.Equal(t, target, dst) } func TestMulAddTo(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} c := []float32{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30} target := []float32{0, 5, 14, 27, 44, 65, 90, 119, 152, 189, 230} MulAddTo(a, b, c) assert.Equal(t, target, c) assert.Panics(t, func() { MulAddTo(nil, nil, c) }) } func TestAddTo(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} dst := make([]float32, 11) target := []float32{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30} AddTo(a, b, dst) assert.Equal(t, target, dst) assert.Panics(t, func() { AddTo(nil, nil, dst) }) } func TestAddConst(t *testing.T) { a := []float32{1, 2, 3, 4} AddConst(a, 2) assert.Equal(t, []float32{3, 4, 5, 6}, a) } func TestDivTo(t *testing.T) { a := []float32{1, 4, 9, 16} b := []float32{1, 2, 3, 4} c := make([]float32, 4) DivTo(a, b, c) assert.Equal(t, []float32{1, 2, 3, 4}, c) assert.Panics(t, func() { DivTo([]float32{1}, nil, nil) }) } func TestSqrtTo(t *testing.T) { a := []float32{1, 4, 9, 16} b := make([]float32, 4) SqrtTo(a, b) assert.Equal(t, []float32{1, 2, 3, 4}, b) assert.Panics(t, func() { SqrtTo([]float32{1}, nil) }) } func TestSqrt(t *testing.T) { a := []float32{1, 4, 9, 16} Sqrt(a) assert.Equal(t, []float32{1, 2, 3, 4}, a) } func TestDot(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} assert.Equal(t, float32(770), Dot(a, b)) assert.Panics(t, func() { Dot([]float32{1}, nil) 
}) } func TestEuclidean(t *testing.T) { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} assert.Equal(t, float32(19.621416), Euclidean(a, b)) assert.Panics(t, func() { Euclidean([]float32{1}, nil) }) } type NativeTestSuite struct { suite.Suite } func (suite *NativeTestSuite) TestDot() { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} suite.Equal(float32(770), dot(a, b)) } func (suite *NativeTestSuite) TestEuclidean() { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} suite.InDelta(float32(19.621416), euclidean(a, b), 1e-6) } func (suite *NativeTestSuite) TestMulConstAddTo() { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} dst := make([]float32, 11) mulConstAddTo(a, 3, b, dst) suite.Equal([]float32{0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50}, dst) } func (suite *NativeTestSuite) TestMulConstAdd() { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} dst := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} target := []float32{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30} mulConstAdd(a, 2, dst) suite.Equal(target, dst) } func (suite *NativeTestSuite) TestMulConstTo() { a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} dst := make([]float32, 11) target := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20} mulConstTo(a, 2, dst) suite.Equal(target, dst) } func (suite *NativeTestSuite) TestAddConst() { a := []float32{1, 2, 3, 4} addConst(a, 2) suite.Equal([]float32{3, 4, 5, 6}, a) } func (suite *NativeTestSuite) TestSub() { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} sub(a, b) suite.Equal([]float32{-4, -4, -4, -4}, a) } func (suite *NativeTestSuite) TestSubTo() { a := []float32{1, 2, 3, 4} b := []float32{5, 6, 7, 8} c := make([]float32, 4) subTo(a, b, c) suite.Equal([]float32{-4, -4, -4, -4}, c) } func (suite *NativeTestSuite) TestMulTo() { a := []float32{1, 2, 3, 4} 
b := []float32{5, 6, 7, 8} c := make([]float32, 4) mulTo(a, b, c) suite.Equal([]float32{5, 12, 21, 32}, c) } func (suite *NativeTestSuite) TestDivTo() { a := []float32{1, 4, 9, 16} b := []float32{1, 2, 3, 4} c := make([]float32, 4) divTo(a, b, c) suite.Equal([]float32{1, 2, 3, 4}, c) } func (suite *NativeTestSuite) TestSqrtTo() { a := []float32{1, 4, 9, 16} b := make([]float32, 4) sqrtTo(a, b) suite.Equal([]float32{1, 2, 3, 4}, b) } func (suite *NativeTestSuite) TestMulConst() { a := []float32{1, 2, 3, 4} mulConst(a, 2) suite.Equal([]float32{2, 4, 6, 8}, a) } func (suite *NativeTestSuite) TestMM() { a := []float32{1, 2, 3, 4, 5, 6} b := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} c := make([]float32, 8) target := []float32{38, 44, 50, 56, 83, 98, 113, 128} mm(false, false, 2, 4, 3, a, 3, b, 4, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{14, 32, 50, 68, 32, 77, 122, 167} mm(false, true, 2, 4, 3, a, 3, b, 3, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{61, 70, 79, 88, 76, 88, 100, 112} mm(true, false, 2, 4, 3, a, 2, b, 4, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{22, 49, 76, 103, 28, 64, 100, 136} mm(true, true, 2, 4, 3, a, 2, b, 3, c, 4) suite.Equal(target, c) } func TestNativeTestSuite(t *testing.T) { suite.Run(t, new(NativeTestSuite)) } type SIMDTestSuite struct { suite.Suite Feature } func (suite *SIMDTestSuite) SetupSuite() { if feature&suite.Feature != suite.Feature { suite.T().Skipf("%s is not supported", (suite.Feature - (feature & suite.Feature)).String()) } } func (suite *SIMDTestSuite) TestMulConstAddTo() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} dst := make([]float32, len(a)) suite.mulConstAddTo(a, 2, b, dst) c := make([]float32, len(a)) mulConstAddTo(a, 2, b, c) assert.Equal(suite.T(), c, dst) } func (suite 
*SIMDTestSuite) TestMulConstAdd() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} suite.mulConstAdd(a, 2, b) c := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} mulConstAdd(a, 2, c) assert.Equal(suite.T(), c, b) } func (suite *SIMDTestSuite) TestMulConstTo() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} suite.mulConstTo(a, 2, b) c := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} mulConstTo(a, 2, c) assert.Equal(suite.T(), c, b) } func (suite *SIMDTestSuite) TestAddConst() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} suite.addConst(a, 2) c := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} addConst(c, 2) assert.Equal(suite.T(), c, a) } func (suite *SIMDTestSuite) TestSub() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} suite.sub(a, b) c := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} sub(c, b) suite.Equal(c, a) } func (suite *SIMDTestSuite) TestSubTo() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} c := make([]float32, len(a)) suite.subTo(a, b, c) d := make([]float32, len(a)) subTo(a, b, d) assert.Equal(suite.T(), c, d) } func (suite *SIMDTestSuite) TestMulTo() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 
30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} expected, actual := make([]float32, len(a)), make([]float32, len(a)) suite.mulTo(a, b, actual) mulTo(a, b, expected) assert.Equal(suite.T(), expected, actual) } func (suite *SIMDTestSuite) TestMulConst() { b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} suite.mulConst(b, 2) c := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} mulConst(c, 2) assert.Equal(suite.T(), c, b) } func (suite *SIMDTestSuite) TestDivTo() { a := []float32{1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400} b := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} c := make([]float32, len(a)) suite.divTo(a, b, c) d := make([]float32, len(a)) divTo(a, b, d) assert.Equal(suite.T(), c, d) } func (suite *SIMDTestSuite) TestSqrtTo() { a := []float32{1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400} b := make([]float32, len(a)) suite.sqrtTo(a, b) c := make([]float32, len(a)) sqrtTo(a, c) assert.Equal(suite.T(), b, c) } func (suite *SIMDTestSuite) TestDot() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} actual := suite.dot(a, b) expected := dot(a, b) assert.Equal(suite.T(), expected, actual) } func (suite *SIMDTestSuite) TestEuclidean() { a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200} actual := suite.euclidean(a, b) expected := euclidean(a, b) assert.Equal(suite.T(), expected, actual) } func (suite *SIMDTestSuite) TestMM() { a := []float32{1, 2, 3, 4, 5, 6} b := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} c := 
make([]float32, 8) target := []float32{38, 44, 50, 56, 83, 98, 113, 128} suite.mm(false, false, 2, 4, 3, a, 3, b, 4, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{14, 32, 50, 68, 32, 77, 122, 167} suite.mm(false, true, 2, 4, 3, a, 3, b, 3, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{61, 70, 79, 88, 76, 88, 100, 112} suite.mm(true, false, 2, 4, 3, a, 2, b, 4, c, 4) suite.Equal(target, c) c = make([]float32, 8) target = []float32{22, 49, 76, 103, 28, 64, 100, 136} suite.mm(true, true, 2, 4, 3, a, 2, b, 3, c, 4) suite.Equal(target, c) } ================================================ FILE: common/floats/mm.go ================================================ //go:build !cgo || (!(darwin && arm64) && !mkl && !openblas) // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package floats func mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { if !transA && !transB { for i := 0; i < m; i++ { for l := 0; l < k; l++ { // C_l += A_{il} * B_i MulConstAdd(b[l*ldb:(l+1)*ldb], a[i*lda+l], c[i*ldc:(i+1)*ldc]) } } } else if !transA && transB { for i := 0; i < m; i++ { for j := 0; j < n; j++ { c[i*ldc+j] = Dot(a[i*lda:(i+1)*lda], b[j*ldb:(j+1)*ldb]) } } } else if transA && !transB { for i := 0; i < m; i++ { for l := 0; l < k; l++ { // C_j += A_{ji} * B_i MulConstAdd(b[l*ldb:(l+1)*ldb], a[l*lda+i], c[i*ldc:(i+1)*ldc]) } } } else { for i := 0; i < m; i++ { for j := 0; j < n; j++ { for l := 0; l < k; l++ { c[i*ldc+j] += a[l*lda+i] * b[j*ldb+l] } } } } } ================================================ FILE: common/floats/mm_darwin_arm64.go ================================================ //go:build cgo // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package floats import "github.com/gorse-io/gorse/common/blas" func init() { feature = feature | AMX } func mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { blas.SGEMM(blas.RowMajor, blas.NewTranspose(transA), blas.NewTranspose(transB), m, n, k, 1.0, a, lda, b, ldb, 0, c, ldc) } ================================================ FILE: common/floats/mm_mkl.go ================================================ //go:build cgo && mkl // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package floats import "github.com/gorse-io/gorse/common/blas" func init() { feature = feature | MKL } func mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) { blas.SGEMM(blas.RowMajor, blas.NewTranspose(transA), blas.NewTranspose(transB), m, n, k, 1.0, a, lda, b, ldb, 0, c, ldc) } ================================================ FILE: common/floats/mm_openblas.go ================================================ //go:build cgo && openblas // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package floats

import "github.com/gorse-io/gorse/common/blas"

// init adds the OPENBLAS bit to the package feature set for this build.
func init() {
	feature = feature | OPENBLAS
}

// mm computes C = alpha*op(A)*op(B) + beta*C through blas.SGEMM in row-major
// order, with alpha = 1 and beta = 0. transA/transB select the transposed
// form of A/B; lda, ldb and ldc are the row strides of the stored matrices.
func mm(transA, transB bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int) {
	blas.SGEMM(blas.RowMajor, blas.NewTranspose(transA), blas.NewTranspose(transB), m, n, k, 1.0, a, lda, b, ldb, 0, c, ldc)
}

================================================
FILE: common/floats/src/.gitignore
================================================
*
!.gitignore
!*.c
!*.h
!Makefile

================================================
FILE: common/floats/src/Makefile
================================================
SOURCES = munit.c floats_test.c

# Select the SIMD kernels and matching compiler flags for the host architecture.
ifeq ($(shell uname -m),x86_64)
SOURCES += floats_avx.c floats_avx512.c
CFLAGS = -O3 -mavx -mavx512f -mavx512dq
else ifeq ($(shell uname -m),aarch64)
SOURCES += floats_neon.c floats_sve2.c
CFLAGS = -O3 -march=armv8-a+sve
else ifeq ($(shell uname -m),riscv64)
SOURCES += floats_rvv.c
CFLAGS = -O3 -march=rv64imafdv
endif

OBJECTS = $(SOURCES:.c=.o)
DEPENDENCES = $(SOURCES:.c=.d)
EXECUTE = floats_test

# test and clean name actions, not files: never treat them as up to date.
.PHONY: test clean

$(EXECUTE): $(OBJECTS)
	$(CC) $(OBJECTS) -lm -o $(EXECUTE)

test: $(EXECUTE)
	./${EXECUTE}

# -f: do not fail when some artifacts were never built.
clean:
	rm -f $(OBJECTS) $(DEPENDENCES) $(EXECUTE)

# Auto-generated header dependencies. Skipped for a plain `make clean`, which
# would otherwise regenerate every .d file only to delete it.
ifneq ($(MAKECMDGOALS),clean)
-include $(DEPENDENCES)
endif

# Generate foo.d from foo.c with the compiler's -MM output; the sed turns
# "foo.o:" into "foo.o foo.d:" so the .d file is rebuilt when a header changes.
%.d: %.c
	@set -e; \
	rm -f $@; \
	$(CC) $(CFLAGS) -MM -MT $(@:.d=.o) $< > $@.$$$$; \
	sed 's,\($*\)\.o[ :]*,\1.o $@: ,g' $@.$$$$ > $@; \
	rm -f $@.$$$$

================================================
FILE: common/floats/src/floats_avx.c
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include void _mm256_mul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_broadcast_ss(b); __m256 v3 = _mm256_loadu_ps(c); __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3); _mm256_storeu_ps(dst, v); a += 8; c += 8; dst += 8; } for (int i = 0; i < remain; i++) { dst[i] = c[i] + a[i] * b[0]; } } void _mm256_mul_const_add(float *a, float *b, float *c, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_broadcast_ss(b); __m256 v3 = _mm256_loadu_ps(c); __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3); _mm256_storeu_ps(c, v); a += 8; c += 8; } for (int i = 0; i < remain; i++) { c[i] += a[i] * b[0]; } } void _mm256_mul_const_to(float *a, float *b, float *c, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_broadcast_ss(b); __m256 v = _mm256_mul_ps(v1, v2); _mm256_storeu_ps(c, v); a += 8; c += 8; } for (int i = 0; i < remain; i++) { c[i] = a[i] * b[0]; } } void _mm256_mul_const(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_broadcast_ss(b); __m256 v = _mm256_mul_ps(v1, v2); _mm256_storeu_ps(a, v); a += 8; } for (int i = 0; i < 
remain; i++) { a[i] *= b[0]; } } void _mm256_add_const(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_broadcast_ss(b); __m256 v = _mm256_add_ps(v1, v2); _mm256_storeu_ps(a, v); a += 8; } for (int i = 0; i < remain; i++) { a[i] += b[0]; } } void _mm256_sub_to(float *a, float *b, float *c, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_sub_ps(v1, v2); _mm256_storeu_ps(c, v); a += 8; b += 8; c += 8; } for (int i = 0; i < remain; i++) { c[i] = a[i] - b[i]; } } void _mm256_sub(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_sub_ps(v1, v2); _mm256_storeu_ps(a, v); a += 8; b += 8; } for (int i = 0; i < remain; i++) { a[i] -= b[i]; } } void _mm256_mul_to(float *a, float *b, float *c, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_mul_ps(v1, v2); _mm256_storeu_ps(c, v); a += 8; b += 8; c += 8; } for (int i = 0; i < remain; i++) { c[i] = a[i] * b[i]; } } void _mm256_div_to(float *a, float *b, float *c, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_div_ps(v1, v2); _mm256_storeu_ps(c, v); a += 8; b += 8; c += 8; } for (int i = 0; i < remain; i++) { c[i] = a[i] / b[i]; } } void _mm256_sqrt_to(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; for (int i = 0; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_sqrt_ps(v1); _mm256_storeu_ps(b, v2); a += 8; b += 8; } for (int i = 0; i < remain; i++) { __m128 v = _mm_set1_ps(a[i]); __m128 r = 
_mm_sqrt_ss(v); b[i] = _mm_cvtss_f32(r); } } inline __attribute__((always_inline)) float dot(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; __m256 s = _mm256_setzero_ps(); if (epoch > 0) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); s = _mm256_mul_ps(v1, v2); a += 8; b += 8; } for (int i = 1; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); s = _mm256_add_ps(_mm256_mul_ps(v1, v2), s); a += 8; b += 8; } __m128 s7_6_5_4 = _mm256_extractf128_ps(s, 1); __m128 s3_2_1_0 = _mm256_castps256_ps128(s); __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0); __m128 sxx_15_04 = s37_26_15_04; __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04); const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26); const __m128 sxxx_0246 = sxx_1357_0246; const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1); __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357); float sum = _mm_cvtss_f32(sxxx_01234567); for (int i = 0; i < remain; i++) { sum += a[i] * b[i]; } return sum; } float _mm256_dot(float *a, float *b, int64_t n) { return dot(a, b, n); } float _mm256_euclidean(float *a, float *b, int64_t n) { int epoch = n / 8; int remain = n % 8; __m256 sum = _mm256_setzero_ps(); if (epoch > 0) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_sub_ps(v1, v2); sum = _mm256_mul_ps(v, v); a += 8; b += 8; } for (int i = 1; i < epoch; i++) { __m256 v1 = _mm256_loadu_ps(a); __m256 v2 = _mm256_loadu_ps(b); __m256 v = _mm256_sub_ps(v1, v2); v = _mm256_mul_ps(v, v); sum = _mm256_add_ps(v, sum); a += 8; b += 8; } __m128 s7_6_5_4 = _mm256_extractf128_ps(sum, 1); __m128 s3_2_1_0 = _mm256_castps256_ps128(sum); __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0); __m128 sxx_15_04 = s37_26_15_04; __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04); const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26); const __m128 sxxx_0246 = sxx_1357_0246; const 
__m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1); __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357); float ret = _mm_cvtss_f32(sxxx_01234567); for (int i = 0; i < remain; i++) { ret += (a[i] - b[i]) * (a[i] - b[i]); } __m128 v = _mm_set1_ps(ret); __m128 r = _mm_sqrt_ss(v); return _mm_cvtss_f32(r); } void _mm256_mm(_Bool transA, _Bool transB, int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *b, int64_t ldb, float *c, int64_t ldc) { if (!transA && !transB) { for (int i = 0; i < m; i++) { for (int l = 0; l < k; l++) { for (int j = 0; j < n; j++) { c[i * ldc + j] += a[i * lda + l] * b[l * ldb + j]; } } } } else if (!transA && transB) { for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { c[i * ldc + j] = dot(a + i * lda, b + j * ldb, k); } } } else if (transA && !transB) { for (int i = 0; i < m; i++) { for (int l = 0; l < k; l++) { for (int j = 0; j < n; j++) { c[i * ldc + j] += a[l * lda + i] * b[l * ldb + j]; } } } } else if (transA && transB) { for (int i = 0; i < m; i++) { for (int l = 0; l < k; l++) { for (int j = 0; j < n; j++) { c[i * ldc + j] += a[l * lda + i] * b[j * ldb + l]; } } } } } ================================================ FILE: common/floats/src/floats_avx512.c ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include

// AVX-512 float32 kernels. Each kernel processes 16 floats per iteration
// with unaligned AVX-512 loads/stores, then at most one 8-float AVX step
// (except _mm512_sub), then a scalar loop for what is left.
// Parameters named b that are only read as *b / b[0] hold a single scalar.
// NOTE(review): epoch/remain and loop counters are int while n is int64_t,
// so n / 16 > INT_MAX would overflow.

// dst[i] = a[i] * b[0] + c[i].
// The 16-wide path uses fused multiply-add (_mm512_fmadd_ps) while the 8-wide
// step uses a separate multiply and add, so rounding may differ slightly
// between segments.
void _mm512_mul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_set1_ps(*b);
    __m512 v3 = _mm512_loadu_ps(c);
    __m512 v = _mm512_fmadd_ps(v1, v2, v3);
    _mm512_storeu_ps(dst, v);
    a += 16;
    c += 16;
    dst += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_broadcast_ss(b);
    __m256 v3 = _mm256_loadu_ps(c);
    __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3);
    _mm256_storeu_ps(dst, v);
    a += 8;
    c += 8;
    dst += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    dst[i] = c[i] + a[i] * b[0];
  }
}

// c[i] += a[i] * b[0] (in place into c).
void _mm512_mul_const_add(float *a, float *b, float *c, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_set1_ps(*b);
    __m512 v3 = _mm512_loadu_ps(c);
    __m512 v = _mm512_fmadd_ps(v1, v2, v3);
    _mm512_storeu_ps(c, v);
    a += 16;
    c += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_broadcast_ss(b);
    __m256 v3 = _mm256_loadu_ps(c);
    __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3);
    _mm256_storeu_ps(c, v);
    a += 8;
    c += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    c[i] += a[i] * b[0];
  }
}

// c[i] = a[i] * b[0].
void _mm512_mul_const_to(float *a, float *b, float *c, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_set1_ps(*b);
    __m512 v = _mm512_mul_ps(v1, v2);
    _mm512_storeu_ps(c, v);
    a += 16;
    c += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_broadcast_ss(b);
    __m256 v = _mm256_mul_ps(v1, v2);
    _mm256_storeu_ps(c, v);
    a += 8;
    c += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    c[i] = a[i] * b[0];
  }
}

// a[i] *= b[0] (in place).
void _mm512_mul_const(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_set1_ps(*b);
    __m512 v = _mm512_mul_ps(v1, v2);
    _mm512_storeu_ps(a, v);
    a += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_broadcast_ss(b);
    __m256 v = _mm256_mul_ps(v1, v2);
    _mm256_storeu_ps(a, v);
    a += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    a[i] *= b[0];
  }
}

// a[i] += b[0] (in place).
void _mm512_add_const(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_set1_ps(*b);
    __m512 v = _mm512_add_ps(v1, v2);
    _mm512_storeu_ps(a, v);
    a += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_broadcast_ss(b);
    __m256 v = _mm256_add_ps(v1, v2);
    _mm256_storeu_ps(a, v);
    a += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    a[i] += b[0];
  }
}

// c[i] = a[i] - b[i].
void _mm512_sub_to(float *a, float *b, float *c, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_sub_ps(v1, v2);
    _mm512_storeu_ps(c, v);
    a += 16;
    b += 16;
    c += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_loadu_ps(b);
    __m256 v = _mm256_sub_ps(v1, v2);
    _mm256_storeu_ps(c, v);
    a += 8;
    b += 8;
    c += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    c[i] = a[i] - b[i];
  }
}

// a[i] -= b[i] (in place). Unlike the other kernels there is no 8-wide step;
// the scalar loop handles up to 15 trailing elements.
void _mm512_sub(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_sub_ps(v1, v2);
    _mm512_storeu_ps(a, v);
    a += 16;
    b += 16;
  }
  for (int i = 0; i < remain; i++) {
    a[i] -= b[i];
  }
}

// c[i] = a[i] * b[i].
void _mm512_mul_to(float *a, float *b, float *c, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_mul_ps(v1, v2);
    _mm512_storeu_ps(c, v);
    a += 16;
    b += 16;
    c += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_loadu_ps(b);
    __m256 v = _mm256_mul_ps(v1, v2);
    _mm256_storeu_ps(c, v);
    a += 8;
    b += 8;
    c += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    c[i] = a[i] * b[i];
  }
}

// c[i] = a[i] / b[i].
void _mm512_div_to(float *a, float *b, float *c, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_div_ps(v1, v2);
    _mm512_storeu_ps(c, v);
    a += 16;
    b += 16;
    c += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_loadu_ps(b);
    __m256 v = _mm256_div_ps(v1, v2);
    _mm256_storeu_ps(c, v);
    a += 8;
    b += 8;
    c += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    c[i] = a[i] / b[i];
  }
}

// b[i] = sqrt(a[i]); the scalar tail uses _mm_sqrt_ss on one lane.
void _mm512_sqrt_to(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  for (int i = 0; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_sqrt_ps(v1);
    _mm512_storeu_ps(b, v2);
    a += 16;
    b += 16;
  }
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_sqrt_ps(v1);
    _mm256_storeu_ps(b, v2);
    a += 8;
    b += 8;
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    __m128 v = _mm_set1_ps(a[i]);
    __m128 r = _mm_sqrt_ss(v);
    b[i] = _mm_cvtss_f32(r);
  }
}

// dot returns sum(a[i] * b[i]) for i in [0, n): 16-wide FMA accumulation,
// then one optional 8-wide step, then a scalar tail. Force-inlined into
// _mm512_dot and _mm512_mm. The partial sums are combined in a fixed order.
inline __attribute__((always_inline)) float dot(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  __m512 s = _mm512_setzero_ps();
  // The first chunk initialises the accumulator; later chunks FMA into it.
  if (epoch > 0) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    s = _mm512_mul_ps(v1, v2);
    a += 16;
    b += 16;
  }
  for (int i = 1; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    s = _mm512_fmadd_ps(v1, v2, s);
    a += 16;
    b += 16;
  }
  // Horizontal sum of the 16 lanes: fold 512 -> 256 -> 128 bits, then lanes
  // 2,3 onto 0,1, then lane 1 onto lane 0. Names list the lanes (hex) summed.
  __m256 sf_e_d_c_b_a_9_8 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(s), 1));
  __m256 s7_6_5_4_3_2_1_0 = _mm512_castps512_ps256(s);
  __m256 s7f_6e_5d_4c_3b_2a_19_08 = _mm256_add_ps(sf_e_d_c_b_a_9_8, s7_6_5_4_3_2_1_0);
  __m128 s7f_6e_5d_4c = _mm_castsi128_ps(_mm256_extracti128_si256(_mm256_castps_si256(s7f_6e_5d_4c_3b_2a_19_08), 1));
  __m128 s3b_2a_19_08 = _mm256_castps256_ps128(s7f_6e_5d_4c_3b_2a_19_08);
  __m128 s37bf_26ae_159d_048c = _mm_add_ps(s7f_6e_5d_4c, s3b_2a_19_08);
  __m128 sxx_159d_048c = s37bf_26ae_159d_048c;
  __m128 sxx_37bf_26ae = _mm_movehl_ps(sxx_159d_048c, s37bf_26ae_159d_048c);
  const __m128 sxx_13579bdf_02468ace = _mm_add_ps(sxx_159d_048c, sxx_37bf_26ae);
  const __m128 sxxx_02468ace = sxx_13579bdf_02468ace;
  const __m128 sxxx_13579bdf = _mm_shuffle_ps(sxx_13579bdf_02468ace, sxx_13579bdf_02468ace, 0x1);
  __m128 sxxx_0123456789abcdef = _mm_add_ss(sxxx_02468ace, sxxx_13579bdf);
  float sum = _mm_cvtss_f32(sxxx_0123456789abcdef);
  // One 8-wide AVX step for remain in [8, 15], reduced and added to sum.
  if (remain >= 8) {
    __m256 s;
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_loadu_ps(b);
    s = _mm256_mul_ps(v1, v2);
    a += 8;
    b += 8;
    __m128 s7_6_5_4 = _mm256_extractf128_ps(s, 1);
    __m128 s3_2_1_0 = _mm256_castps256_ps128(s);
    __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0);
    __m128 sxx_15_04 = s37_26_15_04;
    __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04);
    const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26);
    const __m128 sxxx_0246 = sxx_1357_0246;
    const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1);
    __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357);
    sum += _mm_cvtss_f32(sxxx_01234567);
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}

// _mm512_dot is the exported entry point for dot.
float _mm512_dot(float *a, float *b, int64_t n) {
  return dot(a, b, n);
}

// _mm512_euclidean returns sqrt(sum((a[i] - b[i])^2)) for i in [0, n),
// using the same 16 / 8 / scalar segmentation and reductions as dot.
// Only lane 0 of the final _mm_add_ps is read, so it matches _mm_add_ss.
float _mm512_euclidean(float *a, float *b, int64_t n) {
  int epoch = n / 16;
  int remain = n % 16;
  __m512 s = _mm512_setzero_ps();
  if (epoch > 0) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_sub_ps(v1, v2);
    s = _mm512_mul_ps(v, v);
    a += 16;
    b += 16;
  }
  for (int i = 1; i < epoch; i++) {
    __m512 v1 = _mm512_loadu_ps(a);
    __m512 v2 = _mm512_loadu_ps(b);
    __m512 v = _mm512_sub_ps(v1, v2);
    v = _mm512_mul_ps(v, v);
    s = _mm512_add_ps(v, s);
    a += 16;
    b += 16;
  }
  __m256 sf_e_d_c_b_a_9_8 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(s), 1));
  __m256 s7_6_5_4_3_2_1_0 = _mm512_castps512_ps256(s);
  __m256 s7f_6e_5d_4c_3b_2a_19_08 = _mm256_add_ps(sf_e_d_c_b_a_9_8, s7_6_5_4_3_2_1_0);
  __m128 s7f_6e_5d_4c = _mm_castsi128_ps(_mm256_extracti128_si256(_mm256_castps_si256(s7f_6e_5d_4c_3b_2a_19_08), 1));
  __m128 s3b_2a_19_08 = _mm256_castps256_ps128(s7f_6e_5d_4c_3b_2a_19_08);
  __m128 s37bf_26ae_159d_048c = _mm_add_ps(s7f_6e_5d_4c, s3b_2a_19_08);
  __m128 sxx_159d_048c = s37bf_26ae_159d_048c;
  __m128 sxx_37bf_26ae = _mm_movehl_ps(sxx_159d_048c, s37bf_26ae_159d_048c);
  const __m128 sxx_13579bdf_02468ace = _mm_add_ps(sxx_159d_048c, sxx_37bf_26ae);
  const __m128 sxxx_02468ace = sxx_13579bdf_02468ace;
  const __m128 sxxx_13579bdf = _mm_shuffle_ps(sxx_13579bdf_02468ace, sxx_13579bdf_02468ace, 0x1);
  __m128 sxxx_0123456789abcdef = _mm_add_ps(sxxx_02468ace, sxxx_13579bdf);
  float sum = _mm_cvtss_f32(sxxx_0123456789abcdef);
  if (remain >= 8) {
    __m256 v1 = _mm256_loadu_ps(a);
    __m256 v2 = _mm256_loadu_ps(b);
    __m256 v = _mm256_sub_ps(v1, v2);
    v = _mm256_mul_ps(v, v);
    a += 8;
    b += 8;
    __m128 s7_6_5_4 = _mm256_extractf128_ps(v, 1);
    __m128 s3_2_1_0 = _mm256_castps256_ps128(v);
    __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0);
    __m128 sxx_15_04 = s37_26_15_04;
    __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04);
    const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26);
    const __m128 sxxx_0246 = sxx_1357_0246;
    const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1);
    __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357);
    sum += _mm_cvtss_f32(sxxx_01234567);
    remain -= 8;
  }
  for (int i = 0; i < remain; i++) {
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  }
  __m128 v = _mm_set1_ps(sum);
  __m128 r = _mm_sqrt_ss(v);
  return _mm_cvtss_f32(r);
}

// _mm512_mm is a naive matrix multiply over row-major buffers with leading
// dimensions lda, ldb, ldc (c is m x n, inner dimension k); transA / transB
// read a / b transposed.
// NOTE(review): the (!transA && transB) branch assigns c while the other
// branches accumulate into it (+=); confirm this matches caller expectations.
void _mm512_mm(_Bool transA, _Bool transB, int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *b, int64_t ldb, float *c, int64_t ldc) {
  if (!transA && !transB) {
    for (int i = 0; i < m; i++) {
      for (int l = 0; l < k; l++) {
        for (int j = 0; j < n; j++) {
          c[i * ldc + j] += a[i * lda + l] * b[l * ldb + j];
        }
      }
    }
  } else if (!transA && transB) {
    // Rows of a and b are contiguous here, so the SIMD dot applies.
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        c[i * ldc + j] = dot(a + i * lda, b + j * ldb, k);
      }
    }
  } else if (transA && !transB) {
    for (int i = 0; i < m; i++) {
      for (int l = 0; l < k; l++) {
        for (int j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[l * ldb + j];
        }
      }
    }
  } else if (transA && transB) {
    for (int i = 0; i < m; i++) {
      for (int l = 0; l < k; l++) {
        for (int j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[j * ldb + l];
        }
      }
    }
  }
}
================================================ FILE: common/floats/src/floats_neon.c ================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include

// NEON float32 kernels. Parameters named b that are only read as *b / b[0]
// hold a single scalar operand. The element-wise kernels are plain C loops;
// vsqrt_to, dot and veuclidean use NEON intrinsics explicitly.
// Loop counters are long so they cover the full range of the long length n.

// dst[i] = a[i] * b[0] + c[i].
void vmul_const_add_to(float *a, float *b, float *c, float *dst, long n) {
  for (long i = 0; i < n; i++) {
    dst[i] = a[i] * (*b) + c[i];
  }
}

// c[i] += a[i] * b[0].
void vmul_const_add(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] += a[i] * b[0];
  }
}

// c[i] = a[i] * b[0].
void vmul_const_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] * b[0];
  }
}

// a[i] *= b[0].
void vmul_const(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] *= b[0];
  }
}

// a[i] += b[0].
void vadd_const(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] += b[0];
  }
}

// c[i] = a[i] - b[i].
void vsub_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] - b[i];
  }
}

// a[i] -= b[i].
void vsub(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] -= b[i];
  }
}

// c[i] = a[i] * b[i].
void vmul_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] * b[i];
  }
}

// c[i] = a[i] / b[i].
void vdiv_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] / b[i];
  }
}

// b[i] = sqrt(a[i]): 4 lanes per iteration, then a per-element NEON sqrt.
void vsqrt_to(float *a, float *b, long n) {
  long epoch = n / 4;
  long remain = n % 4;
  for (long i = 0; i < epoch; i++) {
    float32x4_t v1 = vld1q_f32(a);
    float32x4_t v2 = vsqrtq_f32(v1);
    vst1q_f32(b, v2);
    a += 4;
    b += 4;
  }
  for (long i = 0; i < remain; i++) {
    float32x2_t v = vdup_n_f32(a[i]);
    float32x2_t r = vsqrt_f32(v);
    b[i] = vget_lane_f32(r, 0);
  }
}

// dot returns sum(a[i] * b[i]) for i in [0, n): 4-lane multiply-accumulate,
// lanes summed in order 0..3, then a scalar tail.
// Force-inlined, as in the AVX kernels: a plain C99 `inline` definition
// provides no external definition, so an un-inlined call from vdot or vmm
// would fail to link.
inline __attribute__((always_inline)) float dot(float *a, float *b, long n) {
  long epoch = n / 4;
  long remain = n % 4;
  float32x4_t s = vdupq_n_f32(0);
  if (epoch > 0) {
    float32x4_t v1 = vld1q_f32(a);
    float32x4_t v2 = vld1q_f32(b);
    s = vmulq_f32(v1, v2);
    a += 4;
    b += 4;
  }
  for (long i = 1; i < epoch; i++) {
    float32x4_t v1 = vld1q_f32(a);
    float32x4_t v2 = vld1q_f32(b);
    s = vmlaq_f32(s, v1, v2);
    a += 4;
    b += 4;
  }
  float partial[4];
  vst1q_f32(partial, s);
  float sum = 0;
  for (int i = 0; i < 4; i++) {
    sum += partial[i];
  }
  for (long i = 0; i < remain; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}

// vdot is the exported entry point for dot.
float vdot(float *a, float *b, long n) {
  return dot(a, b, n);
}

// veuclidean returns sqrt(sum((a[i] - b[i])^2)) for i in [0, n).
float veuclidean(float *a, float *b, long n) {
  long epoch = n / 4;
  long remain = n % 4;
  float32x4_t s = vdupq_n_f32(0);
  if (epoch > 0) {
    float32x4_t v1 = vld1q_f32(a);
    float32x4_t v2 = vld1q_f32(b);
    float32x4_t v = vsubq_f32(v1, v2);
    s = vmulq_f32(v, v);
    a += 4;
    b += 4;
  }
  for (long i = 1; i < epoch; i++) {
    float32x4_t v1 = vld1q_f32(a);
    float32x4_t v2 = vld1q_f32(b);
    float32x4_t v = vsubq_f32(v1, v2);
    s = vmlaq_f32(s, v, v);
    a += 4;
    b += 4;
  }
  float partial[4];
  vst1q_f32(partial, s);
  float sum = 0;
  for (int i = 0; i < 4; i++) {
    sum += partial[i];
  }
  for (long i = 0; i < remain; i++) {
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  }
  // Broadcast the scalar with vdup_n_f32: loading it with vld1_f32(&sum)
  // would read 8 bytes from a 4-byte variable (out-of-bounds read).
  float32x2_t v = vdup_n_f32(sum);
  float32x2_t r = vsqrt_f32(v);
  return vget_lane_f32(r, 0);
}

// vmm is a naive matrix multiply over row-major buffers with leading
// dimensions lda, ldb, ldc (c is m x n, inner dimension k); transA / transB
// read a / b transposed.
// NOTE(review): the (!transA && transB) branch assigns c while the other
// branches accumulate into it (+=); behaviour kept as-is, confirm with callers.
void vmm(_Bool transA, _Bool transB, long m, long n, long k, float *a, long lda, float *b, long ldb, float *c, long ldc) {
  if (!transA && !transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[i * lda + l] * b[l * ldb + j];
        }
      }
    }
  } else if (!transA && transB) {
    for (long i = 0; i < m; i++) {
      for (long j = 0; j < n; j++) {
        c[i * ldc + j] = dot(a + i * lda, b + j * ldb, k);
      }
    }
  } else if (transA && !transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[l * ldb + j];
        }
      }
    }
  } else if (transA && transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[j * ldb + l];
        }
      }
    }
  }
}
================================================ FILE: common/floats/src/floats_rvv.c ================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include

// RISC-V Vector (RVV) float32 kernels. Parameters named b that are only read
// as *b / b[0] hold a single scalar operand. Element-wise kernels are plain
// C loops; vsqrt_to, dot and veuclidean use RVV intrinsics (LMUL=1).

// dst[i] = a[i] * b[0] + c[i].
void vmul_const_add_to(float *a, float *b, float *c, float *dst, long n) {
  for (long i = 0; i < n; i++) {
    dst[i] = a[i] * (*b) + c[i];
  }
}

// c[i] += a[i] * b[0].
void vmul_const_add(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] += a[i] * b[0];
  }
}

// c[i] = a[i] * b[0].
void vmul_const_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] * b[0];
  }
}

// a[i] *= b[0].
void vmul_const(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] *= b[0];
  }
}

// a[i] += b[0].
void vadd_const(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] += b[0];
  }
}

// c[i] = a[i] - b[i].
void vsub_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] - b[i];
  }
}

// a[i] -= b[i].
void vsub(float *a, float *b, long n) {
  for (long i = 0; i < n; i++) {
    a[i] -= b[i];
  }
}

// c[i] = a[i] * b[i].
void vmul_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] * b[i];
  }
}

// c[i] = a[i] / b[i].
void vdiv_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i++) {
    c[i] = a[i] / b[i];
  }
}

// b[i] = sqrt(a[i]); vsetvl picks the chunk length each iteration, so the
// last partial chunk needs no separate tail loop.
void vsqrt_to(float *a, float *b, long n) {
  for (size_t vl; n > 0; a += vl, b += vl, n -= vl) {
    vl = __riscv_vsetvl_e32m1(n);
    vfloat32m1_t v1 = __riscv_vle32_v_f32m1(a, vl);
    vfloat32m1_t v2 = __riscv_vfsqrt_v_f32m1(v1, vl);
    __riscv_vse32_v_f32m1(b, v2, vl);
  }
}

// dot returns sum(a[i] * b[i]) for i in [0, n): full-VLMAX multiply-
// accumulate chunks, an ordered reduction, then one shorter chunk for the
// remaining n % VLMAX elements.
// The tail runs only when remain > 0: with vl == 0 an RVV reduction leaves
// its destination unwritten, so the returned vector would be unspecified.
// Force-inlined (as in the AVX kernels) so vdot / vmm need no out-of-line
// definition of this C99 inline function.
inline __attribute__((always_inline)) float dot(float *a, float *b, long n) {
  size_t vlmax = __riscv_vsetvlmax_e32m1();
  long epoch = n / (long)vlmax;
  long remain = n % (long)vlmax;
  vfloat32m1_t s1 = __riscv_vfmv_v_f_f32m1(0, vlmax);
  for (long i = 0; i < epoch; i++) {
    vfloat32m1_t v1 = __riscv_vle32_v_f32m1(a, vlmax);
    vfloat32m1_t v2 = __riscv_vle32_v_f32m1(b, vlmax);
    s1 = __riscv_vfmacc_vv_f32m1(s1, v1, v2, vlmax);
    a += vlmax;
    b += vlmax;
  }
  vfloat32m1_t s = __riscv_vfmv_v_f_f32m1(0, vlmax);
  s = __riscv_vfredosum_vs_f32m1_f32m1(s1, s, vlmax);
  if (remain > 0) {
    size_t vl = __riscv_vsetvl_e32m1(remain);
    vfloat32m1_t v1 = __riscv_vle32_v_f32m1(a, vl);
    vfloat32m1_t v2 = __riscv_vle32_v_f32m1(b, vl);
    vfloat32m1_t s2 = __riscv_vfmul_vv_f32m1(v1, v2, vl);
    s = __riscv_vfredosum_vs_f32m1_f32m1(s2, s, vl);
  }
  return __riscv_vfmv_f_s_f32m1_f32(s);
}

// vdot is the exported entry point for dot.
float vdot(float *a, float *b, long n) {
  return dot(a, b, n);
}

// veuclidean returns sqrt(sum((a[i] - b[i])^2)) for i in [0, n), with the
// same chunking as dot.
float veuclidean(float *a, float *b, long n) {
  size_t vlmax = __riscv_vsetvlmax_e32m1();
  long epoch = n / (long)vlmax;
  long remain = n % (long)vlmax;
  vfloat32m1_t s1 = __riscv_vfmv_v_f_f32m1(0, vlmax);
  for (long i = 0; i < epoch; i++) {
    vfloat32m1_t v1 = __riscv_vle32_v_f32m1(a, vlmax);
    vfloat32m1_t v2 = __riscv_vle32_v_f32m1(b, vlmax);
    vfloat32m1_t v = __riscv_vfsub_vv_f32m1(v1, v2, vlmax);
    s1 = __riscv_vfmacc_vv_f32m1(s1, v, v, vlmax);
    a += vlmax;
    b += vlmax;
  }
  vfloat32m1_t s = __riscv_vfmv_v_f_f32m1(0, vlmax);
  s = __riscv_vfredosum_vs_f32m1_f32m1(s1, s, vlmax);
  if (remain > 0) {
    size_t vl = __riscv_vsetvl_e32m1(remain);
    vfloat32m1_t v1 = __riscv_vle32_v_f32m1(a, vl);
    vfloat32m1_t v2 = __riscv_vle32_v_f32m1(b, vl);
    // Operate on exactly vl elements; lanes past vl were never loaded.
    vfloat32m1_t v = __riscv_vfsub_vv_f32m1(v1, v2, vl);
    vfloat32m1_t s2 = __riscv_vfmul_vv_f32m1(v, v, vl);
    s = __riscv_vfredosum_vs_f32m1_f32m1(s2, s, vl);
  }
  // The total lives in element 0. Take the square root over exactly one
  // element so it also happens when remain == 0; the tail vl would be 0
  // then, and a vl == 0 vfsqrt performs no operation.
  s = __riscv_vfsqrt_v_f32m1(s, 1);
  return __riscv_vfmv_f_s_f32m1_f32(s);
}

// vmm is a naive matrix multiply over row-major buffers with leading
// dimensions lda, ldb, ldc (c is m x n, inner dimension k); transA / transB
// read a / b transposed.
// NOTE(review): the (!transA && transB) branch assigns c while the other
// branches accumulate into it (+=); behaviour kept as-is, confirm with callers.
void vmm(_Bool transA, _Bool transB, long m, long n, long k, float *a, long lda, float *b, long ldb, float *c, long ldc) {
  if (!transA && !transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[i * lda + l] * b[l * ldb + j];
        }
      }
    }
  } else if (!transA && transB) {
    for (long i = 0; i < m; i++) {
      for (long j = 0; j < n; j++) {
        c[i * ldc + j] = dot(a + i * lda, b + j * ldb, k);
      }
    }
  } else if (transA && !transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[l * ldb + j];
        }
      }
    }
  } else if (transA && transB) {
    for (long i = 0; i < m; i++) {
      for (long l = 0; l < k; l++) {
        for (long j = 0; j < n; j++) {
          c[i * ldc + j] += a[l * lda + i] * b[j * ldb + l];
        }
      }
    }
  }
}
================================================ FILE: common/floats/src/floats_sve2.c ================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include

// SVE kernels: each iteration handles svcntw() floats, and a svwhilelt
// predicate masks off lanes at or beyond n, so no scalar tail is needed.

// c[i] += a[i] * b[0] (in place into c).
// NOTE(review): despite the "_to" suffix this writes into c and has no dst
// parameter, unlike the other files' mul_const_add_to; confirm intended use.
void svmul_const_add_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i += svcntw()) {
    svbool_t pg = svwhilelt_b32(i, n);
    svfloat32_t a_seg = svld1(pg, a + i);
    svfloat32_t c_seg = svld1(pg, c + i);
    svst1(pg, c + i, svmla_x(pg, c_seg, a_seg, *b));
  }
}

// c[i] = a[i] * b[0].
void svmul_const_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i += svcntw()) {
    svbool_t pg = svwhilelt_b32(i, n);
    svfloat32_t a_seg = svld1(pg, a + i);
    svst1(pg, c + i, svmul_x(pg, a_seg, *b));
  }
}

// a[i] *= b[0].
void svmul_const(float *a, float *b, long n) {
  for (long i = 0; i < n; i += svcntw()) {
    svbool_t pg = svwhilelt_b32(i, n);
    svfloat32_t a_seg = svld1(pg, a + i);
    svst1(pg, a + i, svmul_x(pg, a_seg, *b));
  }
}

// c[i] = a[i] * b[i].
void svmul_to(float *a, float *b, float *c, long n) {
  for (long i = 0; i < n; i += svcntw()) {
    svbool_t pg = svwhilelt_b32(i, n);
    svfloat32_t a_seg = svld1(pg, a + i);
    svfloat32_t b_seg = svld1(pg, b + i);
    svst1(pg, c + i, svmul_x(pg, a_seg, b_seg));
  }
}
================================================
FILE: common/floats/src/floats_test.c ================================================ #include "munit.h" #include "math.h" const size_t kVectorLength = 63; const size_t kIteration = 1; /* no simd */ void mul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n) { for (int64_t i = 0; i < n; i++) { dst[i] = a[i] * (*b) + c[i]; } } void mul_const_add(float *a, float *b, float *c, int64_t n) { for (int64_t i = 0; i < n; i++) { c[i] += a[i] * (*b); } } void mul_const_to(float *a, float *b, float *c, int64_t n) { for (int64_t i = 0; i < n; i++) { c[i] = a[i] * (*b); } } void mul_const(float *a, float *b, int64_t n) { for (int64_t i = 0; i < n; i++) { a[i] *= *b; } } void sub_to(float *a, float *b, float *c, int64_t n) { for (int64_t i = 0; i < n; i++) { c[i] = a[i] - b[i]; } } void sub(float *a, float *b, int64_t n) { for (int64_t i = 0; i < n; i++) { a[i] -= b[i]; } } void mul_to(float *a, float *b, float *c, int64_t n) { for (int64_t i = 0; i < n; i++) { c[i] = a[i] * b[i]; } } void div_to(float *a, float *b, float *c, int64_t n) { for (int64_t i = 0; i < n; i++) { c[i] = a[i] / b[i]; } } void sqrt_to(float *a, float *b, int64_t n) { for (int64_t i = 0; i < n; i++) { b[i] = sqrtf(a[i]); } } float dot(float *a, float *b, int64_t n) { float sum = 0; for (int64_t i = 0; i < n; i++) { sum += a[i] * b[i]; } return sum; } float euclidean(float *a, float *b, int64_t n) { float sum = 0; for (int64_t i = 0; i < n; i++) { sum += powf(a[i] - b[i], 2); } return sqrtf(sum); } int rand_float(float *a, int64_t n) { for (int i = 0; i < n; i++) { a[i] = munit_rand_double(); } } #if defined(__x86_64__) void _mm256_mul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n); void _mm256_mul_const_add(float *a, float *b, float *c, int64_t n); void _mm256_mul_const_to(float *a, float *b, float *c, int64_t n); void _mm256_mul_const(float *a, float *b, int64_t n); void _mm256_sub_to(float *a, float *b, float *c, int64_t n); void _mm256_sub(float *a, float *b, int64_t 
n); void _mm256_mul_to(float *a, float *b, float *c, int64_t n); void _mm256_div_to(float *a, float *b, float *c, int64_t n); void _mm256_sqrt_to(float *a, float *b, int64_t n); float _mm256_dot(float *a, float *b, int64_t n); float _mm256_euclidean(float *a, float *b, int64_t n); void _mm512_mul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n); void _mm512_mul_const_add(float *a, float *b, float *c, int64_t n); void _mm512_mul_const_to(float *a, float *b, float *c, int64_t n); void _mm512_mul_const(float *a, float *b, int64_t n); void _mm512_sub_to(float *a, float *b, float *c, int64_t n); void _mm512_sub(float *a, float *b, int64_t n); void _mm512_mul_to(float *a, float *b, float *c, int64_t n); void _mm512_div_to(float *a, float *b, float *c, int64_t n); void _mm512_sqrt_to(float *a, float *b, int64_t n); float _mm512_dot(float *a, float *b, int64_t n); float _mm512_euclidean(float *a, float *b, int64_t n); MunitResult mm256_mul_const_add_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float d = munit_rand_double(); mul_const_add_to(a, &d, b, expect, kVectorLength); _mm256_mul_const_add_to(a, &d, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_mul_const_add_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const_add(a, &b, expect, kVectorLength); _mm256_mul_const_add(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_mul_const_to_test(const MunitParameter params[], void *user_data_or_fixture) { 
float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); float b = munit_rand_double(); mul_const_to(a, &b, expect, kVectorLength); _mm256_mul_const_to(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_mul_const_test(const MunitParameter params[], void *user_data_or_fixture) { float expect[kVectorLength], actual[kVectorLength]; rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const(expect, &b, kVectorLength); _mm256_mul_const(actual, &b, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_sub_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); sub_to(a, b, expect, kVectorLength); _mm256_sub_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_sub_test(const MunitParameter params[], void *user_data_or_fixture) { float expected[kVectorLength], actual[kVectorLength], b[kVectorLength]; rand_float(b, kVectorLength); rand_float(expected, kVectorLength); memcpy(expected, actual, sizeof(float) * kVectorLength); sub(expected, b, kVectorLength); _mm256_sub(actual, b, kVectorLength); munit_assert_floats_equal(kVectorLength, expected, actual); return MUNIT_OK; } MunitResult mm256_mul_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); mul_to(a, b, expect, kVectorLength); _mm256_mul_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_div_to_test(const 
MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); div_to(a, b, expect, kVectorLength); _mm256_div_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_sqrt_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); sqrt_to(a, expect, kVectorLength); _mm256_sqrt_to(a, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm256_dot_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = dot(a, b, kVectorLength); float actual = _mm256_dot(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitResult mm256_euclidean_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = euclidean(a, b, kVectorLength); float actual = _mm256_euclidean(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitTest mm256_tests[] = { {"mul_const_add_to", mm256_mul_const_add_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_add", mm256_mul_const_add_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_to", mm256_mul_const_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const", mm256_mul_const_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sub", mm256_sub_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sub_to", mm256_sub_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_to", 
mm256_mul_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"div_to", mm256_div_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sqrt_to", mm256_sqrt_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"dot", mm256_dot_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"euclidean", mm256_euclidean_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite mm256_suite = { "mm256_", mm256_tests, NULL, kIteration, MUNIT_SUITE_OPTION_NONE}; MunitResult mm512_mul_const_add_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], c[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); rand_float(c, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float d = munit_rand_double(); mul_const_add_to(a, &d, b, expect, kVectorLength); _mm512_mul_const_add_to(a, &d, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_mul_const_add_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const_add(a, &b, expect, kVectorLength); _mm512_mul_const_add(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_mul_const_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); float b = munit_rand_double(); mul_const_to(a, &b, expect, kVectorLength); _mm512_mul_const_to(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult 
mm512_mul_const_test(const MunitParameter params[], void *user_data_or_fixture) { float expect[kVectorLength], actual[kVectorLength]; rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const(expect, &b, kVectorLength); _mm512_mul_const(actual, &b, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_sub_test(const MunitParameter params[], void *user_data_or_fixture) { float expected[kVectorLength], actual[kVectorLength], b[kVectorLength]; rand_float(b, kVectorLength); rand_float(expected, kVectorLength); memcpy(expected, actual, sizeof(float) * kVectorLength); sub(expected, b, kVectorLength); _mm512_sub(actual, b, kVectorLength); munit_assert_floats_equal(kVectorLength, expected, actual); return MUNIT_OK; } MunitResult mm512_sub_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); sub_to(a, b, expect, kVectorLength); _mm512_sub_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_mul_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); mul_to(a, b, expect, kVectorLength); _mm512_mul_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_div_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); div_to(a, b, expect, kVectorLength); _mm512_div_to(a, b, actual, kVectorLength); 
munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_sqrt_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); sqrt_to(a, expect, kVectorLength); _mm512_sqrt_to(a, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult mm512_dot_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = dot(a, b, kVectorLength); float actual = _mm512_dot(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitResult mm512_euclidean_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = euclidean(a, b, kVectorLength); float actual = _mm512_euclidean(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitTest mm512_tests[] = { {"mul_const_add_to", mm512_mul_const_add_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_add", mm512_mul_const_add_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_to", mm512_mul_const_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const", mm512_mul_const_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sub_to", mm512_sub_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sub", mm512_sub_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_to", mm512_mul_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"div_to", mm512_div_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sqrt_to", mm512_sqrt_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"dot", mm512_dot_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"euclidean", 
mm512_euclidean_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite mm512_suite = { "mm512_", mm512_tests, NULL, kIteration, MUNIT_SUITE_OPTION_NONE}; int main(int argc, char *const argv[MUNIT_ARRAY_PARAM(argc + 1)]) { munit_suite_main(&mm256_suite, NULL, argc, argv); munit_suite_main(&mm512_suite, NULL, argc, argv); return 0; } #elif defined(__aarch64__) void vmul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n); void vmul_const_add(float *a, float *b, float *c, int64_t n); void vmul_const_to(float *a, float *b, float *c, int64_t n); void vmul_const(float *a, float *b, int64_t n); void vsub_to(float *a, float *b, float *c, int64_t n); void vsub(float *a, float *b, int64_t n); void vmul_to(float *a, float *b, float *c, int64_t n); void vdiv_to(float *a, float *b, float *c, int64_t n); void vsqrt_to(float *a, float *b, int64_t n); float vdot(float *a, float *b, int64_t n); float veuclidean(float *a, float *b, int64_t n); MunitResult vmul_const_add_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float d = munit_rand_double(); mul_const_add_to(a, &d, b, expect, kVectorLength); vmul_const_add_to(a, &d, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_const_add_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const_add(a, &b, expect, kVectorLength); vmul_const_add(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult 
vmul_const_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); float b = munit_rand_double(); mul_const_to(a, &b, expect, kVectorLength); vmul_const_to(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_const_test(const MunitParameter params[], void *user_data_or_fixture) { float expect[kVectorLength], actual[kVectorLength]; rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const(expect, &b, kVectorLength); vmul_const(actual, &b, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); mul_to(a, b, expect, kVectorLength); vmul_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vdot_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = dot(a, b, kVectorLength); float actual = vdot(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitResult veuclidean_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = euclidean(a, b, kVectorLength); float actual = veuclidean(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitTest vtests[] = { {"mul_const_add_to", vmul_const_add_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_to", 
vmul_const_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const", vmul_const_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_to", vmul_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"dot", vdot_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"euclidean", veuclidean_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite vsuite = { "v", vtests, NULL, kIteration, MUNIT_SUITE_OPTION_NONE}; void svmul_const_add_to(float *a, float *b, float *c, float *dst, long n); void svmul_const_to(float *a, float *b, float *c, long n); void svmul_const(float *a, float *b, long n); void svmul_to(float *a, float *b, float *c, long n); MunitResult svmul_const_add_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float d = munit_rand_double(); mul_const_add_to(a, &d, b, expect, kVectorLength); svmul_const_add_to(a, &d, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult svmul_const_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); float b = munit_rand_double(); mul_const_to(a, &b, expect, kVectorLength); svmul_const_to(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult svmul_const_test(const MunitParameter params[], void *user_data_or_fixture) { float expect[kVectorLength], actual[kVectorLength]; rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const(expect, &b, kVectorLength); svmul_const(actual, &b, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } 
MunitResult svmul_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); mul_to(a, b, expect, kVectorLength); svmul_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitTest svtests[] = { {"mul_const_add_to", svmul_const_add_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_to", svmul_const_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const", svmul_const_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_to", svmul_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite svsuite = { "sv", svtests, NULL, kIteration, MUNIT_SUITE_OPTION_NONE}; int main(int argc, char *const argv[MUNIT_ARRAY_PARAM(argc + 1)]) { munit_suite_main(&vsuite, NULL, argc, argv); munit_suite_main(&svsuite, NULL, argc, argv); return 0; } #elif defined(__riscv) && (__riscv_xlen == 64) void vmul_const_add_to(float *a, float *b, float *c, float *dst, int64_t n); void vmul_const_add(float *a, float *b, float *c, int64_t n); void vmul_const_to(float *a, float *b, float *c, int64_t n); void vmul_const(float *a, float *b, int64_t n); void vsub_to(float *a, float *b, float *c, int64_t n); void vsub(float *a, float *b, int64_t n); void vmul_to(float *a, float *b, float *c, int64_t n); void vdiv_to(float *a, float *b, float *c, int64_t n); void vsqrt_to(float *a, float *b, int64_t n); float vdot(float *a, float *b, int64_t n); float veuclidean(float *a, float *b, int64_t n); MunitResult vmul_const_add_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float d = munit_rand_double(); mul_const_add_to(a, &d, b, expect, 
kVectorLength); vmul_const_add_to(a, &d, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_const_add_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const_add(a, &b, expect, kVectorLength); vmul_const_add(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_const_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); float b = munit_rand_double(); mul_const_to(a, &b, expect, kVectorLength); vmul_const_to(a, &b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_const_test(const MunitParameter params[], void *user_data_or_fixture) { float expect[kVectorLength], actual[kVectorLength]; rand_float(expect, kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); float b = munit_rand_double(); mul_const(expect, &b, kVectorLength); vmul_const(actual, &b, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vmul_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); mul_to(a, b, expect, kVectorLength); vmul_to(a, b, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vsqrt_to_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], expect[kVectorLength], actual[kVectorLength]; rand_float(a, 
kVectorLength); memcpy(expect, actual, sizeof(float) * kVectorLength); sqrt_to(a, expect, kVectorLength); vsqrt_to(a, actual, kVectorLength); munit_assert_floats_equal(kVectorLength, expect, actual); return MUNIT_OK; } MunitResult vdot_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = dot(a, b, kVectorLength); float actual = vdot(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitResult veuclidean_test(const MunitParameter params[], void *user_data_or_fixture) { float a[kVectorLength], b[kVectorLength]; rand_float(a, kVectorLength); rand_float(b, kVectorLength); float expect = euclidean(a, b, kVectorLength); float actual = veuclidean(a, b, kVectorLength); munit_assert_float_equal(expect, actual, 5); return MUNIT_OK; } MunitTest vtests[] = { {"mul_const_add_to", vmul_const_add_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const_to", vmul_const_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_const", vmul_const_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"mul_to", vmul_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"sqrt_to", vsqrt_to_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"dot", vdot_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"euclidean", veuclidean_test, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite vsuite = { "v", vtests, NULL, kIteration, MUNIT_SUITE_OPTION_NONE}; int main(int argc, char *const argv[MUNIT_ARRAY_PARAM(argc + 1)]) { munit_suite_main(&vsuite, NULL, argc, argv); return 0; } #endif ================================================ FILE: common/floats/src/munit.c ================================================ /* Copyright (c) 2013-2018 Evan Nemerson * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and 
associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, copy, * modify, merge, publish, distribute, sublicense, and/or sell copies * of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ /*** Configuration ***/ /* This is just where the output from the test goes. It's really just * meant to let you choose stdout or stderr, but if anyone really want * to direct it to a file let me know, it would be fairly easy to * support. */ #if !defined(MUNIT_OUTPUT_FILE) # define MUNIT_OUTPUT_FILE stdout #endif /* This is a bit more useful; it tells µnit how to format the seconds in * timed tests. If your tests run for longer you might want to reduce * it, and if your computer is really fast and your tests are tiny you * can increase it. */ #if !defined(MUNIT_TEST_TIME_FORMAT) # define MUNIT_TEST_TIME_FORMAT "0.8f" #endif /* If you have long test names you might want to consider bumping * this. The result information takes 43 characters. */ #if !defined(MUNIT_TEST_NAME_LEN) # define MUNIT_TEST_NAME_LEN 37 #endif /* If you don't like the timing information, you can disable it by * defining MUNIT_DISABLE_TIMING. 
*/ #if !defined(MUNIT_DISABLE_TIMING) # define MUNIT_ENABLE_TIMING #endif /*** End configuration ***/ #if defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE < 200809L) # undef _POSIX_C_SOURCE #endif #if !defined(_POSIX_C_SOURCE) # define _POSIX_C_SOURCE 200809L #endif /* Solaris freaks out if you try to use a POSIX or SUS standard without * the "right" C standard. */ #if defined(_XOPEN_SOURCE) # undef _XOPEN_SOURCE #endif #if defined(__STDC_VERSION__) # if __STDC_VERSION__ >= 201112L # define _XOPEN_SOURCE 700 # elif __STDC_VERSION__ >= 199901L # define _XOPEN_SOURCE 600 # endif #endif /* Because, according to Microsoft, POSIX is deprecated. You've got * to appreciate the chutzpah. */ #if defined(_MSC_VER) && !defined(_CRT_NONSTDC_NO_DEPRECATE) # define _CRT_NONSTDC_NO_DEPRECATE #endif #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) # include #elif defined(_WIN32) /* https://msdn.microsoft.com/en-us/library/tf4dy80a.aspx */ #endif #include #include #include #include #include #include #include #include #if !defined(MUNIT_NO_NL_LANGINFO) && !defined(_WIN32) #define MUNIT_NL_LANGINFO #include #include #include #endif #if !defined(_WIN32) # include # include # include #else # include # include # include # if !defined(STDERR_FILENO) # define STDERR_FILENO _fileno(stderr) # endif #endif #include "munit.h" #define MUNIT_STRINGIFY(x) #x #define MUNIT_XSTRINGIFY(x) MUNIT_STRINGIFY(x) #if defined(__GNUC__) || defined(__INTEL_COMPILER) || defined(__SUNPRO_CC) || defined(__IBMCPP__) # define MUNIT_THREAD_LOCAL __thread #elif (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201102L)) || defined(_Thread_local) # define MUNIT_THREAD_LOCAL _Thread_local #elif defined(_WIN32) # define MUNIT_THREAD_LOCAL __declspec(thread) #endif /* MSVC 12.0 will emit a warning at /W4 for code like 'do { ... } * while (0)', or 'do { ... } while (1)'. I'm pretty sure nobody * at Microsoft compiles with /W4. 
*/
#if defined(_MSC_VER) && (_MSC_VER <= 1800)
#pragma warning(disable: 4127)
#endif

/* Turn off fork-based test isolation on Windows and Emscripten. */
#if defined(_WIN32) || defined(__EMSCRIPTEN__)
#  define MUNIT_NO_FORK
#endif

#if defined(__EMSCRIPTEN__)
#  define MUNIT_NO_BUFFER
#endif

/*** Logging ***/

/* munit_logf_exv prints nothing for levels below munit_log_level_visible.
 * munit_logf_ex treats levels at or above munit_log_level_fatal as fatal:
 * it longjmps to munit_error_jmp_buf when that buffer is marked valid, and
 * calls abort() otherwise. */
static MunitLogLevel munit_log_level_visible = MUNIT_LOG_INFO;
static MunitLogLevel munit_log_level_fatal = MUNIT_LOG_ERROR;

#if defined(MUNIT_THREAD_LOCAL)
/* Per-thread escape hatch for fatal errors: while munit_error_jmp_buf_valid
 * is non-zero, munit_logf_ex/munit_errorf_ex longjmp() to
 * munit_error_jmp_buf instead of calling abort(). */
static MUNIT_THREAD_LOCAL munit_bool munit_error_jmp_buf_valid = 0;
static MUNIT_THREAD_LOCAL jmp_buf munit_error_jmp_buf;
#endif

/* At certain warning levels, mingw will trigger warnings about
 * suggesting the format attribute, which we've explicitly *not* set
 * because it will then choke on our attempts to use the MS-specific
 * I64 modifier for size_t (which we have to use since MSVC doesn't
 * support the C99 z modifier). */

#if defined(__MINGW32__) || defined(__MINGW64__)
#  pragma GCC diagnostic push
#  pragma GCC diagnostic ignored "-Wsuggest-attribute=format"
#endif

/* Core log writer. When `level` is visible it writes
 *   "<Level>: [<filename>:<line>: ]<formatted message>\n"
 * to `fp`; the "<filename>:<line>: " part is omitted when filename is NULL.
 * An unrecognized level is itself reported through munit_logf_ex at
 * MUNIT_LOG_ERROR and the original message is dropped. */
MUNIT_PRINTF(5,0)
static void
munit_logf_exv(MunitLogLevel level, FILE* fp, const char* filename, int line, const char* format, va_list ap) {
  if (level < munit_log_level_visible)
    return;

  switch (level) {
    case MUNIT_LOG_DEBUG:
      fputs("Debug", fp);
      break;
    case MUNIT_LOG_INFO:
      fputs("Info", fp);
      break;
    case MUNIT_LOG_WARNING:
      fputs("Warning", fp);
      break;
    case MUNIT_LOG_ERROR:
      fputs("Error", fp);
      break;
    default:
      munit_logf_ex(MUNIT_LOG_ERROR, filename, line, "Invalid log level (%d)", level);
      return;
  }

  fputs(": ", fp);
  if (filename != NULL)
    fprintf(fp, "%s:%d: ", filename, line);
  vfprintf(fp, format, ap);
  fputc('\n', fp);
}

/* Variadic front end for munit_logf_exv that logs without a source
 * location. Unlike munit_logf_ex it never aborts or longjmps itself. */
MUNIT_PRINTF(3,4)
static void
munit_logf_internal(MunitLogLevel level, FILE* fp, const char* format, ...) {
  va_list ap;

  va_start(ap, format);
  munit_logf_exv(level, fp, NULL, 0, format, ap);
  va_end(ap);
}

/* Logs a fixed message. Routing it through "%s" keeps any '%' characters in
 * `message` from being interpreted as conversion specifiers. */
static void
munit_log_internal(MunitLogLevel level, FILE* fp, const char* message) {
  munit_logf_internal(level, fp, "%s", message);
}

/* Logs to stderr with a source location. For levels at or above
 * munit_log_level_fatal this does not return: it longjmps to
 * munit_error_jmp_buf when valid, otherwise calls abort(). */
void
munit_logf_ex(MunitLogLevel level, const char* filename, int line, const char* format, ...) {
  va_list ap;

  va_start(ap, format);
  munit_logf_exv(level, stderr, filename, line, format, ap);
  va_end(ap);

  if (level >= munit_log_level_fatal) {
#if defined(MUNIT_THREAD_LOCAL)
    if (munit_error_jmp_buf_valid)
      longjmp(munit_error_jmp_buf, 1);
#endif
    abort();
  }
}

/* Logs at MUNIT_LOG_ERROR to stderr and then unconditionally fails: it never
 * returns (longjmp when munit_error_jmp_buf is valid, abort() otherwise),
 * regardless of munit_log_level_fatal. */
void
munit_errorf_ex(const char* filename, int line, const char* format, ...) {
  va_list ap;

  va_start(ap, format);
  munit_logf_exv(MUNIT_LOG_ERROR, stderr, filename, line, format, ap);
  va_end(ap);

#if defined(MUNIT_THREAD_LOCAL)
  if (munit_error_jmp_buf_valid)
    longjmp(munit_error_jmp_buf, 1);
#endif
  abort();
}

#if defined(__MINGW32__) || defined(__MINGW64__)
#pragma GCC diagnostic pop
#endif

/* Size of the stack buffer munit_log_errno formats the errno text into;
 * may be overridden at build time. */
#if !defined(MUNIT_STRERROR_LEN)
#  define MUNIT_STRERROR_LEN 80
#endif

/* Logs "<msg>: <errno description> (<errno value>)" via munit_logf_internal.
 * Uses strerror_r (non-Windows) or strerror_s (Windows) into a local buffer,
 * falling back to strerror() when MUNIT_NO_STRERROR_R is set or MinGW lacks
 * the secure API.
 * NOTE(review): this relies on the XSI strerror_r that fills the buffer; if
 * _GNU_SOURCE is in effect the GNU variant may return a static string and
 * leave the buffer empty — confirm build flags. */
static void
munit_log_errno(MunitLogLevel level, FILE* fp, const char* msg) {
#if defined(MUNIT_NO_STRERROR_R) || (defined(__MINGW32__) && !defined(MINGW_HAS_SECURE_API))
  munit_logf_internal(level, fp, "%s: %s (%d)", msg, strerror(errno), errno);
#else
  char munit_error_str[MUNIT_STRERROR_LEN];
  munit_error_str[0] = '\0';

#if !defined(_WIN32)
  strerror_r(errno, munit_error_str, MUNIT_STRERROR_LEN);
#else
  strerror_s(munit_error_str, MUNIT_STRERROR_LEN, errno);
#endif

  munit_logf_internal(level, fp, "%s: %s (%d)", msg, munit_error_str, errno);
#endif
}

/*** Memory allocation ***/

/* Returns `size` zero-initialized bytes (via calloc), or NULL for a
 * zero-byte request. On allocation failure it logs at MUNIT_LOG_ERROR through
 * munit_logf_ex; with the default munit_log_level_fatal (MUNIT_LOG_ERROR)
 * that call does not return, so a NULL result for a non-zero size is only
 * seen if the fatal level has been changed. */
void*
munit_malloc_ex(const char* filename, int line, size_t size) {
  void* ptr;

  if (size == 0)
    return NULL;

  ptr = calloc(1, size);
  if (MUNIT_UNLIKELY(ptr == NULL)) {
    munit_logf_ex(MUNIT_LOG_ERROR, filename, line, "Failed to allocate %" MUNIT_SIZE_MODIFIER "u bytes.", size);
  }

  return ptr;
}

/*** Timer code ***/

#if defined(MUNIT_ENABLE_TIMING)

/* Map portable-snippets' fixed-width integer names onto munit's own types.
 * Because psnip_uint64_t is now defined, the exact-int.h include below is
 * skipped. */
#define psnip_uint64_t munit_uint64_t
#define psnip_uint32_t munit_uint32_t

/* Code copied from portable-snippets
 * <https://github.com/nemequ/portable-snippets>. If you need to
 * change something, please do it there so we can keep the code in
 * sync. */

/* Clocks (v1)
 * Portable Snippets - https://github.com/nemequ/portable-snippets
 * Created by Evan Nemerson
 *
 *   To the extent possible under law, the authors have waived all
 *   copyright and related or neighboring rights to this code.  For
 *   details, see the Creative Commons Zero 1.0 Universal license at
 *   https://creativecommons.org/publicdomain/zero/1.0/
 */

#if !defined(PSNIP_CLOCK_H)
#define PSNIP_CLOCK_H

#if !defined(psnip_uint64_t)
#  include "../exact-int/exact-int.h"
#endif

/* Unless PSNIP_CLOCK_STATIC_INLINE is defined, every clock helper below is
 * declared `static`; under GCC it is also marked unused so helpers that end
 * up uncalled do not trigger unused-function warnings. */
#if !defined(PSNIP_CLOCK_STATIC_INLINE)
#  if defined(__GNUC__)
#    define PSNIP_CLOCK__COMPILER_ATTRIBUTES __attribute__((__unused__))
#  else
#    define PSNIP_CLOCK__COMPILER_ATTRIBUTES
#  endif

#  define PSNIP_CLOCK__FUNCTION PSNIP_CLOCK__COMPILER_ATTRIBUTES static
#endif

enum PsnipClockType {
  /* This clock provides the current time, in units since 1970-01-01
   * 00:00:00 UTC not including leap seconds.  In other words, UNIX
   * time.  Keep in mind that this clock doesn't account for leap
   * seconds, and can go backwards (think NTP adjustments). */
  PSNIP_CLOCK_TYPE_WALL = 1,
  /* The CPU time is a clock which increases only when the current
   * process is active (i.e., it doesn't increment while blocking on
   * I/O). */
  PSNIP_CLOCK_TYPE_CPU = 2,
  /* Monotonic time is always running (unlike CPU time), but it only
     ever moves forward unless you reboot the system.  Things like NTP
     adjustments have no effect on this clock. */
  PSNIP_CLOCK_TYPE_MONOTONIC = 3
};

/* A time value split into whole seconds and a nanoseconds field. */
struct PsnipClockTimespec {
  psnip_uint64_t seconds;
  psnip_uint64_t nanoseconds;
};

/* Methods we support: */

/* Backend identifiers; PSNIP_CLOCK_{WALL,CPU,MONOTONIC}_METHOD are set to one
 * of these to pick the implementation of each clock. */
#define PSNIP_CLOCK_METHOD_CLOCK_GETTIME                   1
#define PSNIP_CLOCK_METHOD_TIME                            2
#define PSNIP_CLOCK_METHOD_GETTIMEOFDAY                    3
#define PSNIP_CLOCK_METHOD_QUERYPERFORMANCECOUNTER         4
#define PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME              5
#define PSNIP_CLOCK_METHOD_CLOCK                           6
#define PSNIP_CLOCK_METHOD_GETPROCESSTIMES                 7
#define PSNIP_CLOCK_METHOD_GETRUSAGE                       8
#define PSNIP_CLOCK_METHOD_GETSYSTEMTIMEPRECISEASFILETIME  9
#define PSNIP_CLOCK_METHOD_GETTICKCOUNT64                 10

#include

#if defined(HEDLEY_UNREACHABLE)
#  define PSNIP_CLOCK_UNREACHABLE() HEDLEY_UNREACHABLE()
#else
#  define PSNIP_CLOCK_UNREACHABLE() assert(0)
#endif

/* Choose an implementation */

/* #undef PSNIP_CLOCK_WALL_METHOD */
/* #undef PSNIP_CLOCK_CPU_METHOD */
/* #undef PSNIP_CLOCK_MONOTONIC_METHOD */

/* We want to be able to detect the libc implementation, so we include
   ( isn't available everywhere). */
#if defined(__unix__) || defined(__unix) || defined(__linux__)
#  include
#  include
#endif

#if defined(_POSIX_TIMERS) && (_POSIX_TIMERS > 0)
/* These are known to work without librt.  If you know of others
 * please let us know so we can add them.
*/ # if \ (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 17))) || \ (defined(__FreeBSD__)) # define PSNIP_CLOCK_HAVE_CLOCK_GETTIME # elif !defined(PSNIP_CLOCK_NO_LIBRT) # define PSNIP_CLOCK_HAVE_CLOCK_GETTIME # endif #endif #if defined(_WIN32) # if !defined(PSNIP_CLOCK_CPU_METHOD) # define PSNIP_CLOCK_CPU_METHOD PSNIP_CLOCK_METHOD_GETPROCESSTIMES # endif # if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) # define PSNIP_CLOCK_MONOTONIC_METHOD PSNIP_CLOCK_METHOD_QUERYPERFORMANCECOUNTER # endif #endif #if defined(__MACH__) && !defined(__gnu_hurd__) # if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) # define PSNIP_CLOCK_MONOTONIC_METHOD PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME # endif #endif #if defined(PSNIP_CLOCK_HAVE_CLOCK_GETTIME) # include # if !defined(PSNIP_CLOCK_WALL_METHOD) # if defined(CLOCK_REALTIME_PRECISE) # define PSNIP_CLOCK_WALL_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_WALL CLOCK_REALTIME_PRECISE # elif !defined(__sun) # define PSNIP_CLOCK_WALL_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_WALL CLOCK_REALTIME # endif # endif # if !defined(PSNIP_CLOCK_CPU_METHOD) # if defined(_POSIX_CPUTIME) || defined(CLOCK_PROCESS_CPUTIME_ID) # define PSNIP_CLOCK_CPU_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_CPU CLOCK_PROCESS_CPUTIME_ID # elif defined(CLOCK_VIRTUAL) # define PSNIP_CLOCK_CPU_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_CPU CLOCK_VIRTUAL # endif # endif # if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) # if defined(CLOCK_MONOTONIC_RAW) # define PSNIP_CLOCK_MONOTONIC_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_MONOTONIC CLOCK_MONOTONIC # elif defined(CLOCK_MONOTONIC_PRECISE) # define PSNIP_CLOCK_MONOTONIC_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_MONOTONIC CLOCK_MONOTONIC_PRECISE # elif defined(_POSIX_MONOTONIC_CLOCK) || defined(CLOCK_MONOTONIC) # define 
PSNIP_CLOCK_MONOTONIC_METHOD PSNIP_CLOCK_METHOD_CLOCK_GETTIME # define PSNIP_CLOCK_CLOCK_GETTIME_MONOTONIC CLOCK_MONOTONIC # endif # endif #endif #if defined(_POSIX_VERSION) && (_POSIX_VERSION >= 200112L) # if !defined(PSNIP_CLOCK_WALL_METHOD) # define PSNIP_CLOCK_WALL_METHOD PSNIP_CLOCK_METHOD_GETTIMEOFDAY # endif #endif #if !defined(PSNIP_CLOCK_WALL_METHOD) # define PSNIP_CLOCK_WALL_METHOD PSNIP_CLOCK_METHOD_TIME #endif #if !defined(PSNIP_CLOCK_CPU_METHOD) # define PSNIP_CLOCK_CPU_METHOD PSNIP_CLOCK_METHOD_CLOCK #endif /* Primarily here for testing. */ #if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) && defined(PSNIP_CLOCK_REQUIRE_MONOTONIC) # error No monotonic clock found. #endif /* Implementations */ #if \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) || \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_CLOCK)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_CLOCK)) || \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_TIME)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_TIME)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_TIME)) # include #endif #if \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETTIMEOFDAY)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETTIMEOFDAY)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETTIMEOFDAY)) # include #endif #if \ 
(defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETPROCESSTIMES)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETPROCESSTIMES)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETPROCESSTIMES)) || \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETTICKCOUNT64)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETTICKCOUNT64)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETTICKCOUNT64)) # include #endif #if \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETRUSAGE)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETRUSAGE)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETRUSAGE)) # include # include #endif #if \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME)) # include # include # include #endif /*** Implementations ***/ #define PSNIP_CLOCK_NSEC_PER_SEC ((psnip_uint32_t) (1000000000ULL)) #if \ (defined(PSNIP_CLOCK_CPU_METHOD) && (PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) || \ (defined(PSNIP_CLOCK_WALL_METHOD) && (PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) || \ (defined(PSNIP_CLOCK_MONOTONIC_METHOD) && (PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME)) PSNIP_CLOCK__FUNCTION psnip_uint32_t psnip_clock__clock_getres (clockid_t clk_id) { struct timespec res; int r; r = clock_getres(clk_id, &res); if (r != 0) return 0; return (psnip_uint32_t) 
(PSNIP_CLOCK_NSEC_PER_SEC / res.tv_nsec); } PSNIP_CLOCK__FUNCTION int psnip_clock__clock_gettime (clockid_t clk_id, struct PsnipClockTimespec* res) { struct timespec ts; if (clock_gettime(clk_id, &ts) != 0) return -10; res->seconds = (psnip_uint64_t) (ts.tv_sec); res->nanoseconds = (psnip_uint64_t) (ts.tv_nsec); return 0; } #endif PSNIP_CLOCK__FUNCTION psnip_uint32_t psnip_clock_wall_get_precision (void) { #if !defined(PSNIP_CLOCK_WALL_METHOD) return 0; #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_getres(PSNIP_CLOCK_CLOCK_GETTIME_WALL); #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETTIMEOFDAY return 1000000; #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_TIME return 1; #else return 0; #endif } PSNIP_CLOCK__FUNCTION int psnip_clock_wall_get_time (struct PsnipClockTimespec* res) { (void) res; #if !defined(PSNIP_CLOCK_WALL_METHOD) return -2; #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_gettime(PSNIP_CLOCK_CLOCK_GETTIME_WALL, res); #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_TIME res->seconds = time(NULL); res->nanoseconds = 0; #elif defined(PSNIP_CLOCK_WALL_METHOD) && PSNIP_CLOCK_WALL_METHOD == PSNIP_CLOCK_METHOD_GETTIMEOFDAY struct timeval tv; if (gettimeofday(&tv, NULL) != 0) return -6; res->seconds = tv.tv_sec; res->nanoseconds = tv.tv_usec * 1000; #else return -2; #endif return 0; } PSNIP_CLOCK__FUNCTION psnip_uint32_t psnip_clock_cpu_get_precision (void) { #if !defined(PSNIP_CLOCK_CPU_METHOD) return 0; #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_getres(PSNIP_CLOCK_CLOCK_GETTIME_CPU); #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK return 
CLOCKS_PER_SEC; #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETPROCESSTIMES return PSNIP_CLOCK_NSEC_PER_SEC / 100; #else return 0; #endif } PSNIP_CLOCK__FUNCTION int psnip_clock_cpu_get_time (struct PsnipClockTimespec* res) { #if !defined(PSNIP_CLOCK_CPU_METHOD) (void) res; return -2; #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_gettime(PSNIP_CLOCK_CLOCK_GETTIME_CPU, res); #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_CLOCK clock_t t = clock(); if (t == ((clock_t) -1)) return -5; res->seconds = t / CLOCKS_PER_SEC; res->nanoseconds = (t % CLOCKS_PER_SEC) * (PSNIP_CLOCK_NSEC_PER_SEC / CLOCKS_PER_SEC); #elif defined(PSNIP_CLOCK_CPU_METHOD) && PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETPROCESSTIMES FILETIME CreationTime, ExitTime, KernelTime, UserTime; LARGE_INTEGER date, adjust; if (!GetProcessTimes(GetCurrentProcess(), &CreationTime, &ExitTime, &KernelTime, &UserTime)) return -7; /* http://www.frenk.com/2009/12/convert-filetime-to-unix-timestamp/ */ date.HighPart = UserTime.dwHighDateTime; date.LowPart = UserTime.dwLowDateTime; adjust.QuadPart = 11644473600000 * 10000; date.QuadPart -= adjust.QuadPart; res->seconds = date.QuadPart / 10000000; res->nanoseconds = (date.QuadPart % 10000000) * (PSNIP_CLOCK_NSEC_PER_SEC / 100); #elif PSNIP_CLOCK_CPU_METHOD == PSNIP_CLOCK_METHOD_GETRUSAGE struct rusage usage; if (getrusage(RUSAGE_SELF, &usage) != 0) return -8; res->seconds = usage.ru_utime.tv_sec; res->nanoseconds = tv.tv_usec * 1000; #else (void) res; return -2; #endif return 0; } PSNIP_CLOCK__FUNCTION psnip_uint32_t psnip_clock_monotonic_get_precision (void) { #if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) return 0; #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_getres(PSNIP_CLOCK_CLOCK_GETTIME_MONOTONIC); #elif 
defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME static mach_timebase_info_data_t tbi = { 0, }; if (tbi.denom == 0) mach_timebase_info(&tbi); return (psnip_uint32_t) (tbi.numer / tbi.denom); #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETTICKCOUNT64 return 1000; #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_QUERYPERFORMANCECOUNTER LARGE_INTEGER Frequency; QueryPerformanceFrequency(&Frequency); return (psnip_uint32_t) ((Frequency.QuadPart > PSNIP_CLOCK_NSEC_PER_SEC) ? PSNIP_CLOCK_NSEC_PER_SEC : Frequency.QuadPart); #else return 0; #endif } PSNIP_CLOCK__FUNCTION int psnip_clock_monotonic_get_time (struct PsnipClockTimespec* res) { #if !defined(PSNIP_CLOCK_MONOTONIC_METHOD) (void) res; return -2; #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_CLOCK_GETTIME return psnip_clock__clock_gettime(PSNIP_CLOCK_CLOCK_GETTIME_MONOTONIC, res); #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_MACH_ABSOLUTE_TIME psnip_uint64_t nsec = mach_absolute_time(); static mach_timebase_info_data_t tbi = { 0, }; if (tbi.denom == 0) mach_timebase_info(&tbi); nsec *= ((psnip_uint64_t) tbi.numer) / ((psnip_uint64_t) tbi.denom); res->seconds = nsec / PSNIP_CLOCK_NSEC_PER_SEC; res->nanoseconds = nsec % PSNIP_CLOCK_NSEC_PER_SEC; #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) && PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_QUERYPERFORMANCECOUNTER LARGE_INTEGER t, f; if (QueryPerformanceCounter(&t) == 0) return -12; QueryPerformanceFrequency(&f); res->seconds = t.QuadPart / f.QuadPart; res->nanoseconds = t.QuadPart % f.QuadPart; if (f.QuadPart > PSNIP_CLOCK_NSEC_PER_SEC) res->nanoseconds /= f.QuadPart / PSNIP_CLOCK_NSEC_PER_SEC; else res->nanoseconds *= PSNIP_CLOCK_NSEC_PER_SEC / f.QuadPart; #elif defined(PSNIP_CLOCK_MONOTONIC_METHOD) 
&& PSNIP_CLOCK_MONOTONIC_METHOD == PSNIP_CLOCK_METHOD_GETTICKCOUNT64 const ULONGLONG msec = GetTickCount64(); res->seconds = msec / 1000; res->nanoseconds = sec % 1000; #else return -2; #endif return 0; } /* Returns the number of ticks per second for the specified clock. * For example, a clock with millisecond precision would return 1000, * and a clock with 1 second (such as the time() function) would * return 1. * * If the requested clock isn't available, it will return 0. * Hopefully this will be rare, but if it happens to you please let us * know so we can work on finding a way to support your system. * * Note that different clocks on the same system often have a * different precisions. */ PSNIP_CLOCK__FUNCTION psnip_uint32_t psnip_clock_get_precision (enum PsnipClockType clock_type) { switch (clock_type) { case PSNIP_CLOCK_TYPE_MONOTONIC: return psnip_clock_monotonic_get_precision (); case PSNIP_CLOCK_TYPE_CPU: return psnip_clock_cpu_get_precision (); case PSNIP_CLOCK_TYPE_WALL: return psnip_clock_wall_get_precision (); } PSNIP_CLOCK_UNREACHABLE(); return 0; } /* Set the provided timespec to the requested time. Returns 0 on * success, or a negative value on failure. 
*/ PSNIP_CLOCK__FUNCTION int psnip_clock_get_time (enum PsnipClockType clock_type, struct PsnipClockTimespec* res) { assert(res != NULL); switch (clock_type) { case PSNIP_CLOCK_TYPE_MONOTONIC: return psnip_clock_monotonic_get_time (res); case PSNIP_CLOCK_TYPE_CPU: return psnip_clock_cpu_get_time (res); case PSNIP_CLOCK_TYPE_WALL: return psnip_clock_wall_get_time (res); } return -1; } #endif /* !defined(PSNIP_CLOCK_H) */ static psnip_uint64_t munit_clock_get_elapsed(struct PsnipClockTimespec* start, struct PsnipClockTimespec* end) { psnip_uint64_t r = (end->seconds - start->seconds) * PSNIP_CLOCK_NSEC_PER_SEC; if (end->nanoseconds < start->nanoseconds) { r -= (start->nanoseconds - end->nanoseconds); } else { r += (end->nanoseconds - start->nanoseconds); } return r; } #else # include #endif /* defined(MUNIT_ENABLE_TIMING) */ /*** PRNG stuff ***/ /* This is (unless I screwed up, which is entirely possible) the * version of PCG with 32-bit state. It was chosen because it has a * small enough state that we should reliably be able to use CAS * instead of requiring a lock for thread-safety. * * If I did screw up, I probably will not bother changing it unless * there is a significant bias. It's really not important this be * particularly strong, as long as it is fairly random it's much more * important that it be reproducible, so bug reports have a better * chance of being reproducible. 
*/ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && !defined(__STDC_NO_ATOMICS__) && !defined(__EMSCRIPTEN__) && (!defined(__GNUC_MINOR__) || (__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ > 8)) # define HAVE_STDATOMIC #elif defined(__clang__) # if __has_extension(c_atomic) # define HAVE_CLANG_ATOMICS # endif #endif /* Workaround for http://llvm.org/bugs/show_bug.cgi?id=26911 */ #if defined(__clang__) && defined(_WIN32) # undef HAVE_STDATOMIC # if defined(__c2__) # undef HAVE_CLANG_ATOMICS # endif #endif #if defined(_OPENMP) # define ATOMIC_UINT32_T uint32_t # define ATOMIC_UINT32_INIT(x) (x) #elif defined(HAVE_STDATOMIC) # include # define ATOMIC_UINT32_T _Atomic uint32_t # define ATOMIC_UINT32_INIT(x) ATOMIC_VAR_INIT(x) #elif defined(HAVE_CLANG_ATOMICS) # define ATOMIC_UINT32_T _Atomic uint32_t # define ATOMIC_UINT32_INIT(x) (x) #elif defined(_WIN32) # define ATOMIC_UINT32_T volatile LONG # define ATOMIC_UINT32_INIT(x) (x) #else # define ATOMIC_UINT32_T volatile uint32_t # define ATOMIC_UINT32_INIT(x) (x) #endif static ATOMIC_UINT32_T munit_rand_state = ATOMIC_UINT32_INIT(42); #if defined(_OPENMP) static inline void munit_atomic_store(ATOMIC_UINT32_T* dest, ATOMIC_UINT32_T value) { #pragma omp critical (munit_atomics) *dest = value; } static inline uint32_t munit_atomic_load(ATOMIC_UINT32_T* src) { int ret; #pragma omp critical (munit_atomics) ret = *src; return ret; } static inline uint32_t munit_atomic_cas(ATOMIC_UINT32_T* dest, ATOMIC_UINT32_T* expected, ATOMIC_UINT32_T desired) { munit_bool ret; #pragma omp critical (munit_atomics) { if (*dest == *expected) { *dest = desired; ret = 1; } else { ret = 0; } } return ret; } #elif defined(HAVE_STDATOMIC) # define munit_atomic_store(dest, value) atomic_store(dest, value) # define munit_atomic_load(src) atomic_load(src) # define munit_atomic_cas(dest, expected, value) atomic_compare_exchange_weak(dest, expected, value) #elif defined(HAVE_CLANG_ATOMICS) # define munit_atomic_store(dest, value) 
__c11_atomic_store(dest, value, __ATOMIC_SEQ_CST) # define munit_atomic_load(src) __c11_atomic_load(src, __ATOMIC_SEQ_CST) # define munit_atomic_cas(dest, expected, value) __c11_atomic_compare_exchange_weak(dest, expected, value, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST) #elif defined(__GNUC__) && (__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7) # define munit_atomic_store(dest, value) __atomic_store_n(dest, value, __ATOMIC_SEQ_CST) # define munit_atomic_load(src) __atomic_load_n(src, __ATOMIC_SEQ_CST) # define munit_atomic_cas(dest, expected, value) __atomic_compare_exchange_n(dest, expected, value, 1, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST) #elif defined(__GNUC__) && (__GNUC__ >= 4) # define munit_atomic_store(dest,value) do { *(dest) = (value); } while (0) # define munit_atomic_load(src) (*(src)) # define munit_atomic_cas(dest, expected, value) __sync_bool_compare_and_swap(dest, *expected, value) #elif defined(_WIN32) /* Untested */ # define munit_atomic_store(dest,value) do { *(dest) = (value); } while (0) # define munit_atomic_load(src) (*(src)) # define munit_atomic_cas(dest, expected, value) InterlockedCompareExchange((dest), (value), *(expected)) #else # warning No atomic implementation, PRNG will not be thread-safe # define munit_atomic_store(dest, value) do { *(dest) = (value); } while (0) # define munit_atomic_load(src) (*(src)) static inline munit_bool munit_atomic_cas(ATOMIC_UINT32_T* dest, ATOMIC_UINT32_T* expected, ATOMIC_UINT32_T desired) { if (*dest == *expected) { *dest = desired; return 1; } else { return 0; } } #endif #define MUNIT_PRNG_MULTIPLIER (747796405U) #define MUNIT_PRNG_INCREMENT (1729U) static munit_uint32_t munit_rand_next_state(munit_uint32_t state) { return state * MUNIT_PRNG_MULTIPLIER + MUNIT_PRNG_INCREMENT; } static munit_uint32_t munit_rand_from_state(munit_uint32_t state) { munit_uint32_t res = ((state >> ((state >> 28) + 4)) ^ state) * (277803737U); res ^= res >> 22; return res; } void munit_rand_seed(munit_uint32_t seed) { 
munit_uint32_t state = munit_rand_next_state(seed + MUNIT_PRNG_INCREMENT); munit_atomic_store(&munit_rand_state, state); } static munit_uint32_t munit_rand_generate_seed(void) { munit_uint32_t seed, state; #if defined(MUNIT_ENABLE_TIMING) struct PsnipClockTimespec wc = { 0, }; psnip_clock_get_time(PSNIP_CLOCK_TYPE_WALL, &wc); seed = (munit_uint32_t) wc.nanoseconds; #else seed = (munit_uint32_t) time(NULL); #endif state = munit_rand_next_state(seed + MUNIT_PRNG_INCREMENT); return munit_rand_from_state(state); } static munit_uint32_t munit_rand_state_uint32(munit_uint32_t* state) { const munit_uint32_t old = *state; *state = munit_rand_next_state(old); return munit_rand_from_state(old); } munit_uint32_t munit_rand_uint32(void) { munit_uint32_t old, state; do { old = munit_atomic_load(&munit_rand_state); state = munit_rand_next_state(old); } while (!munit_atomic_cas(&munit_rand_state, &old, state)); return munit_rand_from_state(old); } static void munit_rand_state_memory(munit_uint32_t* state, size_t size, munit_uint8_t data[MUNIT_ARRAY_PARAM(size)]) { size_t members_remaining = size / sizeof(munit_uint32_t); size_t bytes_remaining = size % sizeof(munit_uint32_t); munit_uint8_t* b = data; munit_uint32_t rv; while (members_remaining-- > 0) { rv = munit_rand_state_uint32(state); memcpy(b, &rv, sizeof(munit_uint32_t)); b += sizeof(munit_uint32_t); } if (bytes_remaining != 0) { rv = munit_rand_state_uint32(state); memcpy(b, &rv, bytes_remaining); } } void munit_rand_memory(size_t size, munit_uint8_t data[MUNIT_ARRAY_PARAM(size)]) { munit_uint32_t old, state; do { state = old = munit_atomic_load(&munit_rand_state); munit_rand_state_memory(&state, size, data); } while (!munit_atomic_cas(&munit_rand_state, &old, state)); } static munit_uint32_t munit_rand_state_at_most(munit_uint32_t* state, munit_uint32_t salt, munit_uint32_t max) { /* We want (UINT32_MAX + 1) % max, which in unsigned arithmetic is the same * as (UINT32_MAX + 1 - max) % max = -max % max. 
We compute -max using not * to avoid compiler warnings. */ const munit_uint32_t min = (~max + 1U) % max; munit_uint32_t x; if (max == (~((munit_uint32_t) 0U))) return munit_rand_state_uint32(state) ^ salt; max++; do { x = munit_rand_state_uint32(state) ^ salt; } while (x < min); return x % max; } static munit_uint32_t munit_rand_at_most(munit_uint32_t salt, munit_uint32_t max) { munit_uint32_t old, state; munit_uint32_t retval; do { state = old = munit_atomic_load(&munit_rand_state); retval = munit_rand_state_at_most(&state, salt, max); } while (!munit_atomic_cas(&munit_rand_state, &old, state)); return retval; } int munit_rand_int_range(int min, int max) { munit_uint64_t range = (munit_uint64_t) max - (munit_uint64_t) min; if (min > max) return munit_rand_int_range(max, min); if (range > (~((munit_uint32_t) 0U))) range = (~((munit_uint32_t) 0U)); return min + munit_rand_at_most(0, (munit_uint32_t) range); } double munit_rand_double(void) { munit_uint32_t old, state; double retval = 0.0; do { state = old = munit_atomic_load(&munit_rand_state); /* See http://mumble.net/~campbell/tmp/random_real.c for how to do * this right. Patches welcome if you feel that this is too * biased. 
*/ retval = munit_rand_state_uint32(&state) / ((~((munit_uint32_t) 0U)) + 1.0); } while (!munit_atomic_cas(&munit_rand_state, &old, state)); return retval; } /*** Test suite handling ***/ typedef struct { unsigned int successful; unsigned int skipped; unsigned int failed; unsigned int errored; #if defined(MUNIT_ENABLE_TIMING) munit_uint64_t cpu_clock; munit_uint64_t wall_clock; #endif } MunitReport; typedef struct { const char* prefix; const MunitSuite* suite; const char** tests; munit_uint32_t seed; unsigned int iterations; MunitParameter* parameters; munit_bool single_parameter_mode; void* user_data; MunitReport report; munit_bool colorize; munit_bool fork; munit_bool show_stderr; munit_bool fatal_failures; } MunitTestRunner; const char* munit_parameters_get(const MunitParameter params[], const char* key) { const MunitParameter* param; for (param = params ; param != NULL && param->name != NULL ; param++) if (strcmp(param->name, key) == 0) return param->value; return NULL; } #if defined(MUNIT_ENABLE_TIMING) static void munit_print_time(FILE* fp, munit_uint64_t nanoseconds) { fprintf(fp, "%" MUNIT_TEST_TIME_FORMAT, ((double) nanoseconds) / ((double) PSNIP_CLOCK_NSEC_PER_SEC)); } #endif /* Add a parameter to an array of parameters. */ static MunitResult munit_parameters_add(size_t* params_size, MunitParameter* params[MUNIT_ARRAY_PARAM(*params_size)], char* name, char* value) { *params = realloc(*params, sizeof(MunitParameter) * (*params_size + 2)); if (*params == NULL) return MUNIT_ERROR; (*params)[*params_size].name = name; (*params)[*params_size].value = value; (*params_size)++; (*params)[*params_size].name = NULL; (*params)[*params_size].value = NULL; return MUNIT_OK; } /* Concatenate two strings, but just return one of the components * unaltered if the other is NULL or "". */ static char* munit_maybe_concat(size_t* len, char* prefix, char* suffix) { char* res; size_t res_l; const size_t prefix_l = prefix != NULL ? 
strlen(prefix) : 0; const size_t suffix_l = suffix != NULL ? strlen(suffix) : 0; if (prefix_l == 0 && suffix_l == 0) { res = NULL; res_l = 0; } else if (prefix_l == 0 && suffix_l != 0) { res = suffix; res_l = suffix_l; } else if (prefix_l != 0 && suffix_l == 0) { res = prefix; res_l = prefix_l; } else { res_l = prefix_l + suffix_l; res = malloc(res_l + 1); memcpy(res, prefix, prefix_l); memcpy(res + prefix_l, suffix, suffix_l); res[res_l] = 0; } if (len != NULL) *len = res_l; return res; } /* Possibly free a string returned by munit_maybe_concat. */ static void munit_maybe_free_concat(char* s, const char* prefix, const char* suffix) { if (prefix != s && suffix != s) free(s); } /* Cheap string hash function, just used to salt the PRNG. */ static munit_uint32_t munit_str_hash(const char* name) { const char *p; munit_uint32_t h = 5381U; for (p = name; *p != '\0'; p++) h = (h << 5) + h + *p; return h; } static void munit_splice(int from, int to) { munit_uint8_t buf[1024]; #if !defined(_WIN32) ssize_t len; ssize_t bytes_written; ssize_t write_res; #else int len; int bytes_written; int write_res; #endif do { len = read(from, buf, sizeof(buf)); if (len > 0) { bytes_written = 0; do { write_res = write(to, buf + bytes_written, len - bytes_written); if (write_res < 0) break; bytes_written += write_res; } while (bytes_written < len); } else break; } while (1); } /* This is the part that should be handled in the child process */ static MunitResult munit_test_runner_exec(MunitTestRunner* runner, const MunitTest* test, const MunitParameter params[], MunitReport* report) { unsigned int iterations = runner->iterations; MunitResult result = MUNIT_FAIL; #if defined(MUNIT_ENABLE_TIMING) struct PsnipClockTimespec wall_clock_begin = { 0, }, wall_clock_end = { 0, }; struct PsnipClockTimespec cpu_clock_begin = { 0, }, cpu_clock_end = { 0, }; #endif unsigned int i = 0; if ((test->options & MUNIT_TEST_OPTION_SINGLE_ITERATION) == MUNIT_TEST_OPTION_SINGLE_ITERATION) iterations = 1; else if 
(iterations == 0) iterations = runner->suite->iterations; munit_rand_seed(runner->seed); do { void* data = (test->setup == NULL) ? runner->user_data : test->setup(params, runner->user_data); #if defined(MUNIT_ENABLE_TIMING) psnip_clock_get_time(PSNIP_CLOCK_TYPE_WALL, &wall_clock_begin); psnip_clock_get_time(PSNIP_CLOCK_TYPE_CPU, &cpu_clock_begin); #endif result = test->test(params, data); #if defined(MUNIT_ENABLE_TIMING) psnip_clock_get_time(PSNIP_CLOCK_TYPE_WALL, &wall_clock_end); psnip_clock_get_time(PSNIP_CLOCK_TYPE_CPU, &cpu_clock_end); #endif if (test->tear_down != NULL) test->tear_down(data); if (MUNIT_LIKELY(result == MUNIT_OK)) { report->successful++; #if defined(MUNIT_ENABLE_TIMING) report->wall_clock += munit_clock_get_elapsed(&wall_clock_begin, &wall_clock_end); report->cpu_clock += munit_clock_get_elapsed(&cpu_clock_begin, &cpu_clock_end); #endif } else { switch ((int) result) { case MUNIT_SKIP: report->skipped++; break; case MUNIT_FAIL: report->failed++; break; case MUNIT_ERROR: report->errored++; break; default: break; } break; } } while (++i < iterations); return result; } #if defined(MUNIT_EMOTICON) # define MUNIT_RESULT_STRING_OK ":)" # define MUNIT_RESULT_STRING_SKIP ":|" # define MUNIT_RESULT_STRING_FAIL ":(" # define MUNIT_RESULT_STRING_ERROR ":o" # define MUNIT_RESULT_STRING_TODO ":/" #else # define MUNIT_RESULT_STRING_OK "OK " # define MUNIT_RESULT_STRING_SKIP "SKIP " # define MUNIT_RESULT_STRING_FAIL "FAIL " # define MUNIT_RESULT_STRING_ERROR "ERROR" # define MUNIT_RESULT_STRING_TODO "TODO " #endif static void munit_test_runner_print_color(const MunitTestRunner* runner, const char* string, char color) { if (runner->colorize) fprintf(MUNIT_OUTPUT_FILE, "\x1b[3%cm%s\x1b[39m", color, string); else fputs(string, MUNIT_OUTPUT_FILE); } #if !defined(MUNIT_NO_BUFFER) static int munit_replace_stderr(FILE* stderr_buf) { if (stderr_buf != NULL) { const int orig_stderr = dup(STDERR_FILENO); int errfd = fileno(stderr_buf); if (MUNIT_UNLIKELY(errfd == -1)) 
{ exit(EXIT_FAILURE); } dup2(errfd, STDERR_FILENO); return orig_stderr; } return -1; } static void munit_restore_stderr(int orig_stderr) { if (orig_stderr != -1) { dup2(orig_stderr, STDERR_FILENO); close(orig_stderr); } } #endif /* !defined(MUNIT_NO_BUFFER) */ /* Run a test with the specified parameters. */ static void munit_test_runner_run_test_with_params(MunitTestRunner* runner, const MunitTest* test, const MunitParameter params[]) { MunitResult result = MUNIT_OK; MunitReport report = { 0, 0, 0, 0, #if defined(MUNIT_ENABLE_TIMING) 0, 0 #endif }; unsigned int output_l; munit_bool first; const MunitParameter* param; FILE* stderr_buf; #if !defined(MUNIT_NO_FORK) int pipefd[2]; pid_t fork_pid; int orig_stderr; ssize_t bytes_written = 0; ssize_t write_res; ssize_t bytes_read = 0; ssize_t read_res; int status = 0; pid_t changed_pid; #endif if (params != NULL) { output_l = 2; fputs(" ", MUNIT_OUTPUT_FILE); first = 1; for (param = params ; param != NULL && param->name != NULL ; param++) { if (!first) { fputs(", ", MUNIT_OUTPUT_FILE); output_l += 2; } else { first = 0; } output_l += fprintf(MUNIT_OUTPUT_FILE, "%s=%s", param->name, param->value); } while (output_l++ < MUNIT_TEST_NAME_LEN) { fputc(' ', MUNIT_OUTPUT_FILE); } } fflush(MUNIT_OUTPUT_FILE); stderr_buf = NULL; #if !defined(_WIN32) || defined(__MINGW32__) stderr_buf = tmpfile(); #else tmpfile_s(&stderr_buf); #endif if (stderr_buf == NULL) { munit_log_errno(MUNIT_LOG_ERROR, stderr, "unable to create buffer for stderr"); result = MUNIT_ERROR; goto print_result; } #if !defined(MUNIT_NO_FORK) if (runner->fork) { pipefd[0] = -1; pipefd[1] = -1; if (pipe(pipefd) != 0) { munit_log_errno(MUNIT_LOG_ERROR, stderr, "unable to create pipe"); result = MUNIT_ERROR; goto print_result; } fork_pid = fork(); if (fork_pid == 0) { close(pipefd[0]); orig_stderr = munit_replace_stderr(stderr_buf); munit_test_runner_exec(runner, test, params, &report); /* Note that we don't restore stderr. 
This is so we can buffer * things written to stderr later on (such as by * asan/tsan/ubsan, valgrind, etc.) */ close(orig_stderr); do { write_res = write(pipefd[1], ((munit_uint8_t*) (&report)) + bytes_written, sizeof(report) - bytes_written); if (write_res < 0) { if (stderr_buf != NULL) { munit_log_errno(MUNIT_LOG_ERROR, stderr, "unable to write to pipe"); } exit(EXIT_FAILURE); } bytes_written += write_res; } while ((size_t) bytes_written < sizeof(report)); if (stderr_buf != NULL) fclose(stderr_buf); close(pipefd[1]); exit(EXIT_SUCCESS); } else if (fork_pid == -1) { close(pipefd[0]); close(pipefd[1]); if (stderr_buf != NULL) { munit_log_errno(MUNIT_LOG_ERROR, stderr, "unable to fork"); } report.errored++; result = MUNIT_ERROR; } else { close(pipefd[1]); do { read_res = read(pipefd[0], ((munit_uint8_t*) (&report)) + bytes_read, sizeof(report) - bytes_read); if (read_res < 1) break; bytes_read += read_res; } while (bytes_read < (ssize_t) sizeof(report)); changed_pid = waitpid(fork_pid, &status, 0); if (MUNIT_LIKELY(changed_pid == fork_pid) && MUNIT_LIKELY(WIFEXITED(status))) { if (bytes_read != sizeof(report)) { munit_logf_internal(MUNIT_LOG_ERROR, stderr_buf, "child exited unexpectedly with status %d", WEXITSTATUS(status)); report.errored++; } else if (WEXITSTATUS(status) != EXIT_SUCCESS) { munit_logf_internal(MUNIT_LOG_ERROR, stderr_buf, "child exited with status %d", WEXITSTATUS(status)); report.errored++; } } else { if (WIFSIGNALED(status)) { #if defined(_XOPEN_VERSION) && (_XOPEN_VERSION >= 700) munit_logf_internal(MUNIT_LOG_ERROR, stderr_buf, "child killed by signal %d (%s)", WTERMSIG(status), strsignal(WTERMSIG(status))); #else munit_logf_internal(MUNIT_LOG_ERROR, stderr_buf, "child killed by signal %d", WTERMSIG(status)); #endif } else if (WIFSTOPPED(status)) { munit_logf_internal(MUNIT_LOG_ERROR, stderr_buf, "child stopped by signal %d", WSTOPSIG(status)); } report.errored++; } close(pipefd[0]); waitpid(fork_pid, NULL, 0); } } else #endif { #if 
!defined(MUNIT_NO_BUFFER) const volatile int orig_stderr = munit_replace_stderr(stderr_buf); #endif #if defined(MUNIT_THREAD_LOCAL) if (MUNIT_UNLIKELY(setjmp(munit_error_jmp_buf) != 0)) { result = MUNIT_FAIL; report.failed++; } else { munit_error_jmp_buf_valid = 1; result = munit_test_runner_exec(runner, test, params, &report); } #else result = munit_test_runner_exec(runner, test, params, &report); #endif #if !defined(MUNIT_NO_BUFFER) munit_restore_stderr(orig_stderr); #endif /* Here just so that the label is used on Windows and we don't get * a warning */ goto print_result; } print_result: fputs("[ ", MUNIT_OUTPUT_FILE); if ((test->options & MUNIT_TEST_OPTION_TODO) == MUNIT_TEST_OPTION_TODO) { if (report.failed != 0 || report.errored != 0 || report.skipped != 0) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_TODO, '3'); result = MUNIT_OK; } else { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_ERROR, '1'); if (MUNIT_LIKELY(stderr_buf != NULL)) munit_log_internal(MUNIT_LOG_ERROR, stderr_buf, "Test marked TODO, but was successful."); runner->report.failed++; result = MUNIT_ERROR; } } else if (report.failed > 0) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_FAIL, '1'); runner->report.failed++; result = MUNIT_FAIL; } else if (report.errored > 0) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_ERROR, '1'); runner->report.errored++; result = MUNIT_ERROR; } else if (report.skipped > 0) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_SKIP, '3'); runner->report.skipped++; result = MUNIT_SKIP; } else if (report.successful > 1) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_OK, '2'); #if defined(MUNIT_ENABLE_TIMING) fputs(" ] [ ", MUNIT_OUTPUT_FILE); munit_print_time(MUNIT_OUTPUT_FILE, report.wall_clock / report.successful); fputs(" / ", MUNIT_OUTPUT_FILE); munit_print_time(MUNIT_OUTPUT_FILE, report.cpu_clock / report.successful); fprintf(MUNIT_OUTPUT_FILE, " CPU ]\n %-" 
MUNIT_XSTRINGIFY(MUNIT_TEST_NAME_LEN) "s Total: [ ", ""); munit_print_time(MUNIT_OUTPUT_FILE, report.wall_clock); fputs(" / ", MUNIT_OUTPUT_FILE); munit_print_time(MUNIT_OUTPUT_FILE, report.cpu_clock); fputs(" CPU", MUNIT_OUTPUT_FILE); #endif runner->report.successful++; result = MUNIT_OK; } else if (report.successful > 0) { munit_test_runner_print_color(runner, MUNIT_RESULT_STRING_OK, '2'); #if defined(MUNIT_ENABLE_TIMING) fputs(" ] [ ", MUNIT_OUTPUT_FILE); munit_print_time(MUNIT_OUTPUT_FILE, report.wall_clock); fputs(" / ", MUNIT_OUTPUT_FILE); munit_print_time(MUNIT_OUTPUT_FILE, report.cpu_clock); fputs(" CPU", MUNIT_OUTPUT_FILE); #endif runner->report.successful++; result = MUNIT_OK; } fputs(" ]\n", MUNIT_OUTPUT_FILE); if (stderr_buf != NULL) { if (result == MUNIT_FAIL || result == MUNIT_ERROR || runner->show_stderr) { fflush(MUNIT_OUTPUT_FILE); rewind(stderr_buf); munit_splice(fileno(stderr_buf), STDERR_FILENO); fflush(stderr); } fclose(stderr_buf); } } static void munit_test_runner_run_test_wild(MunitTestRunner* runner, const MunitTest* test, const char* test_name, MunitParameter* params, MunitParameter* p) { const MunitParameterEnum* pe; char** values; MunitParameter* next; for (pe = test->parameters ; pe != NULL && pe->name != NULL ; pe++) { if (p->name == pe->name) break; } if (pe == NULL) return; for (values = pe->values ; *values != NULL ; values++) { next = p + 1; p->value = *values; if (next->name == NULL) { munit_test_runner_run_test_with_params(runner, test, params); } else { munit_test_runner_run_test_wild(runner, test, test_name, params, next); } if (runner->fatal_failures && (runner->report.failed != 0 || runner->report.errored != 0)) break; } } /* Run a single test, with every combination of parameters * requested. 
*/ static void munit_test_runner_run_test(MunitTestRunner* runner, const MunitTest* test, const char* prefix) { char* test_name = munit_maybe_concat(NULL, (char*) prefix, (char*) test->name); /* The array of parameters to pass to * munit_test_runner_run_test_with_params */ MunitParameter* params = NULL; size_t params_l = 0; /* Wildcard parameters are parameters which have possible values * specified in the test, but no specific value was passed to the * CLI. That means we want to run the test once for every * possible combination of parameter values or, if --single was * passed to the CLI, a single time with a random set of * parameters. */ MunitParameter* wild_params = NULL; size_t wild_params_l = 0; const MunitParameterEnum* pe; const MunitParameter* cli_p; munit_bool filled; unsigned int possible; char** vals; size_t first_wild; const MunitParameter* wp; int pidx; munit_rand_seed(runner->seed); fprintf(MUNIT_OUTPUT_FILE, "%-" MUNIT_XSTRINGIFY(MUNIT_TEST_NAME_LEN) "s", test_name); if (test->parameters == NULL) { /* No parameters. Simple, nice. */ munit_test_runner_run_test_with_params(runner, test, NULL); } else { fputc('\n', MUNIT_OUTPUT_FILE); for (pe = test->parameters ; pe != NULL && pe->name != NULL ; pe++) { /* Did we received a value for this parameter from the CLI? */ filled = 0; for (cli_p = runner->parameters ; cli_p != NULL && cli_p->name != NULL ; cli_p++) { if (strcmp(cli_p->name, pe->name) == 0) { if (MUNIT_UNLIKELY(munit_parameters_add(¶ms_l, ¶ms, pe->name, cli_p->value) != MUNIT_OK)) goto cleanup; filled = 1; break; } } if (filled) continue; /* Nothing from CLI, is the enum NULL/empty? We're not a * fuzzer… */ if (pe->values == NULL || pe->values[0] == NULL) continue; /* If --single was passed to the CLI, choose a value from the * list of possibilities randomly. 
*/ if (runner->single_parameter_mode) { possible = 0; for (vals = pe->values ; *vals != NULL ; vals++) possible++; /* We want the tests to be reproducible, even if you're only * running a single test, but we don't want every test with * the same number of parameters to choose the same parameter * number, so use the test name as a primitive salt. */ pidx = munit_rand_at_most(munit_str_hash(test_name), possible - 1); if (MUNIT_UNLIKELY(munit_parameters_add(¶ms_l, ¶ms, pe->name, pe->values[pidx]) != MUNIT_OK)) goto cleanup; } else { /* We want to try every permutation. Put in a placeholder * entry, we'll iterate through them later. */ if (MUNIT_UNLIKELY(munit_parameters_add(&wild_params_l, &wild_params, pe->name, NULL) != MUNIT_OK)) goto cleanup; } } if (wild_params_l != 0) { first_wild = params_l; for (wp = wild_params ; wp != NULL && wp->name != NULL ; wp++) { for (pe = test->parameters ; pe != NULL && pe->name != NULL && pe->values != NULL ; pe++) { if (strcmp(wp->name, pe->name) == 0) { if (MUNIT_UNLIKELY(munit_parameters_add(¶ms_l, ¶ms, pe->name, pe->values[0]) != MUNIT_OK)) goto cleanup; } } } munit_test_runner_run_test_wild(runner, test, test_name, params, params + first_wild); } else { munit_test_runner_run_test_with_params(runner, test, params); } cleanup: free(params); free(wild_params); } munit_maybe_free_concat(test_name, prefix, test->name); } /* Recurse through the suite and run all the tests. If a list of * tests to run was provied on the command line, run only those * tests. */ static void munit_test_runner_run_suite(MunitTestRunner* runner, const MunitSuite* suite, const char* prefix) { size_t pre_l; char* pre = munit_maybe_concat(&pre_l, (char*) prefix, (char*) suite->prefix); const MunitTest* test; const char** test_name; const MunitSuite* child_suite; /* Run the tests. 
*/ for (test = suite->tests ; test != NULL && test->test != NULL ; test++) { if (runner->tests != NULL) { /* Specific tests were requested on the CLI */ for (test_name = runner->tests ; test_name != NULL && *test_name != NULL ; test_name++) { if ((pre_l == 0 || strncmp(pre, *test_name, pre_l) == 0) && strncmp(test->name, *test_name + pre_l, strlen(*test_name + pre_l)) == 0) { munit_test_runner_run_test(runner, test, pre); if (runner->fatal_failures && (runner->report.failed != 0 || runner->report.errored != 0)) goto cleanup; } } } else { /* Run all tests */ munit_test_runner_run_test(runner, test, pre); } } if (runner->fatal_failures && (runner->report.failed != 0 || runner->report.errored != 0)) goto cleanup; /* Run any child suites. */ for (child_suite = suite->suites ; child_suite != NULL && child_suite->prefix != NULL ; child_suite++) { munit_test_runner_run_suite(runner, child_suite, pre); } cleanup: munit_maybe_free_concat(pre, prefix, suite->prefix); } static void munit_test_runner_run(MunitTestRunner* runner) { munit_test_runner_run_suite(runner, runner->suite, NULL); } static void munit_print_help(int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)], void* user_data, const MunitArgument arguments[]) { const MunitArgument* arg; (void) argc; printf("USAGE: %s [OPTIONS...] [TEST...]\n\n", argv[0]); puts(" --seed SEED\n" " Value used to seed the PRNG. Must be a 32-bit integer in decimal\n" " notation with no separators (commas, decimals, spaces, etc.), or\n" " hexadecimal prefixed by \"0x\".\n" " --iterations N\n" " Run each test N times. 0 means the default number.\n" " --param name value\n" " A parameter key/value pair which will be passed to any test with\n" " takes a parameter of that name. 
If not provided, the test will be\n" " run once for each possible parameter value.\n" " --list Write a list of all available tests.\n" " --list-params\n" " Write a list of all available tests and their possible parameters.\n" " --single Run each parameterized test in a single configuration instead of\n" " every possible combination\n" " --log-visible debug|info|warning|error\n" " --log-fatal debug|info|warning|error\n" " Set the level at which messages of different severities are visible,\n" " or cause the test to terminate.\n" #if !defined(MUNIT_NO_FORK) " --no-fork Do not execute tests in a child process. If this option is supplied\n" " and a test crashes (including by failing an assertion), no further\n" " tests will be performed.\n" #endif " --fatal-failures\n" " Stop executing tests as soon as a failure is found.\n" " --show-stderr\n" " Show data written to stderr by the tests, even if the test succeeds.\n" " --color auto|always|never\n" " Colorize (or don't) the output.\n" /* 12345678901234567890123456789012345678901234567890123456789012345678901234567890 */ " --help Print this help message and exit.\n"); #if defined(MUNIT_NL_LANGINFO) setlocale(LC_ALL, ""); fputs((strcasecmp("UTF-8", nl_langinfo(CODESET)) == 0) ? 
"µnit" : "munit", stdout); #else puts("munit"); #endif printf(" %d.%d.%d\n" "Full documentation at: https://nemequ.github.io/munit/\n", (MUNIT_CURRENT_VERSION >> 16) & 0xff, (MUNIT_CURRENT_VERSION >> 8) & 0xff, (MUNIT_CURRENT_VERSION >> 0) & 0xff); for (arg = arguments ; arg != NULL && arg->name != NULL ; arg++) arg->write_help(arg, user_data); } static const MunitArgument* munit_arguments_find(const MunitArgument arguments[], const char* name) { const MunitArgument* arg; for (arg = arguments ; arg != NULL && arg->name != NULL ; arg++) if (strcmp(arg->name, name) == 0) return arg; return NULL; } static void munit_suite_list_tests(const MunitSuite* suite, munit_bool show_params, const char* prefix) { size_t pre_l; char* pre = munit_maybe_concat(&pre_l, (char*) prefix, (char*) suite->prefix); const MunitTest* test; const MunitParameterEnum* params; munit_bool first; char** val; const MunitSuite* child_suite; for (test = suite->tests ; test != NULL && test->name != NULL ; test++) { if (pre != NULL) fputs(pre, stdout); puts(test->name); if (show_params) { for (params = test->parameters ; params != NULL && params->name != NULL ; params++) { fprintf(stdout, " - %s: ", params->name); if (params->values == NULL) { puts("Any"); } else { first = 1; for (val = params->values ; *val != NULL ; val++ ) { if(!first) { fputs(", ", stdout); } else { first = 0; } fputs(*val, stdout); } putc('\n', stdout); } } } } for (child_suite = suite->suites ; child_suite != NULL && child_suite->prefix != NULL ; child_suite++) { munit_suite_list_tests(child_suite, show_params, pre); } munit_maybe_free_concat(pre, prefix, suite->prefix); } static munit_bool munit_stream_supports_ansi(FILE *stream) { #if !defined(_WIN32) return isatty(fileno(stream)); #else #if !defined(__MINGW32__) size_t ansicon_size = 0; #endif if (isatty(fileno(stream))) { #if !defined(__MINGW32__) getenv_s(&ansicon_size, NULL, 0, "ANSICON"); return ansicon_size != 0; #else return getenv("ANSICON") != NULL; #endif } return 0; 
#endif } int munit_suite_main_custom(const MunitSuite* suite, void* user_data, int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)], const MunitArgument arguments[]) { int result = EXIT_FAILURE; MunitTestRunner runner; size_t parameters_size = 0; size_t tests_size = 0; int arg; char* envptr; unsigned long ts; char* endptr; unsigned long long iterations; MunitLogLevel level; const MunitArgument* argument; const char** runner_tests; unsigned int tests_run; unsigned int tests_total; runner.prefix = NULL; runner.suite = NULL; runner.tests = NULL; runner.seed = 0; runner.iterations = 0; runner.parameters = NULL; runner.single_parameter_mode = 0; runner.user_data = NULL; runner.report.successful = 0; runner.report.skipped = 0; runner.report.failed = 0; runner.report.errored = 0; #if defined(MUNIT_ENABLE_TIMING) runner.report.cpu_clock = 0; runner.report.wall_clock = 0; #endif runner.colorize = 0; #if !defined(_WIN32) runner.fork = 1; #else runner.fork = 0; #endif runner.show_stderr = 0; runner.fatal_failures = 0; runner.suite = suite; runner.user_data = user_data; runner.seed = munit_rand_generate_seed(); runner.colorize = munit_stream_supports_ansi(MUNIT_OUTPUT_FILE); for (arg = 1 ; arg < argc ; arg++) { if (strncmp("--", argv[arg], 2) == 0) { if (strcmp("seed", argv[arg] + 2) == 0) { if (arg + 1 >= argc) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "%s requires an argument", argv[arg]); goto cleanup; } envptr = argv[arg + 1]; ts = strtoul(argv[arg + 1], &envptr, 0); if (*envptr != '\0' || ts > (~((munit_uint32_t) 0U))) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "invalid value ('%s') passed to %s", argv[arg + 1], argv[arg]); goto cleanup; } runner.seed = (munit_uint32_t) ts; arg++; } else if (strcmp("iterations", argv[arg] + 2) == 0) { if (arg + 1 >= argc) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "%s requires an argument", argv[arg]); goto cleanup; } endptr = argv[arg + 1]; iterations = strtoul(argv[arg + 1], &endptr, 0); if (*endptr != '\0' || iterations 
> UINT_MAX) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "invalid value ('%s') passed to %s", argv[arg + 1], argv[arg]); goto cleanup; } runner.iterations = (unsigned int) iterations; arg++; } else if (strcmp("param", argv[arg] + 2) == 0) { if (arg + 2 >= argc) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "%s requires two arguments", argv[arg]); goto cleanup; } runner.parameters = realloc(runner.parameters, sizeof(MunitParameter) * (parameters_size + 2)); if (runner.parameters == NULL) { munit_log_internal(MUNIT_LOG_ERROR, stderr, "failed to allocate memory"); goto cleanup; } runner.parameters[parameters_size].name = (char*) argv[arg + 1]; runner.parameters[parameters_size].value = (char*) argv[arg + 2]; parameters_size++; runner.parameters[parameters_size].name = NULL; runner.parameters[parameters_size].value = NULL; arg += 2; } else if (strcmp("color", argv[arg] + 2) == 0) { if (arg + 1 >= argc) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "%s requires an argument", argv[arg]); goto cleanup; } if (strcmp(argv[arg + 1], "always") == 0) runner.colorize = 1; else if (strcmp(argv[arg + 1], "never") == 0) runner.colorize = 0; else if (strcmp(argv[arg + 1], "auto") == 0) runner.colorize = munit_stream_supports_ansi(MUNIT_OUTPUT_FILE); else { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "invalid value ('%s') passed to %s", argv[arg + 1], argv[arg]); goto cleanup; } arg++; } else if (strcmp("help", argv[arg] + 2) == 0) { munit_print_help(argc, argv, user_data, arguments); result = EXIT_SUCCESS; goto cleanup; } else if (strcmp("single", argv[arg] + 2) == 0) { runner.single_parameter_mode = 1; } else if (strcmp("show-stderr", argv[arg] + 2) == 0) { runner.show_stderr = 1; #if !defined(_WIN32) } else if (strcmp("no-fork", argv[arg] + 2) == 0) { runner.fork = 0; #endif } else if (strcmp("fatal-failures", argv[arg] + 2) == 0) { runner.fatal_failures = 1; } else if (strcmp("log-visible", argv[arg] + 2) == 0 || strcmp("log-fatal", argv[arg] + 2) == 0) { if (arg + 1 >= 
argc) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "%s requires an argument", argv[arg]); goto cleanup; } if (strcmp(argv[arg + 1], "debug") == 0) level = MUNIT_LOG_DEBUG; else if (strcmp(argv[arg + 1], "info") == 0) level = MUNIT_LOG_INFO; else if (strcmp(argv[arg + 1], "warning") == 0) level = MUNIT_LOG_WARNING; else if (strcmp(argv[arg + 1], "error") == 0) level = MUNIT_LOG_ERROR; else { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "invalid value ('%s') passed to %s", argv[arg + 1], argv[arg]); goto cleanup; } if (strcmp("log-visible", argv[arg] + 2) == 0) munit_log_level_visible = level; else munit_log_level_fatal = level; arg++; } else if (strcmp("list", argv[arg] + 2) == 0) { munit_suite_list_tests(suite, 0, NULL); result = EXIT_SUCCESS; goto cleanup; } else if (strcmp("list-params", argv[arg] + 2) == 0) { munit_suite_list_tests(suite, 1, NULL); result = EXIT_SUCCESS; goto cleanup; } else { argument = munit_arguments_find(arguments, argv[arg] + 2); if (argument == NULL) { munit_logf_internal(MUNIT_LOG_ERROR, stderr, "unknown argument ('%s')", argv[arg]); goto cleanup; } if (!argument->parse_argument(suite, user_data, &arg, argc, argv)) goto cleanup; } } else { runner_tests = realloc((void*) runner.tests, sizeof(char*) * (tests_size + 2)); if (runner_tests == NULL) { munit_log_internal(MUNIT_LOG_ERROR, stderr, "failed to allocate memory"); goto cleanup; } runner.tests = runner_tests; runner.tests[tests_size++] = argv[arg]; runner.tests[tests_size] = NULL; } } fflush(stderr); fprintf(MUNIT_OUTPUT_FILE, "Running test suite with seed 0x%08" PRIx32 "...\n", runner.seed); munit_test_runner_run(&runner); tests_run = runner.report.successful + runner.report.failed + runner.report.errored; tests_total = tests_run + runner.report.skipped; if (tests_run == 0) { fprintf(stderr, "No tests run, %d (100%%) skipped.\n", runner.report.skipped); } else { fprintf(MUNIT_OUTPUT_FILE, "%d of %d (%0.0f%%) tests successful, %d (%0.0f%%) test skipped.\n", runner.report.successful, 
tests_run, (((double) runner.report.successful) / ((double) tests_run)) * 100.0, runner.report.skipped, (((double) runner.report.skipped) / ((double) tests_total)) * 100.0); } if (runner.report.failed == 0 && runner.report.errored == 0) { result = EXIT_SUCCESS; } cleanup: free(runner.parameters); free((void*) runner.tests); return result; } int munit_suite_main(const MunitSuite* suite, void* user_data, int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)]) { return munit_suite_main_custom(suite, user_data, argc, argv, NULL); } ================================================ FILE: common/floats/src/munit.h ================================================ /* µnit Testing Framework * Copyright (c) 2013-2017 Evan Nemerson * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, copy, * modify, merge, publish, distribute, sublicense, and/or sell copies * of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ #if !defined(MUNIT_H) #define MUNIT_H #include #include #define MUNIT_VERSION(major, minor, revision) \ (((major) << 16) | ((minor) << 8) | (revision)) #define MUNIT_CURRENT_VERSION MUNIT_VERSION(0, 4, 1) #if defined(_MSC_VER) && (_MSC_VER < 1600) # define munit_int8_t __int8 # define munit_uint8_t unsigned __int8 # define munit_int16_t __int16 # define munit_uint16_t unsigned __int16 # define munit_int32_t __int32 # define munit_uint32_t unsigned __int32 # define munit_int64_t __int64 # define munit_uint64_t unsigned __int64 #else # include # define munit_int8_t int8_t # define munit_uint8_t uint8_t # define munit_int16_t int16_t # define munit_uint16_t uint16_t # define munit_int32_t int32_t # define munit_uint32_t uint32_t # define munit_int64_t int64_t # define munit_uint64_t uint64_t #endif #if defined(_MSC_VER) && (_MSC_VER < 1800) # if !defined(PRIi8) # define PRIi8 "i" # endif # if !defined(PRIi16) # define PRIi16 "i" # endif # if !defined(PRIi32) # define PRIi32 "i" # endif # if !defined(PRIi64) # define PRIi64 "I64i" # endif # if !defined(PRId8) # define PRId8 "d" # endif # if !defined(PRId16) # define PRId16 "d" # endif # if !defined(PRId32) # define PRId32 "d" # endif # if !defined(PRId64) # define PRId64 "I64d" # endif # if !defined(PRIx8) # define PRIx8 "x" # endif # if !defined(PRIx16) # define PRIx16 "x" # endif # if !defined(PRIx32) # define PRIx32 "x" # endif # if !defined(PRIx64) # define PRIx64 "I64x" # endif # if !defined(PRIu8) # define PRIu8 "u" # endif # if !defined(PRIu16) # define PRIu16 "u" # endif # if !defined(PRIu32) # define PRIu32 "u" # endif # if !defined(PRIu64) # define PRIu64 "I64u" # endif #else # include #endif #if !defined(munit_bool) # if defined(bool) # define munit_bool bool # elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) # define munit_bool _Bool # else # define munit_bool int # endif #endif #if defined(__cplusplus) extern "C" { #endif #if defined(__GNUC__) # define MUNIT_LIKELY(expr) (__builtin_expect 
((expr), 1)) # define MUNIT_UNLIKELY(expr) (__builtin_expect ((expr), 0)) # define MUNIT_UNUSED __attribute__((__unused__)) #else # define MUNIT_LIKELY(expr) (expr) # define MUNIT_UNLIKELY(expr) (expr) # define MUNIT_UNUSED #endif #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__PGI) # define MUNIT_ARRAY_PARAM(name) name #else # define MUNIT_ARRAY_PARAM(name) #endif #if !defined(_WIN32) # define MUNIT_SIZE_MODIFIER "z" # define MUNIT_CHAR_MODIFIER "hh" # define MUNIT_SHORT_MODIFIER "h" #else # if defined(_M_X64) || defined(__amd64__) # define MUNIT_SIZE_MODIFIER "I64" # else # define MUNIT_SIZE_MODIFIER "" # endif # define MUNIT_CHAR_MODIFIER "" # define MUNIT_SHORT_MODIFIER "" #endif #if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L # define MUNIT_NO_RETURN _Noreturn #elif defined(__GNUC__) # define MUNIT_NO_RETURN __attribute__((__noreturn__)) #elif defined(_MSC_VER) # define MUNIT_NO_RETURN __declspec(noreturn) #else # define MUNIT_NO_RETURN #endif #if defined(_MSC_VER) && (_MSC_VER >= 1500) # define MUNIT_PUSH_DISABLE_MSVC_C4127_ __pragma(warning(push)) __pragma(warning(disable:4127)) # define MUNIT_POP_DISABLE_MSVC_C4127_ __pragma(warning(pop)) #else # define MUNIT_PUSH_DISABLE_MSVC_C4127_ # define MUNIT_POP_DISABLE_MSVC_C4127_ #endif typedef enum { MUNIT_LOG_DEBUG, MUNIT_LOG_INFO, MUNIT_LOG_WARNING, MUNIT_LOG_ERROR } MunitLogLevel; #if defined(__GNUC__) && !defined(__MINGW32__) # define MUNIT_PRINTF(string_index, first_to_check) __attribute__((format (printf, string_index, first_to_check))) #else # define MUNIT_PRINTF(string_index, first_to_check) #endif MUNIT_PRINTF(4, 5) void munit_logf_ex(MunitLogLevel level, const char* filename, int line, const char* format, ...); #define munit_logf(level, format, ...) 
\ munit_logf_ex(level, __FILE__, __LINE__, format, __VA_ARGS__) #define munit_log(level, msg) \ munit_logf(level, "%s", msg) MUNIT_NO_RETURN MUNIT_PRINTF(3, 4) void munit_errorf_ex(const char* filename, int line, const char* format, ...); #define munit_errorf(format, ...) \ munit_errorf_ex(__FILE__, __LINE__, format, __VA_ARGS__) #define munit_error(msg) \ munit_errorf("%s", msg) #define munit_assert(expr) \ do { \ if (!MUNIT_LIKELY(expr)) { \ munit_error("assertion failed: " #expr); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_true(expr) \ do { \ if (!MUNIT_LIKELY(expr)) { \ munit_error("assertion failed: " #expr " is not true"); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_false(expr) \ do { \ if (!MUNIT_LIKELY(!(expr))) { \ munit_error("assertion failed: " #expr " is not false"); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_type_full(prefix, suffix, T, fmt, a, op, b) \ do { \ T munit_tmp_a_ = (a); \ T munit_tmp_b_ = (b); \ if (!(munit_tmp_a_ op munit_tmp_b_)) { \ munit_errorf("assertion failed: %s %s %s (" prefix "%" fmt suffix " %s " prefix "%" fmt suffix ")", \ #a, #op, #b, munit_tmp_a_, #op, munit_tmp_b_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_type(T, fmt, a, op, b) \ munit_assert_type_full("", "", T, fmt, a, op, b) #define munit_assert_char(a, op, b) \ munit_assert_type_full("'\\x", "'", char, "02" MUNIT_CHAR_MODIFIER "x", a, op, b) #define munit_assert_uchar(a, op, b) \ munit_assert_type_full("'\\x", "'", unsigned char, "02" MUNIT_CHAR_MODIFIER "x", a, op, b) #define munit_assert_short(a, op, b) \ munit_assert_type(short, MUNIT_SHORT_MODIFIER "d", a, op, b) #define munit_assert_ushort(a, op, b) \ munit_assert_type(unsigned short, MUNIT_SHORT_MODIFIER "u", a, op, b) #define munit_assert_int(a, op, b) \ 
munit_assert_type(int, "d", a, op, b) #define munit_assert_uint(a, op, b) \ munit_assert_type(unsigned int, "u", a, op, b) #define munit_assert_long(a, op, b) \ munit_assert_type(long int, "ld", a, op, b) #define munit_assert_ulong(a, op, b) \ munit_assert_type(unsigned long int, "lu", a, op, b) #define munit_assert_llong(a, op, b) \ munit_assert_type(long long int, "lld", a, op, b) #define munit_assert_ullong(a, op, b) \ munit_assert_type(unsigned long long int, "llu", a, op, b) #define munit_assert_size(a, op, b) \ munit_assert_type(size_t, MUNIT_SIZE_MODIFIER "u", a, op, b) #define munit_assert_float(a, op, b) \ munit_assert_type(float, "f", a, op, b) #define munit_assert_double(a, op, b) \ munit_assert_type(double, "g", a, op, b) #define munit_assert_ptr(a, op, b) \ munit_assert_type(const void*, "p", a, op, b) #define munit_assert_int8(a, op, b) \ munit_assert_type(munit_int8_t, PRIi8, a, op, b) #define munit_assert_uint8(a, op, b) \ munit_assert_type(munit_uint8_t, PRIu8, a, op, b) #define munit_assert_int16(a, op, b) \ munit_assert_type(munit_int16_t, PRIi16, a, op, b) #define munit_assert_uint16(a, op, b) \ munit_assert_type(munit_uint16_t, PRIu16, a, op, b) #define munit_assert_int32(a, op, b) \ munit_assert_type(munit_int32_t, PRIi32, a, op, b) #define munit_assert_uint32(a, op, b) \ munit_assert_type(munit_uint32_t, PRIu32, a, op, b) #define munit_assert_int64(a, op, b) \ munit_assert_type(munit_int64_t, PRIi64, a, op, b) #define munit_assert_uint64(a, op, b) \ munit_assert_type(munit_uint64_t, PRIu64, a, op, b) #define munit_assert_double_equal(a, b, precision) \ do { \ const double munit_tmp_a_ = (a); \ const double munit_tmp_b_ = (b); \ const double munit_tmp_diff_ = ((munit_tmp_a_ - munit_tmp_b_) < 0) ? \ -(munit_tmp_a_ - munit_tmp_b_) : \ (munit_tmp_a_ - munit_tmp_b_); \ if (MUNIT_UNLIKELY(munit_tmp_diff_ > 1e-##precision)) { \ munit_errorf("assertion failed: %s == %s (%0." #precision "g == %0." 
#precision "g)", \ #a, #b, munit_tmp_a_, munit_tmp_b_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_float_equal(a, b, precision) \ do { \ const float munit_tmp_a_ = (a); \ const float munit_tmp_b_ = (b); \ const float munit_tmp_diff_ = ((munit_tmp_a_ - munit_tmp_b_) < 0) ? \ -(munit_tmp_a_ - munit_tmp_b_) : \ (munit_tmp_a_ - munit_tmp_b_); \ if (MUNIT_UNLIKELY(munit_tmp_diff_ > 1e-##precision)) { \ munit_errorf("assertion failed: %s == %s (%0." #precision "g == %0." #precision "g)", \ #a, #b, munit_tmp_a_, munit_tmp_b_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #include #define munit_assert_string_equal(a, b) \ do { \ const char* munit_tmp_a_ = a; \ const char* munit_tmp_b_ = b; \ if (MUNIT_UNLIKELY(strcmp(munit_tmp_a_, munit_tmp_b_) != 0)) { \ munit_errorf("assertion failed: string %s == %s (\"%s\" == \"%s\")", \ #a, #b, munit_tmp_a_, munit_tmp_b_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_string_not_equal(a, b) \ do { \ const char* munit_tmp_a_ = a; \ const char* munit_tmp_b_ = b; \ if (MUNIT_UNLIKELY(strcmp(munit_tmp_a_, munit_tmp_b_) == 0)) { \ munit_errorf("assertion failed: string %s != %s (\"%s\" == \"%s\")", \ #a, #b, munit_tmp_a_, munit_tmp_b_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_floats_equal(size, a, b) \ do { \ const float* munit_tmp_a_ = (const float*) (a); \ const float* munit_tmp_b_ = (const float*) (b); \ const size_t munit_tmp_size_ = (size); \ if (MUNIT_UNLIKELY(memcmp(munit_tmp_a_, munit_tmp_b_, munit_tmp_size_ * sizeof(float))) != 0) { \ size_t munit_tmp_pos_; \ for (munit_tmp_pos_ = 0 ; munit_tmp_pos_ < munit_tmp_size_ ; munit_tmp_pos_++) { \ if (abs(munit_tmp_a_[munit_tmp_pos_] - munit_tmp_b_[munit_tmp_pos_]) > 1e-6) { \ munit_errorf("assertion failed: floats %s == %s (%f == %f), at offset %" MUNIT_SIZE_MODIFIER 
"u", \ #a, #b, munit_tmp_a_[munit_tmp_pos_], munit_tmp_b_[munit_tmp_pos_], munit_tmp_pos_); \ break; \ } \ } \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_memory_equal(size, a, b) \ do { \ const unsigned char* munit_tmp_a_ = (const unsigned char*) (a); \ const unsigned char* munit_tmp_b_ = (const unsigned char*) (b); \ const size_t munit_tmp_size_ = (size); \ if (MUNIT_UNLIKELY(memcmp(munit_tmp_a_, munit_tmp_b_, munit_tmp_size_)) != 0) { \ size_t munit_tmp_pos_; \ for (munit_tmp_pos_ = 0 ; munit_tmp_pos_ < munit_tmp_size_ ; munit_tmp_pos_++) { \ if (munit_tmp_a_[munit_tmp_pos_] != munit_tmp_b_[munit_tmp_pos_]) { \ munit_errorf("assertion failed: memory %s == %s, at offset %" MUNIT_SIZE_MODIFIER "u", \ #a, #b, munit_tmp_pos_); \ break; \ } \ } \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_memory_not_equal(size, a, b) \ do { \ const unsigned char* munit_tmp_a_ = (const unsigned char*) (a); \ const unsigned char* munit_tmp_b_ = (const unsigned char*) (b); \ const size_t munit_tmp_size_ = (size); \ if (MUNIT_UNLIKELY(memcmp(munit_tmp_a_, munit_tmp_b_, munit_tmp_size_)) == 0) { \ munit_errorf("assertion failed: memory %s != %s (%zu bytes)", \ #a, #b, munit_tmp_size_); \ } \ MUNIT_PUSH_DISABLE_MSVC_C4127_ \ } while (0) \ MUNIT_POP_DISABLE_MSVC_C4127_ #define munit_assert_ptr_equal(a, b) \ munit_assert_ptr(a, ==, b) #define munit_assert_ptr_not_equal(a, b) \ munit_assert_ptr(a, !=, b) #define munit_assert_null(ptr) \ munit_assert_ptr(ptr, ==, NULL) #define munit_assert_not_null(ptr) \ munit_assert_ptr(ptr, !=, NULL) #define munit_assert_ptr_null(ptr) \ munit_assert_ptr(ptr, ==, NULL) #define munit_assert_ptr_not_null(ptr) \ munit_assert_ptr(ptr, !=, NULL) /*** Memory allocation ***/ void* munit_malloc_ex(const char* filename, int line, size_t size); #define munit_malloc(size) \ munit_malloc_ex(__FILE__, __LINE__, (size)) #define munit_new(type) \ ((type*) 
munit_malloc(sizeof(type))) #define munit_calloc(nmemb, size) \ munit_malloc((nmemb) * (size)) #define munit_newa(type, nmemb) \ ((type*) munit_calloc((nmemb), sizeof(type))) /*** Random number generation ***/ void munit_rand_seed(munit_uint32_t seed); munit_uint32_t munit_rand_uint32(void); int munit_rand_int_range(int min, int max); double munit_rand_double(void); void munit_rand_memory(size_t size, munit_uint8_t buffer[MUNIT_ARRAY_PARAM(size)]); /*** Tests and Suites ***/ typedef enum { /* Test successful */ MUNIT_OK, /* Test failed */ MUNIT_FAIL, /* Test was skipped */ MUNIT_SKIP, /* Test failed due to circumstances not intended to be tested * (things like network errors, invalid parameter value, failure to * allocate memory in the test harness, etc.). */ MUNIT_ERROR } MunitResult; typedef struct { char* name; char** values; } MunitParameterEnum; typedef struct { char* name; char* value; } MunitParameter; const char* munit_parameters_get(const MunitParameter params[], const char* key); typedef enum { MUNIT_TEST_OPTION_NONE = 0, MUNIT_TEST_OPTION_SINGLE_ITERATION = 1 << 0, MUNIT_TEST_OPTION_TODO = 1 << 1 } MunitTestOptions; typedef MunitResult (* MunitTestFunc)(const MunitParameter params[], void* user_data_or_fixture); typedef void* (* MunitTestSetup)(const MunitParameter params[], void* user_data); typedef void (* MunitTestTearDown)(void* fixture); typedef struct { char* name; MunitTestFunc test; MunitTestSetup setup; MunitTestTearDown tear_down; MunitTestOptions options; MunitParameterEnum* parameters; } MunitTest; typedef enum { MUNIT_SUITE_OPTION_NONE = 0 } MunitSuiteOptions; typedef struct MunitSuite_ MunitSuite; struct MunitSuite_ { char* prefix; MunitTest* tests; MunitSuite* suites; unsigned int iterations; MunitSuiteOptions options; }; int munit_suite_main(const MunitSuite* suite, void* user_data, int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)]); /* Note: I'm not very happy with this API; it's likely to change if I * figure out something better. 
Suggestions welcome. */ typedef struct MunitArgument_ MunitArgument; struct MunitArgument_ { char* name; munit_bool (* parse_argument)(const MunitSuite* suite, void* user_data, int* arg, int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)]); void (* write_help)(const MunitArgument* argument, void* user_data); }; int munit_suite_main_custom(const MunitSuite* suite, void* user_data, int argc, char* const argv[MUNIT_ARRAY_PARAM(argc + 1)], const MunitArgument arguments[]); #if defined(MUNIT_ENABLE_ASSERT_ALIASES) #define assert_true(expr) munit_assert_true(expr) #define assert_false(expr) munit_assert_false(expr) #define assert_char(a, op, b) munit_assert_char(a, op, b) #define assert_uchar(a, op, b) munit_assert_uchar(a, op, b) #define assert_short(a, op, b) munit_assert_short(a, op, b) #define assert_ushort(a, op, b) munit_assert_ushort(a, op, b) #define assert_int(a, op, b) munit_assert_int(a, op, b) #define assert_uint(a, op, b) munit_assert_uint(a, op, b) #define assert_long(a, op, b) munit_assert_long(a, op, b) #define assert_ulong(a, op, b) munit_assert_ulong(a, op, b) #define assert_llong(a, op, b) munit_assert_llong(a, op, b) #define assert_ullong(a, op, b) munit_assert_ullong(a, op, b) #define assert_size(a, op, b) munit_assert_size(a, op, b) #define assert_float(a, op, b) munit_assert_float(a, op, b) #define assert_double(a, op, b) munit_assert_double(a, op, b) #define assert_ptr(a, op, b) munit_assert_ptr(a, op, b) #define assert_int8(a, op, b) munit_assert_int8(a, op, b) #define assert_uint8(a, op, b) munit_assert_uint8(a, op, b) #define assert_int16(a, op, b) munit_assert_int16(a, op, b) #define assert_uint16(a, op, b) munit_assert_uint16(a, op, b) #define assert_int32(a, op, b) munit_assert_int32(a, op, b) #define assert_uint32(a, op, b) munit_assert_uint32(a, op, b) #define assert_int64(a, op, b) munit_assert_int64(a, op, b) #define assert_uint64(a, op, b) munit_assert_uint64(a, op, b) #define assert_double_equal(a, b, precision) 
munit_assert_double_equal(a, b, precision) #define assert_string_equal(a, b) munit_assert_string_equal(a, b) #define assert_string_not_equal(a, b) munit_assert_string_not_equal(a, b) #define assert_memory_equal(size, a, b) munit_assert_memory_equal(size, a, b) #define assert_memory_not_equal(size, a, b) munit_assert_memory_not_equal(size, a, b) #define assert_ptr_equal(a, b) munit_assert_ptr_equal(a, b) #define assert_ptr_not_equal(a, b) munit_assert_ptr_not_equal(a, b) #define assert_ptr_null(ptr) munit_assert_null_equal(ptr) #define assert_ptr_not_null(ptr) munit_assert_not_null(ptr) #define assert_null(ptr) munit_assert_null(ptr) #define assert_not_null(ptr) munit_assert_not_null(ptr) #endif /* defined(MUNIT_ENABLE_ASSERT_ALIASES) */ #if defined(__cplusplus) } #endif #endif /* !defined(MUNIT_H) */ #if defined(MUNIT_ENABLE_ASSERT_ALIASES) # if defined(assert) # undef assert # endif # define assert(expr) munit_assert(expr) #endif ================================================ FILE: common/heap/filter.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package heap import ( "container/heap" "golang.org/x/exp/constraints" ) // TopKFilter filters out top k items with maximum weights. type TopKFilter[T any, W constraints.Ordered] struct { _heap[T, W] k int } // NewTopKFilter creates a top k filter. 
func NewTopKFilter[T any, W constraints.Ordered](k int) *TopKFilter[T, W] { return &TopKFilter[T, W]{k: k} } // Push pushes the element x onto the heap. // The complexity is O(log n) where n = h.Count(). func (filter *TopKFilter[T, W]) Push(item T, weight W) { heap.Push(&filter._heap, Elem[T, W]{item, weight}) if filter.Len() > filter.k { heap.Pop(&filter._heap) } } // PopAllValues pops all values in the filter with decreasing order. func (filter *TopKFilter[T, W]) PopAllValues() []T { items := make([]T, filter.Len()) for i := len(items) - 1; i >= 0; i-- { elem := heap.Pop(&filter._heap).(Elem[T, W]) items[i] = elem.Value } return items } // PopAll pops all items in the filter with decreasing order. func (filter *TopKFilter[T, W]) PopAll() []Elem[T, W] { results := make([]Elem[T, W], filter.Len()) for i := len(results) - 1; i >= 0; i-- { results[i] = heap.Pop(&filter._heap).(Elem[T, W]) } return results } ================================================ FILE: common/heap/filter_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package heap

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestTopKFilter exercises TopKFilter with int32 values and float32 weights,
// both below and above its capacity.
func TestTopKFilter(t *testing.T) {
	// Fewer elements than k: all of them come back, heaviest first.
	filter := NewTopKFilter[int32, float32](3)
	for _, e := range []Elem[int32, float32]{{10, 2}, {20, 8}, {30, 1}} {
		filter.Push(e.Value, e.Weight)
	}
	assert.Equal(t, []int32{20, 10, 30}, filter.PopAllValues())

	// More elements than k: only the three heaviest survive.
	filter = NewTopKFilter[int32, float32](3)
	for _, e := range []Elem[int32, float32]{
		{10, 2}, {20, 8}, {30, 1}, {40, 2}, {50, 5}, {12, 10}, {67, 7}, {32, 9},
	} {
		filter.Push(e.Value, e.Weight)
	}
	assert.Equal(t, []Elem[int32, float32]{
		{Value: 12, Weight: 10},
		{Value: 32, Weight: 9},
		{Value: 20, Weight: 8},
	}, filter.PopAll())
}

// TestTopKStringFilter is the same scenario with string values and float64
// weights.
func TestTopKStringFilter(t *testing.T) {
	// Fewer elements than k: all of them come back, heaviest first.
	filter := NewTopKFilter[string, float64](3)
	for _, e := range []Elem[string, float64]{{"10", 2}, {"20", 8}, {"30", 1}} {
		filter.Push(e.Value, e.Weight)
	}
	assert.Equal(t, []Elem[string, float64]{
		{Value: "20", Weight: 8},
		{Value: "10", Weight: 2},
		{Value: "30", Weight: 1},
	}, filter.PopAll())

	// More elements than k: only the three heaviest survive.
	filter = NewTopKFilter[string, float64](3)
	for _, e := range []Elem[string, float64]{
		{"10", 2}, {"20", 8}, {"30", 1}, {"40", 2}, {"50", 5}, {"12", 10}, {"67", 7}, {"32", 9},
	} {
		filter.Push(e.Value, e.Weight)
	}
	assert.Equal(t, []Elem[string, float64]{
		{Value: "12", Weight: 10},
		{Value: "32", Weight: 9},
		{Value: "20", Weight: 8},
	}, filter.PopAll())
}

================================================
FILE: common/heap/pq.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package heap

import (
	"container/heap"
	"encoding/binary"
	"io"

	"github.com/chewxy/math32"
	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/encoding"
	"golang.org/x/exp/constraints"
)

// Elem is a value paired with the weight that orders it in a heap.
type Elem[E any, W constraints.Ordered] struct {
	Value  E
	Weight W
}

// _heap implements heap.Interface over a slice of Elem. Elements are ordered
// by ascending weight, or by descending weight when desc is set, so elems[0]
// holds the minimum or the maximum respectively.
type _heap[T any, W constraints.Ordered] struct {
	elems []Elem[T, W]
	desc  bool
}

func (h *_heap[T, W]) Len() int { return len(h.elems) }

func (h *_heap[T, W]) Less(i, j int) bool {
	if h.desc {
		return h.elems[i].Weight > h.elems[j].Weight
	}
	return h.elems[i].Weight < h.elems[j].Weight
}

func (h *_heap[T, W]) Swap(i, j int) { h.elems[i], h.elems[j] = h.elems[j], h.elems[i] }

// Push appends x, which must be an Elem[T, W]. It is meant to be called by
// heap.Push, not directly.
func (h *_heap[T, W]) Push(x any) {
	h.elems = append(h.elems, x.(Elem[T, W]))
}

// Pop removes and returns the last element. It is meant to be called by
// heap.Pop, not directly.
func (h *_heap[T, W]) Pop() any {
	last := len(h.elems) - 1
	item := h.elems[last]
	h.elems = h.elems[:last]
	return item
}

// PriorityQueue is a heap of int32 values ordered by float32 weight. A lookup
// set of values is kept next to the heap so that duplicate pushes are ignored.
type PriorityQueue struct {
	_heap[int32, float32]
	lookup mapset.Set[int32]
}

// NewPriorityQueue returns an empty queue. With desc false, Pop yields the
// lowest weight first; with desc true, the highest weight first.
func NewPriorityQueue(desc bool) *PriorityQueue {
	return &PriorityQueue{
		_heap:  _heap[int32, float32]{desc: desc},
		lookup: mapset.NewSet[int32](),
	}
}

// Push inserts v with the given weight. If v is already recorded in the
// queue's lookup set, nothing happens, even when the weight differs.
// Push panics if weight is NaN, since NaN cannot be ordered.
func (p *PriorityQueue) Push(v int32, weight float32) {
	if math32.IsNaN(weight) {
		panic("NaN weight is forbidden")
	}
	if p.lookup.Contains(v) {
		return
	}
	heap.Push(&p._heap, Elem[int32, float32]{Value: v, Weight: weight})
	p.lookup.Add(v)
}

// Pop removes and returns the element at the top of the queue: the lowest
// weight for an ascending queue, the highest for a descending one.
// It panics if the queue is empty.
func (p *PriorityQueue) Pop() (int32, float32) { item := heap.Pop(&p._heap).(Elem[int32, float32]) return item.Value, item.Weight } func (p *PriorityQueue) Peek() (int32, float32) { return p.elems[0].Value, p.elems[0].Weight } func (p *PriorityQueue) Values() []int32 { values := make([]int32, 0, p.Len()) for _, elem := range p.elems { values = append(values, elem.Value) } return values } func (p *PriorityQueue) Elems() []Elem[int32, float32] { return p.elems } func (p *PriorityQueue) Clone() *PriorityQueue { pq := NewPriorityQueue(p.desc) pq.elems = make([]Elem[int32, float32], p.Len()) copy(pq.elems, p.elems) return pq } func (p *PriorityQueue) Reverse() *PriorityQueue { pq := NewPriorityQueue(!p.desc) pq.elems = make([]Elem[int32, float32], 0, p.Len()) for _, elem := range p.elems { pq.Push(elem.Value, elem.Weight) } return pq } func (p *PriorityQueue) Marshal(w io.Writer) error { if err := binary.Write(w, binary.LittleEndian, p.desc); err != nil { return err } return encoding.WriteSlice(w, p.elems) } func (p *PriorityQueue) Unmarshal(r io.Reader) error { if err := binary.Read(r, binary.LittleEndian, &p.desc); err != nil { return err } if err := encoding.ReadSlice(r, &p.elems); err != nil { return err } p.lookup = mapset.NewSetWithSize[int32](p.Len()) for _, elem := range p.elems { p.lookup.Add(elem.Value) } return nil } ================================================ FILE: common/heap/pq_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package heap import ( "os" "path/filepath" "sort" "testing" "github.com/samber/lo" "github.com/stretchr/testify/assert" "modernc.org/sortutil" ) func TestPriorityQueue(t *testing.T) { pq := NewPriorityQueue(false) elements := []int32{5, 3, 7, 8, 6, 2, 9} for _, e := range elements { pq.Push(e, float32(e)) } assert.Equal(t, len(elements), pq.Len()) assert.ElementsMatch(t, elements, pq.Values()) assert.Equal(t, len(elements), len(pq.Elems())) // test clone cp := pq.Clone() assert.Equal(t, len(elements), cp.Len()) // test peek pop sort.Sort(sortutil.Int32Slice(elements)) for _, e := range elements { value, weight := pq.Peek() assert.Equal(t, e, value) assert.Equal(t, e, int32(weight)) value, weight = pq.Pop() assert.Equal(t, e, value) assert.Equal(t, e, int32(weight)) } // test reverse r := cp.Reverse() lo.Reverse(elements) for _, e := range elements { value, weight := r.Pop() assert.Equal(t, e, value) assert.Equal(t, e, int32(weight)) } } func TestMarshalUnmarshal(t *testing.T) { pq := NewPriorityQueue(false) elements := []int32{5, 3, 7, 8, 6, 2, 9} for _, e := range elements { pq.Push(e, float32(e)) } path := filepath.Join(t.TempDir(), "pq.bin") f, err := os.Create(path) assert.NoError(t, err) defer f.Close() err = pq.Marshal(f) assert.NoError(t, err) f, err = os.Open(path) assert.NoError(t, err) defer f.Close() pq2 := NewPriorityQueue(false) err = pq2.Unmarshal(f) assert.NoError(t, err) assert.Equal(t, pq.Len(), pq2.Len()) assert.Equal(t, pq.desc, pq2.desc) assert.Equal(t, pq.elems, pq2.elems) assert.True(t, pq.lookup.Equal(pq2.lookup)) } ================================================ FILE: common/jsonutil/json.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package jsonutil import "encoding/json" // Marshal returns the JSON encoding of v. func Marshal(v interface{}) ([]byte, error) { return json.Marshal(v) } // Unmarshal parses the JSON-encoded data and stores the result // in the value pointed to by v. If data is empty, Unmarshal clears // contents in v. func Unmarshal(data []byte, v interface{}) error { if len(data) == 0 { data = []byte("null") } return json.Unmarshal(data, v) } // MustMarshal returns the JSON encoding of v. Panic if error occurs. func MustMarshal(v interface{}) string { data, err := Marshal(v) if err != nil { panic(err) } return string(data) } ================================================ FILE: common/jsonutil/json_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package jsonutil

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestUnmarshal covers a regular payload and the empty-payload reset.
func TestUnmarshal(t *testing.T) {
	var numbers []int
	assert.NoError(t, Unmarshal([]byte("[1,2,3]"), &numbers))
	assert.Equal(t, []int{1, 2, 3}, numbers)

	// An empty payload is treated as null and resets the slice.
	assert.NoError(t, Unmarshal([]byte(""), &numbers))
	assert.Empty(t, numbers)
}

// TestMarshal checks that nil encodes as JSON null.
func TestMarshal(t *testing.T) {
	encoded, err := Marshal(nil)
	assert.NoError(t, err)
	assert.Equal(t, "null", string(encoded))
}

// TestMustMarshal checks that an unencodable value causes a panic.
func TestMustMarshal(t *testing.T) {
	unencodable := make(chan int)
	assert.Panics(t, func() {
		MustMarshal(unencodable)
	})
}

================================================
FILE: common/log/log.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package log

import (
	"net/url"
	"os"
	"runtime"
	"strings"
	"testing"

	"github.com/c-bata/goptuna"
	"github.com/emicklei/go-restful/v3"
	"github.com/go-sql-driver/mysql"
	"github.com/spf13/pflag"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"gopkg.in/natefinch/lumberjack.v2"
)

var (
	logger       *zap.Logger
	openaiLogger *zap.Logger
)

func init() {
	// setup default logger
	var err error
	logger, err = zap.NewProduction()
	if err != nil {
		panic(err)
	}
	// setup OpenAI logger: a development logger under `go test`, a no-op
	// logger otherwise (InitOpenAILogger may replace it later).
	if testing.Testing() {
		openaiLogger, err = zap.NewDevelopment()
		if err != nil {
			panic(err)
		}
	} else {
		openaiLogger = zap.NewNop()
	}
	// Windows file sink support: https://github.com/uber-go/zap/issues/621
	if runtime.GOOS == "windows" {
		if err := zap.RegisterSink("windows", func(u *url.URL) (zap.Sink, error) {
			// Remove leading slash left by url.Parse()
			return os.OpenFile(u.Path[1:], os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
		}); err != nil {
			logger.Fatal("failed to register Windows file sink", zap.Error(err))
		}
	}
}

// Logger get current logger
func Logger() *zap.Logger {
	return logger
}

// ResponseLogger returns the current logger annotated with the response's
// X-Request-ID header as the "request_id" field.
func ResponseLogger(resp *restful.Response) *zap.Logger {
	return logger.With(zap.String("request_id", resp.Header().Get("X-Request-ID")))
}

// CloseLogger replaces the current logger with a production logger that only
// emits fatal-level entries, effectively silencing normal logging.
func CloseLogger() {
	cfg := zap.NewProductionConfig()
	cfg.Level = zap.NewAtomicLevelAt(zap.FatalLevel)
	var err error
	logger, err = cfg.Build()
	if err != nil {
		panic(err)
	}
}

// AddFlags registers the log file flags consumed by SetLogger.
func AddFlags(flagSet *pflag.FlagSet) {
	flagSet.String("log-path", "", "path of log file")
	flagSet.Int("log-max-size", 100, "maximum size in megabytes of the log file")
	flagSet.Int("log-max-age", 0, "maximum number of days to retain old log files")
	flagSet.Int("log-max-backups", 0, "maximum number of old log files to retain")
}

// SetLogger rebuilds the global logger. debug selects a console encoder at
// debug level; otherwise a JSON encoder at info level is used. Output always
// goes to stdout and, when --log-path was set, also to a rotating file.
func SetLogger(flagSet *pflag.FlagSet, debug bool) {
	// enable or disable debug mode
	var (
		encoder zapcore.Encoder
		level   zapcore.LevelEnabler
	)
	if debug {
		encoder = zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
		level = zap.DebugLevel
	} else {
		encoder = zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig())
		level = zap.InfoLevel
	}
	// create lumberjack logger
	writers := []zapcore.WriteSyncer{zapcore.AddSync(os.Stdout)}
	if flagSet.Changed("log-path") {
		path, _ := flagSet.GetString("log-path")
		maxSize, _ := flagSet.GetInt("log-max-size")
		maxAge, _ := flagSet.GetInt("log-max-age")
		maxBackups, _ := flagSet.GetInt("log-max-backups")
		writers = append(writers, zapcore.AddSync(&lumberjack.Logger{
			Filename:   path,
			MaxSize:    maxSize,
			MaxBackups: maxBackups,
			MaxAge:     maxAge,
			Compress:   false,
		}))
	}
	// create zap logger
	core := zapcore.NewCore(encoder, zap.CombineWriteSyncers(writers...), level)
	logger = zap.New(core)
}

const mysqlPrefix = "mysql://"

// RedactDBURL masks the username and password of a database URL with 'x'
// characters of the same length, so the URL can be logged. SQLite URLs are
// returned unchanged. MySQL URLs are parsed as go-sql-driver DSNs; all others
// with net/url. A URL without credentials gets no userinfo added, and a
// username without a password is not given an empty password.
// NOTE(review): a URL that fails to parse is returned verbatim, credentials
// included.
func RedactDBURL(rawURL string) string {
	if strings.HasPrefix(rawURL, "sqlite://") {
		return rawURL
	}
	if strings.HasPrefix(rawURL, mysqlPrefix) {
		parsed, err := mysql.ParseDSN(rawURL[len(mysqlPrefix):])
		if err != nil {
			return rawURL
		}
		parsed.User = strings.Repeat("x", len(parsed.User))
		parsed.Passwd = strings.Repeat("x", len(parsed.Passwd))
		return mysqlPrefix + parsed.FormatDSN()
	}
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return rawURL
	}
	// Only rewrite userinfo that is actually present; installing
	// url.UserPassword("", "") unconditionally would render as ":@host".
	if parsed.User != nil {
		maskedUser := strings.Repeat("x", len(parsed.User.Username()))
		if password, ok := parsed.User.Password(); ok {
			parsed.User = url.UserPassword(maskedUser, strings.Repeat("x", len(password)))
		} else {
			parsed.User = url.User(maskedUser)
		}
	}
	return parsed.String()
}

// GetErrorHandler returns an OpenTelemetry error handler that logs errors.
func GetErrorHandler() otel.ErrorHandler {
	return &errorHandler{}
}

type errorHandler struct{}

func (h *errorHandler) Handle(err error) {
	Logger().Error("opentelemetry failure", zap.Error(err))
}

// OpenAILogger returns the logger used for OpenAI traffic.
func OpenAILogger() *zap.Logger {
	return openaiLogger
}

// InitOpenAILogger directs OpenAI logs to filename as JSON at info level. An
// empty filename leaves the current OpenAI logger in place.
func InitOpenAILogger(filename string) {
	if filename != "" {
		openaiLogger = zap.New(
			zapcore.NewCore(
				zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()),
				zapcore.AddSync(&lumberjack.Logger{
					Filename: filename,
				}),
				zap.InfoLevel))
	}
}

// OptunaLogger adapts a zap logger to the goptuna.Logger interface.
type OptunaLogger struct {
	logger *zap.Logger
}

// NewOptunaLogger wraps logger as a goptuna.Logger.
func NewOptunaLogger(logger *zap.Logger) goptuna.Logger {
	return &OptunaLogger{logger: logger}
}

func (o OptunaLogger) Debug(msg string, fields ...interface{}) {
	o.logger.Debug(msg, zap.Any("fields", fields))
}

func (o OptunaLogger) Info(msg string, fields ...interface{}) {
	o.logger.Info(msg, zap.Any("fields", fields))
}

func (o OptunaLogger) Warn(msg string, fields ...interface{}) {
	o.logger.Warn(msg, zap.Any("fields", fields))
}

func (o OptunaLogger) Error(msg string, fields ...interface{}) {
	o.logger.Error(msg, zap.Any("fields", fields))
}

================================================
FILE: common/log/log_test.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package log

import (
	"os"
	"testing"

	"github.com/spf13/pflag"
	"github.com/stretchr/testify/assert"
)

// checkFileLogging points the logger first at a file in a fresh temp directory
// and then at a file in a not-yet-existing subdirectory. After each SetLogger
// call it logs through emit and asserts that the file was created.
func checkFileLogging(t *testing.T, debug bool, emit func(msg string)) {
	dir, err := os.MkdirTemp("", "gorse")
	assert.NoError(t, err)
	flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
	AddFlags(flags)
	for _, path := range []string{dir + "/gorse.log", dir + "/gorse/gorse.log"} {
		err = flags.Set("log-path", path)
		assert.NoError(t, err)
		SetLogger(flags, debug)
		emit("test")
		assert.FileExists(t, path)
	}
}

func TestSetDevelopmentLogger(t *testing.T) {
	checkFileLogging(t, true, func(msg string) { Logger().Debug(msg) })
}

func TestSetProductionLogger(t *testing.T) {
	checkFileLogging(t, false, func(msg string) { Logger().Info(msg) })
}

func TestRedactDBURL(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{"sqlite://data/data.sqlite", "sqlite://data/data.sqlite"},
		{"mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse?parseTime=true", "mysql://xxxxx:xxxxxxxxxx@tcp(localhost:3306)/gorse?parseTime=true"},
		{"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full", "postgres://xxx:xxxxxx@1.2.3.4:5432/mydb?sslmode=verify-full"},
		// Unparseable inputs are returned unchanged.
		{"mysql://gorse:gorse_pass@tcp(localhost:3306) gorse?parseTime=true", "mysql://gorse:gorse_pass@tcp(localhost:3306) gorse?parseTime=true"},
		{"postgres://bob:secret@1.2.3.4:5432 mydb?sslmode=verify-full", "postgres://bob:secret@1.2.3.4:5432 mydb?sslmode=verify-full"},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, RedactDBURL(c.input))
	}
}
================================================ FILE: common/mock/openai.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mock import ( "bytes" "crypto/md5" "encoding/json" "fmt" "net" "net/http" "github.com/emicklei/go-restful/v3" "github.com/samber/lo" "github.com/sashabaranov/go-openai" ) type OpenAIServer struct { listener net.Listener httpServer *http.Server authToken string ready chan struct{} } func NewOpenAIServer() *OpenAIServer { s := &OpenAIServer{} ws := new(restful.WebService) ws.Path("/v1"). Consumes(restful.MIME_JSON). Produces(restful.MIME_JSON, "text/event-stream") ws.Route(ws.POST("chat/completions"). Reads(openai.ChatCompletionRequest{}). Writes(openai.ChatCompletionResponse{}). To(s.chatCompletion)) ws.Route(ws.POST("embeddings"). Reads(openai.EmbeddingRequest{}). Writes(openai.EmbeddingResponse{}). 
To(s.embeddings)) container := restful.NewContainer() container.Add(ws) s.httpServer = &http.Server{Handler: container} s.authToken = "ollama" s.ready = make(chan struct{}) return s } func (s *OpenAIServer) Start() error { var err error s.listener, err = net.Listen("tcp", "") if err != nil { return err } close(s.ready) return s.httpServer.Serve(s.listener) } func (s *OpenAIServer) BaseURL() string { return fmt.Sprintf("http://%s/v1", s.listener.Addr().String()) } func (s *OpenAIServer) AuthToken() string { return s.authToken } func (s *OpenAIServer) Ready() { <-s.ready } func (s *OpenAIServer) Close() error { return s.httpServer.Close() } func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) { var r openai.ChatCompletionRequest err := req.ReadEntity(&r) if err != nil { _ = resp.WriteError(http.StatusBadRequest, err) return } content := r.Messages[0].Content if r.Model == "deepseek-r1" { content = "To be or not to be, that is the question." + content } if r.Stream { for i := 0; i < len(content); i += 8 { buf := bytes.NewBuffer(nil) buf.WriteString("data: ") encoder := json.NewEncoder(buf) _ = encoder.Encode(openai.ChatCompletionStreamResponse{ Choices: []openai.ChatCompletionStreamChoice{{ Delta: openai.ChatCompletionStreamChoiceDelta{ Content: content[i:min(i+8, len(content))], }, }}, }) buf.WriteString("\n") _, _ = resp.Write(buf.Bytes()) } } else { _ = resp.WriteEntity(openai.ChatCompletionResponse{ Choices: []openai.ChatCompletionChoice{{ Message: openai.ChatCompletionMessage{ Content: content, }, }}, }) } } func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) { // parse request var r openai.EmbeddingRequest err := req.ReadEntity(&r) if err != nil { _ = resp.WriteError(http.StatusBadRequest, err) return } input, ok := r.Input.(string) if !ok { _ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type")) return } // write response _ = resp.WriteEntity(openai.EmbeddingResponse{ Data: 
[]openai.Embedding{{ Embedding: Hash(input), }}, }) } func Hash(input string) []float32 { hasher := md5.New() _, err := hasher.Write([]byte(input)) if err != nil { panic(err) } h := hasher.Sum(nil) return lo.Map(h, func(b byte, _ int) float32 { return float32(b) }) } ================================================ FILE: common/mock/openai_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mock import ( "io" "strings" "testing" "github.com/juju/errors" "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/suite" ) type OpenAITestSuite struct { suite.Suite server *OpenAIServer client *openai.Client } func (suite *OpenAITestSuite) SetupSuite() { // Start mock server suite.server = NewOpenAIServer() go func() { _ = suite.server.Start() }() suite.server.Ready() // Create client clientConfig := openai.DefaultConfig(suite.server.AuthToken()) clientConfig.BaseURL = suite.server.BaseURL() suite.client = openai.NewClientWithConfig(clientConfig) } func (suite *OpenAITestSuite) TearDownSuite() { suite.NoError(suite.server.Close()) } func (suite *OpenAITestSuite) TestChatCompletion() { resp, err := suite.client.CreateChatCompletion( suite.T().Context(), openai.ChatCompletionRequest{ Model: "qwen2.5", Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, Content: "Hello", }, }, }, ) suite.NoError(err) suite.Equal("Hello", 
resp.Choices[0].Message.Content) } func (suite *OpenAITestSuite) TestChatCompletionStream() { content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" + " my mind ever since. Whenever you feel like criticizing anyone, he told me, just remember that all the " + "people in this world haven't had the advantages that you've had." stream, err := suite.client.CreateChatCompletionStream( suite.T().Context(), openai.ChatCompletionRequest{ Model: "qwen2.5", Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, Content: content, }, }, Stream: true, }, ) suite.NoError(err) defer stream.Close() var buffer strings.Builder for { var resp openai.ChatCompletionStreamResponse resp, err = stream.Recv() if errors.Is(err, io.EOF) { suite.Equal(content, buffer.String()) return } suite.NoError(err) buffer.WriteString(resp.Choices[0].Delta.Content) } } func (suite *OpenAITestSuite) TestEmbeddings() { resp, err := suite.client.CreateEmbeddings( suite.T().Context(), openai.EmbeddingRequest{ Input: "Hello", Model: "mxbai-embed-large", }, ) suite.NoError(err) suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding) } func TestOpenAITestSuite(t *testing.T) { suite.Run(t, new(OpenAITestSuite)) } ================================================ FILE: common/monitor/progress.go ================================================ // Copyright 2023 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package monitor import ( "context" "sort" "sync" "time" "github.com/google/uuid" "github.com/gorse-io/gorse/protocol" "modernc.org/mathutil" ) type spanKeyType string var spanKeyName = spanKeyType(uuid.New().String()) type Status string const ( StatusPending Status = "Pending" StatusComplete Status = "Complete" StatusRunning Status = "Running" StatusSuspended Status = "Suspended" StatusFailed Status = "Failed" ) type Monitor struct { name string spans sync.Map } func NewTracer(name string) *Monitor { return &Monitor{name: name} } // Start creates a root span. func (t *Monitor) Start(ctx context.Context, name string, total int) (context.Context, *Span) { span := &Span{ name: name, status: StatusRunning, total: total, start: time.Now(), } t.spans.Store(name, span) return context.WithValue(ctx, spanKeyName, span), span } func (t *Monitor) List() []Progress { var progress []Progress t.spans.Range(func(key, value interface{}) bool { span := value.(*Span) p := span.Progress() p.Tracer = t.name progress = append(progress, p) return true }) // sort by start time sort.Slice(progress, func(i, j int) bool { return progress[i].StartTime.Before(progress[j].StartTime) }) return progress } type Span struct { name string status Status total int count int err string start time.Time finish time.Time children sync.Map } func (s *Span) Add(n int) { s.count = mathutil.Min(s.count+n, s.total) } func (s *Span) End() { if s.status == StatusRunning { s.status = StatusComplete s.count = s.total s.finish = time.Now() } } func (s *Span) Fail(err error) { s.status = StatusFailed s.err = err.Error() } func (s *Span) Count() int { return s.count } func (s *Span) Progress() Progress { // find running children var children []Progress s.children.Range(func(key, value interface{}) bool { child := value.(*Span) progress := child.Progress() if progress.Status == StatusRunning { children = 
append(children, progress) } if s.err == "" && progress.Error != "" { s.err = progress.Error s.status = StatusFailed } return true }) // leaf node if len(children) == 0 { return Progress{ Name: s.name, Status: s.status, Error: s.err, Count: s.count, Total: s.total, StartTime: s.start, FinishTime: s.finish, } } // non-leaf node childTotal := children[0].Total parentTotal := s.total * childTotal parentCount := s.count * childTotal for _, child := range children { parentCount += childTotal * child.Count / child.Total } return Progress{ Name: s.name, Status: s.status, Error: s.err, Count: parentCount, Total: parentTotal, StartTime: s.start, FinishTime: s.finish, } } func Start(ctx context.Context, name string, total int) (context.Context, *Span) { childSpan := &Span{ name: name, status: StatusRunning, total: total, count: 0, start: time.Now(), } if ctx == nil { return nil, childSpan } span, ok := (ctx).Value(spanKeyName).(*Span) if !ok { return ctx, childSpan } span.children.Store(name, childSpan) return context.WithValue(ctx, spanKeyName, childSpan), childSpan } func Fail(ctx context.Context, err error) { span, ok := (ctx).Value(spanKeyName).(*Span) if !ok { return } span.Fail(err) } type Progress struct { Tracer string Name string Status Status Error string Count int Total int StartTime time.Time FinishTime time.Time } func DecodeProgress(in *protocol.PushProgressRequest) []Progress { var progressList []Progress for _, p := range in.Progress { progressList = append(progressList, Progress{ Tracer: p.GetTracer(), Name: p.GetName(), Status: Status(p.GetStatus()), Count: int(p.GetCount()), Total: int(p.GetTotal()), StartTime: time.UnixMilli(p.GetStartTime()), FinishTime: time.UnixMilli(p.GetFinishTime()), }) } return progressList } func EncodeProgress(progressList []Progress) *protocol.PushProgressRequest { var pbList []*protocol.Progress for _, p := range progressList { pbList = append(pbList, &protocol.Progress{ Tracer: p.Tracer, Name: p.Name, Status: string(p.Status), 
Count: int64(p.Count), Total: int64(p.Total), StartTime: p.StartTime.UnixMilli(), FinishTime: p.FinishTime.UnixMilli(), }) } return &protocol.PushProgressRequest{ Progress: pbList, } } ================================================ FILE: common/monitor/progress_test.go ================================================ // Copyright 2023 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package monitor import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) type ProgressTestSuite struct { suite.Suite tracer Monitor } func (suite *ProgressTestSuite) SetupTest() { suite.tracer = Monitor{} } func TestProgressTestSuite(t *testing.T) { suite.Run(t, new(ProgressTestSuite)) } func TestEncodeDecode(t *testing.T) { progressList := []Progress{ { Tracer: "tracer", Name: "a", Total: 100, Count: 50, Status: StatusRunning, StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), }, { Tracer: "tracer", Name: "b", Total: 100, Count: 50, Status: StatusRunning, StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), }, } pb := EncodeProgress(progressList) assert.Equal(t, progressList, DecodeProgress(pb)) } ================================================ FILE: common/nn/functions.go ================================================ // Copyright 2024 gorse Project 
Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nn import ( "fmt" ) func Neg(x *Tensor) *Tensor { return apply(&neg{}, x) } // Add returns the element-wise sum of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. func Add(x0 *Tensor, x ...*Tensor) *Tensor { output := x0 for _, x1 := range x { if len(x0.shape) < len(x1.shape) { output, x1 = x1, output } for i := 0; i < len(x1.shape); i++ { if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { panic(fmt.Sprintf("the shape of one tensor %v must be a suffix sequence of the shape of the other tensor %v", x0.shape, x1.shape)) } } output = apply(&add{}, output, x1) } return output } // Sub returns the element-wise difference of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. func Sub(x0, x1 *Tensor) *Tensor { if len(x0.shape) < len(x1.shape) { panic(fmt.Sprintf("the shape of the second tensor %v must be a suffix sequence of the shape of the first tensor %v", x1.shape, x0.shape)) } for i := 0; i < len(x1.shape); i++ { if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") } } return apply(&sub{}, x0, x1) } // Mul returns the element-wise product of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. 
func Mul(x0, x1 *Tensor) *Tensor { if len(x0.shape) < len(x1.shape) { x0, x1 = x1, x0 } for i := 0; i < len(x1.shape); i++ { if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { panic(fmt.Sprintf("the shape of the second tensor %v must be a suffix sequence of the shape of the first tensor %v", x1.shape, x0.shape)) } } return apply(&mul{}, x0, x1) } // Div returns the element-wise division of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. func Div(x0, x1 *Tensor) *Tensor { if len(x0.shape) < len(x1.shape) { panic(fmt.Sprintf("the shape of the second tensor %v must be a suffix sequence of the shape of the first tensor %v", x1.shape, x0.shape)) } for i := 0; i < len(x1.shape); i++ { if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") } } return apply(&div{}, x0, x1) } // Square returns the element-wise square of a tensor. func Square(x *Tensor) *Tensor { return apply(&square{}, x) } // Pow returns the element-wise power of a tensor. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. func Pow(x *Tensor, n *Tensor) *Tensor { if len(x.shape) < len(n.shape) { panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") } for i := 0; i < len(n.shape); i++ { if n.shape[len(n.shape)-len(x.shape)+i] != x.shape[i] { panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") } } return apply(&pow{}, x, n) } // Exp returns the element-wise exponential of a tensor. func Exp(x *Tensor) *Tensor { return apply(&exp{}, x) } // Log returns the element-wise natural logarithm of a tensor. func Log(x *Tensor) *Tensor { return apply(&log{}, x) } // Sin returns the element-wise sine of a tensor. 
func Sin(x *Tensor) *Tensor { return apply(&sin{}, x) } func Cos(x *Tensor) *Tensor { return apply(&cos{}, x) } func Abs(x *Tensor) *Tensor { return apply(&abs{}, x) } // Sum returns the sum of all elements in a tensor. func Sum(x *Tensor, along ...int) *Tensor { if len(along) > 1 { panic("only one along is allowed") } else if len(along) == 1 { return apply(&partialSum{along: int64(along[0])}, x) } return apply(&sum{}, x) } // Mean returns the mean of all elements in a tensor. func Mean(x *Tensor) *Tensor { return apply(&mean{}, x) } func MatMul(x, y *Tensor, transpose1, transpose2 bool, jobs int) *Tensor { op := &matMul{ transpose1: transpose1, transpose2: transpose2, jobs: jobs, } return apply(op, x, y) } func BMM(x, y *Tensor, transpose1, transpose2 bool, jobs int) *Tensor { op := &batchMatMul{ transpose1: transpose1, transpose2: transpose2, jobs: jobs, } return apply(op, x, y) } func Broadcast(x *Tensor, shape ...int) *Tensor { return apply(&broadcast{shape: shape}, x) } func Flatten(x *Tensor) *Tensor { return apply(&flatten{}, x) } func Reshape(x *Tensor, shape ...int) *Tensor { size1 := 1 for i := range x.shape { size1 *= x.shape[i] } size2 := 1 for i := range shape { size2 *= shape[i] } if size1 != size2 { panic("the size of the tensor must be equal to the size of the new shape") } return apply(&reshape{shape: shape}, x) } func Embedding(w, x *Tensor) *Tensor { return apply(&embedding{}, w, x) } func Sigmoid(x *Tensor) *Tensor { return apply(&sigmoid{}, x) } func ReLu(x *Tensor) *Tensor { return apply(&relu{}, x) } func Softmax(x *Tensor, axis int) *Tensor { return apply(&softmax{axis: axis}, x) } func MeanSquareError(x, y *Tensor) *Tensor { return Mean(Square(Sub(x, y))) } func SoftmaxCrossEntropy(x, y *Tensor) *Tensor { if len(x.shape) != 2 { panic("the shape of the first tensor must be 2-D") } if len(y.shape) != 1 { panic("the shape of the second tensor must be 1-D") } if x.shape[0] != y.shape[0] { panic("the size of the first tensor must be equal to the 
size of the second tensor") } return apply(&softmaxCrossEntropy{}, x, y) } // BCEWithLogits calculates the binary cross-entropy loss between target and prediction // with logits. This implementation is numerically stable. // It is equivalent to the formula: // // max(prediction, 0) - prediction*y + log(1 + exp(-|prediction|)) // // where y = (target + 1) / 2, target is -1 or 1. func BCEWithLogits(target, prediction, weights *Tensor) *Tensor { // To prevent overflow, we use the mathematically equivalent and more stable formula. // This avoids calculating exp(x) where x is a large positive number. // term1 = max(prediction, 0) term1 := ReLu(prediction) // y = (target + 1) / 2 y := Div(Add(NewScalar(1), target), NewScalar(2)) // term2 = prediction * y term2 := Mul(prediction, y) // term3 = log(1 + exp(-|prediction|)) absPrediction := Abs(prediction) expTerm := Exp(Neg(absPrediction)) logTerm := Log(Add(NewScalar(1), expTerm)) // loss = max(prediction, 0) - prediction*y + log(1 + exp(-|prediction|)) loss := Add(Sub(term1, term2), logTerm) if weights != nil { loss = Mul(loss, weights) } return Mean(loss) } ================================================ FILE: common/nn/layers.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package nn import ( "io" "reflect" "strconv" "github.com/chewxy/math32" "github.com/gorse-io/gorse/protocol" "github.com/juju/errors" "github.com/matttproud/golang_protobuf_extensions/pbutil" ) type Layer interface { Parameters() []*Tensor Forward(x *Tensor) *Tensor SetJobs(jobs int) } type Model Layer type LinearLayer struct { W *Tensor B *Tensor jobs int } func NewLinear(in, out int) Layer { bound := 1.0 / math32.Sqrt(float32(in)) return &LinearLayer{ W: Uniform(-bound, bound, in, out), B: Zeros(out), } } func (l *LinearLayer) Forward(x *Tensor) *Tensor { return Add(MatMul(x, l.W, false, false, l.jobs), l.B) } func (l *LinearLayer) Parameters() []*Tensor { return []*Tensor{l.W, l.B} } func (l *LinearLayer) SetJobs(jobs int) { l.jobs = max(1, jobs) } type flattenLayer struct{} func NewFlatten() Layer { return &flattenLayer{} } func (f *flattenLayer) Parameters() []*Tensor { return nil } func (f *flattenLayer) Forward(x *Tensor) *Tensor { return Flatten(x) } func (f *flattenLayer) SetJobs(int) {} type EmbeddingLayer struct { W *Tensor } func NewEmbedding(n int, shape ...int) Layer { wShape := append([]int{n}, shape...) 
return &EmbeddingLayer{ W: Normal(0, 0.01, wShape...), } } func (e *EmbeddingLayer) Parameters() []*Tensor { return []*Tensor{e.W} } func (e *EmbeddingLayer) Forward(x *Tensor) *Tensor { return Embedding(e.W, x) } func (e *EmbeddingLayer) SetJobs(int) {} type sigmoidLayer struct{} func NewSigmoid() Layer { return &sigmoidLayer{} } func (s *sigmoidLayer) Parameters() []*Tensor { return nil } func (s *sigmoidLayer) Forward(x *Tensor) *Tensor { return Sigmoid(x) } func (s *sigmoidLayer) SetJobs(int) {} type reluLayer struct{} func NewReLU() Layer { return &reluLayer{} } func (r *reluLayer) Parameters() []*Tensor { return nil } func (r *reluLayer) Forward(x *Tensor) *Tensor { return ReLu(x) } func (r *reluLayer) SetJobs(int) {} type Sequential struct { Layers []Layer } func NewSequential(layers ...Layer) Model { return &Sequential{Layers: layers} } func (s *Sequential) Parameters() []*Tensor { var params []*Tensor for _, l := range s.Layers { params = append(params, l.Parameters()...) } return params } func (s *Sequential) Forward(x *Tensor) *Tensor { for _, l := range s.Layers { x = l.Forward(x) } return x } func (s *Sequential) SetJobs(jobs int) { for _, l := range s.Layers { l.SetJobs(jobs) } } type Attention struct { W Layer H *Tensor jobs int } func NewAttention(dimensions, k int) *Attention { return &Attention{ W: NewLinear(dimensions, k), H: Normal(0, 0.01, k, dimensions), } } func (a *Attention) Parameters() []*Tensor { var params []*Tensor params = append(params, a.H) params = append(params, a.W.Parameters()...) 
return params } func (a *Attention) Forward(x *Tensor) *Tensor { return Mul( Softmax(MatMul(ReLu(a.W.Forward(x)), a.H, false, false, a.jobs), 1), x, ) } func (a *Attention) SetJobs(jobs int) { a.W.SetJobs(jobs) a.jobs = max(1, jobs) } func Save(o any, w io.Writer) error { var save func(o any, key []string) error save = func(o any, key []string) error { switch typed := o.(type) { case *Tensor: pb := typed.toPB() pb.Key = key _, err := pbutil.WriteDelimited(w, pb) if err != nil { return err } default: tp := reflect.TypeOf(o) if tp.Kind() == reflect.Ptr { return save(reflect.ValueOf(o).Elem().Interface(), key) } else if tp.Kind() == reflect.Struct { for i := 0; i < tp.NumField(); i++ { field := tp.Field(i) if field.IsExported() { newKey := make([]string, len(key)) copy(newKey, key) newKey = append(newKey, field.Name) if err := save(reflect.ValueOf(o).Field(i).Interface(), newKey); err != nil { return err } } } } else if tp.Kind() == reflect.Slice { for i := 0; i < reflect.ValueOf(o).Len(); i++ { newKey := make([]string, len(key)) copy(newKey, key) newKey = append(newKey, strconv.Itoa(i)) if err := save(reflect.ValueOf(o).Index(i).Interface(), newKey); err != nil { return err } } } else { return errors.New("unexpected type") } } return nil } return save(o, nil) } func Load(o any, r io.Reader) error { var place func(o any, key []string, pb *protocol.Tensor) error place = func(o any, key []string, pb *protocol.Tensor) error { switch typed := o.(type) { case *Tensor: typed.fromPB(pb) default: tp := reflect.TypeOf(o) if tp.Kind() == reflect.Ptr { return place(reflect.ValueOf(o).Elem().Interface(), key, pb) } else if tp.Kind() == reflect.Struct { field := reflect.ValueOf(o).FieldByName(key[0]) if field.IsValid() { if err := place(field.Interface(), key[1:], pb); err != nil { return err } } } else if tp.Kind() == reflect.Slice { index, err := strconv.Atoi(key[0]) if err != nil { return err } elem := reflect.ValueOf(o).Index(index) if elem.IsValid() { if err := 
place(elem.Interface(), key[1:], pb); err != nil { return err } } } else { return errors.New("unexpected type") } } return nil } // Read data for { pb := new(protocol.Tensor) if _, err := pbutil.ReadDelimited(r, pb); err != nil { if errors.Is(err, io.EOF) { break } return err } if err := place(o, pb.Key, pb); err != nil { return err } } return nil } ================================================ FILE: common/nn/nn_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package nn import ( "bufio" "bytes" "encoding/csv" "fmt" "math/rand" "os" "path/filepath" "runtime" "strconv" "strings" "testing" "time" "github.com/chewxy/math32" "github.com/gorse-io/gorse/common/datautil" "github.com/gorse-io/gorse/common/util" "github.com/samber/lo" "github.com/schollz/progressbar/v3" "github.com/stretchr/testify/assert" "golang.org/x/sys/cpu" ) func TestLinearRegression(t *testing.T) { x := Rand(100, 1) y := Add(Rand(100, 1), NewScalar(5), Mul(NewScalar(2), x)) w := Zeros(1, 1) b := Zeros(1) predict := func(x *Tensor) *Tensor { return Add(MatMul(x, w, false, false, 0), b) } lr := float32(0.1) for i := 0; i < 100; i++ { yPred := predict(x) loss := MeanSquareError(y, yPred) w.grad = nil b.grad = nil loss.Backward() w.sub(w.grad.mul(NewScalar(lr))) b.sub(b.grad.mul(NewScalar(lr))) } assert.Equal(t, []int{1, 1}, w.shape) assert.InDelta(t, float64(2), w.data[0], 0.6) assert.Equal(t, []int{1}, b.shape) assert.InDelta(t, float64(5), b.data[0], 0.6) } func TestNeuralNetwork(t *testing.T) { x := Rand(100, 1) y := Add(Rand(100, 1), Sin(Mul(x, NewScalar(2*math32.Pi)))) model := NewSequential( NewLinear(1, 10), NewSigmoid(), NewLinear(10, 1), ) NormalInit(model.(*Sequential).Layers[0].(*LinearLayer).W, 0, 0.01) NormalInit(model.(*Sequential).Layers[2].(*LinearLayer).W, 0, 0.01) optimizer := NewSGD(model.Parameters(), 0.2) var l float32 for i := 0; i < 10000; i++ { yPred := model.Forward(x) loss := MeanSquareError(y, yPred) optimizer.ZeroGrad() loss.Backward() optimizer.Step() l = loss.data[0] } assert.InDelta(t, float64(0), l, 0.2) } func iris() (*Tensor, *Tensor, error) { // Download dataset path, err := datautil.DownloadAndUnzip("iris") if err != nil { return nil, nil, err } dataFile := filepath.Join(path, "iris.data") // Load data f, err := os.Open(dataFile) if err != nil { return nil, nil, err } reader := csv.NewReader(f) rows, err := reader.ReadAll() if err != nil { return nil, nil, err } // Parse data data := make([]float32, len(rows)*4) target := 
make([]float32, len(rows)) types := make(map[string]int) for i, row := range rows { for j, cell := range row[:4] { data[i*4+j], err = util.ParseFloat[float32](cell) if err != nil { return nil, nil, err } } if _, exist := types[row[4]]; !exist { types[row[4]] = len(types) } target[i] = float32(types[row[4]]) } return NewTensor(data, len(rows), 4), NewTensor(target, len(rows)), nil } func TestIris(t *testing.T) { x, y, err := iris() assert.NoError(t, err) model := NewSequential( NewLinear(4, 100), NewLinear(100, 100), NewLinear(100, 3), ) optimizer := NewAdam(model.Parameters(), 0.01) var l float32 for i := 0; i < 1000; i++ { yPred := model.Forward(x) loss := SoftmaxCrossEntropy(yPred, y) optimizer.ZeroGrad() loss.Backward() optimizer.Step() l = loss.data[0] } assert.InDelta(t, float32(0), l, 0.1) } func mnist() (lo.Tuple2[*Tensor, *Tensor], lo.Tuple2[*Tensor, *Tensor], error) { var train, test lo.Tuple2[*Tensor, *Tensor] // Download and unzip dataset path, err := datautil.DownloadAndUnzip("mnist") if err != nil { return train, test, err } // Open dataset train.A, train.B, err = openMNISTFile(filepath.Join(path, "train.libfm")) if err != nil { return train, test, err } test.A, test.B, err = openMNISTFile(filepath.Join(path, "test.libfm")) if err != nil { return train, test, err } return train, test, nil } func openMNISTFile(path string) (*Tensor, *Tensor, error) { // Open file f, err := os.Open(path) if err != nil { return nil, nil, err } defer f.Close() // Read data line by line var ( images []float32 labels []float32 ) scanner := bufio.NewScanner(f) for scanner.Scan() { line := scanner.Text() splits := strings.Split(line, " ") // Parse label label, err := util.ParseFloat[float32](splits[0]) if err != nil { return nil, nil, err } labels = append(labels, label) // Parse image image := make([]float32, 784) for _, split := range splits[1:] { kv := strings.Split(split, ":") index, err := strconv.Atoi(kv[0]) if err != nil { return nil, nil, err } value, err := 
util.ParseFloat[float32](kv[1]) if err != nil { return nil, nil, err } image[index] = value } images = append(images, image...) } return NewTensor(images, len(labels), 784), NewTensor(labels, len(labels)), nil } func accuracy(prediction, target *Tensor) float32 { var precision float32 for i, gt := range target.data { if prediction.Slice(i, i+1).argmax()[1] == int(gt) { precision += 1 } } precision /= float32(len(target.data)) return precision } func TestMNIST(t *testing.T) { if (runtime.GOOS != "darwin" || runtime.GOARCH != "arm64") && !cpu.X86.HasAVX512F { // Since the test takes a long time, we run the test only in development environment. // 1. Mac with Apple Silicon. // 2. x86 CPU with AVX512 support. t.Skip("Skip test on non-development environment.") } train, test, err := mnist() assert.NoError(t, err) model := NewSequential( NewLinear(784, 1000), NewReLU(), NewLinear(1000, 10), ) model.SetJobs(runtime.NumCPU()) optimizer := NewAdam(model.Parameters(), 0.001) const ( batchSize = 1000 numEpoch = 5 ) for i := 0; i < numEpoch; i++ { startTime := time.Now() sumLoss, sumAcc := float32(0), float32(0) bar := progressbar.Default(int64(train.A.shape[0]), fmt.Sprintf("Epoch %v/%v", i+1, numEpoch)) for j := 0; j < train.A.shape[0]; j += batchSize { xBatch := train.A.Slice(j, j+batchSize) yBatch := train.B.Slice(j, j+batchSize) yPred := model.Forward(xBatch) loss := SoftmaxCrossEntropy(yPred, yBatch) optimizer.ZeroGrad() loss.Backward() optimizer.Step() sumLoss += loss.data[0] sumAcc += accuracy(yPred, yBatch) assert.NoError(t, bar.Add(batchSize)) } sumLoss /= float32(train.A.shape[0] / batchSize) sumAcc /= float32(train.A.shape[0] / batchSize) assert.NoError(t, bar.Finish()) fmt.Println("Duration:", time.Since(startTime), "Loss:", sumLoss, "Accuracy:", sumAcc) } SetInferenceMode(true) testAcc := accuracy(model.Forward(test.A), test.B) SetInferenceMode(false) fmt.Println("Test Accuracy:", testAcc) assert.Greater(t, float64(testAcc), 0.95) } func spiral() (*Tensor, 
*Tensor, error) { numData, numClass, inputDim := 100, 3, 2 dataSize := numClass * numData x := Zeros(dataSize, inputDim) t := Zeros(dataSize) for j := 0; j < numClass; j++ { for i := 0; i < numData; i++ { rate := float32(i) / float32(numData) radius := 1.0 * rate theta := float32(j)*4.0 + 4.0*rate + float32(rand.NormFloat64())*0.2 ix := numData*j + i x.data[ix*inputDim] = radius * math32.Sin(theta) x.data[ix*inputDim+1] = radius * math32.Cos(theta) t.data[ix] = float32(j) } } indices := rand.Perm(dataSize) x = x.SliceIndices(indices...) t = t.SliceIndices(indices...) return x, t, nil } func TestSaveAndLoad(t *testing.T) { x, y, err := spiral() assert.NoError(t, err) model := NewSequential( NewLinear(2, 10), NewSigmoid(), NewLinear(10, 3), ) optimizer := NewAdam(model.Parameters(), 0.01) var expected float32 for i := 0; i < 300; i++ { yPred := model.Forward(x) loss := SoftmaxCrossEntropy(yPred, y) optimizer.ZeroGrad() loss.Backward() optimizer.Step() expected = loss.data[0] } buffer := bytes.NewBuffer(nil) err = Save(model, buffer) assert.NoError(t, err) modelLoaded := NewSequential( NewLinear(2, 10), NewSigmoid(), NewLinear(10, 3), ) err = Load(modelLoaded, buffer) assert.NoError(t, err) yPred := modelLoaded.Forward(x) loss := SoftmaxCrossEntropy(yPred, y) assert.InDelta(t, float64(expected), float64(loss.data[0]), 0.01) } ================================================ FILE: common/nn/op.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package nn import ( "weak" "github.com/chewxy/math32" "github.com/gorse-io/gorse/common/floats" ) type op interface { String() string forward(inputs ...*Tensor) *Tensor backward(dy *Tensor) []*Tensor inputsAndOutput() ([]*Tensor, *Tensor) setInputs(inputs ...*Tensor) setOutput(y *Tensor) generation() int setGeneration(gen int) } type base struct { inputs []*Tensor output weak.Pointer[Tensor] gen int } func (b *base) inputsAndOutput() ([]*Tensor, *Tensor) { return b.inputs, b.output.Value() } func (b *base) setInputs(inputs ...*Tensor) { b.inputs = inputs } func (b *base) setOutput(y *Tensor) { b.output = weak.Make(y) } func (b *base) generation() int { return b.gen } func (b *base) setGeneration(gen int) { b.gen = gen } func apply[T op](f T, inputs ...*Tensor) *Tensor { y := f.forward(inputs...) f.setInputs(inputs...) f.setOutput(y) if !inferenceMode.Load() { y.op = f gen := 0 for _, x := range inputs { gen = max(gen, x.generation()) } f.setGeneration(gen + 1) } else { y.op = nil } return y } type neg struct { base } func (n *neg) String() string { return "Neg" } func (n *neg) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.neg() return y } func (n *neg) backward(dy *Tensor) []*Tensor { dx := dy.clone() dx.neg() return []*Tensor{dx} } type add struct { base } func (a *add) String() string { return "Add" } func (a *add) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.add(inputs[1]) return y } func (a *add) backward(dy *Tensor) []*Tensor { gx0 := dy.clone() gx1 := Zeros(a.inputs[1].shape...) 
wSize := 1 for i := range gx1.shape { wSize *= gx1.shape[i] } for i := range dy.data { gx1.data[i%wSize] += dy.data[i] } return []*Tensor{gx0, gx1} } type sub struct { base } func (s *sub) String() string { return "Sub" } func (s *sub) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.sub(inputs[1]) return y } func (s *sub) backward(dy *Tensor) []*Tensor { gx0 := dy.clone() gx1 := Zeros(s.inputs[1].shape...) wSize := 1 for i := range gx1.shape { wSize *= gx1.shape[i] } for i := range dy.data { gx1.data[i%wSize] -= dy.data[i] } return []*Tensor{gx0, gx1} } type mul struct { base } func (m *mul) String() string { return "Mul" } func (m *mul) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.mul(inputs[1]) return y } func (m *mul) backward(dy *Tensor) []*Tensor { gx0 := dy.clone() gx0.mul(m.inputs[1]) gx1 := Zeros(m.inputs[1].shape...) wSize := 1 for i := range gx1.shape { wSize *= gx1.shape[i] } for i := range dy.data { gx1.data[i%wSize] += dy.data[i] * m.inputs[0].data[i] } return []*Tensor{gx0, gx1} } type div struct { base } func (d *div) String() string { return "Div" } func (d *div) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.div(inputs[1]) return y } func (d *div) backward(dy *Tensor) []*Tensor { wSize := 1 for i := range d.inputs[1].shape { wSize *= d.inputs[1].shape[i] } gx0 := Zeros(d.inputs[0].shape...) for i := range dy.data { gx0.data[i] = dy.data[i] / d.inputs[1].data[i%wSize] } gx1 := Zeros(d.inputs[1].shape...) 
for i := range dy.data { gx1.data[i%wSize] -= dy.data[i] * d.inputs[0].data[i] / d.inputs[1].data[i%wSize] / d.inputs[1].data[i%wSize] } return []*Tensor{gx0, gx1} } type sin struct { base } func (s *sin) String() string { return "Sin" } func (s *sin) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.sin() return y } func (s *sin) backward(dy *Tensor) []*Tensor { dx := s.inputs[0].clone() dx.cos() dx.mul(dy) return []*Tensor{dx} } type cos struct { base } func (c *cos) String() string { return "Cos" } func (c *cos) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.cos() return y } func (c *cos) backward(dy *Tensor) []*Tensor { dx := c.inputs[0].clone() dx.sin() dx.neg() dx.mul(dy) return []*Tensor{dx} } type square struct { base } func (s *square) String() string { return "Square" } func (s *square) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.square() return y } func (s *square) backward(dy *Tensor) []*Tensor { dx := s.inputs[0].clone() floats.MulTo(dx.data, dy.data, dx.data) floats.MulConst(dx.data, 2) return []*Tensor{dx} } type pow struct { base } func (p *pow) String() string { return "Pow" } func (p *pow) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.pow(inputs[1]) return y } func (p *pow) backward(dy *Tensor) []*Tensor { dx0 := p.inputs[0].clone() dx0.pow(p.inputs[1]) dx0.mul(p.inputs[1]) dx0.div(p.inputs[0]) dx0.mul(dy) wSize := 1 for i := range p.inputs[1].shape { wSize *= p.inputs[1].shape[i] } dx1 := Zeros(p.inputs[1].shape...) 
for i := range dy.data { dx1.data[i%wSize] += dy.data[i] * p.output.Value().data[i] * math32.Log(p.inputs[0].data[i]) } return []*Tensor{dx0, dx1} } type exp struct { base } func (e *exp) String() string { return "Exp" } func (e *exp) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.exp() return y } func (e *exp) backward(dy *Tensor) []*Tensor { dx := e.inputs[0].clone() dx.exp() dx.mul(dy) return []*Tensor{dx} } type log struct { base } func (l *log) String() string { return "Log" } func (l *log) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.log() return y } func (l *log) backward(dy *Tensor) []*Tensor { dx := dy.clone() dx.div(l.inputs[0]) return []*Tensor{dx} } type abs struct { base } func (a *abs) String() string { return "Abs" } func (a *abs) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() for i := range y.data { if y.data[i] < 0 { y.data[i] = -y.data[i] } } return y } func (a *abs) backward(dy *Tensor) []*Tensor { dx := dy.clone() for i := range dx.data { if a.inputs[0].data[i] < 0 { dx.data[i] = -dx.data[i] } } return []*Tensor{dx} } type sum struct { base } func (s *sum) String() string { return "Sum" } func (s *sum) forward(inputs ...*Tensor) *Tensor { x := inputs[0] y := NewTensor([]float32{0}) for i := range x.data { y.data[0] += x.data[i] } return y } func (s *sum) backward(dy *Tensor) []*Tensor { dx := Zeros(s.inputs[0].shape...) for i := range dx.data { dx.data[i] = dy.data[0] } return []*Tensor{dx} } type partialSum struct { base along int64 } func (p *partialSum) String() string { return "Sum" } func (p *partialSum) forward(inputs ...*Tensor) *Tensor { x := inputs[0] // Squash the shape. s1, s2, s3 := 1, 1, 1 for i := 0; i < len(x.shape); i++ { if int64(i) == p.along { s2 = x.shape[i] } else if int64(i) < p.along { s1 *= x.shape[i] } else { s3 *= x.shape[i] } } // Calculate the output size and shape. 
outputSize := s1 * s3 outputShape := make([]int, 0) for i := 0; i < len(x.shape); i++ { if int64(i) != p.along { outputShape = append(outputShape, x.shape[i]) } } // Calculate the output. y := NewTensor(make([]float32, outputSize), outputShape...) for i := 0; i < s1; i++ { for j := 0; j < s2; j++ { for k := 0; k < s3; k++ { y.data[i*s3+k] += x.data[i*s2*s3+j*s3+k] } } } return y } func (p *partialSum) backward(dy *Tensor) []*Tensor { x := p.inputs[0] // Squash the shape. s1, s2, s3 := 1, 1, 1 for i := 0; i < len(x.shape); i++ { if int64(i) == p.along { s2 = x.shape[i] } else if int64(i) < p.along { s1 *= x.shape[i] } else { s3 *= x.shape[i] } } // Calculate the output. dx := Zeros(x.shape...) for i := 0; i < s1; i++ { for j := 0; j < s2; j++ { for k := 0; k < s3; k++ { dx.data[i*s2*s3+j*s3+k] = dy.data[i*s3+k] } } } return []*Tensor{dx} } type mean struct { base } func (m *mean) String() string { return "Mean" } func (m *mean) forward(inputs ...*Tensor) *Tensor { x := inputs[0] y := NewTensor([]float32{0}) for i := range x.data { y.data[0] += x.data[i] } y.data[0] /= float32(len(x.data)) return y } func (m *mean) backward(dy *Tensor) []*Tensor { dx := Zeros(m.inputs[0].shape...) 
for i := range dx.data { dx.data[i] = dy.data[0] / float32(len(dx.data)) } return []*Tensor{dx} } type matMul struct { base transpose1 bool transpose2 bool jobs int } func (m *matMul) String() string { return "MatMul" } func (m *matMul) forward(inputs ...*Tensor) *Tensor { return inputs[0].matMul(inputs[1], m.transpose1, m.transpose2, m.jobs) } func (m *matMul) backward(dy *Tensor) []*Tensor { var dx0, dx1 *Tensor if !m.transpose1 && !m.transpose2 { // y = x0 * x1 // dx0 = dy * x1^T dx0 = dy.matMul(m.inputs[1], false, true, m.jobs) // dx1 = x0^T * dy dx1 = m.inputs[0].matMul(dy, true, false, m.jobs) } else if m.transpose1 && !m.transpose2 { // y = x0^T * x1 // dx0 = dy * x1^T dx0 = m.inputs[1].matMul(dy, false, true, m.jobs) // dx1 = dy^T * x0 dx1 = m.inputs[0].matMul(dy, false, false, m.jobs) } else if !m.transpose1 && m.transpose2 { // y = x0 * x1^T // dx0 = dy * x1 dx0 = dy.matMul(m.inputs[1], false, false, m.jobs) // dx1 = dy^T * x0 dx1 = dy.matMul(m.inputs[0], true, false, m.jobs) } else { // y = x0^T * x1^T // dx0 = x1 * dy^T dx0 = m.inputs[1].matMul(dy, true, true, m.jobs) // dx1 = dy * x0^T dx1 = dy.matMul(m.inputs[0], true, true, m.jobs) } return []*Tensor{dx0, dx1} } type batchMatMul struct { base transpose1 bool transpose2 bool jobs int } func (b *batchMatMul) String() string { return "BatchMatMul" } func (b *batchMatMul) forward(inputs ...*Tensor) *Tensor { return inputs[0].batchMatMul(inputs[1], b.transpose1, b.transpose2, b.jobs) } func (b *batchMatMul) backward(dy *Tensor) []*Tensor { var dx0, dx1 *Tensor if !b.transpose1 && !b.transpose2 { // y = x0 * x1 // dx0 = dy * x1^T dx0 = dy.batchMatMul(b.inputs[1], false, true, b.jobs) // dx1 = x0^T * dy dx1 = b.inputs[0].batchMatMul(dy, true, false, b.jobs) } else if b.transpose1 && !b.transpose2 { // y = x0^T * x1 // dx0 = dy * x1^T dx0 = b.inputs[1].batchMatMul(dy, false, true, b.jobs) // dx1 = dy^T * x0 dx1 = b.inputs[0].batchMatMul(dy, false, false, b.jobs) } else if !b.transpose1 && b.transpose2 { // y 
= x0 * x1^T // dx0 = dy * x1 dx0 = dy.batchMatMul(b.inputs[1], false, false, b.jobs) // dx1 = dy^T * x0 dx1 = dy.batchMatMul(b.inputs[0], true, false, b.jobs) } else { // y = x0^T * x1^T // dx0 = x1 * dy^T dx0 = b.inputs[1].batchMatMul(dy, true, true, b.jobs) // dx1 = dy * x0^T dx1 = dy.batchMatMul(b.inputs[0], true, true, b.jobs) } return []*Tensor{dx0, dx1} } type broadcast struct { base shape []int } func (b *broadcast) String() string { return "Broadcast" } func (b *broadcast) forward(inputs ...*Tensor) *Tensor { x := inputs[0] // Concatenate the shape shape := make([]int, len(x.shape)) copy(shape, x.shape) shape = append(shape, b.shape...) size := 1 for i := range shape { size *= shape[i] } // Create a new tensor with the new shape y := NewTensor(make([]float32, size), shape...) wSize := 1 for i := range b.shape { wSize *= b.shape[i] } for i := range x.data { for j := i * wSize; j < (i+1)*wSize; j++ { y.data[j] = x.data[i] } } return y } func (b *broadcast) backward(dy *Tensor) []*Tensor { gx := Zeros(b.inputs[0].shape...) wSize := 1 for i := range b.shape { wSize *= b.shape[i] } for i := range gx.data { for j := i * wSize; j < (i+1)*wSize; j++ { gx.data[i] += dy.data[j] } } return []*Tensor{gx} } type flatten struct { base } func (f *flatten) String() string { return "Flatten" } func (f *flatten) forward(inputs ...*Tensor) *Tensor { return NewTensor(inputs[0].data, len(inputs[0].data)) } func (f *flatten) backward(dy *Tensor) []*Tensor { return []*Tensor{NewTensor(dy.data, f.inputs[0].shape...)} } type reshape struct { base shape []int } func (r *reshape) String() string { return "Reshape" } func (r *reshape) forward(inputs ...*Tensor) *Tensor { return NewTensor(inputs[0].data, r.shape...) 
} func (r *reshape) backward(dy *Tensor) []*Tensor { return []*Tensor{NewTensor(dy.data, r.inputs[0].shape...)} } type embedding struct { base } func (e *embedding) String() string { return "Embedding" } func (e *embedding) forward(inputs ...*Tensor) *Tensor { w, x := inputs[0], inputs[1] // Calculate embedding size dim := 1 for i := 1; i < len(w.shape); i++ { dim *= w.shape[i] } // Calculate shape shape := make([]int, len(x.shape), len(x.shape)+1) copy(shape, x.shape) shape = append(shape, w.shape[1:]...) // Calculate data size size := 1 for _, s := range shape { size *= s } // Create output tensor data := make([]float32, size) for i := 0; i < len(x.data); i++ { index := int(x.data[i]) copy(data[i*dim:(i+1)*dim], w.data[index*dim:(index+1)*dim]) } return NewTensor(data, shape...) } func (e *embedding) backward(dy *Tensor) []*Tensor { w, x := e.inputs[0], e.inputs[1] dim := 1 for i := 1; i < len(w.shape); i++ { dim *= w.shape[i] } dw := Zeros(w.shape...) for i := 0; i < len(x.data); i++ { index := int(x.data[i]) for j := 0; j < dim; j++ { dw.data[index*dim+j] += dy.data[i*dim+j] } } return []*Tensor{dw} } type sigmoid struct { base } func (s *sigmoid) String() string { return "Sigmoid" } func (s *sigmoid) forward(inputs ...*Tensor) *Tensor { // y = tanh(x * 0.5) * 0.5 + 0.5 y := inputs[0].clone() y.mul(NewScalar(0.5)) y.tanh() y.mul(NewScalar(0.5)) y.add(NewScalar(0.5)) return y } func (s *sigmoid) backward(dy *Tensor) []*Tensor { // dx = dy * y * (1 - y) dx := s.output.Value().clone() dx.neg() dx.add(NewScalar(1)) dx.mul(s.output.Value()) dx.mul(dy) return []*Tensor{dx} } type relu struct { base } func (r *relu) String() string { return "ReLU" } func (r *relu) forward(inputs ...*Tensor) *Tensor { y := inputs[0].clone() y.maximum(NewScalar(0)) return y } func (r *relu) backward(dy *Tensor) []*Tensor { x := r.inputs[0] dx := x.clone().gt(NewScalar(0)).mul(dy) return []*Tensor{dx} } type softmax struct { base axis int } func (s *softmax) String() string { return 
"Softmax" } func (s *softmax) forward(inputs ...*Tensor) *Tensor { x := inputs[0] y := x.clone() y.sub(x.max(s.axis, true)) y.exp() y.div(y.sum(s.axis, true)) return y } func (s *softmax) backward(dy *Tensor) []*Tensor { y := s.output.Value() gx := y.clone() gx.mul(dy) sumdx := gx.sum(s.axis, true) y.mul(sumdx) gx.sub(y) return []*Tensor{gx} } type softmaxCrossEntropy struct { base } func (c *softmaxCrossEntropy) String() string { return "SoftmaxCrossEntropy" } func (c *softmaxCrossEntropy) forward(inputs ...*Tensor) *Tensor { x, t := inputs[0], inputs[1] m := x.max(1, true) s := x.clone().bSub(m) // x - m s = s.exp() // exp(x - m) s = s.sum(1, true) // sum(exp(x - m)) s.log() // log(sum(exp(x - m))) m.add(s) // m + log(sum(exp(x - m))) logP := x.clone().bSub(m) // x - (m + log(sum(exp(x - m)))) var crossEntropy float32 for i := 0; i < len(t.data); i++ { crossEntropy -= logP.Get(i, int(t.data[i])) } crossEntropy /= float32(len(t.data)) return NewScalar(crossEntropy) } func (c *softmaxCrossEntropy) backward(dy *Tensor) []*Tensor { x, t := c.inputs[0], c.inputs[1] // gy *= 1/N gy := dy.clone().mul(NewScalar(1 / float32(len(t.data)))) // y = softmax(x) y := x.clone() y.bSub(x.max(1, true)) y.exp() y.bDiv(y.sum(1, true)) // convert to one-hot oneHot := Zeros(x.shape...) 
for i := 0; i < len(t.data); i++ { oneHot.data[i*x.shape[1]+int(t.data[i])] = 1 } // y = (y - t_onehot) * gy y = y.sub(oneHot).mul(gy) return []*Tensor{y, Zeros(t.shape...)} } type opHeap []op func (h opHeap) Len() int { return len(h) } func (h opHeap) Less(i, j int) bool { return h[i].generation() > h[j].generation() } func (h opHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *opHeap) Push(o any) { *h = append(*h, o.(op)) } func (h *opHeap) Pop() any { old := *h n := len(old) x := old[n-1] *h = old[0 : n-1] return x } ================================================ FILE: common/nn/op_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nn import ( "testing" "github.com/chewxy/math32" "github.com/stretchr/testify/assert" ) const ( eps = 1e-4 rtol = 1e-2 atol = 5e-3 ) func numericalDiff(f func(*Tensor) *Tensor, x *Tensor) *Tensor { x0, x1 := x.clone(), x.clone() dx := make([]float32, len(x.data)) for i, v := range x.data { x0.data[i] = v - eps x1.data[i] = v + eps y0 := f(x0) y1 := f(x1) for j := range y0.data { dx[i] += (y1.data[j] - y0.data[j]) / (2 * eps) } x0.data[i] = v x1.data[i] = v } return NewTensor(dx, x.shape...) 
} func allClose(t *testing.T, a, b *Tensor) { if !assert.Equal(t, a.shape, b.shape) { return } for i := range a.data { if math32.Abs(a.data[i]-b.data[i]) > atol+rtol*math32.Abs(b.data[i]) { t.Fatalf("a.data[%d] = %f, b.data[%d] = %f\n", i, a.data[i], i, b.data[i]) return } } } func TestAdd(t *testing.T) { // (2,3) + (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{2, 3, 4, 5, 6, 7}, 2, 3) z := Add(x, y) assert.Equal(t, []float32{3, 5, 7, 9, 11, 13}, z.data) // Test gradient x = Rand(2, 3) y = Rand(2, 3) z = Add(x, y) z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, y) }, x) allClose(t, x.grad, dx) dy := numericalDiff(func(y *Tensor) *Tensor { return Add(x, y) }, y) allClose(t, y.grad, dy) // (2,3) + () -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2}) z = Add(x, y) assert.Equal(t, []float32{3, 4, 5, 6, 7, 8}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) assert.Equal(t, []float32{6}, y.grad.data) // (2,3) + (3) -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2, 3, 4}, 3) z = Add(x, y) assert.Equal(t, []float32{3, 5, 7, 6, 8, 10}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) assert.Equal(t, []float32{2, 2, 2}, y.grad.data) } func TestSub(t *testing.T) { // (2,3) - (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{2, 3, 4, 5, 6, 7}, 2, 3) z := Sub(x, y) assert.Equal(t, []float32{-1, -1, -1, -1, -1, -1}, z.data) // Test gradient x = Rand(2, 3) y = Rand(2, 3) z = Sub(x, y) z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Sub(x, y) }, x) allClose(t, x.grad, dx) dy := numericalDiff(func(y *Tensor) *Tensor { return Sub(x, y) }, y) allClose(t, y.grad, dy) // (2,3) - () -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2}) z = Sub(x, y) assert.Equal(t, 
[]float32{-1, 0, 1, 2, 3, 4}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) assert.Equal(t, []float32{-6}, y.grad.data) // (2,3) - (3) -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2, 3, 4}, 3) z = Sub(x, y) assert.Equal(t, []float32{-1, -1, -1, 2, 2, 2}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) assert.Equal(t, []float32{-2, -2, -2}, y.grad.data) } func TestMul(t *testing.T) { // (2,3) * (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{2, 3, 4, 5, 6, 7}, 2, 3) z := Mul(x, y) assert.Equal(t, []float32{2, 6, 12, 20, 30, 42}, z.data) // Test gradient x = Rand(2, 3) y = Rand(2, 3) z = Mul(x, y) z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Mul(x, y) }, x) allClose(t, x.grad, dx) dy := numericalDiff(func(y *Tensor) *Tensor { return Mul(x, y) }, y) allClose(t, y.grad, dy) // (2,3) * () -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2}) z = Mul(x, y) assert.Equal(t, []float32{2, 4, 6, 8, 10, 12}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{2, 2, 2, 2, 2, 2}, x.grad.data) assert.Equal(t, []float32{21}, y.grad.data) // (2,3) * (3) -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2, 3, 4}, 3) z = Mul(x, y) assert.Equal(t, []float32{2, 6, 12, 8, 15, 24}, z.data) // Test gradient z.Backward() assert.Equal(t, []float32{2, 3, 4, 2, 3, 4}, x.grad.data) assert.Equal(t, []float32{5, 7, 9}, y.grad.data) } func TestDiv(t *testing.T) { // (2,3) / (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{2, 3, 4, 5, 6, 7}, 2, 3) z := Div(x, y) assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 0.75, 4.0 / 5.0, 5.0 / 6.0, 6.0 / 7.0}, z.data, 1e-6) // Test gradient z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Div(x, y) }, x) allClose(t, x.grad, dx) dy 
:= numericalDiff(func(y *Tensor) *Tensor { return Div(x, y) }, y) allClose(t, y.grad, dy) // (2,3) / () -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2}) z = Div(x, y) assert.InDeltaSlice(t, []float32{0.5, 1, 1.5, 2, 2.5, 3}, z.data, 1e-6) // Test gradient z.Backward() assert.InDeltaSlice(t, []float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, x.grad.data, 1e-6) assert.InDeltaSlice(t, []float32{-21.0 / 4.0}, y.grad.data, 1e-6) // (2,3) / (3) -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2, 3, 4}, 3) z = Div(x, y) assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 3.0 / 4.0, 2, 5.0 / 3.0, 1.5}, z.data, 1e-6) // Test gradient z.Backward() assert.InDeltaSlice(t, []float32{1.0 / 2, 1.0 / 3, 1.0 / 4, 1.0 / 2, 1.0 / 3, 1.0 / 4}, x.grad.data, 1e-6) assert.InDeltaSlice(t, []float32{-5.0 / 4.0, -7.0 / 9.0, -9.0 / 16.0}, y.grad.data, 1e-6) } func TestSquare(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Square(x) assert.Equal(t, []float32{1, 4, 9, 16, 25, 36}, y.data) // Test gradient x = Rand(2, 3) y = Square(x) y.Backward() dx := numericalDiff(Square, x) allClose(t, x.grad, dx) } func TestPow(t *testing.T) { // (2,3) ** (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{2, 3, 4, 5, 6, 7}, 2, 3) z := Pow(x, y) assert.InDeltaSlice(t, []float32{1, 8, 81, 1024, 15625, 279936}, z.data, 1e-6) // Test gradient x = Uniform(0.5, 1, 2, 3) y = Uniform(0.5, 1, 2, 3) z = Pow(x, y) z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Pow(x, y) }, x) allClose(t, x.grad, dx) dy := numericalDiff(func(y *Tensor) *Tensor { return Pow(x, y) }, y) allClose(t, y.grad, dy) // (2,3) ** () -> (2,3) x = NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y = NewTensor([]float32{2}) z = Pow(x, y) assert.InDeltaSlice(t, []float32{1, 4, 9, 16, 25, 36}, z.data, 1e-6) // Test gradient z.Backward() assert.InDeltaSlice(t, []float32{2, 4, 6, 8, 10, 12}, 
x.grad.data, 1e-6) assert.InDeltaSlice(t, []float32{ math32.Pow(1, 2)*math32.Log(1) + math32.Pow(2, 2)*math32.Log(2) + math32.Pow(3, 2)*math32.Log(3) + math32.Pow(4, 2)*math32.Log(4) + math32.Pow(5, 2)*math32.Log(5) + math32.Pow(6, 2)*math32.Log(6), }, y.grad.data, 1e-6) } func TestExp(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{0, 1, 2, 3, 4, 5}, 2, 3) y := Exp(x) assert.InDeltaSlice(t, []float32{1, math32.Exp(1), math32.Exp(2), math32.Exp(3), math32.Exp(4), math32.Exp(5)}, y.data, 1e-5) // Test gradient x = Rand(2, 3) y = Exp(x) y.Backward() dx := numericalDiff(Exp, x) allClose(t, x.grad, dx) } func TestLog(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Log(x) assert.InDeltaSlice(t, []float32{0, math32.Log(2), math32.Log(3), math32.Log(4), math32.Log(5), math32.Log(6)}, y.data, 1e-6) // Test gradient x = Uniform(1, 2, 2, 3) y = Log(x) y.Backward() dx := numericalDiff(Log, x) allClose(t, x.grad, dx) } func TestAbs(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{1, -2, 3, -4, 5, -6}, 2, 3) y := Abs(x) assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, y.data) // Test gradient x = Rand(2, 3) y = Abs(x) y.Backward() dx := numericalDiff(Abs, x) allClose(t, x.grad, dx) } func TestSum(t *testing.T) { // (2,3) -> () x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Sum(x) assert.Equal(t, []float32{21}, y.data) // Test gradient x = Rand(2, 3) y = Sum(x) y.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) // (2,3,2) -> (2,2) x = NewTensor([]float32{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, 2, 3, 2) y = Sum(x, 1) assert.Equal(t, []int{2, 2}, y.shape) assert.Equal(t, []float32{9, 12, 9, 12}, y.data) // Test gradient x = Rand(2, 3, 2) y = Sum(x, 1) y.Backward() assert.Equal(t, []int{2, 3, 2}, x.grad.shape) assert.Equal(t, []float32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, x.grad.data) } func TestMean(t *testing.T) { // (2,3) -> () x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Mean(x) 
assert.Equal(t, []float32{3.5}, y.data) // Test gradient x = Rand(2, 3) y = Mean(x) y.Backward() assert.Equal(t, []float32{1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6}, x.grad.data) } func TestCos(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{0, 0.1, 0.2, 0.3, 0.4, 0.5}, 2, 3) y := Cos(x) assert.InDeltaSlice(t, []float32{1, 0.9950041652780258, 0.9800665778412416, 0.955336489125606, 0.9210609940028851, 0.8775825618903728}, y.data, 1e-6) // Test gradient x = Rand(2, 3) y = Cos(x) y.Backward() dx := numericalDiff(Cos, x) allClose(t, x.grad, dx) } func TestSin(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{0, 1, 2, 3, 4, 5}, 2, 3) y := Sin(x) assert.InDeltaSlice(t, []float32{0, 0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385}, y.data, 1e-6) // Test gradient x = Rand(2, 3) y = Sin(x) y.Backward() dx := numericalDiff(Sin, x) allClose(t, x.grad, dx) } func TestMatMul(t *testing.T) { // (2,3) * (3,4) -> (2,4) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := NewTensor([]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 3, 4) z := MatMul(x, y, false, false, 0) assert.Equal(t, []int{2, 4}, z.shape) assert.Equal(t, []float32{38, 44, 50, 56, 83, 98, 113, 128}, z.data) // Test gradient z.Backward() assert.Equal(t, []int{2, 3}, x.grad.shape) assert.Equal(t, []float32{10, 26, 42, 10, 26, 42}, x.grad.data) assert.Equal(t, []int{3, 4}, y.grad.shape) assert.Equal(t, []float32{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9}, y.grad.data) // (3,2).T * (3,4) -> (2,4) x = Rand(3, 2) y = Rand(3, 4) z = MatMul(x, y, true, false, 0) assert.Equal(t, []int{2, 4}, z.shape) z.Backward() assert.Equal(t, []int{3, 2}, x.grad.shape) assert.Equal(t, []int{3, 4}, y.grad.shape) // (2,3) * (4,3).T -> (2,4) x = Rand(2, 3) y = Rand(4, 3) z = MatMul(x, y, false, true, 0) assert.Equal(t, []int{2, 4}, z.shape) z.Backward() assert.Equal(t, []int{2, 3}, x.grad.shape) assert.Equal(t, []int{4, 3}, y.grad.shape) // (3,2).T * (4,3).T 
-> (2,4) x = Rand(3, 2) y = Rand(4, 3) z = MatMul(x, y, true, true, 0) assert.Equal(t, []int{2, 4}, z.shape) z.Backward() assert.Equal(t, []int{3, 2}, x.grad.shape) } func TestBMM(t *testing.T) { // (2,2,3) * (2,3,4) -> (2,2,4) x := NewTensor([]float32{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, 2, 2, 3) y := NewTensor([]float32{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, }, 2, 3, 4) z := BMM(x, y, false, false, 0) assert.Equal(t, []int{2, 2, 4}, z.shape) assert.Equal(t, []float32{ 38, 44, 50, 56, 83, 98, 113, 128, 38, 44, 50, 56, 83, 98, 113, 128, }, z.data) // Test gradient z.Backward() assert.Equal(t, []int{2, 2, 3}, x.grad.shape) assert.Equal(t, []float32{ 10, 26, 42, 10, 26, 42, 10, 26, 42, 10, 26, 42, }, x.grad.data) assert.Equal(t, []int{2, 3, 4}, y.grad.shape) assert.Equal(t, []float32{ 5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9, 5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9, }, y.grad.data) // (2,3,2).T * (2,3,4) -> (2,2,4) x = Rand(2, 3, 2) y = Rand(2, 3, 4) z = BMM(x, y, true, false, 0) assert.Equal(t, []int{2, 2, 4}, z.shape) z.Backward() assert.Equal(t, []int{2, 3, 2}, x.grad.shape) // (2,2,3) * (2,4,3).T -> (2,2,4) x = Rand(2, 2, 3) y = Rand(2, 4, 3) z = BMM(x, y, false, true, 0) assert.Equal(t, []int{2, 2, 4}, z.shape) z.Backward() assert.Equal(t, []int{2, 2, 3}, x.grad.shape) // (2,3,2).T * (2,43).T -> (2,2,4) x = Rand(2, 3, 2) y = Rand(2, 4, 3) z = BMM(x, y, true, true, 0) assert.Equal(t, []int{2, 2, 4}, z.shape) z.Backward() assert.Equal(t, []int{2, 3, 2}, x.grad.shape) } func TestBroadcast(t *testing.T) { // (2) -> (2,3) x := NewTensor([]float32{1, 2}, 2) y := Broadcast(x, 3) assert.Equal(t, []float32{1, 1, 1, 2, 2, 2}, y.data) // Test gradient y.Backward() assert.Equal(t, []float32{3, 3}, x.grad.data) } func TestEmbedding(t *testing.T) { // (2,3) -> (2,3,2) x := NewTensor([]float32{0, 1, 0, 3, 0, 5}, 2, 3) w := NewTensor([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 2) y := Embedding(w, x) assert.Equal(t, []int{2, 3, 2}, y.shape) 
assert.Equal(t, []float32{0, 1, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11}, y.data) // Test gradient y.Backward() assert.Nil(t, x.grad) assert.Equal(t, []float32{3, 3, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1}, w.grad.data) // (2,3) -> (2,3,1,2) x = NewTensor([]float32{0, 1, 0, 3, 0, 5}, 2, 3) w = NewTensor([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 1, 2) y = Embedding(w, x) assert.Equal(t, []int{2, 3, 1, 2}, y.shape) assert.Equal(t, []float32{0, 1, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11}, y.data) // Test gradient y.Backward() assert.Nil(t, x.grad) assert.Equal(t, []float32{3, 3, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1}, w.grad.data) } func TestSigmoid(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{0, 1, 2, 3, 4, 5}, 2, 3) y := Sigmoid(x) assert.InDeltaSlice(t, []float32{0.5, 0.7310585786300049, 0.8807970779778823, 0.9525741268224334, 0.9820137900379085, 0.9933071490757153}, y.data, 1e-6) // Test gradient x = Rand(2, 3) y = Sigmoid(x) y.Backward() dx := numericalDiff(Sigmoid, x) allClose(t, x.grad, dx) } func TestReLu(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{-1, 0, 1, 2, 3, 4}, 2, 3) y := ReLu(x) assert.Equal(t, []float32{0, 0, 1, 2, 3, 4}, y.data) // Test gradient x = Rand(2, 3) y = ReLu(x) y.Backward() dx := numericalDiff(ReLu, x) allClose(t, x.grad, dx) } func TestSoftmax(t *testing.T) { // (1,3) -> (1,3) x := NewTensor([]float32{3.0, 1.0, 0.2}, 1, 3) y := Softmax(x, 1) assert.Equal(t, []int{1, 3}, y.shape) assert.InDeltaSlice(t, []float32{0.8360188027814407, 0.11314284146556013, 0.05083835575299916}, y.data, 1e-6) // Test gradient y.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Softmax(x, 1) }, x) allClose(t, x.grad, dx) } func TestFlatten(t *testing.T) { // (2,3) -> (6) x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Flatten(x) assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, y.data) // Test gradient y.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) } func TestReshape(t *testing.T) { // (2,3) -> (3,2) x := 
NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Reshape(x, 3, 2) assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, y.data) // Test gradient y.Backward() assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) } func TestSoftmaxCrossEntropy(t *testing.T) { // (2,3) -> (2,3) x := NewTensor([]float32{0.3, 2.9, 4.0, 0.2, 1.0, 3.0}, 3, 2) y := NewTensor([]float32{1, 0, 1}, 3) z := SoftmaxCrossEntropy(x, y) assert.Empty(t, z.shape) assert.InDelta(t, float32(0.07356563982184072), z.data[0], 1e-4) // Test gradient z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return SoftmaxCrossEntropy(x, y) }, x) allClose(t, x.grad, dx) } func TestReuseLeaf(t *testing.T) { // x + x x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) y := Add(x, x) assert.Equal(t, []float32{2, 4, 6, 8, 10, 12}, y.data) // Test gradient y.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, x) }, x) allClose(t, x.grad, dx) } func TestReuseNode(t *testing.T) { // x^2 + x^2 x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) temp := Pow(x, NewTensor([]float32{2})) y := Add(temp, temp) assert.Equal(t, []float32{2, 8, 18, 32, 50, 72}, y.data) // Test gradient y.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { temp := Pow(x, NewTensor([]float32{2})) return Add(temp, temp) }, x) allClose(t, x.grad, dx) } func TestDependency(t *testing.T) { // x^2 + 2x^2 x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 2, 3) temp := Pow(x, NewTensor([]float32{2})) y := Add(temp, Mul(NewTensor([]float32{2}), temp)) assert.Equal(t, []float32{3, 12, 27, 48, 75, 108}, y.data) // Test gradient y.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { temp := Pow(x, NewTensor([]float32{2})) return Add(temp, Mul(NewTensor([]float32{2}), temp)) }, x) allClose(t, x.grad, dx) } func TestSphere(t *testing.T) { // x^2 + y^2 x := NewScalar(1) y := NewScalar(1) z := Add(Mul(x, x), Mul(y, y)) assert.Equal(t, []float32{2}, z.data) // Test gradient z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return 
Add(Mul(x, x), Mul(y, y)) }, x) dy := numericalDiff(func(y *Tensor) *Tensor { return Add(Mul(x, x), Mul(y, y)) }, y) allClose(t, x.grad, dx) allClose(t, y.grad, dy) } func TestMatyas(t *testing.T) { // 0.26 * (x^2 + y^2) - 0.48 * x * y x := NewScalar(1) y := NewScalar(1) z := Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y))) assert.InDeltaSlice(t, []float32{0.04}, z.data, 1e-6) // Test gradient z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y))) }, x) dy := numericalDiff(func(y *Tensor) *Tensor { return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y))) }, y) allClose(t, x.grad, dx) allClose(t, y.grad, dy) } func TestGoldsteinPrice(t *testing.T) { // (1 + (x + y + 1)^2 * (19 - 14x + 3x^2 - 14y + 6xy + 3y^2)) * (30 + (2x - 3y)^2 * (18 - 32x + 12x^2 + 48y - 36xy + 27y^2)) x := NewScalar(1) y := NewScalar(1) z := Mul( Add(NewScalar(1), Mul( Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2 Add( NewScalar(19), // 19 Mul(NewScalar(-14), x), // -14x Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2 Mul(NewScalar(-14), y), // -14y Mul(NewScalar(6), Mul(x, y)), // 6xy Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2 Add(NewScalar(30), Mul( Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2 Add( NewScalar(18), // 18 Mul(NewScalar(-32), x), // -32x Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2 Mul(NewScalar(48), y), // 48y Mul(NewScalar(-36), Mul(x, y)), // -36xy Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2 assert.InDeltaSlice(t, []float32{1876}, z.data, 1e-6) // Test gradient z.Backward() dx := numericalDiff(func(x *Tensor) *Tensor { return Mul( Add(NewScalar(1), Mul( Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2 Add( NewScalar(19), // 19 Mul(NewScalar(-14), x), // -14x Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2 
Mul(NewScalar(-14), y), // -14y Mul(NewScalar(6), Mul(x, y)), // 6xy Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2 Add(NewScalar(30), Mul( Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2 Add( NewScalar(18), // 18 Mul(NewScalar(-32), x), // -32x Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2 Mul(NewScalar(48), y), // 48y Mul(NewScalar(-36), Mul(x, y)), // -36xy Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2 }, x) dy := numericalDiff(func(y *Tensor) *Tensor { return Mul( Add(NewScalar(1), Mul( Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2 Add( NewScalar(19), // 19 Mul(NewScalar(-14), x), // -14x Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2 Mul(NewScalar(-14), y), // -14y Mul(NewScalar(6), Mul(x, y)), // 6xy Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2 Add(NewScalar(30), Mul( Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2 Add( NewScalar(18), // 18 Mul(NewScalar(-32), x), // -32x Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2 Mul(NewScalar(48), y), // 48y Mul(NewScalar(-36), Mul(x, y)), // -36xy Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2 }, y) allClose(t, x.grad, dx) allClose(t, y.grad, dy) } ================================================ FILE: common/nn/optimizers.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package nn

import (
	"sync"

	"github.com/chewxy/math32"
	"github.com/gorse-io/gorse/common/floats"
	"github.com/samber/lo"
)

// Optimizer updates a fixed set of parameter tensors from their gradients.
type Optimizer interface {
	// SetWeightDecay sets the L2 weight-decay coefficient; wd * param is
	// added to each gradient before the update.
	SetWeightDecay(rate float32)
	// SetJobs sets the maximum number of concurrent chunks each parameter is
	// split into during Step (<= 0 means a single chunk).
	SetJobs(jobs int)
	// ZeroGrad clears the gradients of all parameters.
	ZeroGrad()
	// Step applies one update to all parameters that have a gradient.
	Step()
}

// baseOptimizer holds the state shared by all optimizers.
type baseOptimizer struct {
	params []*Tensor // tensors updated in place by Step
	wd     float32   // weight-decay coefficient
	jobs   int       // maximum concurrent chunks per parameter (<= 0 means one)
}

// ZeroGrad drops the gradients of all parameters by setting them to nil, so
// the next backward pass starts accumulating from scratch.
func (o *baseOptimizer) ZeroGrad() {
	for _, p := range o.params {
		p.grad = nil
	}
}

// SetWeightDecay sets the weight-decay coefficient.
func (o *baseOptimizer) SetWeightDecay(wd float32) {
	o.wd = wd
}

// SetJobs sets the maximum number of concurrent chunks per parameter.
func (o *baseOptimizer) SetJobs(jobs int) {
	o.jobs = jobs
}

// SGD implements stochastic gradient descent with optional weight decay.
type SGD struct {
	baseOptimizer
	lr float32
	b  []float32 // scratch buffer sized for the largest parameter
}

// NewSGD creates an SGD optimizer over params with learning rate lr.
func NewSGD(params []*Tensor, lr float32) Optimizer {
	bufSize := 0
	for _, p := range params {
		bufSize = max(bufSize, len(p.data))
	}
	return &SGD{
		baseOptimizer: baseOptimizer{params: params},
		lr:            lr,
		b:             make([]float32, bufSize),
	}
}

// Step applies one SGD update to every parameter that has a gradient:
//
//	param -= lr * (grad + wd * param)
//
// Parameters whose gradient is nil (cleared by ZeroGrad and not reached by
// the last backward pass) are skipped instead of dereferencing nil.
func (s *SGD) Step() {
	for _, p := range s.params {
		if p.grad == nil {
			continue
		}
		b := s.b[:len(p.data)]
		parts := partitionAligned(len(p.data), s.jobs, 32)
		var wg sync.WaitGroup
		for _, part := range parts {
			i, j := part.A, part.B
			wg.Go(func() {
				// b = grad + wd * param
				floats.MulConstAddTo(p.data[i:j], s.wd, p.grad.data[i:j], b[i:j])
				// param += -lr * b
				floats.MulConstAdd(b[i:j], -s.lr, p.data[i:j])
			})
		}
		wg.Wait()
	}
}

// Adam implements the Adam optimizer with L2 weight decay added to the
// gradient. Bias correction is folded into the step size (see Step).
type Adam struct {
	baseOptimizer
	alpha float32
	beta1 float32
	beta2 float32
	eps   float32
	ms    map[*Tensor]*Tensor // first-moment estimate per parameter
	vs    map[*Tensor]*Tensor // second-moment estimate per parameter
	t     float32             // number of Step calls so far
	b1    []float32           // scratch buffer sized for the largest parameter
	b2    []float32           // scratch buffer sized for the largest parameter
}

// NewAdam creates an Adam optimizer over params with step size alpha and
// beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
func NewAdam(params []*Tensor, alpha float32) Optimizer {
	bufSize := 0
	for _, p := range params {
		bufSize = max(bufSize, len(p.data))
	}
	return &Adam{
		baseOptimizer: baseOptimizer{params: params},
		alpha:         alpha,
		beta1:         0.9,
		beta2:         0.999,
		eps:           1e-8,
		ms:            make(map[*Tensor]*Tensor),
		vs:            make(map[*Tensor]*Tensor),
		b1:            make([]float32, bufSize),
		b2:            make([]float32, bufSize),
	}
}

// Step applies one Adam update to every parameter that has a gradient.
// Moment estimates are created lazily on a parameter's first update.
// Parameters whose gradient is nil are skipped, and their moments are left
// untouched.
func (a *Adam) Step() {
	a.t++
	// Bias correction applied to the step size:
	// lr = alpha * sqrt(1 - beta2^t) / (1 - beta1^t).
	fix1 := 1 - math32.Pow(a.beta1, a.t)
	fix2 := 1 - math32.Pow(a.beta2, a.t)
	lr := a.alpha * math32.Sqrt(fix2) / fix1
	for _, p := range a.params {
		if p.grad == nil {
			continue
		}
		if _, ok := a.ms[p]; !ok {
			a.ms[p] = Zeros(p.shape...)
			a.vs[p] = Zeros(p.shape...)
		}
		m, v := a.ms[p], a.vs[p]
		b1, b2 := a.b1[:len(p.data)], a.b2[:len(p.data)]
		parts := partitionAligned(len(p.data), a.jobs, 32)
		var wg sync.WaitGroup
		for _, part := range parts {
			i, j := part.A, part.B
			wg.Go(func() {
				// grad = grad + wd * param.data
				floats.MulConstAddTo(p.data[i:j], a.wd, p.grad.data[i:j], b1[i:j])
				// m += (1 - beta1) * (grad - m)
				floats.SubTo(b1[i:j], m.data[i:j], b2[i:j])
				floats.MulConstAdd(b2[i:j], 1-a.beta1, m.data[i:j])
				// v += (1 - beta2) * (grad * grad - v)
				floats.MulTo(b1[i:j], b1[i:j], b2[i:j])
				floats.Sub(b2[i:j], v.data[i:j])
				floats.MulConstAdd(b2[i:j], 1-a.beta2, v.data[i:j])
				// param.data -= self.lr * m / (xp.sqrt(v) + eps)
				floats.SqrtTo(v.data[i:j], b2[i:j])
				floats.AddConst(b2[i:j], a.eps)
				floats.DivTo(m.data[i:j], b2[i:j], b1[i:j])
				floats.MulConstAdd(b1[i:j], -lr, p.data[i:j])
			})
		}
		wg.Wait()
	}
}

// partitionAligned splits the range [0, n) into at most m contiguous parts.
// Every part except possibly the last has a length that is a multiple of k.
// For example:
//
//	partitionAligned(10, 3, 2) = [(0, 4), (4, 8), (8, 10)]
//
// It returns nil when n <= 0 and the single part [0, n) when m <= 0. A k <= 0
// is treated as 1.
func partitionAligned(n, m, k int) []lo.Tuple2[int, int] {
	if n <= 0 {
		return nil
	}
	if m <= 0 {
		return []lo.Tuple2[int, int]{{A: 0, B: n}}
	}
	if k <= 0 {
		k = 1
	}
	// Round n/m up rather than down so that m parts always cover all n
	// elements. Rounding down spills the remainder into an extra (m+1)-th
	// part, e.g. n=10, m=3, k=1 would yield 4 parts.
	partSize := (n + m - 1) / m
	if r := partSize % k; r != 0 {
		partSize += k - r
	}
	parts := make([]lo.Tuple2[int, int], 0, m)
	for start := 0; start < n; start += partSize {
		parts = append(parts, lo.Tuple2[int, int]{A: start, B: min(start+partSize, n)})
	}
	return parts
}

================================================
FILE: common/nn/tensor.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nn import ( "container/heap" "context" "fmt" "math" "math/rand" "strings" "sync" "sync/atomic" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/parallel" "github.com/gorse-io/gorse/protocol" "github.com/samber/lo" "golang.org/x/exp/slices" ) // inferenceMode disables gradient computation when enabled, improving performance during model evaluation. var inferenceMode = atomic.Bool{} // SetInferenceMode enables or disables inference mode, which disables gradient computation to improve performance during model evaluation. 
func SetInferenceMode(enabled bool) { inferenceMode.Store(enabled) } type Tensor struct { data []float32 shape []int grad *Tensor op op } func NewTensor(data []float32, shape ...int) *Tensor { size := 1 for i := range shape { size *= shape[i] } if len(data) != size { panic(fmt.Sprintf("shape %v does not match data size %v", shape, len(data))) } return &Tensor{ data: data, shape: shape, } } func NewScalar(data float32) *Tensor { return &Tensor{ data: []float32{data}, shape: []int{}, } } func LinSpace(start, end float32, shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) delta := (end - start) / float32(n-1) for i := range data { data[i] = start + delta*float32(i) } return &Tensor{ data: data, shape: shape, } } func Rand(shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) for i := range data { data[i] = rand.Float32() } return &Tensor{ data: data, shape: shape, } } func Uniform(low, high float32, shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) for i := range data { data[i] = rand.Float32()*(high-low) + low } return &Tensor{ data: data, shape: shape, } } func Normal(mean, std float32, shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) for i := range data { data[i] = float32(rand.NormFloat64())*std + mean } return &Tensor{ data: data, shape: shape, } } // Ones creates a tensor filled with ones. func Ones(shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) for i := range data { data[i] = 1 } return &Tensor{ data: data, shape: shape, } } // Zeros creates a tensor filled with zeros. 
func Zeros(shape ...int) *Tensor { n := 1 for _, s := range shape { n *= s } data := make([]float32, n) return &Tensor{ data: data, shape: shape, } } func (t *Tensor) generation() int { if t.op != nil { return t.op.generation() } return 0 } func (t *Tensor) IsScalar() bool { return len(t.shape) == 0 } // NoGrad convert a node tensor to a leaf tensor. func (t *Tensor) NoGrad() *Tensor { if t.op != nil { t.op = nil } return t } func (t *Tensor) Shape() []int { return t.shape } // Slice returns a slice of the tensor. func (t *Tensor) Slice(start, end int) *Tensor { if len(t.shape) < 1 { panic("slice requires at least 1-D tensor") } if start < 0 || end > t.shape[0] { panic("slice out of range") } subSize := 1 for i := 1; i < len(t.shape); i++ { subSize *= t.shape[i] } return &Tensor{ data: t.data[start*subSize : end*subSize], shape: append([]int{end - start}, t.shape[1:]...), } } func (t *Tensor) SliceIndices(indices ...int) *Tensor { shape := []int{len(indices)} subSize := 1 for i := range t.shape[1:] { shape = append(shape, t.shape[i+1]) subSize *= t.shape[i+1] } data := make([]float32, len(indices)*subSize) for i, index := range indices { copy(data[i*subSize:(i+1)*subSize], t.data[index*subSize:(index+1)*subSize]) } return &Tensor{ data: data, shape: shape, } } // Get returns the value of the tensor at the given indices. 
func (t *Tensor) Get(indices ...int) float32 { if len(indices) != len(t.shape) { panic("the number of indices does not match the shape of the tensor") } index := 0 for i := range indices { if indices[i] < 0 || indices[i] >= t.shape[i] { panic("index out of range") } index = index*t.shape[i] + indices[i] } return t.data[index] } func (t *Tensor) String() string { // Print scalar value if len(t.shape) == 0 { return fmt.Sprint(t.data[0]) } builder := strings.Builder{} builder.WriteString("[") if len(t.data) <= 10 { for i := 0; i < len(t.data); i++ { fmt.Fprint(&builder, t.data[i]) if i != len(t.data)-1 { builder.WriteString(", ") } } } else { for i := 0; i < 5; i++ { fmt.Fprint(&builder, t.data[i]) builder.WriteString(", ") } builder.WriteString("..., ") for i := len(t.data) - 5; i < len(t.data); i++ { fmt.Fprint(&builder, t.data[i]) if i != len(t.data)-1 { builder.WriteString(", ") } } } builder.WriteString("]") return builder.String() } func (t *Tensor) Backward() { t.grad = Ones(t.shape...) 
ops := &opHeap{t.op} seen := mapset.NewSet[op](t.op) for ops.Len() > 0 { op := heap.Pop(ops).(op) inputs, output := op.inputsAndOutput() grads := op.backward(output.grad) for i := range grads { if !slices.Equal(inputs[i].shape, grads[i].shape) { panic(fmt.Sprintf("%s: shape %v does not match shape %v", op.String(), inputs[i].shape, grads[i].shape)) } if inputs[i].grad == nil { inputs[i].grad = grads[i] } else { inputs[i].grad.add(grads[i]) } if inputs[i].op != nil && !seen.Contains(inputs[i].op) { heap.Push(ops, inputs[i].op) seen.Add(inputs[i].op) } } output.grad = nil } } func (t *Tensor) Grad() *Tensor { return t.grad } func (t *Tensor) Data() []float32 { return t.data } func (t *Tensor) clone() *Tensor { newData := make([]float32, len(t.data)) copy(newData, t.data) return &Tensor{ data: newData, shape: t.shape, } } func (t *Tensor) add(other *Tensor) *Tensor { wSize := 1 for i := range other.shape { wSize *= other.shape[i] } if wSize == 1 { floats.AddConst(t.data, other.data[0]) } else { for i := 0; i < len(t.data); i += wSize { floats.Add(t.data[i:i+wSize], other.data) } } return t } // sub returns the element-wise addition of two tensors. The shape // of the second tensor must be a suffix sequence of the shape of // the first tensor: (...,m,n) - (m,n) = (...,m,n). func (t *Tensor) sub(other *Tensor) *Tensor { wSize := 1 for i := range other.shape { wSize *= other.shape[i] } for i := range t.data { t.data[i] -= other.data[i%wSize] } return t } // bSub returns the element-wise addition of two tensors. The shape // of the second tensor must be a prefix sequence of the shape of // the first tensor: (m,n,...) - (m,n) = (m,n,...). 
func (t *Tensor) bSub(other *Tensor) *Tensor {
	// bSize is the number of consecutive elements of t that share one element
	// of other (the product of t's trailing dims beyond other's rank).
	bSize := 1
	for i := range t.shape {
		bSize *= t.shape[i]
	}
	for i := range other.shape {
		bSize /= other.shape[i]
	}
	for i := range t.data {
		t.data[i] -= other.data[i/bSize]
	}
	return t
}

// mul multiplies t by other element-wise, in place, and returns t. The shape
// of other must be a suffix of t's shape: (...,m,n) * (m,n) = (...,m,n).
func (t *Tensor) mul(other *Tensor) *Tensor {
	wSize := 1
	for i := range other.shape {
		wSize *= other.shape[i]
	}
	for i := range t.data {
		t.data[i] *= other.data[i%wSize]
	}
	return t
}

// div divides t by other element-wise, in place, and returns t. The shape
// of the second tensor must be a suffix sequence of the shape of
// the first tensor: (...,m,n) / (m,n) = (...,m,n).
func (t *Tensor) div(other *Tensor) *Tensor {
	wSize := 1
	for i := range other.shape {
		wSize *= other.shape[i]
	}
	for i := range t.data {
		t.data[i] /= other.data[i%wSize]
	}
	return t
}

// bDiv divides t by other element-wise, in place, and returns t. The shape
// of the second tensor must be a prefix sequence of the shape of
// the first tensor: (m,n,...) / (m,n) = (m,n,...).
func (t *Tensor) bDiv(other *Tensor) *Tensor {
	// bSize is the number of consecutive elements of t that share one element
	// of other.
	bSize := 1
	for i := range t.shape {
		bSize *= t.shape[i]
	}
	for i := range other.shape {
		bSize /= other.shape[i]
	}
	for i := range t.data {
		t.data[i] /= other.data[i/bSize]
	}
	return t
}

// square squares every element of t in place and returns t.
func (t *Tensor) square() *Tensor {
	floats.MulTo(t.data, t.data, t.data)
	return t
}

// pow raises every element of t to the matching element of other, in place,
// and returns t. The shape of other must be a suffix of t's shape.
func (t *Tensor) pow(other *Tensor) *Tensor {
	wSize := 1
	for i := range other.shape {
		wSize *= other.shape[i]
	}
	for i := range t.data {
		t.data[i] = math32.Pow(t.data[i], other.data[i%wSize])
	}
	return t
}

// exp applies e^x element-wise in place and returns t. Unlike the other unary
// ops, it computes in float64 via math.Exp and converts back to float32.
func (t *Tensor) exp() *Tensor {
	for i := range t.data {
		t.data[i] = float32(math.Exp(float64(t.data[i])))
	}
	return t
}

// log applies the natural logarithm element-wise in place and returns t.
func (t *Tensor) log() *Tensor {
	for i := range t.data {
		t.data[i] = math32.Log(t.data[i])
	}
	return t
}

// sin applies sine element-wise in place and returns t.
func (t *Tensor) sin() *Tensor {
	for i := range t.data {
		t.data[i] = math32.Sin(t.data[i])
	}
	return t
}

// cos applies cosine element-wise in place and returns t.
func (t *Tensor) cos() *Tensor {
	for i := range t.data {
		t.data[i] = math32.Cos(t.data[i])
	}
	return t
}

// tanh applies the hyperbolic tangent element-wise in place and returns t.
func (t *Tensor) tanh() *Tensor {
	for i := range t.data {
		t.data[i] = math32.Tanh(t.data[i])
	}
	return t
}

// neg negates every element of t in place and returns t.
func (t *Tensor) neg() *Tensor {
	for i := range t.data {
		t.data[i] = -t.data[i]
	}
	return t
}

// partition splits the index range [0, n) into at most p contiguous,
// half-open ranges [A, B). Their sizes differ by at most one, and the larger
// ranges come first. It returns nil for n <= 0 and a single range for p <= 1.
func partition(n, p int) []lo.Tuple2[int, int] {
	// If n is less than or equal to 0, return nil.
	if n <= 0 {
		return nil
	}
	// If p is less than or equal to 1, return a single part covering the whole range.
	if p <= 1 {
		return []lo.Tuple2[int, int]{{A: 0, B: n}}
	}
	// If n is less than or equal to p, return each index as a separate part.
	if n <= p {
		return lo.Map(lo.Range(n), func(i int, _ int) lo.Tuple2[int, int] {
			return lo.Tuple2[int, int]{A: i, B: i + 1}
		})
	}
	// Otherwise, split n into p parts as evenly as possible.
	minPartSize := n / p
	maxPartCount := n % p
	parts := make([]lo.Tuple2[int, int], 0, p)
	for i := 0; i < n; {
		partSize := minPartSize
		if maxPartCount > 0 {
			partSize++
			maxPartCount--
		}
		end := i + partSize
		if end > n {
			end = n
		}
		parts = append(parts, lo.Tuple2[int, int]{A: i, B: end})
		i = end
	}
	return parts
}

// matMul returns op(t) · op(other) as a new (m, n) tensor. op(x) is x^T when
// the matching transpose flag is set. Both inputs must be 2-D, and the
// contracted dimensions must agree, otherwise matMul panics.
//
// The m output rows are split with partition(m, jobs). Each band [p.A, p.B)
// is computed by its own goroutine into a disjoint slice of result.
//
// NOTE(review): floats.MM is assumed to follow the BLAS sgemm argument order
// (transA, transB, M, N, K, A, lda, B, ldb, C, ldc). Confirm against
// common/floats.
func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool, jobs int) *Tensor {
	if len(t.shape) != 2 || len(other.shape) != 2 {
		panic("matMul requires 2-D tensors")
	}
	var m, n, k int
	var result []float32
	var wg sync.WaitGroup
	if !transpose1 && !transpose2 {
		// t is (m, k), other is (k, n).
		if t.shape[1] != other.shape[0] {
			panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
		}
		m, n, k = t.shape[0], other.shape[1], t.shape[1]
		result = make([]float32, m*n)
		for _, p := range partition(m, jobs) {
			wg.Go(func() {
				floats.MM(transpose1, transpose2, p.B-p.A, n, k, t.data[p.A*k:], k, other.data, n, result[p.A*n:], n)
			})
		}
	} else if transpose1 && !transpose2 {
		// t is stored as (k, m), other is (k, n). A band of output rows is a
		// band of t's columns, so it starts at offset p.A with row stride m.
		if t.shape[0] != other.shape[0] {
			panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
		}
		m, n, k = t.shape[1], other.shape[1], t.shape[0]
		result = make([]float32, m*n)
		for _, p := range partition(m, jobs) {
			wg.Go(func() {
				floats.MM(transpose1, transpose2, p.B-p.A, n, k, t.data[p.A:], m, other.data, n, result[p.A*n:], n)
			})
		}
	} else if !transpose1 && transpose2 {
		// t is (m, k), other is stored as (n, k).
		if t.shape[1] != other.shape[1] {
			panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
		}
		m, n, k = t.shape[0], other.shape[0], t.shape[1]
		result = make([]float32, m*n)
		for _, p := range partition(m, jobs) {
			wg.Go(func() {
				floats.MM(transpose1, transpose2, p.B-p.A, n, k, t.data[p.A*k:], k, other.data, k, result[p.A*n:], n)
			})
		}
	} else {
		// t is stored as (k, m), other is stored as (n, k).
		if t.shape[0] != other.shape[1] {
			panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
		}
		m, n, k = t.shape[1], other.shape[0], t.shape[0]
		result = make([]float32, m*n)
		for _, p := range partition(m, jobs) {
			wg.Go(func() {
				floats.MM(transpose1, transpose2, p.B-p.A, n, k, t.data[p.A:], m, other.data, k, result[p.A*n:], n)
			})
		}
	}
	wg.Wait()
	return &Tensor{
		data:  result,
		shape: []int{m, n},
	}
}

// batchMatMul multiplies two 3-D tensors batch by batch. For each i in
// [0, b) it computes op(t[i]) · op(other[i]) into a new (b, m, n) tensor,
// where op applies the matching transpose flag. Batches are distributed over
// jobs workers with parallel.For. The error from parallel.For is discarded; it
// can only come from cancellation of the background context, which never
// happens. The batch sizes and contracted dimensions must agree, otherwise
// batchMatMul panics.
func (t *Tensor) batchMatMul(other *Tensor, transpose1, transpose2 bool, jobs int) *Tensor {
	if len(t.shape) != 3 || len(other.shape) != 3 {
		panic("BatchMatMul requires 3-D tensors")
	}
	var b, m, n, k int
	var result []float32
	if !transpose1 && !transpose2 {
		// t is (b, m, k), other is (b, k, n).
		if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[1] {
			panic("BatchMatMul requires the shapes of tensors are compatible")
		}
		b, m, n, k = t.shape[0], t.shape[1], other.shape[2], t.shape[2]
		result = make([]float32, b*m*n)
		_ = parallel.For(context.Background(), b, jobs, func(i int) {
			floats.MM(transpose1, transpose2, m, n, k, t.data[i*m*k:(i+1)*m*k], k, other.data[i*n*k:(i+1)*n*k], n, result[i*m*n:(i+1)*m*n], n)
		})
	} else if transpose1 && !transpose2 {
		// t is stored as (b, k, m), other is (b, k, n).
		if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[1] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		b, m, n, k = t.shape[0], t.shape[2], other.shape[2], t.shape[1]
		result = make([]float32, b*m*n)
		_ = parallel.For(context.Background(), b, jobs, func(i int) {
			floats.MM(transpose1, transpose2, m, n, k, t.data[i*m*k:(i+1)*m*k], m, other.data[i*n*k:(i+1)*n*k], n, result[i*m*n:(i+1)*m*n], n)
		})
	} else if !transpose1 && transpose2 {
		// t is (b, m, k), other is stored as (b, n, k).
		if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[2] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		b, m, n, k = t.shape[0], t.shape[1], other.shape[1], t.shape[2]
		result = make([]float32, b*m*n)
		_ = parallel.For(context.Background(), b, jobs, func(i int) {
			floats.MM(transpose1, transpose2, m, n, k, t.data[i*m*k:(i+1)*m*k], k, other.data[i*n*k:(i+1)*n*k], k, result[i*m*n:(i+1)*m*n], n)
		})
	} else {
		// t is stored as (b, k, m), other is stored as (b, n, k).
		if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[2] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		b, m, n, k = t.shape[0], t.shape[2], other.shape[1], t.shape[1]
		result = make([]float32, b*m*n)
		_ = parallel.For(context.Background(), b, jobs, func(i int) {
			floats.MM(transpose1, transpose2, m, n, k, t.data[i*m*k:(i+1)*m*k], m, other.data[i*n*k:(i+1)*n*k], k, result[i*m*n:(i+1)*m*n], n)
		})
	}
	return &Tensor{
		data:  result,
		shape: []int{b, m, n},
	}
}

// maximum replaces every element of t with max(t, other), in place. A scalar
// other is compared against every element. Otherwise other is indexed
// element-by-element and must be at least as long as t; this is not checked
// here.
func (t *Tensor) maximum(other *Tensor) {
	if other.IsScalar() {
		for i := range t.data {
			t.data[i] = max(t.data[i], other.data[0])
		}
	} else {
		for i := range t.data {
			t.data[i] = max(t.data[i], other.data[i])
		}
	}
}

// gt overwrites t in place with 1 where t > other and 0 elsewhere, and returns
// t. A scalar other is compared against every element. Otherwise other is
// indexed element-by-element, with the same unchecked length assumption as
// maximum.
func (t *Tensor) gt(other *Tensor) *Tensor {
	if other.IsScalar() {
		for i := range t.data {
			if t.data[i] > other.data[0] {
				t.data[i] = 1
			} else {
				t.data[i] = 0
			}
		}
	} else {
		for i := range t.data {
			if t.data[i] > other.data[i] {
				t.data[i] = 1
			} else {
				t.data[i] = 0
			}
		}
	}
	return t
}

// transpose returns a new tensor with the last two dimensions swapped. Any
// leading dimensions are treated as a batch. It panics for tensors of rank < 2.
func (t *Tensor) transpose() *Tensor {
	if len(t.shape) < 2 {
		panic("transpose requires at least 2-D tensor")
	}
	shape := make([]int, 0, len(t.shape))
	batchSize := 1
	for i := 0; i < len(t.shape)-2; i++ {
		batchSize *= t.shape[i]
		shape = append(shape, t.shape[i])
	}
	m, n := t.shape[len(t.shape)-2], t.shape[len(t.shape)-1]
	shape = append(shape, n, m)
	data := make([]float32, batchSize*m*n)
	for b := 0; b < batchSize; b++ {
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				data[b*m*n+j*m+i] = t.data[b*m*n+i*n+j]
			}
		}
	}
	return &Tensor{
		data:  data,
		shape: shape,
	}
}

// max returns a new tensor holding the maximum along axis. With keepDim the
// reduced axis is kept with size 1; otherwise it is removed. A 1-D input
// always reduces to a scalar, regardless of keepDim. It panics if axis is out
// of range.
func (t *Tensor) max(axis int, keepDim bool) *Tensor {
	if axis < 0 || axis >= len(t.shape) {
		panic("axis out of range")
	}
	if len(t.shape) == 1 {
		return NewScalar(lo.Max(t.data))
	}
	// View t as (a, b, c): a = product of dims before axis, b = size of
	// axis, c = product of dims after axis.
	shape := make([]int, 0, len(t.shape)-1)
	a, b, c := 1, 1, 1
	for i := 0; i < len(t.shape); i++ {
		if i < axis {
			shape = append(shape, t.shape[i])
			a *= t.shape[i]
		} else if i == axis {
			if keepDim {
				shape = append(shape, 1)
			}
			b = t.shape[i]
		} else {
			shape = append(shape, t.shape[i])
			c *= t.shape[i]
		}
	}
	data := make([]float32, a*c)
	for i := 0; i < a; i++ {
		for j := 0; j < c; j++ {
			maxValue := t.data[i*b*c+j]
			for k := 1; k < b; k++ {
				maxValue = max(maxValue, t.data[i*b*c+j+k*c])
			}
			data[i*c+j] = maxValue
		}
	}
	return &Tensor{
		data:  data,
		shape: shape,
	}
}

// sum returns a new tensor holding the sum along axis, with the same shape
// rules as max: keepDim keeps the reduced axis with size 1, and a 1-D input
// always reduces to a scalar. It panics if axis is out of range.
func (t *Tensor) sum(axis int, keepDim bool) *Tensor {
	if axis < 0 || axis >= len(t.shape) {
		panic("axis out of range")
	}
	if len(t.shape) == 1 {
		return NewScalar(lo.Sum(t.data))
	}
	// View t as (a, b, c) around axis, as in max.
	shape := make([]int, 0, len(t.shape)-1)
	a, b, c := 1, 1, 1
	for i := 0; i < len(t.shape); i++ {
		if i < axis {
			shape = append(shape, t.shape[i])
			a *= t.shape[i]
		} else if i == axis {
			if keepDim {
				shape = append(shape, 1)
			}
			b = t.shape[i]
		} else {
			shape = append(shape, t.shape[i])
			c *= t.shape[i]
		}
	}
	data := make([]float32, a*c)
	for i := 0; i < a; i++ {
		for j := 0; j < c; j++ {
			sumValue := t.data[i*b*c+j]
			for k := 1; k < b; k++ {
				sumValue += t.data[i*b*c+j+k*c]
			}
			data[i*c+j] = sumValue
		}
	}
	return &Tensor{
		data:  data,
		shape: shape,
	}
}

// argmax returns the multi-dimensional index of the largest element. Ties
// resolve to the lowest flat index because the comparison is strict. It
// returns nil for a tensor with no data.
func (t *Tensor) argmax() []int {
	if len(t.data) == 0 {
		return nil
	}
	maxValue := t.data[0]
	maxIndex := 0
	for i := 1; i < len(t.data); i++ {
		if t.data[i] > maxValue {
			maxValue = t.data[i]
			maxIndex = i
		}
	}
	// Unflatten the row-major flat index, starting from the last dimension.
	indices := make([]int, len(t.shape))
	for i := len(t.shape) - 1; i >= 0; i-- {
		indices[i] = maxIndex % t.shape[i]
		maxIndex /= t.shape[i]
	}
	return indices
}

// toPB converts t to its protobuf form. The data slice is shared, not copied.
func (t *Tensor) toPB() *protocol.Tensor {
	return &protocol.Tensor{
		Shape: lo.Map(t.shape, func(i, _ int) int32 {
			return int32(i)
		}),
		Data: t.data,
	}
}

// fromPB overwrites t's shape and data from pb. t takes pb.Data as its data
// slice without copying it.
func (t *Tensor) fromPB(pb *protocol.Tensor) {
	t.shape = make([]int, len(pb.Shape))
	for i := range t.shape {
		t.shape[i] = int(pb.Shape[i])
	}
	t.data = pb.Data
}

// NormalInit fills t in place with samples drawn from rand.NormFloat64,
// scaled by std and shifted by mean.
func NormalInit(t *Tensor, mean, std float32) {
	for i := range t.data {
		t.data[i] = float32(rand.NormFloat64())*(std) + (mean)
	}
}

================================================
FILE: common/nn/tensor_test.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nn

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestTensor_Slice(t *testing.T) {
	x := Rand(3, 4, 5)
	y := x.Slice(1, 3)
	assert.Equal(t, []int{2, 4, 5}, y.Shape())
	for i := 0; i < 2; i++ {
		for j := 0; j < 4; j++ {
			for k := 0; k < 5; k++ {
				assert.Equal(t, x.Get(i+1, j, k), y.Get(i, j, k))
			}
		}
	}
}

func TestTensor_SliceIndices(t *testing.T) {
	x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 3, 2)
	y := x.SliceIndices(2, 0)
	assert.Equal(t, []int{2, 2}, y.Shape())
	assert.Equal(t, []float32{5, 6, 1, 2}, y.Data())
}

func TestTensor_Max(t *testing.T) {
	x := NewTensor([]float32{3, 2, 5, 6, 0, 0}, 6)
	y := x.max(0, false)
	assert.Len(t, y.shape, 0)
	assert.Equal(t, []float32{6}, y.data)
	assert.Panics(t, func() { x.max(-1, false) })
	assert.Panics(t, func() { x.max(2, false) })
	x = NewTensor([]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 3, 2, 2)
	y = x.max(1, false)
	assert.Equal(t, []int{3, 2}, y.shape)
	assert.Equal(t, []float32{3, 4, 7, 8, 11, 12}, y.data)
	y = x.max(1, true)
	assert.Equal(t, []int{3, 1, 2}, y.shape)
	assert.Equal(t, []float32{3, 4, 7, 8, 11, 12}, y.data)
}

func TestTensor_Sum(t *testing.T) {
	x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 6)
	y := x.sum(0, false)
	assert.Len(t, y.shape, 0)
	assert.Equal(t, []float32{21}, y.data)
	assert.Panics(t, func() { x.sum(-1, false) })
	assert.Panics(t, func() { x.sum(2, false) })
	x = NewTensor([]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 3, 2, 2)
	y = x.sum(1, false)
	assert.Equal(t, []int{3, 2}, y.shape)
	assert.Equal(t, []float32{4, 6, 12, 14, 20, 22}, y.data)
	y = x.sum(1, true)
	assert.Equal(t, []int{3, 1, 2}, y.shape)
	assert.Equal(t, []float32{4, 6, 12, 14, 20, 22}, y.data)
}

func TestTensor_Transpose(t *testing.T) {
	x := NewTensor([]float32{1, 2, 3, 4, 5, 6}, 3, 2)
	y := x.transpose()
	assert.Equal(t, []int{2, 3}, y.Shape())
	assert.Equal(t, []float32{1, 3, 5, 2, 4, 6}, y.Data())
}

// matMulLegacy is a naive triple-loop reference for matMul, used by the
// benchmarks below. It returns op(t) · op(other), where op transposes
// according to the matching flag, and panics on non-2-D or incompatible
// inputs.
func (t *Tensor) matMulLegacy(other *Tensor, transpose1, transpose2 bool) *Tensor {
	if !transpose1 && !transpose2 {
		if len(t.shape) != 2 || len(other.shape) != 2 {
			panic("matMul requires 2-D tensors")
		}
		if t.shape[1] != other.shape[0] {
			panic("matMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[1], other.shape[1]
		result := make([]float32, m*p)
		for i := 0; i < m; i++ {
			for j := 0; j < p; j++ {
				for k := 0; k < n; k++ {
					result[i*p+j] += t.data[i*n+k] * other.data[k*p+j]
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, p},
		}
	} else if transpose1 && !transpose2 {
		if len(t.shape) != 2 || len(other.shape) != 2 {
			panic("matMul requires 2-D tensors")
		}
		if t.shape[0] != other.shape[0] {
			panic("matMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[1], t.shape[0], other.shape[1]
		result := make([]float32, m*p)
		for i := 0; i < m; i++ {
			for j := 0; j < p; j++ {
				for k := 0; k < n; k++ {
					result[i*p+j] += t.data[k*m+i] * other.data[k*p+j]
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, p},
		}
	} else if !transpose1 && transpose2 {
		if len(t.shape) != 2 || len(other.shape) != 2 {
			panic("matMul requires 2-D tensors")
		}
		if t.shape[1] != other.shape[1] {
			panic("matMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[1], other.shape[0]
		result := make([]float32, m*p)
		for i := 0; i < m; i++ {
			for j := 0; j < p; j++ {
				for k := 0; k < n; k++ {
					result[i*p+j] += t.data[i*n+k] * other.data[j*n+k]
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, p},
		}
	} else {
		if len(t.shape) != 2 || len(other.shape) != 2 {
			panic("matMul requires 2-D tensors")
		}
		// t is stored as (n, m) and other as (p, n), so the contracted
		// dimension is t.shape[0] == other.shape[1] and the result is (m, p).
		if t.shape[0] != other.shape[1] {
			panic("matMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[1], t.shape[0], other.shape[0]
		result := make([]float32, m*p)
		for i := 0; i < m; i++ {
			for j := 0; j < p; j++ {
				for k := 0; k < n; k++ {
					result[i*p+j] += t.data[k*m+i] * other.data[j*n+k]
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, p},
		}
	}
}

// batchMatMulLegacy is a naive reference for batchMatMul, used by the
// benchmarks below. For each batch i it computes op(t[i]) · op(other[i]) and
// panics on non-3-D or incompatible inputs.
func (t *Tensor) batchMatMulLegacy(other *Tensor, transpose1, transpose2 bool) *Tensor {
	if !transpose1 && !transpose2 {
		if len(t.shape) != 3 || len(other.shape) != 3 {
			panic("BatchMatMul requires 3-D tensors")
		}
		if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[1] {
			panic("BatchMatMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[1], other.shape[2]
		result := make([]float32, m*n*p)
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				for k := 0; k < p; k++ {
					for l := 0; l < t.shape[2]; l++ {
						result[i*n*p+j*p+k] += t.data[i*n*t.shape[2]+j*t.shape[2]+l] * other.data[i*other.shape[1]*other.shape[2]+l*other.shape[2]+k]
					}
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, n, p},
		}
	} else if transpose1 && !transpose2 {
		if len(t.shape) != 3 || len(other.shape) != 3 {
			panic("batchMatMul requires 3-D tensors")
		}
		if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[1] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[2], other.shape[2]
		result := make([]float32, m*n*p)
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				for k := 0; k < p; k++ {
					for l := 0; l < t.shape[1]; l++ {
						result[i*n*p+j*p+k] += t.data[i*t.shape[1]*t.shape[2]+l*t.shape[2]+j] * other.data[i*other.shape[1]*other.shape[2]+l*other.shape[2]+k]
					}
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, n, p},
		}
	} else if !transpose1 && transpose2 {
		if len(t.shape) != 3 || len(other.shape) != 3 {
			panic("batchMatMul requires 3-D tensors")
		}
		if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[2] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[1], other.shape[1]
		result := make([]float32, m*n*p)
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				for k := 0; k < p; k++ {
					for l := 0; l < t.shape[2]; l++ {
						result[i*n*p+j*p+k] += t.data[i*n*t.shape[2]+j*t.shape[2]+l] * other.data[i*other.shape[1]*other.shape[2]+k*other.shape[2]+l]
					}
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, n, p},
		}
	} else {
		if len(t.shape) != 3 || len(other.shape) != 3 {
			panic("batchMatMul requires 3-D tensors")
		}
		// t is stored as (b, K, n) and other as (b, p, K): the contracted
		// dimension is t.shape[1] == other.shape[2] and the result is (b, n, p).
		if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[2] {
			panic("batchMatMul requires the shapes of tensors are compatible")
		}
		m, n, p := t.shape[0], t.shape[2], other.shape[1]
		result := make([]float32, m*n*p)
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				for k := 0; k < p; k++ {
					for l := 0; l < t.shape[1]; l++ {
						result[i*n*p+j*p+k] += t.data[i*t.shape[1]*t.shape[2]+l*t.shape[2]+j] * other.data[i*other.shape[1]*other.shape[2]+k*other.shape[2]+l]
					}
				}
			}
		}
		return &Tensor{
			data:  result,
			shape: []int{m, n, p},
		}
	}
}

func BenchmarkMatMulLegacy64(b *testing.B) {
	x := Rand(64, 64)
	y := Rand(64, 64)
	for t1 := 0; t1 < 2; t1++ {
		for t2 := 0; t2 < 2; t2++ {
			b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) {
				for b.Loop() {
					x.matMulLegacy(y, t1 == 1, t2 == 1)
				}
			})
		}
	}
}

func BenchmarkMatMul64(b *testing.B) {
	x := Rand(64, 64)
	y := Rand(64, 64)
	for t1 := 0; t1 < 2; t1++ {
		for t2 := 0; t2 < 2; t2++ {
			b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) {
				for b.Loop() {
					x.matMul(y, t1 == 1, t2 == 1, 0)
				}
			})
		}
	}
}

func BenchmarkBatchMatMulLegacy64(b *testing.B) {
	x := Rand(64, 64, 64)
	y := Rand(64, 64, 64)
	for t1 := 0; t1 < 2; t1++ {
		for t2 := 0; t2 < 2; t2++ {
			b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) {
				for b.Loop() {
					x.batchMatMulLegacy(y, t1 == 1, t2 == 1)
				}
			})
		}
	}
}

func BenchmarkBatchMatMul64(b *testing.B) {
	x := Rand(64, 64, 64)
	y := Rand(64, 64, 64)
	for t1 := 0; t1 < 2; t1++ {
		for t2 := 0; t2 < 2; t2++ {
			b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) {
				for b.Loop() {
					x.batchMatMul(y, t1 == 1, t2 == 1, 0)
				}
			})
		}
	}
}

================================================
FILE: common/parallel/parallel.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package parallel import ( "context" "sync" "github.com/gorse-io/gorse/common/util" "github.com/juju/errors" "github.com/samber/lo" ) const chanSize = 1024 /* Parallel Schedulers */ // Parallel schedules and runs tasks in parallel. nTask is the number of tasks. nJob is // the number of executors. worker is the executed function which passed a range of task // Names (begin, end). The ctx argument allows callers to cancel outstanding work. func Parallel(ctx context.Context, nJobs, nWorkers int, worker func(workerId, jobId int) error) error { if nWorkers <= 1 { for i := 0; i < nJobs; i++ { if err := ctx.Err(); err != nil { return errors.Trace(err) } if err := worker(0, i); err != nil { return errors.Trace(err) } } } else { c := make(chan int, chanSize) // producer go func() { defer close(c) for i := 0; i < nJobs; i++ { select { case <-ctx.Done(): return case c <- i: } } }() // consumer var wg sync.WaitGroup errs := make([]error, nJobs) for j := 0; j < nWorkers; j++ { // start workers workerId := j wg.Go(func() { defer util.CheckPanic() for { select { case <-ctx.Done(): return case jobId, ok := <-c: if !ok { return } if err := ctx.Err(); err != nil { errs[jobId] = err return } // run job if err := worker(workerId, jobId); err != nil { errs[jobId] = err return } } } }) } wg.Wait() // check errors for _, err := range errs { if err != nil { return errors.Trace(err) } } } return ctx.Err() } func For(ctx context.Context, nJobs, nWorkers int, worker func(int)) error { if nWorkers <= 1 { for i := 0; i < nJobs; i++ { if err := ctx.Err(); err != nil { return 
errors.Trace(err) } worker(i) } } else { c := make(chan int, chanSize) // producer go func() { defer close(c) for i := 0; i < nJobs; i++ { select { case <-ctx.Done(): return case c <- i: } } }() // consumer var wg sync.WaitGroup for j := 0; j < nWorkers; j++ { // start workers wg.Go(func() { for { select { case <-ctx.Done(): return case jobId, ok := <-c: if !ok { return } if err := ctx.Err(); err != nil { return } worker(jobId) } } }) } wg.Wait() } return ctx.Err() } func ForEach[T any](ctx context.Context, a []T, nWorkers int, worker func(int, T)) error { if nWorkers <= 1 { for i, v := range a { if err := ctx.Err(); err != nil { return errors.Trace(err) } worker(i, v) } } else { c := make(chan lo.Tuple2[int, T], chanSize) // producer go func() { defer close(c) for i, v := range a { select { case <-ctx.Done(): return case c <- lo.Tuple2[int, T]{A: i, B: v}: } } }() // consumer var wg sync.WaitGroup for j := 0; j < nWorkers; j++ { // start workers wg.Go(func() { for { select { case <-ctx.Done(): return case job, ok := <-c: if !ok { return } if err := ctx.Err(); err != nil { return } worker(job.A, job.B) } } }) } wg.Wait() } return ctx.Err() } // Split a slice into n slices and keep the order of elements. 
func Split[T any](a []T, n int) [][]T { if len(a) == 0 { return nil } if n > len(a) { n = len(a) } minChunkSize := len(a) / n maxChunkNum := len(a) % n chunks := make([][]T, n) for i, j := 0, 0; i < n; i++ { chunkSize := minChunkSize if i < maxChunkNum { chunkSize++ } chunks[i] = a[j : j+chunkSize] j += chunkSize } return chunks } type Context struct { sem chan struct{} detachedSem chan struct{} detached bool } func (ctx *Context) Detach() { if ctx == nil || ctx.detached { return } ctx.detachedSem <- struct{}{} ctx.detached = true <-ctx.sem } func (ctx *Context) Attach() { if ctx == nil || !ctx.detached { return } ctx.detached = false <-ctx.detachedSem ctx.sem <- struct{}{} } func Detachable(ctx context.Context, nJobs, nWorkers, nMaxDetached int, worker func(*Context, int)) error { sem := make(chan struct{}, nWorkers) detachedSem := make(chan struct{}, nMaxDetached) var wg sync.WaitGroup for i := 0; i < nJobs; i++ { select { case <-ctx.Done(): wg.Wait() return ctx.Err() case sem <- struct{}{}: } wg.Go(func() { if ctx.Err() != nil { <-sem return } c := &Context{sem: sem, detachedSem: detachedSem} worker(c, i) if c.detached { <-c.detachedSem } else { <-sem } }) } wg.Wait() return ctx.Err() } ================================================ FILE: common/parallel/parallel_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package parallel

import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"testing/synctest"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/util"
	"github.com/stretchr/testify/assert"
)

func TestParallel(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		a := util.RangeInt(10000)
		b := make([]int, len(a))
		workerIds := make([]int, len(a))
		// multiple threads
		err := Parallel(t.Context(), len(a), 4, func(workerId, jobId int) error {
			b[jobId] = a[jobId]
			workerIds[jobId] = workerId
			time.Sleep(time.Microsecond)
			return nil
		})
		assert.NoError(t, err)
		workersSet := mapset.NewSet(workerIds...)
		assert.Equal(t, a, b)
		assert.GreaterOrEqual(t, 4, workersSet.Cardinality())
		assert.Less(t, 1, workersSet.Cardinality())
		// single thread
		err = Parallel(t.Context(), len(a), 1, func(workerId, jobId int) error {
			b[jobId] = a[jobId]
			workerIds[jobId] = workerId
			return nil
		})
		assert.NoError(t, err)
		workersSet = mapset.NewSet(workerIds...)
		assert.Equal(t, a, b)
		assert.Equal(t, 1, workersSet.Cardinality())
	})
}

func TestFor(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		// multiple threads
		a := util.RangeInt(10000)
		b := make([]int, len(a))
		err := For(t.Context(), len(a), 4, func(jobId int) {
			b[jobId] = a[jobId]
			time.Sleep(time.Microsecond)
		})
		assert.NoError(t, err)
		assert.Equal(t, a, b)
		// single thread
		err = For(t.Context(), len(a), 1, func(jobId int) {
			b[jobId] = a[jobId]
			time.Sleep(time.Microsecond)
		})
		assert.NoError(t, err)
		assert.Equal(t, a, b)
	})
}

func TestForCancel(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithCancel(t.Context())
		var count atomic.Int32
		err := For(ctx, 1000, 4, func(jobId int) {
			if jobId == 0 {
				cancel()
			}
			count.Add(1)
			time.Sleep(100 * time.Microsecond)
		})
		assert.ErrorIs(t, err, context.Canceled)
		assert.Less(t, int(count.Load()), 1000)
	})
}

func TestForEach(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		a := util.RangeInt(10000)
		b := make([]int, len(a))
		// multiple threads
		err := ForEach(t.Context(), a, 4, func(i, v int) {
			assert.Equal(t, i, v)
			b[i] = v
			time.Sleep(time.Microsecond)
		})
		assert.NoError(t, err)
		assert.Equal(t, a, b)
		// single thread
		err = ForEach(t.Context(), a, 1, func(i, v int) {
			assert.Equal(t, i, v)
			b[i] = v
			time.Sleep(time.Microsecond)
		})
		assert.NoError(t, err)
		assert.Equal(t, a, b)
	})
}

func TestForEachCancel(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithCancel(t.Context())
		var count atomic.Int32
		err := ForEach(ctx, util.RangeInt(1000), 4, func(i, v int) {
			if i == 0 {
				cancel()
			}
			count.Add(1)
			time.Sleep(100 * time.Microsecond)
		})
		assert.ErrorIs(t, err, context.Canceled)
		assert.Less(t, int(count.Load()), 1000)
	})
}

func TestParallelFail(t *testing.T) {
	// multiple threads
	err := Parallel(t.Context(), 10000, 4, func(workerId, jobId int) error {
		if jobId%2 == 1 {
			return fmt.Errorf("error from %d", jobId)
		}
		return nil
	})
	assert.Error(t, err)
	// single thread
	err = Parallel(t.Context(), 10000, 1, func(workerId, jobId int) error {
		if jobId%2 == 1 {
			return fmt.Errorf("error from %d", jobId)
		}
		return nil
	})
	assert.Error(t, err)
}

func TestParallelCancel(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithCancel(t.Context())
		var count atomic.Int32
		err := Parallel(ctx, 10000, 4, func(_, jobId int) error {
			if jobId == 0 {
				cancel()
			}
			count.Add(1)
			time.Sleep(time.Second)
			return nil
		})
		assert.ErrorIs(t, err, context.Canceled)
		assert.Less(t, int(count.Load()), 10000)
	})
}

func TestSplit(t *testing.T) {
	a := []int{1, 2, 3, 4, 5, 6}
	b := Split(a, 3)
	assert.Equal(t, [][]int{{1, 2}, {3, 4}, {5, 6}}, b)

	a = []int{1, 2, 3, 4, 5, 6, 7}
	b = Split(a, 3)
	assert.Equal(t, [][]int{{1, 2, 3}, {4, 5}, {6, 7}}, b)
}

func TestDetachable(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		err := Detachable(t.Context(), 100, 1, 100, func(ctx *Context, jobId int) {
			ctx.Detach()
			time.Sleep(time.Second)
			ctx.Attach()
		})
		assert.NoError(t, err)
		assert.Less(t, time.Since(start), time.Second*2)
	})

	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		err := Detachable(t.Context(), 100, 1, 10, func(ctx *Context, jobId int) {
			ctx.Detach()
			time.Sleep(time.Second)
			ctx.Attach()
		})
		assert.NoError(t, err)
		assert.Less(t, time.Since(start), time.Second*11)
	})
}

func TestDetachableCancel(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithCancel(t.Context())
		var count atomic.Int32
		err := Detachable(ctx, 100, 4, 10, func(c *Context, jobId int) {
			if jobId == 0 {
				cancel()
			}
			count.Add(1)
			time.Sleep(10 * time.Millisecond)
		})
		assert.ErrorIs(t, err, context.Canceled)
		assert.Less(t, int(count.Load()), 20)
	})
}

================================================
FILE: common/parallel/ratelimit.go
================================================
package parallel

import (
	"time"

	"github.com/juju/ratelimit"
)

// Package-level backoffs and limiters for chat-completion and embedding calls.
// They default to no limit and are replaced by InitChatCompletionLimiters and
// InitEmbeddingLimiters.
var (
	ChatCompletionBackoff                     = time.Duration(0)
	ChatCompletionRequestsLimiter RateLimiter = &Unlimited{}
	ChatCompletionTokensLimiter   RateLimiter = &Unlimited{}
	EmbeddingBackoff                          = time.Duration(0)
	EmbeddingRequestsLimiter      RateLimiter = &Unlimited{}
	EmbeddingTokensLimiter        RateLimiter = &Unlimited{}
)

// InitChatCompletionLimiters configures the chat-completion limiters from a
// requests-per-minute (rpm) and a tokens-per-minute (tpm) budget. A
// non-positive value leaves the corresponding limiter unchanged.
// ChatCompletionBackoff is set to the average spacing between requests,
// time.Minute/rpm.
func InitChatCompletionLimiters(rpm, tpm int) {
	if rpm > 0 {
		ChatCompletionBackoff = time.Minute / time.Duration(rpm)
		ChatCompletionRequestsLimiter = newPerMinuteBucket(rpm)
	}
	if tpm > 0 {
		ChatCompletionTokensLimiter = newPerMinuteBucket(tpm)
	}
}

// InitEmbeddingLimiters configures the embedding limiters from a
// requests-per-minute (rpm) and a tokens-per-minute (tpm) budget. A
// non-positive value leaves the corresponding limiter unchanged.
// EmbeddingBackoff is set to the average spacing between requests,
// time.Minute/rpm.
func InitEmbeddingLimiters(rpm, tpm int) {
	if rpm > 0 {
		EmbeddingBackoff = time.Minute / time.Duration(rpm)
		EmbeddingRequestsLimiter = newPerMinuteBucket(rpm)
	}
	if tpm > 0 {
		EmbeddingTokensLimiter = newPerMinuteBucket(tpm)
	}
}

// newPerMinuteBucket returns a token bucket that admits about perMinute tokens
// per minute. perMinute must be positive.
//
// For perMinute >= 60 the bucket refills perMinute/60 tokens every second and
// has the same capacity, as before. Below 60 that quotient truncates to zero,
// and ratelimit.NewBucketWithQuantum panics on a zero capacity or quantum. In
// that case the bucket instead refills one token every time.Minute/perMinute.
func newPerMinuteBucket(perMinute int) *ratelimit.Bucket {
	if perSecond := int64(perMinute / 60); perSecond > 0 {
		return ratelimit.NewBucketWithQuantum(time.Second, perSecond, perSecond)
	}
	return ratelimit.NewBucketWithQuantum(time.Minute/time.Duration(perMinute), 1, 1)
}

// RateLimiter reserves count tokens and reports how long the caller should
// wait before proceeding.
type RateLimiter interface {
	Take(count int64) time.Duration
}

// Unlimited is a RateLimiter that never asks the caller to wait.
type Unlimited struct{}

// Take always returns zero.
func (n *Unlimited) Take(count int64) time.Duration {
	return 0
}

================================================
FILE:
common/parallel/ratelimit_test.go ================================================ package parallel import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestUnlimited(t *testing.T) { rateLimiter := &Unlimited{} assert.Zero(t, rateLimiter.Take(1)) } func TestInitEmbeddingLimiters(t *testing.T) { InitEmbeddingLimiters(120, 180) assert.Equal(t, time.Duration(0), EmbeddingRequestsLimiter.Take(1)) assert.InDelta(t, time.Second, EmbeddingRequestsLimiter.Take(2), float64(time.Millisecond)) assert.Equal(t, time.Duration(0), EmbeddingTokensLimiter.Take(2)) assert.InDelta(t, 2*time.Second, EmbeddingTokensLimiter.Take(5), float64(time.Millisecond)) } func TestInitChatCompletionLimiters(t *testing.T) { InitChatCompletionLimiters(120, 180) assert.Equal(t, time.Duration(0), ChatCompletionRequestsLimiter.Take(1)) assert.InDelta(t, time.Second, ChatCompletionRequestsLimiter.Take(2), float64(time.Millisecond)) assert.Equal(t, time.Duration(0), ChatCompletionTokensLimiter.Take(2)) assert.InDelta(t, 2*time.Second, ChatCompletionTokensLimiter.Take(5), float64(time.Millisecond)) } ================================================ FILE: common/rc/rc.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package rc

import (
	"io"
	"sync"
	"sync/atomic"

	"github.com/modern-go/reflect2"
)

// DropFunc releases one reference previously acquired from an Rc.
type DropFunc func()

// rcPointer pairs a closable value with its number of live references.
type rcPointer[T io.Closer] struct {
	pointer   T
	reference atomic.Int32
}

// Rc holds a reference-counted io.Closer. The Rc itself owns one reference
// to the current value, and every Get adds another. The value is closed when
// its last reference is dropped, which cannot happen before the next Reset.
type Rc[T io.Closer] struct {
	pointer *rcPointer[T]
	drop    DropFunc
	mu      sync.RWMutex
}

// Reset replaces the held value with pointer. A nil pointer empties the Rc.
// The Rc's own reference to the previous value is dropped first.
func (r *Rc[T]) Reset(pointer T) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.pointer != nil {
		r.drop()
	}
	r.pointer = nil
	if !reflect2.IsNil(pointer) {
		r.pointer = &rcPointer[T]{pointer: pointer}
		// Take the reference the Rc keeps until the next Reset.
		_, r.drop = r.get()
	}
}

// Get returns the current value together with a DropFunc that must be called
// once the caller is done with it. On an empty Rc it returns the zero value
// and a no-op DropFunc.
func (r *Rc[T]) Get() (T, DropFunc) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.get()
}

// get acquires a reference without locking; callers must hold r.mu.
func (r *Rc[T]) get() (T, DropFunc) {
	shared := r.pointer
	if shared == nil {
		var zero T
		return zero, func() {}
	}
	shared.reference.Add(1)
	release := func() {
		// Only the holder that brings the count to zero may close the value.
		// The CAS to -1 additionally guarantees Close runs at most once.
		if shared.reference.Add(-1) != 0 {
			return
		}
		if shared.reference.CompareAndSwap(0, -1) {
			_ = shared.pointer.Close()
		}
	}
	return shared.pointer, release
}

// Apply calls f with the current value while holding a reference to it.
func (r *Rc[T]) Apply(f func(T)) {
	value, release := r.Get()
	defer release()
	f(value)
}

================================================
FILE: common/rc/rc_test.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rc

import (
	"math/rand/v2"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// Foo is a mock io.Closer that records every Close call.
type Foo struct {
	mock.Mock
}

// NewFoo returns a Foo with its Close expectation already registered.
func NewFoo() *Foo {
	foo := new(Foo)
	foo.On("Close").Return(nil)
	return foo
}

// Close records the call and always succeeds.
func (f *Foo) Close() error {
	f.Called()
	return nil
}

func TestRc_Get(t *testing.T) {
	foo := NewFoo()
	var holder Rc[*Foo]
	holder.Reset(foo)
	_, release := holder.Get()
	release()
	holder.Reset(nil)
	foo.AssertNumberOfCalls(t, "Close", 1)
}

func TestRc_Apply(t *testing.T) {
	foo := NewFoo()
	var holder Rc[*Foo]
	holder.Reset(foo)
	holder.Apply(func(got *Foo) {
		assert.NotNil(t, got)
	})
	holder.Reset(nil)
	foo.AssertNumberOfCalls(t, "Close", 1)
}

func TestRc_Concurrent(t *testing.T) {
	foos := make([]*Foo, 1000)
	for i := range foos {
		foos[i] = NewFoo()
	}

	// Take a random number of references to each value before replacing it.
	var holder Rc[*Foo]
	var pending []DropFunc
	for _, foo := range foos {
		holder.Reset(foo)
		for range rand.IntN(10) {
			_, release := holder.Get()
			pending = append(pending, release)
		}
	}
	holder.Reset(nil)

	// Release every outstanding reference concurrently.
	var wg sync.WaitGroup
	for _, release := range pending {
		wg.Go(release)
	}
	wg.Wait()

	for _, foo := range foos {
		foo.AssertNumberOfCalls(t, "Close", 1)
	}
}

================================================
FILE: common/reranker/client.go
================================================
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reranker

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"

	"github.com/emicklei/go-restful/v3"
)

// Client calls a rerank HTTP endpoint that accepts a RerankRequest and replies
// with a RerankResponse.
type Client struct {
	authToken  string
	url        string
	httpClient *http.Client
}

// NewClient returns a Client that posts to endpoint and authenticates with
// apiKey as a bearer token.
func NewClient(apiKey, endpoint string) *Client {
	return &Client{
		authToken:  apiKey,
		url:        endpoint,
		httpClient: &http.Client{},
	}
}

// RerankRequest is the JSON body sent to the rerank endpoint.
type RerankRequest struct {
	Model     string   `json:"model"`
	Query     string   `json:"query"`
	TopN      int      `json:"top_n,omitempty"`
	Documents []string `json:"documents"`
}

// RerankResponse is the JSON body returned by the rerank endpoint.
type RerankResponse struct {
	Model   string         `json:"model"`
	Usage   Usage          `json:"usage"`
	Results []RerankResult `json:"results"`
}

// RerankResult is one scored entry of a RerankResponse; Index refers to a
// position in RerankRequest.Documents (as produced by MockServer).
type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
}

// Usage is the usage section of a RerankResponse.
type Usage struct {
	TotalTokens int `json:"total_tokens"`
}

// Rerank posts req to the endpoint and decodes the reply. Any status other
// than 200 OK is returned as an error that includes the response body.
// Cancellation and deadlines are governed by ctx.
func (c *Client) Rerank(ctx context.Context, req RerankRequest) (*RerankResponse, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("encoding rerank request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Authorization", "Bearer "+c.authToken)
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("rerank request failed with status: %d, body: %s", resp.StatusCode, string(respBody))
	}

	var rerankResp RerankResponse
	if err := json.Unmarshal(respBody, &rerankResp); err != nil {
		return nil, fmt.Errorf("decoding rerank response: %w", err)
	}
	return &rerankResp, nil
}

// MockServer is an in-process rerank endpoint for tests. It serves
// POST /v1/rerank and scores the i-th document as 1/(i+1). The API key is
// exposed through AuthToken but is not validated by the handler.
type MockServer struct {
	listener   net.Listener
	httpServer *http.Server
	apiKey     string
	ready      chan struct{}
}

// NewMockServer builds a MockServer; call Start to begin serving.
func NewMockServer() *MockServer {
	s := &MockServer{}
	ws := new(restful.WebService)
	ws.Path("/").
		Consumes(restful.MIME_JSON).
		Produces(restful.MIME_JSON)
	ws.Route(ws.POST("/v1/rerank").
		Reads(RerankRequest{}).
		Writes(RerankResponse{}).
		To(s.rerank))
	container := restful.NewContainer()
	container.Add(ws)
	s.httpServer = &http.Server{Handler: container}
	s.apiKey = "dashscope"
	s.ready = make(chan struct{})
	return s
}

// Start listens on a random loopback port and serves until Close is called.
// Ready is released whether or not the listener could be opened, so callers
// blocked in Ready do not hang when Start fails. URL is only valid after a
// successful listen.
func (s *MockServer) Start() error {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		close(s.ready)
		return err
	}
	s.listener = listener
	close(s.ready)
	return s.httpServer.Serve(listener)
}

// URL returns the full rerank endpoint URL of the running server.
func (s *MockServer) URL() string {
	return fmt.Sprintf("http://%s/v1/rerank", s.listener.Addr().String())
}

// AuthToken returns the API key clients should present.
func (s *MockServer) AuthToken() string {
	return s.apiKey
}

// Ready blocks until Start has attempted to open its listener.
func (s *MockServer) Ready() {
	<-s.ready
}

// Close stops the HTTP server.
func (s *MockServer) Close() error {
	return s.httpServer.Close()
}

// rerank answers with one result per request document, scored 1/(i+1).
func (s *MockServer) rerank(req *restful.Request, resp *restful.Response) {
	var r RerankRequest
	err := req.ReadEntity(&r)
	if err != nil {
		_ = resp.WriteError(http.StatusBadRequest, err)
		return
	}
	results := make([]RerankResult, len(r.Documents))
	for i := range r.Documents {
		results[i] = RerankResult{
			Index:          i,
			RelevanceScore: 1.0 / float64(i+1),
		}
	}
	_ = resp.WriteEntity(RerankResponse{
		Model: r.Model,
		Usage: Usage{
			TotalTokens: 100,
		},
		Results: results,
	})
}

================================================
FILE: common/reranker/client_test.go
================================================
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reranker

import (
	"net/http"
	"testing"

	"github.com/stretchr/testify/suite"
)

// ClientTestSuite exercises Client against an in-process MockServer.
type ClientTestSuite struct {
	suite.Suite
	s *MockServer
}

// SetupSuite starts the mock server in the background and waits until it has
// tried to listen. Serve returns http.ErrServerClosed once TearDownSuite
// closes the server.
func (suite *ClientTestSuite) SetupSuite() {
	suite.s = NewMockServer()
	go func() {
		err := suite.s.Start()
		suite.ErrorIs(err, http.ErrServerClosed)
	}()
	suite.s.Ready()
}

// TearDownSuite stops the mock server.
func (suite *ClientTestSuite) TearDownSuite() {
	suite.NoError(suite.s.Close())
}

func (suite *ClientTestSuite) TestRerank() {
	client := NewClient(suite.s.AuthToken(), suite.s.URL())
	req := RerankRequest{
		Model: "jina-reranker-v2-base-multilingual",
		Query: "What is the capital of France?",
		Documents: []string{
			"Paris is the capital of France.",
			"Lyon is a city in France.",
		},
	}
	resp, err := client.Rerank(suite.T().Context(), req)
	// Stop on failure: resp is nil when err is non-nil, and indexing below
	// needs exactly two results.
	suite.Require().NoError(err)
	suite.Require().Len(resp.Results, 2)
	suite.Equal(0, resp.Results[0].Index)
	suite.Equal(1.0, resp.Results[0].RelevanceScore)
	suite.Equal(1, resp.Results[1].Index)
	suite.Equal(0.5, resp.Results[1].RelevanceScore)
}

func TestClient(t *testing.T) {
	suite.Run(t, new(ClientTestSuite))
}

================================================
FILE: common/sizeof/size.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sizeof

import (
	"math"
	"reflect"
)

// DeepSize reports the size of v in bytes, as reflect.Size, but also including
// all recursive substructures of v via maps, pointers, and slices. If v
// contains any cycles, the size of each pointer (re)introducing the cycle is
// included but the acyclic substructure is counted only once.
//
// Only values whose size and structure can be obtained by the reflect package
// are counted. Some values have components that are not visible by
// reflection, and so are not counted or may be undercounted. In particular:
//
// The space occupied by code and data, reachable through variables captured in
// the closure of a function pointer, are not counted. A value of function type
// is counted only as a pointer.
//
// The unused buckets of a map cannot be inspected by the reflect package.
// Their size is estimated by assuming unfilled slots contain zeroes of their
// type.
//
// The unused capacity of the array under a slice is estimated by assuming the
// unused slots contain zeroes of their type. It is possible they contain non
// zero values from sharing or reslicing, but without explicitly reslicing the
// reflect package cannot touch them.
//
// Inside structs only pointer and slice fields are followed; strings, maps and
// nested structs held in struct fields contribute just their inline size.
func DeepSize(v any) int {
	return int(valueSize(reflect.ValueOf(v), make(map[uintptr]bool)))
}

// valueSize returns the inline size of v plus the size of what v reaches.
// seen records pointers that were already followed, so shared or cyclic data
// is counted once.
func valueSize(v reflect.Value, seen map[uintptr]bool) uintptr {
	base := v.Type().Size()
	switch v.Kind() {
	case reflect.Ptr:
		p := v.Pointer()
		if !seen[p] && !v.IsNil() {
			seen[p] = true
			return base + valueSize(v.Elem(), seen)
		}
	case reflect.Slice:
		n := v.Len()
		for i := 0; i < n; i++ {
			base += valueSize(v.Index(i), seen)
		}
		// Account for the parts of the array not covered by this slice. Since
		// we can't get the values directly, assume they're zeroes. That may be
		// incorrect, in which case we may underestimate. Each unused slot holds
		// one element, so scale by the element size (not the slice header size).
		if capacity := v.Cap(); capacity > n {
			base += v.Type().Elem().Size() * uintptr(capacity-n)
		}
	case reflect.Map:
		// A map m has len(m) / 6.5 buckets, rounded up to a power of two, and
		// a minimum of one bucket. Each bucket is 16 bytes + 8*(keysize + valsize).
		//
		// We can't tell which keys are in which bucket by reflection, however,
		// so here we count the 16-byte header for each bucket, and then just add
		// in the computed key and value sizes.
		nb := uintptr(math.Pow(2, math.Ceil(math.Log(float64(v.Len())/6.5)/math.Log(2))))
		if nb == 0 {
			nb = 1
		}
		// NOTE(review): this replaces (rather than adds to) the map's own
		// pointer-sized reflect.Size; kept as-is to preserve existing estimates.
		base = 16 * nb
		for _, key := range v.MapKeys() {
			base += valueSize(key, seen)
			base += valueSize(v.MapIndex(key), seen)
		}
		// We have nb buckets of 8 slots each, and v.Len() slots are filled.
		// The remaining slots we will assume contain zero key/value pairs.
		zk := v.Type().Key().Size()  // a zero key
		zv := v.Type().Elem().Size() // a zero value
		base += (8*nb - uintptr(v.Len())) * (zk + zv)
	case reflect.Struct:
		// Chase pointer and slice fields. The struct's own size already covers
		// each field's inline part (pointer word or slice header), so only the
		// excess reported by valueSize is added; adding the whole result would
		// count slice headers twice.
		for i := 0; i < v.NumField(); i++ {
			f := v.Field(i)
			switch f.Kind() {
			case reflect.Ptr, reflect.Slice:
				base += valueSize(f, seen) - f.Type().Size()
			}
		}
	case reflect.String:
		return base + uintptr(v.Len())
	}
	return base
}

================================================
FILE: common/sizeof/size_test.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sizeof

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCyclic(t *testing.T) {
	type node struct {
		Z int
		E *node
	}
	root := &node{Z: 25}
	want := DeepSize(root)
	root.E = root // induce a cycle
	if got := DeepSize(root); got != want {
		t.Errorf("Cyclic size: got %d, want %d", got, want)
	}
}

func TestDeepSize(t *testing.T) {
	cases := []struct {
		name  string
		value any
		want  int
	}{
		// matrices: outer header + four inner headers + element payload
		{"[][]int64", [][]int64{{1}, {2}, {3}, {4}}, 5*24 + 4*8},
		{"[][]int32", [][]int32{{1}, {2}, {3}, {4}}, 5*24 + 4*4},
		{"[][]int16", [][]int16{{1}, {2}, {3}, {4}}, 5*24 + 4*2},
		{"[][]int8", [][]int8{{1}, {2}, {3}, {4}}, 5*24 + 4},
		// strings: slice header + string headers + byte lengths
		{"[]string ascii", []string{"abc", "de", "f"}, 24 + 16*3 + 6},
		{"[]string utf8", []string{"♥♥♥", "♥♥", "♥"}, 24 + 16*3 + 18},
		// flat slices: header + payload
		{"[]int64", []int64{1, 2, 3, 4}, 7 * 8},
		{"[]int32", []int32{1, 2, 3, 4}, 3*8 + 4*4},
		{"[]int16", []int16{1, 2, 3, 4}, 3*8 + 2*4},
		{"[]int8", []int8{1, 2, 3, 4}, 3*8 + 4},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, DeepSize(tc.value), tc.name)
	}
}

================================================
FILE: common/util/random.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package util

import (
	"math/rand"
	"sync"

	mapset "github.com/deckarep/golang-set/v2"
)

// RandomGenerator embeds *rand.Rand and adds helpers that produce random
// vectors, matrices and integer samples.
type RandomGenerator struct {
	*rand.Rand
}

// NewRandomGenerator creates a RandomGenerator.
func NewRandomGenerator(seed int64) RandomGenerator {
	return RandomGenerator{rand.New(rand.NewSource(seed))}
}

// UniformVector returns a vector of size floats drawn uniformly between low
// and high.
func (rng RandomGenerator) UniformVector(size int, low, high float32) []float32 {
	ret := make([]float32, size)
	scale := high - low
	for i := range ret {
		ret[i] = rng.Float32()*scale + low
	}
	return ret
}

// NewNormalVector returns a vector of size floats drawn from a normal
// distribution with the given mean and standard deviation. It is identical to
// NormalVector.
func (rng RandomGenerator) NewNormalVector(size int, mean, stdDev float32) []float32 {
	return rng.NormalVector(size, mean, stdDev)
}

// NormalMatrix returns a row x col matrix whose rows are filled by
// NormalVector.
func (rng RandomGenerator) NormalMatrix(row, col int, mean, stdDev float32) [][]float32 {
	ret := make([][]float32, row)
	for i := range ret {
		ret[i] = rng.NormalVector(col, mean, stdDev)
	}
	return ret
}

// NormalVector returns a vector of size floats drawn from a normal
// distribution with the given mean and standard deviation.
func (rng RandomGenerator) NormalVector(size int, mean, stdDev float32) []float32 {
	ret := make([]float32, size)
	for i := range ret {
		ret[i] = float32(rng.NormFloat64())*stdDev + mean
	}
	return ret
}

// UniformMatrix returns a row x col matrix whose rows are filled by
// UniformVector.
func (rng RandomGenerator) UniformMatrix(row, col int, low, high float32) [][]float32 {
	ret := make([][]float32, row)
	for i := range ret {
		ret[i] = rng.UniformVector(col, low, high)
	}
	return ret
}

// NormalVector64 is the float64 counterpart of NormalVector.
func (rng RandomGenerator) NormalVector64(size int, mean, stdDev float64) []float64 {
	ret := make([]float64, size)
	for i := range ret {
		ret[i] = rng.NormFloat64()*stdDev + mean
	}
	return ret
}

// Sample n values between low and high, but not in exclude.
func (rng RandomGenerator) Sample(low, high, n int, exclude ...mapset.Set[int]) []int { intervalLength := high - low excludeSet := mapset.NewSet[int]() for _, set := range exclude { excludeSet = excludeSet.Union(set) } sampled := make([]int, 0, n) if n >= intervalLength-excludeSet.Cardinality() { for i := low; i < high; i++ { if !excludeSet.Contains(i) { sampled = append(sampled, i) excludeSet.Add(i) } } } else { for len(sampled) < n { v := rng.Intn(intervalLength) + low if !excludeSet.Contains(v) { sampled = append(sampled, v) excludeSet.Add(v) } } } return sampled } // SampleInt32 n 32bit values between low and high, but not in exclude. func (rng RandomGenerator) SampleInt32(low, high int32, n int, exclude ...mapset.Set[int32]) []int32 { intervalLength := high - low excludeSet := mapset.NewSet[int32]() for _, set := range exclude { excludeSet = excludeSet.Union(set) } sampled := make([]int32, 0, n) if n >= int(intervalLength)-excludeSet.Cardinality() { for i := low; i < high; i++ { if !excludeSet.Contains(i) { sampled = append(sampled, i) excludeSet.Add(i) } } } else { for len(sampled) < n { v := rng.Int31n(intervalLength) + low if !excludeSet.Contains(v) { sampled = append(sampled, v) excludeSet.Add(v) } } } return sampled } // lockedSource allows a random number generator to be used by multiple goroutines concurrently. // The code is very similar to math/rand.lockedSource, which is unfortunately not exposed. type lockedSource struct { mut sync.Mutex src rand.Source } // NewRand returns a rand.Rand that is threadsafe. 
func NewRand(seed int64) *rand.Rand { return rand.New(&lockedSource{src: rand.NewSource(seed)}) } func (r *lockedSource) Int63() (n int64) { r.mut.Lock() n = r.src.Int63() r.mut.Unlock() return } func (r *lockedSource) Seed(seed int64) { r.mut.Lock() r.src.Seed(seed) r.mut.Unlock() } ================================================ FILE: common/util/random_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package util

import (
	"testing"

	"github.com/chewxy/math32"
	mapset "github.com/deckarep/golang-set/v2"
	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
)

const randomEpsilon = 0.1

func TestRandomGenerator_MakeNormalMatrix(t *testing.T) {
	rng := NewRandomGenerator(0)
	vec := rng.NormalMatrix(1, 1000, 1, 2)[0]
	// InDelta also fails on NaN, which a bare "> epsilon" comparison lets through.
	assert.InDelta(t, 1, mean(vec), randomEpsilon)
	assert.InDelta(t, 2, stdDev(vec), randomEpsilon)
}

func TestRandomGenerator_MakeUniformMatrix(t *testing.T) {
	rng := NewRandomGenerator(0)
	vec := rng.UniformMatrix(1, 1000, 1, 2)[0]
	assert.GreaterOrEqual(t, lo.Min(vec), float32(1))
	assert.LessOrEqual(t, lo.Max(vec), float32(2))
}

func TestRandomGenerator_Sample(t *testing.T) {
	excludeSet := mapset.NewSet(0, 1, 2, 3, 4)
	rng := NewRandomGenerator(0)
	for i := 1; i <= 10; i++ {
		sampled := rng.Sample(0, 10, i, excludeSet)
		for j := range sampled {
			assert.False(t, excludeSet.Contains(sampled[j]))
		}
	}
}

func TestRandomGenerator_SampleInt32(t *testing.T) {
	excludeSet := mapset.NewSet[int32](0, 1, 2, 3, 4)
	rng := NewRandomGenerator(0)
	for i := 1; i <= 10; i++ {
		sampled := rng.SampleInt32(0, 10, i, excludeSet)
		for j := range sampled {
			assert.False(t, excludeSet.Contains(sampled[j]))
		}
	}
}

// mean of a slice of 32-bit floats.
func mean(x []float32) float32 {
	return lo.Sum(x) / float32(len(x))
}

// stdDev returns the sample standard deviation.
func stdDev(x []float32) float32 {
	_, variance := meanVariance(x)
	return math32.Sqrt(variance)
}

// meanVariance returns the sample mean of x and its unbiased variance:
//
//	mean     = \sum_i x_i / n
//	variance = \sum_i (x_i - mean)^2 / (n - 1)
//
// It uses the corrected two-pass algorithm (1.7) from "Algorithms for
// computing the sample variance: Analysis and recommendations" by Chan, Tony
// F., Gene H. Golub, and Randall J. LeVeque. The compensation term subtracts
// the rounding error left in the first-pass mean. x needs at least two values;
// with fewer the variance is NaN.
func meanVariance(x []float32) (m, variance float32) {
	m = mean(x)
	var (
		ss           float32
		compensation float32
	)
	for _, v := range x {
		d := v - m
		ss += d * d
		compensation += d
	}
	variance = (ss - compensation*compensation/float32(len(x))) / float32(len(x)-1)
	return
}

================================================
FILE: common/util/strconv.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util

import (
	"reflect"
	"strconv"

	"golang.org/x/exp/constraints"
)

// ParseFloat parses s as a floating-point number with the precision of T.
// Errors are those reported by strconv.ParseFloat.
func ParseFloat[T constraints.Float](s string) (T, error) {
	bitSize := reflect.TypeOf(T(0)).Bits()
	parsed, err := strconv.ParseFloat(s, bitSize)
	return T(parsed), err
}

// ParseUInt parses s as a base-10 unsigned integer that must fit in T.
// Errors are those reported by strconv.ParseUint.
func ParseUInt[T constraints.Unsigned](s string) (T, error) {
	bitSize := reflect.TypeOf(T(0)).Bits()
	parsed, err := strconv.ParseUint(s, 10, bitSize)
	return T(parsed), err
}

// ParseInt parses s as a base-10 signed integer that must fit in T.
// Errors are those reported by strconv.ParseInt.
func ParseInt[T constraints.Signed](s string) (T, error) {
	bitSize := reflect.TypeOf(T(0)).Bits()
	parsed, err := strconv.ParseInt(s, 10, bitSize)
	return T(parsed), err
}

// FormatInt formats i in base 10.
func FormatInt[T constraints.Signed](i T) string {
	wide := int64(i)
	return strconv.FormatInt(wide, 10)
}

================================================
FILE: common/util/tls.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util

import (
	"crypto/tls"
	"crypto/x509"
	"os"

	"github.com/juju/errors"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/security/advancedtls"
)

// TLSConfig names the PEM files used to build gRPC transport credentials.
type TLSConfig struct {
	SSLCA   string
	SSLCert string
	SSLKey  string
}

// loadTLSMaterial reads the CA bundle and the certificate/key pair named by o.
func loadTLSMaterial(o *TLSConfig) (*x509.CertPool, tls.Certificate, error) {
	// Load certification authority
	ca := x509.NewCertPool()
	pem, err := os.ReadFile(o.SSLCA)
	if err != nil {
		return nil, tls.Certificate{}, errors.Trace(err)
	}
	if !ca.AppendCertsFromPEM(pem) {
		return nil, tls.Certificate{}, errors.New("failed to append certificate")
	}
	// Load certification
	certificate, err := tls.LoadX509KeyPair(o.SSLCert, o.SSLKey)
	if err != nil {
		return nil, tls.Certificate{}, errors.Trace(err)
	}
	return ca, certificate, nil
}

// NewServerCreds builds gRPC server credentials that present the configured
// certificate and require clients to present one verified against the CA.
func NewServerCreds(o *TLSConfig) (credentials.TransportCredentials, error) {
	ca, certificate, err := loadTLSMaterial(o)
	if err != nil {
		return nil, err
	}
	return advancedtls.NewServerCreds(&advancedtls.Options{
		IdentityOptions: advancedtls.IdentityCertificateOptions{
			Certificates: []tls.Certificate{certificate},
		},
		RootOptions: advancedtls.RootCertificateOptions{
			RootCertificates: ca,
		},
		RequireClientCert: true,
		VerificationType:  advancedtls.CertVerification,
	})
}

// NewClientCreds builds gRPC client credentials that present the configured
// certificate and verify the server against the CA.
func NewClientCreds(o *TLSConfig) (credentials.TransportCredentials, error) {
	ca, certificate, err := loadTLSMaterial(o)
	if err != nil {
		return nil, err
	}
	return advancedtls.NewClientCreds(&advancedtls.Options{
		IdentityOptions: advancedtls.IdentityCertificateOptions{
			Certificates: []tls.Certificate{certificate},
		},
		RootOptions: advancedtls.RootCertificateOptions{
			RootCertificates: ca,
		},
		VerificationType: advancedtls.CertVerification,
	})
}

================================================
FILE: common/util/util.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package util

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"slices"
	"strings"

	"github.com/gorse-io/gorse/common/log"
	"go.uber.org/zap"
)

// RangeInt generate a slice [0, ..., n-1].
func RangeInt(n int) []int {
	a := make([]int, n)
	for i := range a {
		a[i] = i
	}
	return a
}

// RepeatFloat32s repeats value n times.
func RepeatFloat32s(n int, value float32) []float32 {
	a := make([]float32, n)
	for i := range a {
		a[i] = value
	}
	return a
}

// NewMatrix32 creates a 2D matrix of 32-bit floats.
func NewMatrix32(row, col int) [][]float32 {
	ret := make([][]float32, row)
	for i := range ret {
		ret[i] = make([]float32, col)
	}
	return ret
}

// NewTensor32 creates a 3D tensor of 32-bit floats.
func NewTensor32(a, b, c int) [][][]float32 {
	ret := make([][][]float32, a)
	for i := range ret {
		ret[i] = NewMatrix32(b, c)
	}
	return ret
}

// NewMatrixInt creates a 2D matrix of integers.
func NewMatrixInt(row, col int) [][]int {
	ret := make([][]int, row)
	for i := range ret {
		ret[i] = make([]int, col)
	}
	return ret
}

// CheckPanic recovers from a panic and logs it. It only takes effect when
// deferred directly (defer CheckPanic()), because recover works only inside a
// deferred call.
func CheckPanic() {
	if r := recover(); r != nil {
		log.Logger().Error("panic recovered", zap.Any("panic", r))
	}
}

// ValidateId validates a user/item id: it must not be empty (after trimming
// surrounding spaces) and must not contain '/'.
func ValidateId(text string) error {
	text = strings.TrimSpace(text)
	if text == "" {
		return fmt.Errorf("id cannot be empty")
	} else if strings.Contains(text, "/") {
		return fmt.Errorf("id cannot contain `/`")
	}
	return nil
}

// MD5 returns the hex-encoded MD5 digest of the strings in s, ignoring their
// order. The sorted strings are concatenated without a separator, so for
// example MD5("ab", "c") == MD5("a", "bc"). The caller's slice is not
// modified.
func MD5(s ...string) string {
	sorted := slices.Clone(s)
	slices.Sort(sorted)
	hash := md5.New()
	for _, str := range sorted {
		hash.Write([]byte(str))
	}
	return hex.EncodeToString(hash.Sum(nil))
}

================================================
FILE: common/util/util_test.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package util

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestNewMatrix32(t *testing.T) {
	a := NewMatrix32(3, 4)
	assert.Equal(t, 3, len(a))
	for _, row := range a {
		assert.Equal(t, 4, len(row))
	}
}

func TestRangeInt(t *testing.T) {
	a := RangeInt(7)
	assert.Equal(t, 7, len(a))
	for i := range a {
		assert.Equal(t, i, a[i])
	}
}

func TestRepeatFloat32s(t *testing.T) {
	a := RepeatFloat32s(3, 0.1)
	assert.Equal(t, []float32{0.1, 0.1, 0.1}, a)
}

func TestNewMatrixInt(t *testing.T) {
	m := NewMatrixInt(4, 3)
	assert.Equal(t, 4, len(m))
	for _, v := range m {
		assert.Equal(t, 3, len(v))
	}
}

func TestNewTensor32(t *testing.T) {
	a := NewTensor32(3, 4, 5)
	assert.Equal(t, 3, len(a))
	assert.Equal(t, 4, len(a[0]))
	assert.Equal(t, 5, len(a[0][0]))
}

func TestValidateId(t *testing.T) {
	assert.NotNil(t, ValidateId(""))
	assert.NotNil(t, ValidateId("/"))
	assert.Nil(t, ValidateId("abc"))
}

func TestMD5(t *testing.T) {
	hash1 := MD5("b", "a", "c")
	hash2 := MD5("c", "b", "a")
	assert.Equal(t, hash1, hash2)
}

func TestMD5DoesNotMutateInput(t *testing.T) {
	input := []string{"b", "a", "c"}
	MD5(input...)
	assert.Equal(t, []string{"b", "a", "c"}, input)
}

================================================
FILE: config/config.go
================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package config import ( "context" "crypto/md5" _ "embed" "encoding/hex" "os" "path/filepath" "reflect" "strings" "time" mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr/parser" "github.com/go-playground/locales/en" ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" en_translations "github.com/go-playground/validator/v10/translations/en" "github.com/go-viper/mapstructure/v2" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/samber/lo" "github.com/spf13/viper" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" "go.opentelemetry.io/otel/exporters/zipkin" "go.opentelemetry.io/otel/sdk/resource" tracesdk "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.8.0" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" ) func init() { viper.SetOptions(viper.WithDecodeHook(mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), StringToFeedbackTypeHookFunc(), ))) } // Config is the configuration for the engine. 
type Config struct { Database DatabaseConfig `mapstructure:"database"` Master MasterConfig `mapstructure:"master"` Server ServerConfig `mapstructure:"server"` Recommend RecommendConfig `mapstructure:"recommend"` Tracing TracingConfig `mapstructure:"tracing"` OIDC OIDCConfig `mapstructure:"oidc"` OpenAI OpenAIConfig `mapstructure:"openai"` Blob BlobConfig `mapstructure:"blob"` } // DatabaseConfig is the configuration for the database. type DatabaseConfig struct { DataStore string `mapstructure:"data_store" validate:"required,data_store"` // database for data store CacheStore string `mapstructure:"cache_store" validate:"required,cache_store"` // database for cache store TablePrefix string `mapstructure:"table_prefix"` DataTablePrefix string `mapstructure:"data_table_prefix"` CacheTablePrefix string `mapstructure:"cache_table_prefix"` MySQL MySQLConfig `mapstructure:"mysql"` Postgres SQLConfig `mapstructure:"postgres"` Redis RedisConfig `mapstructure:"redis"` } type MySQLConfig struct { IsolationLevel string `mapstructure:"isolation_level" validate:"oneof=READ-UNCOMMITTED READ-COMMITTED REPEATABLE-READ SERIALIZABLE"` MaxOpenConns int `mapstructure:"max_open_conns" validate:"gte=0"` MaxIdleConns int `mapstructure:"max_idle_conns" validate:"gte=0"` ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime" validate:"gte=0"` } type SQLConfig struct { MaxOpenConns int `mapstructure:"max_open_conns" validate:"gte=0"` MaxIdleConns int `mapstructure:"max_idle_conns" validate:"gte=0"` ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime" validate:"gte=0"` } type RedisConfig struct { MaxSearchResults int `mapstructure:"max_search_results" validate:"gt=0"` } func (db *DatabaseConfig) StorageOptions(path string) []storage.Option { if strings.HasPrefix(path, storage.MySQLPrefix) { return []storage.Option{ storage.WithIsolationLevel(db.MySQL.IsolationLevel), storage.WithMaxOpenConns(db.MySQL.MaxOpenConns), storage.WithMaxIdleConns(db.MySQL.MaxIdleConns), 
storage.WithConnMaxLifetime(db.MySQL.ConnMaxLifetime), } } if strings.HasPrefix(path, storage.PostgresPrefix) || strings.HasPrefix(path, storage.PostgreSQLPrefix) { return []storage.Option{ storage.WithMaxOpenConns(db.Postgres.MaxOpenConns), storage.WithMaxIdleConns(db.Postgres.MaxIdleConns), storage.WithConnMaxLifetime(db.Postgres.ConnMaxLifetime), } } if strings.HasPrefix(path, storage.RedisPrefix) || strings.HasPrefix(path, storage.RedissPrefix) || strings.HasPrefix(path, storage.RedisClusterPrefix) || strings.HasPrefix(path, storage.RedissClusterPrefix) { return []storage.Option{ storage.WithMaxSearchResults(db.Redis.MaxSearchResults), } } return nil } // MasterConfig is the configuration for the master. type MasterConfig struct { Port int `mapstructure:"port" validate:"gte=0"` // master port Host string `mapstructure:"host"` // master host SSLMode bool `mapstructure:"ssl_mode"` // enable SSL mode SSLCA string `mapstructure:"ssl_ca"` // SSL CA file SSLCert string `mapstructure:"ssl_cert"` // SSL certificate file SSLKey string `mapstructure:"ssl_key"` // SSL key file HttpPort int `mapstructure:"http_port" validate:"gte=0"` // HTTP port HttpHost string `mapstructure:"http_host"` // HTTP host HttpCorsDomains []string `mapstructure:"http_cors_domains"` // add allowed cors domains HttpCorsMethods []string `mapstructure:"http_cors_methods"` // add allowed cors methods NumJobs int `mapstructure:"n_jobs" validate:"gt=0"` // number of working jobs MetaTimeout time.Duration `mapstructure:"meta_timeout" validate:"gt=0"` // cluster meta timeout (second) DashboardUserName string `mapstructure:"dashboard_user_name"` // dashboard user name DashboardPassword string `mapstructure:"dashboard_password"` // dashboard password DashboardRedacted bool `mapstructure:"dashboard_redacted"` AdminAPIKey string `mapstructure:"admin_api_key"` } // ServerConfig is the configuration for the server. 
type ServerConfig struct { APIKey string `mapstructure:"api_key"` // default number of returned items DefaultN int `mapstructure:"default_n" validate:"gt=0"` // secret key for RESTful APIs (SSL required) ClockError time.Duration `mapstructure:"clock_error" validate:"gte=0"` // clock error in the cluster in seconds AutoInsertUser bool `mapstructure:"auto_insert_user"` // insert new users while inserting feedback AutoInsertItem bool `mapstructure:"auto_insert_item"` // insert new items while inserting feedback CacheExpire time.Duration `mapstructure:"cache_expire" validate:"gt=0"` // server-side cache expire time } // RecommendConfig is the configuration of recommendation setup. type RecommendConfig struct { CacheSize int `mapstructure:"cache_size" validate:"gt=0"` CacheExpire time.Duration `mapstructure:"cache_expire" validate:"gt=0"` ContextSize int `mapstructure:"context_size" validate:"gt=0"` ActiveUserTTL int `mapstructure:"active_user_ttl" validate:"gte=0"` DataSource DataSourceConfig `mapstructure:"data_source"` NonPersonalized []NonPersonalizedConfig `mapstructure:"non-personalized" validate:"dive"` ItemToItem []ItemToItemConfig `mapstructure:"item-to-item" validate:"dive"` UserToUser []UserToUserConfig `mapstructure:"user-to-user" validate:"dive"` Collaborative CollaborativeConfig `mapstructure:"collaborative"` External []ExternalConfig `mapstructure:"external" validate:"dive"` Replacement ReplacementConfig `mapstructure:"replacement"` Ranker RankerConfig `mapstructure:"ranker"` Fallback FallbackConfig `mapstructure:"fallback"` } func (r *RecommendConfig) ListRecommenders() []string { recommenders := make([]string, 0) for _, rec := range r.NonPersonalized { recommenders = append(recommenders, rec.FullName()) } for _, rec := range r.ItemToItem { recommenders = append(recommenders, rec.FullName()) } for _, rec := range r.UserToUser { recommenders = append(recommenders, rec.FullName()) } for _, rec := range r.External { recommenders = append(recommenders, 
rec.FullName()) } recommenders = append(recommenders, r.Collaborative.FullName()) recommenders = append(recommenders, "latest") return recommenders } func (r *RecommendConfig) Hash() string { recommenders := mapset.NewSet(r.Ranker.Recommenders...) if recommenders.IsEmpty() { recommenders.Append((r.ListRecommenders())...) } var digests []string for _, rec := range r.NonPersonalized { if recommenders.Contains(rec.FullName()) { digests = append(digests, rec.Hash()) } } for _, rec := range r.ItemToItem { if recommenders.Contains(rec.FullName()) { digests = append(digests, rec.Hash(r)) } } for _, rec := range r.UserToUser { if recommenders.Contains(rec.FullName()) { digests = append(digests, rec.Hash(r)) } } for _, rec := range r.External { if recommenders.Contains(rec.FullName()) { digests = append(digests, rec.Hash()) } } if recommenders.Contains(r.Collaborative.FullName()) { digests = append(digests, r.Collaborative.Hash(r)) } if recommenders.Contains("latest") { digests = append(digests, "latest") } return util.MD5(digests...) 
} func StringToFeedbackTypeHookFunc() mapstructure.DecodeHookFunc { return func( f reflect.Type, t reflect.Type, data interface{}, ) (interface{}, error) { if f.Kind() == reflect.String && t == reflect.TypeOf(expression.FeedbackTypeExpression{}) { var expr expression.FeedbackTypeExpression if err := expr.FromString(data.(string)); err != nil { return nil, errors.Trace(err) } return expr, nil // only convert string to FeedbackType } return data, nil } } type DataSourceConfig struct { PositiveFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"positive_feedback_types"` // positive feedback type ReadFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"read_feedback_types"` // feedback type for read event PositiveFeedbackTTL uint `mapstructure:"positive_feedback_ttl" validate:"gte=0"` // time-to-live of positive feedbacks ItemTTL uint `mapstructure:"item_ttl" validate:"gte=0"` // item-to-live of items } type NonPersonalizedConfig struct { Name string `mapstructure:"name" json:"name"` Score string `mapstructure:"score" json:"score" validate:"required,item_expr"` Filter string `mapstructure:"filter" json:"filter" validate:"item_expr"` } func (config *NonPersonalizedConfig) FullName() string { return "non-personalized/" + config.Name } func (config *NonPersonalizedConfig) Hash() string { hash := md5.New() hash.Write([]byte(config.Name)) hash.Write([]byte(config.Score)) hash.Write([]byte(config.Filter)) return hex.EncodeToString(hash.Sum(nil)[:]) } type ItemToItemConfig struct { Name string `mapstructure:"name" json:"name"` Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users chat auto"` Column string `mapstructure:"column" json:"column" validate:"item_expr"` Prompt string `mapstructure:"prompt" json:"prompt"` } func (config *ItemToItemConfig) FullName() string { return "item-to-item/" + config.Name } func (config *ItemToItemConfig) Hash(cfg *RecommendConfig) string { hash := md5.New() hash.Write([]byte(config.Name)) 
hash.Write([]byte(config.Type)) hash.Write([]byte(config.Column)) if config.Type == "users" { for _, expr := range cfg.DataSource.PositiveFeedbackTypes { hash.Write([]byte(expr.String())) } } return hex.EncodeToString(hash.Sum(nil)[:]) } type UserToUserConfig struct { Name string `mapstructure:"name" json:"name"` Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags items auto"` Column string `mapstructure:"column" json:"column" validate:"item_expr"` } func (config *UserToUserConfig) FullName() string { return "user-to-user/" + config.Name } func (config *UserToUserConfig) Hash(cfg *RecommendConfig) string { hash := md5.New() hash.Write([]byte(config.Name)) hash.Write([]byte(config.Type)) hash.Write([]byte(config.Column)) if config.Type == "items" { for _, expr := range cfg.DataSource.PositiveFeedbackTypes { hash.Write([]byte(expr.String())) } } return hex.EncodeToString(hash.Sum(nil)[:]) } type CollaborativeConfig struct { Type string `mapstructure:"type" validate:"oneof=none mf"` FitPeriod time.Duration `mapstructure:"fit_period" validate:"gt=0"` FitEpoch int `mapstructure:"fit_epoch" validate:"gt=0"` OptimizePeriod time.Duration `mapstructure:"optimize_period" validate:"gte=0"` OptimizeTrials int `mapstructure:"optimize_trials" validate:"gt=0"` EarlyStopping EarlyStoppingConfig `mapstructure:"early_stopping"` } func (config *CollaborativeConfig) FullName() string { return "collaborative" } func (config *CollaborativeConfig) Hash(cfg *RecommendConfig) string { hash := md5.New() for _, expr := range cfg.DataSource.PositiveFeedbackTypes { hash.Write([]byte(expr.String())) } return hex.EncodeToString(hash.Sum(nil)[:]) } type EarlyStoppingConfig struct { Patience int `mapstructure:"patience"` } type ExternalConfig struct { Name string `mapstructure:"name" json:"name"` Script string `mapstructure:"script" json:"script"` } func (config *ExternalConfig) FullName() string { return "external/" + config.Name } func (config *ExternalConfig) Hash() 
string { hash := md5.New() hash.Write([]byte(config.Name)) hash.Write([]byte(config.Script)) return hex.EncodeToString(hash.Sum(nil)[:]) } type ReplacementConfig struct { EnableReplacement bool `mapstructure:"enable_replacement"` PositiveReplacementDecay float64 `mapstructure:"positive_replacement_decay" validate:"gt=0"` ReadReplacementDecay float64 `mapstructure:"read_replacement_decay" validate:"gt=0"` } type RankerConfig struct { Type string `mapstructure:"type" validate:"oneof=none fm llm"` Recommenders []string `mapstructure:"recommenders"` CacheExpire time.Duration `mapstructure:"cache_expire" validate:"gt=0"` FitPeriod time.Duration `mapstructure:"fit_period" validate:"gt=0"` FitEpoch int `mapstructure:"fit_epoch" validate:"gt=0"` OptimizePeriod time.Duration `mapstructure:"optimize_period" validate:"gte=0"` OptimizeTrials int `mapstructure:"optimize_trials" validate:"gt=0"` QueryTemplate string `mapstructure:"query_template"` DocumentTemplate string `mapstructure:"document_template"` EarlyStopping EarlyStoppingConfig `mapstructure:"early_stopping"` RerankerAPI RerankerAPIConfig `mapstructure:"reranker_api"` } type FallbackConfig struct { Recommenders []string `mapstructure:"recommenders"` } type TracingConfig struct { EnableTracing bool `mapstructure:"enable_tracing"` Exporter string `mapstructure:"exporter" validate:"oneof=zipkin otlp otlphttp"` CollectorEndpoint string `mapstructure:"collector_endpoint"` Sampler string `mapstructure:"sampler"` Ratio float64 `mapstructure:"ratio"` } type OIDCConfig struct { Enable bool `mapstructure:"enable"` Issuer string `mapstructure:"issuer"` ClientID string `mapstructure:"client_id"` ClientSecret string `mapstructure:"client_secret"` RedirectURL string `mapstructure:"redirect_url" validate:"omitempty,endswith=/callback/oauth2"` } type RerankerAPIConfig struct { AuthToken string `mapstructure:"auth_token"` Model string `mapstructure:"model"` URL string `mapstructure:"url"` } type OpenAIConfig struct { BaseURL string 
`mapstructure:"base_url"` AuthToken string `mapstructure:"auth_token"` ChatCompletionModel string `mapstructure:"chat_completion_model"` ChatCompletionRPM int `mapstructure:"chat_completion_rpm"` ChatCompletionTPM int `mapstructure:"chat_completion_tpm"` EmbeddingModel string `mapstructure:"embedding_model"` EmbeddingDimensions int `mapstructure:"embedding_dimensions"` EmbeddingRPM int `mapstructure:"embedding_rpm"` EmbeddingTPM int `mapstructure:"embedding_tpm"` LogFile string `mapstructure:"log_file"` } type BlobConfig struct { URI string `mapstructure:"uri" validate:"required"` S3 S3Config `mapstructure:"s3"` GCS GCSConfig `mapstructure:"gcs"` Azure AzureBlobConfig `mapstructure:"azure"` } type S3Config struct { Endpoint string `mapstructure:"endpoint"` AccessKeyID string `mapstructure:"access_key_id"` SecretAccessKey string `mapstructure:"secret_access_key"` } type GCSConfig struct { CredentialsFile string `mapstructure:"credentials_file"` } type AzureBlobConfig struct { Endpoint string `mapstructure:"endpoint"` AccountName string `mapstructure:"account_name"` AccountKey string `mapstructure:"account_key"` ConnectionString string `mapstructure:"connection_string"` } func GetDefaultConfig() *Config { return &Config{ Database: DatabaseConfig{ DataStore: "sqlite://" + filepath.Join(MkDir(), "data.sqlite"), CacheStore: "sqlite://" + filepath.Join(MkDir(), "cache.sqlite"), MySQL: MySQLConfig{ IsolationLevel: "READ-UNCOMMITTED", MaxOpenConns: 0, MaxIdleConns: 0, ConnMaxLifetime: 0, }, Postgres: SQLConfig{ MaxOpenConns: 64, MaxIdleConns: 64, ConnMaxLifetime: time.Minute, }, Redis: RedisConfig{ MaxSearchResults: 10000, }, }, Master: MasterConfig{ Port: 8086, Host: "0.0.0.0", HttpPort: 8088, HttpHost: "0.0.0.0", HttpCorsDomains: []string{".*"}, HttpCorsMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH"}, NumJobs: 1, MetaTimeout: 10 * time.Second, }, Server: ServerConfig{ DefaultN: 10, ClockError: 5 * time.Second, AutoInsertUser: true, AutoInsertItem: true, 
CacheExpire: 10 * time.Second, }, Recommend: RecommendConfig{ CacheSize: 100, CacheExpire: 72 * time.Hour, ContextSize: 100, Collaborative: CollaborativeConfig{ Type: "none", FitPeriod: 60 * time.Minute, FitEpoch: 100, OptimizePeriod: 0, OptimizeTrials: 10, }, Replacement: ReplacementConfig{ EnableReplacement: false, PositiveReplacementDecay: 0.8, ReadReplacementDecay: 0.6, }, Ranker: RankerConfig{ Type: "none", CacheExpire: 120 * time.Hour, FitPeriod: 60 * time.Minute, FitEpoch: 100, OptimizePeriod: 0, OptimizeTrials: 10, Recommenders: []string{"latest"}, }, }, Tracing: TracingConfig{ Exporter: "otlp", Sampler: "always", }, Blob: BlobConfig{ URI: MkDir("blob"), }, } } //go:embed config.toml var ConfigTOML string func (config *Config) Now() *time.Time { return lo.ToPtr(time.Now().Add(config.Server.ClockError)) } func (config *TracingConfig) NewTracerProvider() (trace.TracerProvider, error) { if !config.EnableTracing { return trace.NewNoopTracerProvider(), nil } var exporter tracesdk.SpanExporter var err error switch config.Exporter { case "zipkin": exporter, err = zipkin.New(config.CollectorEndpoint) if err != nil { return nil, errors.Trace(err) } case "otlp": client := otlptracegrpc.NewClient(otlptracegrpc.WithInsecure(), otlptracegrpc.WithEndpoint(config.CollectorEndpoint)) exporter, err = otlptrace.New(context.TODO(), client) if err != nil { return nil, errors.Trace(err) } case "otlphttp": client := otlptracehttp.NewClient(otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint(config.CollectorEndpoint)) exporter, err = otlptrace.New(context.TODO(), client) if err != nil { return nil, errors.Trace(err) } default: return nil, errors.NotSupportedf("exporter %s", config.Exporter) } var sampler tracesdk.Sampler switch config.Sampler { case "always": sampler = tracesdk.AlwaysSample() case "never": sampler = tracesdk.NeverSample() case "ratio": sampler = tracesdk.TraceIDRatioBased(config.Ratio) default: return nil, errors.NotSupportedf("sampler %s", config.Sampler) } 
return tracesdk.NewTracerProvider( tracesdk.WithSampler(sampler), tracesdk.WithBatcher(exporter), tracesdk.WithResource(resource.NewWithAttributes( semconv.SchemaURL, semconv.ServiceNameKey.String("gorse"), )), ), nil } func (config *TracingConfig) Equal(other TracingConfig) bool { if config == nil { return false } return config.EnableTracing == other.EnableTracing && config.Exporter == other.Exporter && config.CollectorEndpoint == other.CollectorEndpoint && config.Sampler == other.Sampler && config.Ratio == other.Ratio } func setDefault() { defaultConfig := GetDefaultConfig() // [database] viper.SetDefault("database.data_store", defaultConfig.Database.DataStore) viper.SetDefault("database.cache_store", defaultConfig.Database.CacheStore) // [database.mysql] viper.SetDefault("database.mysql.isolation_level", defaultConfig.Database.MySQL.IsolationLevel) viper.SetDefault("database.mysql.max_open_conns", defaultConfig.Database.MySQL.MaxOpenConns) viper.SetDefault("database.mysql.max_idle_conns", defaultConfig.Database.MySQL.MaxIdleConns) viper.SetDefault("database.mysql.conn_max_lifetime", defaultConfig.Database.MySQL.ConnMaxLifetime) // [database.postgres] viper.SetDefault("database.postgres.max_open_conns", defaultConfig.Database.Postgres.MaxOpenConns) viper.SetDefault("database.postgres.max_idle_conns", defaultConfig.Database.Postgres.MaxIdleConns) viper.SetDefault("database.postgres.conn_max_lifetime", defaultConfig.Database.Postgres.ConnMaxLifetime) // [database.redis] viper.SetDefault("database.redis.max_search_results", defaultConfig.Database.Redis.MaxSearchResults) // [master] viper.SetDefault("master.port", defaultConfig.Master.Port) viper.SetDefault("master.host", defaultConfig.Master.Host) viper.SetDefault("master.http_port", defaultConfig.Master.HttpPort) viper.SetDefault("master.http_host", defaultConfig.Master.HttpHost) viper.SetDefault("master.http_cors_domains", defaultConfig.Master.HttpCorsDomains) viper.SetDefault("master.http_cors_methods", 
defaultConfig.Master.HttpCorsMethods) viper.SetDefault("master.n_jobs", defaultConfig.Master.NumJobs) viper.SetDefault("master.meta_timeout", defaultConfig.Master.MetaTimeout) // [server] viper.SetDefault("server.api_key", defaultConfig.Server.APIKey) viper.SetDefault("server.default_n", defaultConfig.Server.DefaultN) viper.SetDefault("server.clock_error", defaultConfig.Server.ClockError) viper.SetDefault("server.auto_insert_user", defaultConfig.Server.AutoInsertUser) viper.SetDefault("server.auto_insert_item", defaultConfig.Server.AutoInsertItem) viper.SetDefault("server.cache_expire", defaultConfig.Server.CacheExpire) // [recommend] viper.SetDefault("recommend.cache_size", defaultConfig.Recommend.CacheSize) viper.SetDefault("recommend.cache_expire", defaultConfig.Recommend.CacheExpire) viper.SetDefault("recommend.context_size", defaultConfig.Recommend.ContextSize) // [recommend.collaborative] viper.SetDefault("recommend.collaborative.type", defaultConfig.Recommend.Collaborative.Type) viper.SetDefault("recommend.collaborative.fit_period", defaultConfig.Recommend.Collaborative.FitPeriod) viper.SetDefault("recommend.collaborative.fit_epoch", defaultConfig.Recommend.Collaborative.FitEpoch) viper.SetDefault("recommend.collaborative.optimize_period", defaultConfig.Recommend.Collaborative.OptimizePeriod) viper.SetDefault("recommend.collaborative.optimize_trials", defaultConfig.Recommend.Collaborative.OptimizeTrials) // [recommend.replacement] viper.SetDefault("recommend.replacement.enable_replacement", defaultConfig.Recommend.Replacement.EnableReplacement) viper.SetDefault("recommend.replacement.positive_replacement_decay", defaultConfig.Recommend.Replacement.PositiveReplacementDecay) viper.SetDefault("recommend.replacement.read_replacement_decay", defaultConfig.Recommend.Replacement.ReadReplacementDecay) // [recommend.ranker] viper.SetDefault("recommend.ranker.type", defaultConfig.Recommend.Ranker.Type) viper.SetDefault("recommend.ranker.cache_expire", 
defaultConfig.Recommend.Ranker.CacheExpire) viper.SetDefault("recommend.ranker.fit_period", defaultConfig.Recommend.Ranker.FitPeriod) viper.SetDefault("recommend.ranker.fit_epoch", defaultConfig.Recommend.Ranker.FitEpoch) viper.SetDefault("recommend.ranker.optimize_period", defaultConfig.Recommend.Ranker.OptimizePeriod) viper.SetDefault("recommend.ranker.optimize_trials", defaultConfig.Recommend.Ranker.OptimizeTrials) viper.SetDefault("recommend.ranker.recommenders", defaultConfig.Recommend.Ranker.Recommenders) // [recommend.fallback] viper.SetDefault("recommend.fallback", defaultConfig.Recommend.Fallback) // [tracing] viper.SetDefault("tracing.exporter", defaultConfig.Tracing.Exporter) viper.SetDefault("tracing.sampler", defaultConfig.Tracing.Sampler) // [blob] viper.SetDefault("blob.uri", defaultConfig.Blob.URI) } type configBinding struct { key string env string } var bindings = []configBinding{ {"database.cache_store", "GORSE_CACHE_STORE"}, {"database.data_store", "GORSE_DATA_STORE"}, {"database.table_prefix", "GORSE_TABLE_PREFIX"}, {"database.cache_table_prefix", "GORSE_CACHE_TABLE_PREFIX"}, {"database.data_table_prefix", "GORSE_DATA_TABLE_PREFIX"}, {"master.port", "GORSE_MASTER_PORT"}, {"master.host", "GORSE_MASTER_HOST"}, {"master.ssl_mode", "GORSE_MASTER_SSL_MODE"}, {"master.ssl_ca", "GORSE_MASTER_SSL_CA"}, {"master.ssl_cert", "GORSE_MASTER_SSL_CERT"}, {"master.ssl_key", "GORSE_MASTER_SSL_KEY"}, {"master.http_port", "GORSE_MASTER_HTTP_PORT"}, {"master.http_host", "GORSE_MASTER_HTTP_HOST"}, {"master.n_jobs", "GORSE_MASTER_JOBS"}, {"master.dashboard_user_name", "GORSE_DASHBOARD_USER_NAME"}, {"master.dashboard_password", "GORSE_DASHBOARD_PASSWORD"}, {"master.dashboard_auth_server", "GORSE_DASHBOARD_AUTH_SERVER"}, {"master.dashboard_redacted", "GORSE_DASHBOARD_REDACTED"}, {"master.admin_api_key", "GORSE_ADMIN_API_KEY"}, {"server.api_key", "GORSE_SERVER_API_KEY"}, {"oidc.enable", "GORSE_OIDC_ENABLE"}, {"oidc.issuer", "GORSE_OIDC_ISSUER"}, {"oidc.client_id", 
"GORSE_OIDC_CLIENT_ID"}, {"oidc.client_secret", "GORSE_OIDC_CLIENT_SECRET"}, {"oidc.redirect_url", "GORSE_OIDC_REDIRECT_URL"}, {"blob.uri", "GORSE_BLOB_URI"}, {"blob.s3.endpoint", "S3_ENDPOINT"}, {"blob.s3.access_key_id", "S3_ACCESS_KEY_ID"}, {"blob.s3.secret_access_key", "S3_SECRET_ACCESS_KEY"}, {"blob.gcs.credentials_file", "GCS_CREDENTIALS_FILE"}, {"blob.azure.endpoint", "AZURE_STORAGE_ENDPOINT"}, {"blob.azure.account_name", "AZURE_STORAGE_ACCOUNT_NAME"}, {"blob.azure.account_key", "AZURE_STORAGE_ACCOUNT_KEY"}, {"blob.azure.connection_string", "AZURE_STORAGE_CONNECTION_STRING"}, {"openai.base_url", "OPENAI_BASE_URL"}, {"openai.auth_token", "OPENAI_AUTH_TOKEN"}, {"openai.chat_completion_model", "OPENAI_CHAT_COMPLETION_MODEL"}, {"recommend.ranker.reranker_api.url", "RERANKER_URL"}, {"recommend.ranker.reranker_api.model", "RERANKER_MODEL"}, {"recommend.ranker.reranker_api.auth_token", "RERANKER_AUTH_TOKEN"}, } // LoadConfig loads configuration from toml file. func LoadConfig(path string) (*Config, error) { // set default config setDefault() // bind environment bindings for _, binding := range bindings { err := viper.BindEnv(binding.key, binding.env) if err != nil { log.Logger().Fatal("failed to bind a Viper key to a ENV variable", zap.Error(err)) } } // load config file if provided if path != "" { // check if file exist if _, err := os.Stat(path); err != nil { if os.IsNotExist(err) { log.Logger().Warn("config file not found, use default config", zap.String("path", path)) } else { return nil, errors.Trace(err) } } else { // load config file viper.SetConfigFile(path) if err := viper.ReadInConfig(); err != nil { return nil, errors.Trace(err) } } } else { log.Logger().Info("no config file provided, use defaults and environment variables") } // unmarshal config file var conf Config if err := viper.Unmarshal(&conf); err != nil { return nil, errors.Trace(err) } // validate config file if err := conf.Validate(); err != nil { return nil, errors.Trace(err) } // apply table 
prefix if conf.Database.CacheTablePrefix == "" { conf.Database.CacheTablePrefix = conf.Database.TablePrefix } if conf.Database.DataTablePrefix == "" { conf.Database.DataTablePrefix = conf.Database.TablePrefix } return &conf, nil } func (config *Config) Validate() error { // Check non-personalized recommenders nonPersonalizedNames := mapset.NewSet[string]() for _, nonPersonalized := range config.Recommend.NonPersonalized { if nonPersonalizedNames.Contains(nonPersonalized.Name) { return errors.Errorf("non-personalized recommender %v is duplicated", nonPersonalized.Name) } nonPersonalizedNames.Add(nonPersonalized.Name) } // Check item-to-item recommenders itemToItemNames := mapset.NewSet[string]() for _, itemToItem := range config.Recommend.ItemToItem { if itemToItemNames.Contains(itemToItem.Name) { return errors.Errorf("item-to-item recommender %v is duplicated", itemToItem.Name) } itemToItemNames.Add(itemToItem.Name) } // Check recommender existence and collaborative enabled availableRecommenders := mapset.NewSet[string]() for _, rec := range config.Recommend.NonPersonalized { availableRecommenders.Add(rec.FullName()) } for _, rec := range config.Recommend.ItemToItem { availableRecommenders.Add(rec.FullName()) } for _, rec := range config.Recommend.UserToUser { availableRecommenders.Add(rec.FullName()) } for _, rec := range config.Recommend.External { availableRecommenders.Add(rec.FullName()) } availableRecommenders.Add("latest") if config.Recommend.Collaborative.Type != "none" { availableRecommenders.Add(config.Recommend.Collaborative.FullName()) } checkRecommenders := func(recommenders []string) error { for _, recommender := range recommenders { if recommender == config.Recommend.Collaborative.FullName() && config.Recommend.Collaborative.Type == "none" { return errors.New("collaborative recommender is disabled") } if !availableRecommenders.Contains(recommender) { return errors.Errorf("recommender %v doesn't exist", recommender) } } return nil } validate := 
validator.New() if err := validate.RegisterValidation("data_store", func(fl validator.FieldLevel) bool { prefixes := []string{ storage.MongoPrefix, storage.MongoSrvPrefix, storage.MySQLPrefix, storage.PostgresPrefix, storage.PostgreSQLPrefix, storage.ClickhousePrefix, storage.CHHTTPPrefix, storage.CHHTTPSPrefix, storage.SQLitePrefix, } for _, prefix := range prefixes { if strings.HasPrefix(fl.Field().String(), prefix) { return true } } return false }); err != nil { return errors.Trace(err) } if err := validate.RegisterValidation("cache_store", func(fl validator.FieldLevel) bool { prefixes := []string{ storage.RedisPrefix, storage.RedissPrefix, storage.MongoPrefix, storage.MongoSrvPrefix, storage.MySQLPrefix, storage.PostgresPrefix, storage.PostgreSQLPrefix, storage.SQLitePrefix, } for _, prefix := range prefixes { if strings.HasPrefix(fl.Field().String(), prefix) { return true } } return false }); err != nil { return errors.Trace(err) } if err := validate.RegisterValidation("item_expr", func(fl validator.FieldLevel) bool { if fl.Field().String() == "" { // Empty expression is legal. 
return true } _, err := parser.Parse(fl.Field().String()) return err == nil }); err != nil { return errors.Trace(err) } validate.RegisterTagNameFunc(func(fld reflect.StructField) string { return strings.SplitN(fld.Tag.Get("mapstructure"), ",", 2)[0] }) err := validate.Struct(config) if err != nil { // translate errors trans := ut.New(en.New()).GetFallback() if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil { return errors.Trace(err) } if err := validate.RegisterTranslation("data_store", trans, func(ut ut.Translator) error { return ut.Add("data_store", "unsupported data storage backend", true) // see universal-translator for details }, func(ut ut.Translator, fe validator.FieldError) string { t, _ := ut.T("data_store", fe.Field()) return t }); err != nil { return errors.Trace(err) } if err := validate.RegisterTranslation("cache_store", trans, func(ut ut.Translator) error { return ut.Add("cache_store", "unsupported cache storage backend", true) // see universal-translator for details }, func(ut ut.Translator, fe validator.FieldError) string { t, _ := ut.T("cache_store", fe.Field()) return t }); err != nil { return errors.Trace(err) } if err := validate.RegisterTranslation("item_expr", trans, func(ut ut.Translator) error { return ut.Add("item_expr", "invalid item expression", true) }, func(ut ut.Translator, fe validator.FieldError) string { t, _ := ut.T("item_expr", fe.Field()) return t }); err != nil { return errors.Trace(err) } errs := err.(validator.ValidationErrors) for _, e := range errs { return errors.New(e.Translate(trans)) } } if len(config.Recommend.Ranker.Recommenders) == 0 { return errors.New("ranker.recommenders must not be empty") } if config.Recommend.Ranker.Type == "none" && len(config.Recommend.Ranker.Recommenders) > 1 { return errors.New("ranker.recommenders must contain at most one recommender when ranker.type is none") } if err := checkRecommenders(config.Recommend.Ranker.Recommenders); err != nil { return err } if 
err := checkRecommenders(config.Recommend.Fallback.Recommenders); err != nil { return err } return nil } var RootDir string // MkDir creates a directory under Gorse home directory. func MkDir(elem ...string) string { if RootDir == "" { RootDir = filepath.Join(lo.Must(os.UserHomeDir()), ".gorse", "var", "lib") } path := filepath.Join(RootDir, filepath.Join(elem...)) lo.Must0(os.MkdirAll(path, 0755)) return path } ================================================ FILE: config/config.toml ================================================ [database] # The database for caching, support Redis, MySQL, Postgres and MongoDB: # redis://:@:/ # rediss://:@:/ # redis+cluster://:@:[?addr=:&addr=:] # rediss+cluster://:@:[?addr=:&addr=:] # mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # sqlite:// cache_store = "redis://localhost:6379/0" # The database for persist data, support MySQL, Postgres, ClickHouse and MongoDB: # mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # clickhouse://user:password@host[:port]/database?param1=value1&...¶mN=valueN # chhttp://user:password@host[:port]/database?param1=value1&...¶mN=valueN # chhttps://user:password@host[:port]/database?param1=value1&...¶mN=valueN # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # sqlite:// data_store = "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse" # The naming 
prefix for tables (collections, keys) in databases. The default value is empty. table_prefix = "" # The naming prefix for tables (collections, keys) in cache storage databases. The default value is `table_prefix`. cache_table_prefix = "" # The naming prefix for tables (collections, keys) in data storage databases. The default value is `table_prefix`. data_table_prefix = "" [database.mysql] # Transaction isolation level. The default value is "READ-UNCOMMITTED". isolation_level = "READ-UNCOMMITTED" # Maximum number of open connections to the database. Set to 0 to keep the driver default. max_open_conns = 0 # Maximum number of idle connections to the database. Set to 0 to keep the driver default. max_idle_conns = 0 # Maximum amount of time a connection may be reused. Set to "0s" to keep the driver default. conn_max_lifetime = "0s" [database.postgres] # Maximum number of open connections to the database. The default value is 64. max_open_conns = 64 # Maximum number of idle connections to the database. The default value is 64. max_idle_conns = 64 # Maximum amount of time a connection may be reused. The default value is "1m". conn_max_lifetime = "1m" [database.redis] # The maximum number of results returned by the FT.SEARCH command when LIMIT is used. The default value is 10000. max_search_results = 10000 [master] # gRPC port of the master node. The default value is 8086. port = 8086 # gRPC host of the master node. The default value is "0.0.0.0". host = "0.0.0.0" # Enable SSL for the gRPC communication. The default value is false. ssl_mode = false # SSL certificate authority (CA) file for the gRPC communication. ssl_ca = "" # SSL certificate file for the gRPC communication. ssl_cert = "" # SSL key file for the gRPC communication. ssl_key = "" # HTTP port of the master node. The default value is 8088. http_port = 8088 # HTTP host of the master node. The default value is "0.0.0.0". http_host = "0.0.0.0" # AllowedDomains is a list of allowed values for the HTTP Origin header.
# The list may contain the special wildcard string ".*", which allows all origins. # If empty, all origins are allowed. http_cors_domains = [] # AllowedMethods is either empty or a list of HTTP method names. Checking is case-insensitive. http_cors_methods = [] # Number of working jobs in the master node. The default value is 1. n_jobs = 1 # Meta information timeout. The default value is 10s. meta_timeout = "10s" # Username for the master node dashboard. dashboard_user_name = "" # Password for the master node dashboard. dashboard_password = "" # Secret key for admin APIs (SSL required). admin_api_key = "" [server] # Default number of returned items. The default value is 10. default_n = 10 # Secret key for RESTful APIs (SSL required). api_key = "" # Clock error in the cluster. The default value is 5s. clock_error = "5s" # Insert new users while inserting feedback. The default value is true. auto_insert_user = true # Insert new items while inserting feedback. The default value is true. auto_insert_item = true # Server-side cache expire time. The default value is 10s. cache_expire = "10s" [recommend] # The cache size for recommended/popular/latest items. The default value is 100. cache_size = 100 # Recommended cache expire time. The default value is 72h. cache_expire = "72h" # The context size for online recommendations. For efficiency, online recommendations can't use all user feedback to generate # recommendations. Instead, only the latest `context_size` feedback records are used. # The default value is 100. context_size = 100 # The time-to-live (days) of active users, 0 means disabled. Recommendations won't be cached for inactive users. The default value is 0. active_user_ttl = 0 [recommend.data_source] # The feedback types for positive events. positive_feedback_types = ["star","like","read>=3"] # The feedback types for read events. read_feedback_types = ["read"] # The time-to-live (days) of positive feedback, 0 means disabled. The default value is 0.
positive_feedback_ttl = 0 # The time-to-live (days) of items, 0 means disabled. The default value is 0. item_ttl = 0 [[recommend.non-personalized]] # The name of the leaderboard. name = "most_starred_weekly" # The score function for items in the leaderboard. score = "count(feedback, .FeedbackType == 'star')" # The filter for items in the leaderboard. filter = "(now() - item.Timestamp).Hours() < 168" [[recommend.item-to-item]] # The name of the item-to-item recommender. name = "neighbors" # The type of the item-to-item recommender. There are three types: # embedding: recommend by Euclidean distance of embeddings. # tags: recommend by number of common tags. # users: recommend by number of common users. type = "embedding" # The column of the item embeddings. Leave blank if type is "users". column = "item.Labels.embedding" [[recommend.user-to-user]] # The name of the user-to-user recommender. name = "neighbors" # The type of the user-to-user recommender. There are three types: # embedding: recommend by Euclidean distance of embeddings. # tags: recommend by number of common tags. # items: recommend by number of common items. type = "items" [[recommend.external]] # The name of the external recommender. name = "trending" # The script to fetch external recommended items. The script should return a list of item IDs. script = """ const response = fetch("https://cdn.jsdelivr.net/gh/isboyjc/github-trending-api/data/daily/all.json"); if (!response.ok) { throw new Error(`${response.status} ${response.body}`); } const data = JSON.parse(response.body); data["items"].map((item) => { return item["title"].toLowerCase().replace("/", ":"); }) """ [recommend.collaborative] # The type of collaborative filtering. Supported values: # none: disable collaborative filtering. # mf: matrix factorization. type = "mf" # The time period for model fitting. The default value is "60m". fit_period = "60m" # The number of epochs for model fitting. The default value is 100. 
fit_epoch = 100 # The time period for hyperparameter optimization, set to 0 to disable. The default value is "0". optimize_period = "360m" # The number of trials for hyperparameter optimization. The default value is 10. optimize_trials = 10 [recommend.collaborative.early_stopping] # Number of epochs without improvement to wait before training stops. The default value is 10. patience = 10 [recommend.replacement] # Replace historical items back into recommendations. The default value is false. enable_replacement = false # Decay the weights of replaced items from positive feedback. The default value is 0.8. positive_replacement_decay = 0.8 # Decay the weights of replaced items from read feedback. The default value is 0.6. read_replacement_decay = 0.6 [recommend.ranker] # The type of the ranker. There are three types: # none: no ranking (default). # fm: factorization machines. # llm: LLM-based reranker. type = "fm" # The time period to refresh recommendations for inactive users. The default value is 120h. cache_expire = "120h" # The recommenders used to fetch candidate items before ranking. The default value is all recommenders. recommenders = ["latest", "collaborative", "non-personalized/most_starred_weekly", "item-to-item/neighbors", "user-to-user/neighbors"] # The time period for model fitting. The default value is "60m". fit_period = "60m" # The number of epochs for model fitting. The default value is 100. fit_epoch = 100 # The time period for hyperparameter optimization, set to 0 to disable. The default value is "0". optimize_period = "360m" # The number of trials for hyperparameter optimization. The default value is 10. optimize_trials = 10 # The prompt template for query. query_template = """ You are a GitHub repository recommender system. Given a user is interested in the following repositories: {% for repo in feedback -%} - {{ repo.Comment }} {% endfor -%} Please sort repositories by the user's interests. """ # The prompt template for documents.
document_template = '{{ item.Comment | replace(",", " ") | replace("\n", " ") }}' [recommend.ranker.early_stopping] # Number of epochs without improvement to wait before training stops. The default value is 10. patience = 10 [recommend.ranker.reranker_api] # Auth token for the reranker API. auth_token = "" # The reranker model. model = "qwen3-rerank" # URL of the reranker API (Jina-style API supported). url = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" [recommend.fallback] # The fallback recommenders are used when cached recommendations run out. The default value is ["latest"]. recommenders = ["item-to-item/neighbors", "latest"] [tracing] # Enable tracing for REST APIs. The default value is false. enable_tracing = false # The tracing exporter type, one of "zipkin", "otlp" and "otlphttp". The default value is "otlp". exporter = "otlp" # The endpoint of the tracing collector. collector_endpoint = "http://localhost:4317" # The tracing sampler type, one of "always", "never" and "ratio". The default value is "always". sampler = "always" # The sampling ratio of the ratio-based sampler. The default value is 1. ratio = 1 [oidc] # Enable OpenID Connect (OIDC) authentication. The default value is false. enable = false # The issuer of the OAuth provider. issuer = "" # Public identifier of the OAuth application. client_id = "" # Client secret of the OAuth application. client_secret = "" # URL used by the OAuth provider to redirect users after they are successfully authenticated # (also referred to as the callback URL). You should set this to the concatenation of the # Gorse dashboard URL and "/callback/oauth2". For example, if the Gorse dashboard URL is # http://localhost:8088, the redirect URL should be: http://localhost:8088/callback/oauth2 redirect_url = "" [blob] # Blob storage URI.
# - S3: s3://bucket/path # - GCS: gs://my-bucket/my-database # - Azure Blob: az://container/path # - Local: /path/without/prefix uri = "/var/lib/gorse/blob" [blob.s3] # S3 endpoint. endpoint = "" # S3 access key ID. access_key_id = "" # S3 secret access key. secret_access_key = "" [blob.gcs] # GCS credentials file. credentials_file = "" [blob.azure] # Azure Blob endpoint. Leave empty to use https://<account_name>.blob.core.windows.net/. endpoint = "" # Azure storage account name. account_name = "" # Azure storage account key. account_key = "" # Azure storage connection string (takes precedence over account_name/account_key). connection_string = "" [openai] # Base URL of the OpenAI API. base_url = "http://localhost:11434/v1" # API key of the OpenAI API. auth_token = "ollama" # Name of the chat completion model. chat_completion_model = "qwen2.5" # Maximum requests per minute for chat completion. chat_completion_rpm = 15000 # Maximum tokens per minute for chat completion. chat_completion_tpm = 1200000 # Name of the embedding model. embedding_model = "mxbai-embed-large" # Dimensions of embedding vectors. embedding_dimensions = 1024 # Maximum requests per minute for embedding. embedding_rpm = 1800 # Maximum tokens per minute for embedding. embedding_tpm = 1200000 # Log file for the OpenAI API. log_file = "openai.log" ================================================ FILE: config/config_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package config import ( "bytes" "fmt" "os" "path/filepath" "strings" "testing" "time" "github.com/gorse-io/gorse/common/expression" "github.com/sclevine/yj/convert" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) func TestUnmarshal(t *testing.T) { text := ConfigTOML text = strings.Replace(text, "ssl_mode = false", "ssl_mode = true", -1) text = strings.Replace(text, "ssl_ca = \"\"", "ssl_ca = \"ca.pem\"", -1) text = strings.Replace(text, "ssl_cert = \"\"", "ssl_cert = \"cert.pem\"", -1) text = strings.Replace(text, "ssl_key = \"\"", "ssl_key = \"key.pem\"", -1) text = strings.Replace(text, "dashboard_user_name = \"\"", "dashboard_user_name = \"admin\"", -1) text = strings.Replace(text, "dashboard_password = \"\"", "dashboard_password = \"password\"", -1) text = strings.Replace(text, "admin_api_key = \"\"", "admin_api_key = \"super_api_key\"", -1) text = strings.Replace(text, "api_key = \"\"", "api_key = \"19260817\"", -1) text = strings.Replace(text, "table_prefix = \"\"", "table_prefix = \"gorse_\"", -1) text = strings.Replace(text, "cache_table_prefix = \"gorse_\"", "cache_table_prefix = \"gorse_cache_\"", -1) text = strings.Replace(text, "data_table_prefix = \"gorse_\"", "data_table_prefix = \"gorse_data_\"", -1) text = strings.Replace(text, "http_cors_domains = []", "http_cors_domains = [\".*\"]", -1) text = strings.Replace(text, "http_cors_methods = []", "http_cors_methods = [\"GET\",\"PATCH\",\"POST\"]", -1) text = strings.Replace(text, "issuer = \"\"", "issuer = \"https://accounts.google.com\"", -1) text = strings.Replace(text, "client_id = \"\"", "client_id = \"client_id\"", -1) text = strings.Replace(text, "client_secret = \"\"", "client_secret = \"client_secret\"", -1) text = strings.Replace(text, "redirect_url = \"\"", "redirect_url = \"http://localhost:8088/callback/oauth2\"", -1) text = 
strings.Replace(text, "auth_token = \"\"", "auth_token = \"\"", -1) text = strings.Replace(text, "url = \"https://dashscope.aliyuncs.com/compatible-api/v1/reranks\"", "url = \"\"", -1) r, err := convert.TOML{}.Decode(bytes.NewBufferString(text)) assert.NoError(t, err) encodings := []convert.Encoding{convert.TOML{}, convert.YAML{}, convert.JSON{}} for _, encoding := range encodings { t.Run(encoding.String(), func(t *testing.T) { filePath := filepath.Join(os.TempDir(), fmt.Sprintf("config.%s", strings.ToLower(encoding.String()))) fp, err := os.Create(filePath) assert.NoError(t, err) err = encoding.Encode(fp, r) assert.NoError(t, err) config, err := LoadConfig(filePath) assert.NoError(t, err) // [database] assert.Equal(t, "redis://localhost:6379/0", config.Database.CacheStore) assert.Equal(t, "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse", config.Database.DataStore) assert.Equal(t, "gorse_", config.Database.TablePrefix) assert.Equal(t, "gorse_cache_", config.Database.CacheTablePrefix) assert.Equal(t, "gorse_data_", config.Database.DataTablePrefix) assert.Equal(t, "READ-UNCOMMITTED", config.Database.MySQL.IsolationLevel) assert.Equal(t, 0, config.Database.MySQL.MaxOpenConns) assert.Equal(t, 0, config.Database.MySQL.MaxIdleConns) assert.Equal(t, time.Duration(0), config.Database.MySQL.ConnMaxLifetime) assert.Equal(t, 64, config.Database.Postgres.MaxOpenConns) assert.Equal(t, 64, config.Database.Postgres.MaxIdleConns) assert.Equal(t, time.Minute, config.Database.Postgres.ConnMaxLifetime) assert.Equal(t, 10000, config.Database.Redis.MaxSearchResults) // [master] assert.Equal(t, 8086, config.Master.Port) assert.Equal(t, "0.0.0.0", config.Master.Host) assert.Equal(t, true, config.Master.SSLMode) assert.Equal(t, "ca.pem", config.Master.SSLCA) assert.Equal(t, "cert.pem", config.Master.SSLCert) assert.Equal(t, "key.pem", config.Master.SSLKey) assert.Equal(t, 8088, config.Master.HttpPort) assert.Equal(t, "0.0.0.0", config.Master.HttpHost) assert.Equal(t, []string{".*"}, 
config.Master.HttpCorsDomains) assert.Equal(t, []string{"GET", "PATCH", "POST"}, config.Master.HttpCorsMethods) assert.Equal(t, 1, config.Master.NumJobs) assert.Equal(t, 10*time.Second, config.Master.MetaTimeout) assert.Equal(t, "admin", config.Master.DashboardUserName) assert.Equal(t, "password", config.Master.DashboardPassword) assert.Equal(t, "super_api_key", config.Master.AdminAPIKey) // [server] assert.Equal(t, 10, config.Server.DefaultN) assert.Equal(t, "19260817", config.Server.APIKey) assert.Equal(t, 5*time.Second, config.Server.ClockError) assert.True(t, config.Server.AutoInsertUser) assert.True(t, config.Server.AutoInsertItem) assert.Equal(t, 10*time.Second, config.Server.CacheExpire) // [recommend] assert.Equal(t, 100, config.Recommend.CacheSize) assert.Equal(t, 72*time.Hour, config.Recommend.CacheExpire) assert.Equal(t, 100, config.Recommend.ContextSize) // [recommend.data_source] assert.Equal(t, []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("star"), expression.MustParseFeedbackTypeExpression("like"), expression.MustParseFeedbackTypeExpression("read>=3"), }, config.Recommend.DataSource.PositiveFeedbackTypes) assert.Equal(t, []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("read"), }, config.Recommend.DataSource.ReadFeedbackTypes) assert.Equal(t, uint(0), config.Recommend.DataSource.PositiveFeedbackTTL) assert.Equal(t, uint(0), config.Recommend.DataSource.ItemTTL) // [recommend.non-personalized] assert.Len(t, config.Recommend.NonPersonalized, 1) assert.Equal(t, "most_starred_weekly", config.Recommend.NonPersonalized[0].Name) assert.Equal(t, "count(feedback, .FeedbackType == 'star')", config.Recommend.NonPersonalized[0].Score) assert.Equal(t, "(now() - item.Timestamp).Hours() < 168", config.Recommend.NonPersonalized[0].Filter) // [recommend.collaborative] assert.Equal(t, "mf", config.Recommend.Collaborative.Type) assert.Equal(t, 60*time.Minute, config.Recommend.Collaborative.FitPeriod) 
assert.Equal(t, 100, config.Recommend.Collaborative.FitEpoch) assert.Equal(t, 360*time.Minute, config.Recommend.Collaborative.OptimizePeriod) assert.Equal(t, 10, config.Recommend.Collaborative.OptimizeTrials) // [recommend.replacement] assert.False(t, config.Recommend.Replacement.EnableReplacement) assert.Equal(t, 0.8, config.Recommend.Replacement.PositiveReplacementDecay) assert.Equal(t, 0.6, config.Recommend.Replacement.ReadReplacementDecay) // [recommend.ranker] assert.Equal(t, "fm", config.Recommend.Ranker.Type) assert.Equal(t, 120*time.Hour, config.Recommend.Ranker.CacheExpire) assert.Equal(t, []string{"latest", "collaborative", "non-personalized/most_starred_weekly", "item-to-item/neighbors", "user-to-user/neighbors"}, config.Recommend.Ranker.Recommenders) assert.Equal(t, 60*time.Minute, config.Recommend.Ranker.FitPeriod) assert.Equal(t, 100, config.Recommend.Ranker.FitEpoch) assert.Equal(t, 360*time.Minute, config.Recommend.Ranker.OptimizePeriod) assert.Equal(t, 10, config.Recommend.Ranker.OptimizeTrials) assert.Equal(t, "", config.Recommend.Ranker.RerankerAPI.AuthToken) assert.Equal(t, "qwen3-rerank", config.Recommend.Ranker.RerankerAPI.Model) assert.Equal(t, "", config.Recommend.Ranker.RerankerAPI.URL) // [recommend.fallback] assert.Equal(t, []string{"item-to-item/neighbors", "latest"}, config.Recommend.Fallback.Recommenders) // [tracing] assert.False(t, config.Tracing.EnableTracing) assert.Equal(t, "otlp", config.Tracing.Exporter) assert.Equal(t, "http://localhost:4317", config.Tracing.CollectorEndpoint) assert.Equal(t, "always", config.Tracing.Sampler) assert.Equal(t, 1.0, config.Tracing.Ratio) // [oauth2] assert.Equal(t, "https://accounts.google.com", config.OIDC.Issuer) assert.Equal(t, "client_id", config.OIDC.ClientID) assert.Equal(t, "client_secret", config.OIDC.ClientSecret) assert.Equal(t, "http://localhost:8088/callback/oauth2", config.OIDC.RedirectURL) // [openai] assert.Equal(t, "http://localhost:11434/v1", config.OpenAI.BaseURL) assert.Equal(t, 
"ollama", config.OpenAI.AuthToken) assert.Equal(t, "qwen2.5", config.OpenAI.ChatCompletionModel) assert.Equal(t, 15000, config.OpenAI.ChatCompletionRPM) assert.Equal(t, 1200000, config.OpenAI.ChatCompletionTPM) assert.Equal(t, "mxbai-embed-large", config.OpenAI.EmbeddingModel) assert.Equal(t, 1024, config.OpenAI.EmbeddingDimensions) assert.Equal(t, 1800, config.OpenAI.EmbeddingRPM) assert.Equal(t, 1200000, config.OpenAI.EmbeddingTPM) }) } } func TestSetDefault(t *testing.T) { for _, binding := range bindings { t.Setenv(binding.env, "") } setDefault() viper.SetConfigType("toml") err := viper.ReadConfig(strings.NewReader("")) assert.NoError(t, err) var config Config err = viper.Unmarshal(&config) assert.NoError(t, err) assert.Equal(t, GetDefaultConfig(), &config) } type environmentVariable struct { key string value string } func TestBindEnv(t *testing.T) { variables := []environmentVariable{ {"GORSE_CACHE_STORE", "redis://"}, {"GORSE_DATA_STORE", "mysql://"}, {"GORSE_TABLE_PREFIX", "gorse_"}, {"GORSE_DATA_TABLE_PREFIX", "gorse_data_"}, {"GORSE_CACHE_TABLE_PREFIX", "gorse_cache_"}, {"GORSE_MASTER_PORT", "123"}, {"GORSE_MASTER_HOST", ""}, {"GORSE_MASTER_SSL_MODE", "true"}, {"GORSE_MASTER_SSL_CA", "ca.pem"}, {"GORSE_MASTER_SSL_CERT", "cert.pem"}, {"GORSE_MASTER_SSL_KEY", "key.pem"}, {"GORSE_MASTER_HTTP_PORT", "456"}, {"GORSE_MASTER_HTTP_HOST", ""}, {"GORSE_MASTER_JOBS", "789"}, {"GORSE_DASHBOARD_USER_NAME", "user_name"}, {"GORSE_DASHBOARD_PASSWORD", "password"}, {"GORSE_DASHBOARD_AUTH_SERVER", "http://127.0.0.1:8888"}, {"GORSE_DASHBOARD_REDACTED", "true"}, {"GORSE_ADMIN_API_KEY", ""}, {"GORSE_SERVER_API_KEY", ""}, {"GORSE_OIDC_ENABLE", "true"}, {"GORSE_OIDC_ISSUER", "https://accounts.google.com"}, {"GORSE_OIDC_CLIENT_ID", "client_id"}, {"GORSE_OIDC_CLIENT_SECRET", "client_secret"}, {"GORSE_OIDC_REDIRECT_URL", "http://localhost:8088/callback/oauth2"}, {"GORSE_BLOB_URI", "s3:///path"}, {"S3_ENDPOINT", "https://s3.example.com"}, {"S3_ACCESS_KEY_ID", ""}, 
{"S3_SECRET_ACCESS_KEY", ""}, {"GCS_CREDENTIALS_FILE", "/path/to/credentials.json"}, {"AZURE_STORAGE_ENDPOINT", "https://.blob.core.windows.net"}, {"AZURE_STORAGE_ACCOUNT_NAME", ""}, {"AZURE_STORAGE_ACCOUNT_KEY", ""}, {"AZURE_STORAGE_CONNECTION_STRING", "DefaultEndpointsProtocol=https;AccountName=;AccountKey="}, {"OPENAI_BASE_URL", "https://api.openai.com/v1"}, {"OPENAI_AUTH_TOKEN", ""}, {"OPENAI_CHAT_COMPLETION_MODEL", "gpt-4"}, {"RERANKER_AUTH_TOKEN", ""}, {"RERANKER_URL", ""}, {"RERANKER_MODEL", ""}, } for _, variable := range variables { t.Setenv(variable.key, variable.value) } config, err := LoadConfig("config.toml") assert.NoError(t, err) assert.Equal(t, "redis://", config.Database.CacheStore) assert.Equal(t, "mysql://", config.Database.DataStore) assert.Equal(t, "gorse_", config.Database.TablePrefix) assert.Equal(t, "gorse_cache_", config.Database.CacheTablePrefix) assert.Equal(t, "gorse_data_", config.Database.DataTablePrefix) assert.Equal(t, 123, config.Master.Port) assert.Equal(t, "", config.Master.Host) assert.Equal(t, true, config.Master.SSLMode) assert.Equal(t, "ca.pem", config.Master.SSLCA) assert.Equal(t, "cert.pem", config.Master.SSLCert) assert.Equal(t, "key.pem", config.Master.SSLKey) assert.Equal(t, 456, config.Master.HttpPort) assert.Equal(t, "", config.Master.HttpHost) assert.Equal(t, 789, config.Master.NumJobs) assert.Equal(t, "user_name", config.Master.DashboardUserName) assert.Equal(t, "password", config.Master.DashboardPassword) assert.Equal(t, true, config.Master.DashboardRedacted) assert.Equal(t, "", config.Master.AdminAPIKey) assert.Equal(t, "", config.Server.APIKey) assert.Equal(t, true, config.OIDC.Enable) assert.Equal(t, "https://accounts.google.com", config.OIDC.Issuer) assert.Equal(t, "client_id", config.OIDC.ClientID) assert.Equal(t, "client_secret", config.OIDC.ClientSecret) assert.Equal(t, "http://localhost:8088/callback/oauth2", config.OIDC.RedirectURL) assert.Equal(t, "s3:///path", config.Blob.URI) assert.Equal(t, 
"https://s3.example.com", config.Blob.S3.Endpoint) assert.Equal(t, "", config.Blob.S3.AccessKeyID) assert.Equal(t, "", config.Blob.S3.SecretAccessKey) assert.Equal(t, "/path/to/credentials.json", config.Blob.GCS.CredentialsFile) assert.Equal(t, "https://.blob.core.windows.net", config.Blob.Azure.Endpoint) assert.Equal(t, "", config.Blob.Azure.AccountName) assert.Equal(t, "", config.Blob.Azure.AccountKey) assert.Equal(t, "DefaultEndpointsProtocol=https;AccountName=;AccountKey=", config.Blob.Azure.ConnectionString) assert.Equal(t, "https://api.openai.com/v1", config.OpenAI.BaseURL) assert.Equal(t, "", config.OpenAI.AuthToken) assert.Equal(t, "gpt-4", config.OpenAI.ChatCompletionModel) assert.Equal(t, "", config.Recommend.Ranker.RerankerAPI.AuthToken) assert.Equal(t, "", config.Recommend.Ranker.RerankerAPI.URL) assert.Equal(t, "", config.Recommend.Ranker.RerankerAPI.Model) // check default values assert.Equal(t, 100, config.Recommend.CacheSize) } func TestTablePrefixCompat(t *testing.T) { data, err := os.ReadFile("config.toml") assert.NoError(t, err) text := string(data) text = strings.Replace(text, "cache_table_prefix = \"\"", "", -1) text = strings.Replace(text, "data_table_prefix = \"\"", "", -1) text = strings.Replace(text, "table_prefix = \"\"", "table_prefix = \"gorse_\"", -1) path := filepath.Join(t.TempDir(), "config.toml") err = os.WriteFile(path, []byte(text), os.ModePerm) assert.NoError(t, err) config, err := LoadConfig(path) assert.NoError(t, err) assert.Equal(t, "gorse_", config.Database.TablePrefix) assert.Equal(t, "gorse_", config.Database.CacheTablePrefix) assert.Equal(t, "gorse_", config.Database.DataTablePrefix) } func TestNonPersonalizedConfig(t *testing.T) { a := NonPersonalizedConfig{} b := NonPersonalizedConfig{} assert.Equal(t, a.Hash(), b.Hash()) a = NonPersonalizedConfig{Name: "a"} b = NonPersonalizedConfig{Name: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) assert.Equal(t, "non-personalized/a", a.FullName()) assert.Equal(t, 
"non-personalized/b", b.FullName()) a = NonPersonalizedConfig{Score: "a"} b = NonPersonalizedConfig{Score: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) a = NonPersonalizedConfig{Filter: "a"} b = NonPersonalizedConfig{Filter: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) } func TestItemToItemConfig(t *testing.T) { a := ItemToItemConfig{} b := ItemToItemConfig{} assert.Equal(t, a.Hash(nil), b.Hash(nil)) a = ItemToItemConfig{Name: "a"} b = ItemToItemConfig{Name: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) assert.Equal(t, "item-to-item/a", a.FullName()) assert.Equal(t, "item-to-item/b", b.FullName()) a = ItemToItemConfig{Type: "a"} b = ItemToItemConfig{Type: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) a = ItemToItemConfig{Column: "a"} b = ItemToItemConfig{Column: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) c := ItemToItemConfig{Type: "users"} d := RecommendConfig{} e := RecommendConfig{} assert.Equal(t, c.Hash(&d), c.Hash(&e)) d.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("like")} e.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("star")} assert.NotEqual(t, c.Hash(&d), c.Hash(&e)) } func TestUserToUserConfig(t *testing.T) { a := UserToUserConfig{} b := UserToUserConfig{} assert.Equal(t, a.Hash(nil), b.Hash(nil)) a = UserToUserConfig{Name: "a"} b = UserToUserConfig{Name: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) assert.Equal(t, "user-to-user/a", a.FullName()) assert.Equal(t, "user-to-user/b", b.FullName()) a = UserToUserConfig{Type: "a"} b = UserToUserConfig{Type: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) a = UserToUserConfig{Column: "a"} b = UserToUserConfig{Column: "b"} assert.NotEqual(t, a.Hash(nil), b.Hash(nil)) c := UserToUserConfig{Type: "items"} d := RecommendConfig{} e := RecommendConfig{} assert.Equal(t, c.Hash(&d), c.Hash(&e)) d.DataSource.PositiveFeedbackTypes = 
[]expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("like")} e.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("star")} assert.NotEqual(t, c.Hash(&d), c.Hash(&e)) } func TestCollaborativeConfig(t *testing.T) { a := RecommendConfig{} b := RecommendConfig{} c := CollaborativeConfig{} assert.Equal(t, c.Hash(&a), c.Hash(&b)) a.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("like")} b.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("star")} assert.NotEqual(t, c.Hash(&a), c.Hash(&b)) } func TestExternalConfig(t *testing.T) { a := ExternalConfig{} b := ExternalConfig{} assert.Equal(t, a.Hash(), b.Hash()) a = ExternalConfig{Name: "a"} b = ExternalConfig{Name: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) assert.Equal(t, "external/a", a.FullName()) assert.Equal(t, "external/b", b.FullName()) a = ExternalConfig{Script: "a"} b = ExternalConfig{Script: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) } func TestRecommendConfig(t *testing.T) { a := RecommendConfig{} b := RecommendConfig{} assert.Equal(t, a.Hash(), b.Hash()) a.NonPersonalized = []NonPersonalizedConfig{{Name: "a"}} b.NonPersonalized = []NonPersonalizedConfig{{Name: "b"}} assert.NotEqual(t, a.Hash(), b.Hash()) a.NonPersonalized = []NonPersonalizedConfig{} b.NonPersonalized = []NonPersonalizedConfig{} a.ItemToItem = []ItemToItemConfig{{Name: "a"}} b.ItemToItem = []ItemToItemConfig{{Name: "b"}} assert.NotEqual(t, a.Hash(), b.Hash()) a.ItemToItem = []ItemToItemConfig{} b.ItemToItem = []ItemToItemConfig{} a.UserToUser = []UserToUserConfig{{Name: "a"}} b.UserToUser = []UserToUserConfig{{Name: "b"}} assert.NotEqual(t, a.Hash(), b.Hash()) a.UserToUser = []UserToUserConfig{} b.UserToUser = []UserToUserConfig{} a.External = []ExternalConfig{{Name: "a"}} b.External = []ExternalConfig{{Name: "b"}} 
assert.NotEqual(t, a.Hash(), b.Hash()) a.External = []ExternalConfig{} b.External = []ExternalConfig{} a.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("like")} b.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("star")} assert.NotEqual(t, a.Hash(), b.Hash()) a.Ranker.Recommenders = []string{"latest"} b.Ranker.Recommenders = []string{"collaborative"} assert.NotEqual(t, a.Hash(), b.Hash()) a.UserToUser = []UserToUserConfig{{Name: "a"}, {Name: "b"}} b.UserToUser = []UserToUserConfig{{Name: "b"}, {Name: "a"}} a.Ranker.Recommenders = []string{"user-to-user/a", "user-to-user/b"} b.Ranker.Recommenders = []string{"user-to-user/b", "user-to-user/a"} assert.Equal(t, a.Hash(), b.Hash()) } type ValidateTestSuite struct { suite.Suite *Config } func (s *ValidateTestSuite) SetupTest() { s.Config = GetDefaultConfig() s.Database.CacheStore = "redis://localhost:6379/0" s.Database.DataStore = "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse" } func (s *ValidateTestSuite) TestDuplicateNonPersonalized() { s.Recommend.NonPersonalized = []NonPersonalizedConfig{{ Name: "most_starred_weekly", Score: "count(feedback, .FeedbackType == 'star')", }, { Name: "most_starred_weekly", Score: "count(feedback, .FeedbackType == 'star')", }} s.Error(s.Validate()) } func (s *ValidateTestSuite) TestDuplicateItemToItem() { s.Recommend.ItemToItem = []ItemToItemConfig{{ Name: "item_to_item", Type: "users", }, { Name: "item_to_item", Type: "users", }} s.Error(s.Validate()) } func (s *ValidateTestSuite) TestRecommendersExistence() { s.Recommend.Ranker.Recommenders = []string{"not_exist"} s.Error(s.Validate()) s.Recommend.Collaborative.Type = "none" s.Recommend.Ranker.Recommenders = []string{"collaborative"} s.Error(s.Validate()) } func TestValidate(t *testing.T) { suite.Run(t, new(ValidateTestSuite)) } ================================================ FILE: dataset/dataset.go 
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

import (
	"bufio"
	"fmt"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/chewxy/math32"
	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/model"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/juju/errors"
	"github.com/samber/lo"
	"modernc.org/strutil"
)

// ID is a 32-bit integer identifier.
// NOTE(review): presumably a dense user/item index, matching the int32
// indices stored in the feedback slices below; confirm against callers.
type ID int32

// CFSplit is the dataset split for collaborative filtering.
//
// Users and items are addressed by int32 indices; the feedback slices returned
// below are indexed by those indices.
type CFSplit interface {
	// CountUsers returns the number of users.
	CountUsers() int
	// CountItems returns the number of items.
	CountItems() int
	// CountFeedback returns the number of (positive) feedback.
	CountFeedback() int
	// GetItems returns the items.
	GetItems() []data.Item
	// GetUserDict returns the frequency dictionary of users.
	GetUserDict() *FreqDict
	// GetItemDict returns the frequency dictionary of items.
	GetItemDict() *FreqDict
	// GetUserFeedback returns the (positive) feedback of users: element u
	// holds the item indices user u interacted with.
	GetUserFeedback() [][]int32
	// GetItemFeedback returns the (positive) feedback of items: element i
	// holds the user indices that interacted with item i.
	GetItemFeedback() [][]int32
	// SampleUserNegatives samples negative (feedback) for users: for every
	// user it returns up to numCandidates item indices, avoiding the user's
	// items in this split and in excludeSet.
	SampleUserNegatives(excludeSet CFSplit, numCandidates int) [][]int32
}

// CTRSplit is the dataset split for click-through rate prediction.
type CTRSplit interface {
	// NOTE(review): the method descriptions below are inferred from names and
	// signatures only; confirm against the implementations.

	// Count returns the number of samples; presumably the valid range of i
	// for Get and GetTarget is [0, Count()).
	Count() int
	// CountUsers returns the number of users.
	CountUsers() int
	// CountItems returns the number of items.
	CountItems() int
	// CountUserLabels returns the number of user labels.
	CountUserLabels() int
	// CountItemLabels returns the number of item labels.
	CountItemLabels() int
	// CountContextLabels returns the number of context labels.
	CountContextLabels() int
	// CountPositive returns the number of positive samples (presumably).
	CountPositive() int
	// CountNegative returns the number of negative samples (presumably).
	CountNegative() int
	// GetIndex returns the unified feature index of the split.
	GetIndex() UnifiedIndex
	// GetTarget returns the target value of the i-th sample.
	GetTarget(i int) float32
	// Get returns the i-th sample. The results look like
	// (feature indices, feature values, embeddings, target) — TODO confirm.
	Get(i int) ([]int32, []float32, [][]float32, float32)
	// GetItemEmbeddingDim returns the dimensions of item embeddings,
	// presumably one entry per embedding column.
	GetItemEmbeddingDim() []int
	// GetItemEmbeddingIndex returns the index of item embedding columns.
	GetItemEmbeddingIndex() *Index
}

// Dataset is an in-memory collection of users, items and feedback. Users
// and items are addressed by the int32 indices assigned by userDict/itemDict,
// which are used directly as slice indices (see AddFeedback).
type Dataset struct {
	// timestamp is the value passed to NewDataset; returned by GetTimestamp.
	timestamp time.Time
	// users and items are appended by AddUser/AddItem, with labels passed
	// through userLabels/itemLabels.processLabels.
	users []data.User
	items []data.Item
	// userLabels and itemLabels hold label dictionaries; their values are
	// also read by the *ColumnValuesIDF methods.
	userLabels *Labels
	itemLabels *Labels
	// userFeedback[u] lists item indices user u interacted with; itemFeedback[i]
	// lists user indices that interacted with item i.
	userFeedback [][]int32
	itemFeedback [][]int32
	// timestamps[u][k] is the timestamp of userFeedback[u][k]; both are
	// appended together in AddFeedback.
	timestamps [][]time.Time
	// negatives caches the result of the first SampleUserNegatives call.
	negatives [][]int32
	// userDict and itemDict map IDs to indices. AddUser/AddItem register IDs
	// via AddNoCount, AddFeedback via Add; the IDF methods read Freq.
	userDict *FreqDict
	itemDict *FreqDict
	// numFeedback is incremented once per AddFeedback call.
	numFeedback int
	// categories counts, per category name, the items added with it.
	categories map[string]int
}

// NewDataset creates an empty dataset stamped with timestamp. userCount and
// itemCount are capacity hints: users/items start empty with that capacity,
// while userFeedback, timestamps and itemFeedback are pre-sized to that length
// (AddUser/AddItem only grow them once the hint is exceeded).
func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset {
	return &Dataset{
		timestamp:    timestamp,
		users:        make([]data.User, 0, userCount),
		items:        make([]data.Item, 0, itemCount),
		userLabels:   NewLabels(),
		itemLabels:   NewLabels(),
		userFeedback: make([][]int32, userCount),
		itemFeedback: make([][]int32, itemCount),
		timestamps:   make([][]time.Time, userCount),
		userDict:     NewFreqDict(),
		itemDict:     NewFreqDict(),
		categories:   make(map[string]int),
	}
}

// The accessors below return the dataset's internal storage, not copies.

// GetTimestamp returns the timestamp given to NewDataset.
func (d *Dataset) GetTimestamp() time.Time {
	return d.timestamp
}

// CountFeedback returns the number of AddFeedback calls so far.
func (d *Dataset) CountFeedback() int {
	return d.numFeedback
}

// GetUsers returns the users added so far.
func (d *Dataset) GetUsers() []data.User {
	return d.users
}

// GetUserDict returns the user ID dictionary.
func (d *Dataset) GetUserDict() *FreqDict {
	return d.userDict
}

// CountUsers returns the number of users added via AddUser.
func (d *Dataset) CountUsers() int {
	return len(d.users)
}

// GetItems returns the items added so far.
func (d *Dataset) GetItems() []data.Item {
	return d.items
}

// GetItemDict returns the item ID dictionary.
func (d *Dataset) GetItemDict() *FreqDict {
	return d.itemDict
}

// CountItems returns the number of items added via AddItem.
func (d *Dataset) CountItems() int {
	return len(d.items)
}

// GetUserFeedback returns, per user index, the item indices of its feedback.
func (d *Dataset) GetUserFeedback() [][]int32 {
	return d.userFeedback
}

// GetItemFeedback returns, per item index, the user indices of its feedback.
func (d *Dataset) GetItemFeedback() [][]int32 {
	return d.itemFeedback
}

// GetCategories returns the number of items per category. The map is the
// dataset's own; callers must not modify it.
func (d *Dataset) GetCategories() map[string]int {
	return d.categories
}

// GetUserIDF returns the IDF of users.
//
//	IDF(u) = log(I/freq(u))
//
// I is the number of items.
// freq(u) is the frequency of user u in all feedback.
// Every value is floored at 1e-3.
func (d *Dataset) GetUserIDF() []float32 { idf := make([]float32, d.userDict.Count()) for i := int32(0); i < d.userDict.Count(); i++ { // Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3. idf[i] = max(math32.Log(float32(len(d.items))/float32(d.userDict.Freq(i))), 1e-3) } return idf } // GetItemIDF returns the IDF of items. // // IDF(i) = log(U/freq(i)) // // U is the number of users. // freq(i) is the frequency of item i in all feedback. func (d *Dataset) GetItemIDF() []float32 { idf := make([]float32, d.itemDict.Count()) for i := int32(0); i < d.itemDict.Count(); i++ { // Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3. idf[i] = max(math32.Log(float32(len(d.users))/float32(d.itemDict.Freq(i))), 1e-3) } return idf } func (d *Dataset) GetUserColumnValuesIDF() []float32 { idf := make([]float32, d.userLabels.values.Count()) for i := int32(0); i < d.userLabels.values.Count(); i++ { // Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3. idf[i] = max(math32.Log(float32(len(d.users))/float32(d.userLabels.values.Freq(i))), 1e-3) } return idf } func (d *Dataset) GetItemColumnValuesIDF() []float32 { idf := make([]float32, d.itemLabels.values.Count()) for i := int32(0); i < d.itemLabels.values.Count(); i++ { // Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3. 
idf[i] = max(math32.Log(float32(len(d.items))/float32(d.itemLabels.values.Freq(i))), 1e-3) } return idf } func (d *Dataset) AddUser(user data.User) { d.users = append(d.users, data.User{ UserId: user.UserId, Labels: d.userLabels.processLabels(user.Labels, ""), Comment: user.Comment, }) d.userDict.AddNoCount(user.UserId) if len(d.userFeedback) < len(d.users) { d.userFeedback = append(d.userFeedback, nil) } if len(d.timestamps) < len(d.users) { d.timestamps = append(d.timestamps, nil) } } func (d *Dataset) AddItem(item data.Item) { d.items = append(d.items, data.Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: item.Timestamp, Labels: d.itemLabels.processLabels(item.Labels, ""), Comment: item.Comment, }) d.itemDict.AddNoCount(item.ItemId) if len(d.itemFeedback) < len(d.items) { d.itemFeedback = append(d.itemFeedback, nil) } for _, category := range item.Categories { d.categories[category]++ } } func (d *Dataset) AddFeedback(userId, itemId string, timestamp time.Time) { userIndex := d.userDict.Add(userId) itemIndex := d.itemDict.Add(itemId) d.userFeedback[userIndex] = append(d.userFeedback[userIndex], itemIndex) d.itemFeedback[itemIndex] = append(d.itemFeedback[itemIndex], userIndex) d.timestamps[userIndex] = append(d.timestamps[userIndex], timestamp) d.numFeedback++ } func (d *Dataset) SampleUserNegatives(excludeSet CFSplit, numCandidates int) [][]int32 { if len(d.negatives) == 0 { rng := util.NewRandomGenerator(0) d.negatives = make([][]int32, d.CountUsers()) for userIndex := 0; userIndex < d.CountUsers(); userIndex++ { s1 := mapset.NewSet(d.GetUserFeedback()[userIndex]...) s2 := mapset.NewSet(excludeSet.GetUserFeedback()[userIndex]...) d.negatives[userIndex] = rng.SampleInt32(0, int32(d.CountItems()), numCandidates, s1, s2) } } return d.negatives } // SplitCF splits dataset by user-leave-one-out method. The argument `numTestUsers` determines the number of users in the test // set. 
If numTestUsers is equal or greater than the number of total users or numTestUsers <= 0, all users are presented // in the test set. func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { trainSet, testSet := new(Dataset), new(Dataset) trainSet.users, testSet.users = d.users, d.users trainSet.items, testSet.items = d.items, d.items trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers()) trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems()) trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers()) trainSet.userDict, testSet.userDict = d.userDict, d.userDict trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict rng := util.NewRandomGenerator(seed) if numTestUsers >= d.CountUsers() || numTestUsers <= 0 { for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { if len(d.userFeedback[userIndex]) > 0 { k := rng.Intn(len(d.userFeedback[userIndex])) testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k]) testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex) testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k]) testSet.numFeedback++ for i, itemIndex := range d.userFeedback[userIndex] { if i != k { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i]) trainSet.numFeedback++ } } } } } else { testUsers := rng.SampleInt32(0, int32(d.CountUsers()), numTestUsers) for _, userIndex := range testUsers { if len(d.userFeedback[userIndex]) > 0 { k := rng.Intn(len(d.userFeedback[userIndex])) 
testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k]) testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex) testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k]) testSet.numFeedback++ for i, itemIndex := range d.userFeedback[userIndex] { if i != k { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i]) trainSet.numFeedback++ } } } } testUserSet := mapset.NewSet(testUsers...) for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { if !testUserSet.Contains(userIndex) { for idx, itemIndex := range d.userFeedback[userIndex] { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][idx]) trainSet.numFeedback++ } } } } return trainSet, testSet } // SplitLatest splits dataset by moving the most recent feedback of all users into the test set to avoid leakage. 
func (d *Dataset) SplitLatest(shots int) (CFSplit, CFSplit) { trainSet, testSet := new(Dataset), new(Dataset) trainSet.users, testSet.users = d.users, d.users trainSet.items, testSet.items = d.items, d.items trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers()) trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems()) trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers()) trainSet.userDict, testSet.userDict = d.userDict, d.userDict trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { if len(d.userFeedback[userIndex]) == 0 { continue } idxs := lo.Range(len(d.userFeedback[userIndex])) sort.Slice(idxs, func(i, j int) bool { return d.timestamps[userIndex][idxs[i]].After(d.timestamps[userIndex][idxs[j]]) }) testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][idxs[0]]) testSet.itemFeedback[d.userFeedback[userIndex][idxs[0]]] = append(testSet.itemFeedback[d.userFeedback[userIndex][idxs[0]]], userIndex) testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][idxs[0]]) testSet.numFeedback++ for i := 1; i < len(d.userFeedback[userIndex]) && i <= shots; i++ { itemIndex := d.userFeedback[userIndex][idxs[i]] trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][idxs[i]]) trainSet.numFeedback++ } } return trainSet, testSet } type Labels struct { fields *strutil.Pool values *FreqDict } func NewLabels() *Labels { return &Labels{ fields: strutil.NewPool(), values: NewFreqDict(), } } func (l *Labels) processLabels(labels any, parent 
string) any { switch typed := labels.(type) { case map[string]any: o := make(map[string]any) for k, v := range typed { o[l.fields.Align(k)] = l.processLabels(v, parent+"."+k) } return o case []any: if isSliceOf[float64](typed) { return lo.Map(typed, func(e any, _ int) float32 { return float32(e.(float64)) }) } else if isSliceOf[string](typed) { return lo.Map(typed, func(e any, _ int) ID { return ID(l.values.Add(parent + ":" + e.(string))) }) } return typed case string: return ID(l.values.Add(parent + ":" + typed)) default: return labels } } func isSliceOf[T any](v []any) bool { for _, e := range v { if _, ok := e.(T); !ok { return false } } return true } func LoadDataFromBuiltIn(dataSetName string) (*Dataset, *Dataset, error) { // Extract Data set information trainFilePath, testFilePath, err := model.LocateBuiltInDataset(dataSetName, model.FormatNCF) if err != nil { return nil, nil, err } // Load dataset train, err := loadTrain(trainFilePath) if err != nil { return nil, nil, err } test := NewDataset(train.GetTimestamp(), 0, 0) test.users, test.items = train.users, train.items test.userDict, test.itemDict = train.userDict, train.itemDict test.userFeedback = make([][]int32, len(train.userFeedback)) test.itemFeedback = make([][]int32, len(train.itemFeedback)) test.timestamps = make([][]time.Time, len(train.userFeedback)) test.negatives = make([][]int32, len(train.userFeedback)) err = loadTest(test, testFilePath) if err != nil { return nil, nil, err } return train, test, nil } func loadTrain(path string) (*Dataset, error) { dataset := NewDataset(time.Now(), 0, 0) // Open file, err := os.Open(path) if err != nil { return nil, err } defer file.Close() // Read lines scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() fields := strings.Split(line, "\t") // add users userId, err := util.ParseInt[int32](fields[0]) if err != nil { return nil, err } for i := dataset.userDict.Count(); i <= userId; i++ { dataset.AddUser(data.User{UserId: 
util.FormatInt(i)}) } // add items itemId, err := util.ParseInt[int32](fields[1]) if err != nil { return nil, err } for i := dataset.itemDict.Count(); i <= itemId; i++ { dataset.AddItem(data.Item{ItemId: util.FormatInt(i)}) } // add feedback dataset.AddFeedback(fields[0], fields[1], time.Time{}) } return dataset, scanner.Err() } func loadTest(dataset *Dataset, path string) error { // Open file, err := os.Open(path) if err != nil { return errors.Trace(err) } defer file.Close() // Read lines scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() // parse line fields := strings.Split(line, "\t") positive, negatives := fields[0], fields[1:] if positive[0] != '(' || positive[len(positive)-1] != ')' { return fmt.Errorf("wrong foramt: %v", line) } positive = positive[1 : len(positive)-1] fields = strings.Split(positive, ",") // add feedback dataset.AddFeedback(fields[0], fields[1], time.Time{}) // add negatives userId, err := strconv.Atoi(fields[0]) if err != nil { return err } dataset.negatives[userId] = make([]int32, len(negatives)) for i, negative := range negatives { dataset.negatives[userId][i] = dataset.itemDict.Add(negative) } } return scanner.Err() } ================================================ FILE: dataset/dataset_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package dataset

import (
	"fmt"
	"math"
	"strconv"
	"testing"
	"time"

	"github.com/chewxy/math32"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/stretchr/testify/assert"
)

// TestDataset_AddItem checks that AddItem keeps plain fields and encodes
// labels: float slices become []float32 and string slices become []ID.
func TestDataset_AddItem(t *testing.T) {
	created := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	inputs := []data.Item{
		{
			ItemId:     "1",
			IsHidden:   false,
			Categories: []string{"a", "b"},
			Timestamp:  created,
			Labels: map[string]any{
				"a":        1,
				"embedded": []any{1.1, 2.2, 3.3},
				"tags":     []any{"a", "b", "c"},
			},
			Comment: "comment",
		},
		{
			ItemId:     "2",
			IsHidden:   true,
			Categories: []string{"a", "b"},
			Timestamp:  created,
			Labels: map[string]any{
				"a":        1,
				"embedded": []any{1.1, 2.2, 3.3},
				"tags":     []any{"b", "c", "a"},
				"topics":   []any{"a", "b", "c"},
			},
			Comment: "comment",
		},
	}
	dataSet := NewDataset(time.Now(), 0, 1)
	for _, item := range inputs {
		dataSet.AddItem(item)
	}

	// Values are numbered by first appearance; "topics" values get their own
	// IDs because the label path is part of the key.
	expected := []data.Item{
		{
			ItemId:     "1",
			IsHidden:   false,
			Categories: []string{"a", "b"},
			Timestamp:  created,
			Labels: map[string]any{
				"a":        1,
				"embedded": []float32{1.1, 2.2, 3.3},
				"tags":     []ID{0, 1, 2},
			},
			Comment: "comment",
		},
		{
			ItemId:     "2",
			IsHidden:   true,
			Categories: []string{"a", "b"},
			Timestamp:  created,
			Labels: map[string]any{
				"a":        1,
				"embedded": []float32{1.1, 2.2, 3.3},
				"tags":     []ID{1, 2, 0},
				"topics":   []ID{3, 4, 5},
			},
			Comment: "comment",
		},
	}
	assert.Len(t, dataSet.GetItems(), len(expected))
	for i, want := range expected {
		assert.Equal(t, want, dataSet.GetItems()[i])
	}
}

// TestDataset_GetItemColumnValuesIDF checks the IDF of item label values:
// a value shared by every item is clamped to 1e-3.
func TestDataset_GetItemColumnValuesIDF(t *testing.T) {
	dataSet := NewDataset(time.Now(), 0, 1)
	for i, tags := range [][]any{{"a", "b", "c"}, {"a", "e"}} {
		dataSet.AddItem(data.Item{
			ItemId:     strconv.Itoa(i + 1),
			IsHidden:   false,
			Categories: []string{"a", "b"},
			Timestamp:  time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
			Labels:     map[string]any{"tags": tags},
			Comment:    "comment",
		})
	}
	idf := dataSet.GetItemColumnValuesIDF()
	assert.Len(t, idf, 4)
	assert.InDelta(t, 1e-3, idf[0], 1e-6)
	assert.InDelta(t, math32.Log(2), idf[1], 1e-6)
}

// TestDataset_AddUser checks that string user labels are replaced by IDs.
func TestDataset_AddUser(t *testing.T) {
	dataSet := NewDataset(time.Now(), 1, 0)
	dataSet.AddUser(data.User{UserId: "1", Labels: map[string]any{"a": 1, "b": "a"}, Comment: "comment"})
	want := data.User{UserId: "1", Labels: map[string]any{"a": 1, "b": ID(0)}, Comment: "comment"}
	assert.Len(t, dataSet.users, 1)
	assert.Equal(t, want, dataSet.users[0])
}

// TestDataset_GetUserColumnValuesIDF mirrors the item variant for users.
func TestDataset_GetUserColumnValuesIDF(t *testing.T) {
	dataSet := NewDataset(time.Now(), 1, 0)
	for i, tags := range [][]any{{"a", "b", "c"}, {"a", "e"}} {
		dataSet.AddUser(data.User{
			UserId:  strconv.Itoa(i + 1),
			Labels:  map[string]any{"tags": tags},
			Comment: "comment",
		})
	}
	idf := dataSet.GetUserColumnValuesIDF()
	assert.Len(t, idf, 4)
	assert.InDelta(t, 1e-3, idf[0], 1e-6)
	assert.InDelta(t, math32.Log(2), idf[1], 1e-6)
}

// TestDataset_AddFeedback builds a triangular feedback matrix and checks the
// feedback lists and both IDFs.
func TestDataset_AddFeedback(t *testing.T) {
	const n = 10
	dataSet := NewDataset(time.Now(), n, n)
	for i := 0; i < n; i++ {
		dataSet.AddUser(data.User{UserId: strconv.Itoa(i)})
	}
	for i := 0; i < n; i++ {
		dataSet.AddItem(data.Item{ItemId: strconv.Itoa(i)})
	}
	// User u rates items u..n-1: user u has n-u feedback and item i has i+1.
	for u := 0; u < n; u++ {
		for i := u; i < n; i++ {
			dataSet.AddFeedback(strconv.Itoa(u), strconv.Itoa(i), time.Unix(int64(u*10+i), 0))
		}
	}
	userIDF, itemIDF := dataSet.GetUserIDF(), dataSet.GetItemIDF()
	for i := 0; i < n; i++ {
		assert.Len(t, dataSet.GetUserFeedback()[i], n-i)
		assert.Len(t, dataSet.GetItemFeedback()[i], i+1)
		assert.Len(t, dataSet.timestamps[i], n-i)
		assert.InDelta(t, math32.Log(float32(n)/float32(n-i)), userIDF[i], 1e-2)
		assert.InDelta(t, math32.Log(float32(n)/float32(i+1)), itemIDF[i], 1e-2)
	}
}

// TestDataset_Split checks SplitCF for all users and for a sampled subset.
func TestDataset_Split(t *testing.T) {
	const numUsers, numItems = 3, 5
	// create dataset: user i rates items i+1..numItems-1 (9 feedback in total)
	dataset := NewDataset(time.Now(), numUsers, numItems)
	for i := 0; i < numUsers; i++ {
		dataset.AddUser(data.User{UserId: fmt.Sprintf("user%v", i)})
	}
	for i := 0; i < numItems; i++ {
		dataset.AddItem(data.Item{ItemId: fmt.Sprintf("item%v", i)})
	}
	for i := 0; i < numUsers; i++ {
		for j := i + 1; j < numItems; j++ {
			dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Time{})
		}
	}
	assert.Equal(t, 9, dataset.CountFeedback())

	// numTestUsers == 0 puts every user in the test set; 2 samples two users.
	cases := []struct {
		numTestUsers  int
		trainFeedback int
		testFeedback  int
	}{
		{numTestUsers: 0, trainFeedback: 9 - numUsers, testFeedback: numUsers},
		{numTestUsers: 2, trainFeedback: 7, testFeedback: 2},
	}
	for _, tc := range cases {
		train, test := dataset.SplitCF(tc.numTestUsers, 0)
		assert.Equal(t, numUsers, train.CountUsers())
		assert.Equal(t, numItems, train.CountItems())
		assert.Equal(t, tc.trainFeedback, train.CountFeedback())
		assert.Equal(t, numUsers, test.CountUsers())
		assert.Equal(t, numItems, test.CountItems())
		assert.Equal(t, tc.testFeedback, test.CountFeedback())
	}
}

// TestDataset_SplitLatest checks that each user's latest feedback (item4)
// lands in the test set.
func TestDataset_SplitLatest(t *testing.T) {
	const numUsers, numItems = 3, 5
	// create dataset: feedback on item j happens at second j
	dataset := NewDataset(time.Now(), numUsers, numItems)
	for i := 0; i < numUsers; i++ {
		dataset.AddUser(data.User{UserId: fmt.Sprintf("user%v", i)})
	}
	for i := 0; i < numItems; i++ {
		dataset.AddItem(data.Item{ItemId: fmt.Sprintf("item%v", i)})
	}
	for i := 0; i < numUsers; i++ {
		for j := i + 1; j < numItems; j++ {
			dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Unix(int64(j), 0))
		}
	}
	assert.Equal(t, 9, dataset.CountFeedback())

	// math.MaxInt shots keeps every older feedback in the training set.
	train, test := dataset.SplitLatest(math.MaxInt)
	assert.Equal(t, numUsers, train.CountUsers())
	assert.Equal(t, numItems, train.CountItems())
	assert.Equal(t, numUsers, test.CountUsers())
	assert.Equal(t, numItems, test.CountItems())
	assert.Equal(t, 6, train.CountFeedback())
	assert.Equal(t, 3, test.CountFeedback())
	for i := 0; i < numUsers; i++ {
		userTrain, userTest := train.GetUserFeedback()[i], test.GetUserFeedback()[i]
		assert.Len(t, userTrain, numItems-i-2)
		assert.Len(t, userTest, 1)
		assert.Equal(t, 4, int(userTest[0]))
	}
}
func TestDataset_LoadMovieLens1M(t *testing.T) { train, test, err := LoadDataFromBuiltIn("ml-1m") assert.NoError(t, err) assert.Len(t, train.GetUsers(), 6040) assert.Len(t, train.GetItems(), 3706) assert.Equal(t, train.CountFeedback(), 994169) assert.Len(t, test.GetUsers(), 6040) assert.Len(t, test.GetItems(), 3706) assert.Equal(t, test.CountFeedback(), 6040) } ================================================ FILE: dataset/dict.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package dataset type FreqDict struct { si map[string]int32 is []string cnt []int32 } func NewFreqDict() (d *FreqDict) { d = &FreqDict{map[string]int32{}, []string{}, []int32{}} return } func (d *FreqDict) Count() int32 { return int32(len(d.is)) } func (d *FreqDict) Add(s string) (y int32) { if y, ok := d.si[s]; ok { d.cnt[y]++ return y } y = int32(len(d.is)) d.si[s] = y d.is = append(d.is, s) d.cnt = append(d.cnt, 1) return } func (d *FreqDict) AddNoCount(s string) (y int32) { if y, ok := d.si[s]; ok { return y } y = int32(len(d.is)) d.si[s] = y d.is = append(d.is, s) d.cnt = append(d.cnt, 0) return } func (d *FreqDict) Id(s string) int32 { if y, ok := d.si[s]; ok { return y } return -1 } func (d *FreqDict) String(id int32) (s string, ok bool) { if id >= int32(len(d.is)) { return "", false } return d.is[id], true } func (d *FreqDict) Freq(id int32) int32 { if id >= int32(len(d.cnt)) { return 0 } return d.cnt[id] } func (d *FreqDict) ToIndex() *Index { return &Index{ Numbers: d.si, Names: d.is, } } ================================================ FILE: dataset/dict_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package dataset

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestFreqDict(t *testing.T) {
	dict := NewFreqDict()
	assert.Equal(t, int32(0), dict.Add("a"))
	assert.Equal(t, int32(1), dict.Add("b"))
	assert.Equal(t, int32(1), dict.Add("b"))
	assert.Equal(t, int32(2), dict.Add("c"))
	assert.Equal(t, int32(2), dict.Add("c"))
	assert.Equal(t, int32(2), dict.Add("c"))
	assert.Equal(t, int32(3), dict.Count())
	assert.Equal(t, int32(1), dict.Freq(0))
	assert.Equal(t, int32(2), dict.Freq(1))
	assert.Equal(t, int32(3), dict.Freq(2))
	assert.Equal(t, int32(0), dict.Id("a"))
	assert.Equal(t, int32(-1), dict.Id("e"))
}
================================================
FILE: dataset/index.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

import (
	"encoding/binary"
	"io"

	"github.com/gorse-io/gorse/common/encoding"
	"github.com/juju/errors"
)

// MarshalIndex marshals index into a byte stream (see Index.Marshal).
func MarshalIndex(w io.Writer, index *Index) error {
	return index.Marshal(w)
}

// UnmarshalIndex unmarshals an index from a byte stream (see Index.Unmarshal).
func UnmarshalIndex(r io.Reader) (*Index, error) {
	index := &Index{}
	err := index.Unmarshal(r)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return index, nil
}

// Index manages the map between sparse Names and dense indices. A sparse ID is
// a user ID or item ID. The dense index is the internal user index or item index
// optimized for faster parameter access and less memory usage.
//
// A nil *Index behaves as an empty index for the read-only methods Len,
// ToNumber, GetNames and Marshal.
type Index struct {
	Numbers map[string]int32 // sparse ID -> dense index
	Names   []string         // dense index -> sparse ID
}

// NotId represents an ID doesn't exist.
const NotId = int32(-1)

// NewMapIndex creates an empty Index.
func NewMapIndex() *Index {
	set := new(Index)
	set.Numbers = make(map[string]int32)
	set.Names = make([]string, 0)
	return set
}

// Len returns the number of indexed Names.
func (idx *Index) Len() int32 {
	if idx == nil {
		return 0
	}
	return int32(len(idx.Names))
}

// Add adds a new ID to the indexer; an existing ID is ignored.
func (idx *Index) Add(name string) {
	if _, exist := idx.Numbers[name]; !exist {
		idx.Numbers[name] = int32(len(idx.Names))
		idx.Names = append(idx.Names, name)
	}
}

// ToNumber converts a sparse ID to a dense index, or NotId when it is unknown.
func (idx *Index) ToNumber(name string) int32 {
	// Accessing idx.Numbers through a nil pointer would panic.
	if idx == nil {
		return NotId
	}
	if denseId, exist := idx.Numbers[name]; exist {
		return denseId
	}
	return NotId
}

// ToName converts a dense index to a sparse ID. It panics when index is out
// of range.
func (idx *Index) ToName(index int32) string {
	return idx.Names[index]
}

// GetNames returns all names in the current index (nil for a nil index).
func (idx *Index) GetNames() []string {
	if idx == nil {
		return nil
	}
	return idx.Names
}

// Marshal writes the index into a byte stream: a little-endian int32 count
// followed by each name written with encoding.WriteString. A nil index is
// written as an empty one.
func (idx *Index) Marshal(w io.Writer) error {
	// write length
	err := binary.Write(w, binary.LittleEndian, idx.Len())
	if err != nil {
		return errors.Trace(err)
	}
	// write names
	for _, s := range idx.GetNames() {
		err = encoding.WriteString(w, s)
		if err != nil {
			return errors.Trace(err)
		}
	}
	return nil
}

// Unmarshal map index from byte stream.
func (idx *Index) Unmarshal(r io.Reader) error { // read length var n int32 err := binary.Read(r, binary.LittleEndian, &n) if err != nil { return errors.Trace(err) } // write names idx.Names = make([]string, 0, n) idx.Numbers = make(map[string]int32, n) for i := 0; i < int(n); i++ { name, err := encoding.ReadString(r) if err != nil { return errors.Trace(err) } idx.Add(name) } return nil } ================================================ FILE: dataset/index_test.go ================================================ package dataset import ( "bytes" "testing" "github.com/stretchr/testify/assert" ) func TestIndex(t *testing.T) { // Null indexer var index *Index assert.Zero(t, index.Len()) // Create a indexer index = NewMapIndex() assert.Zero(t, index.Len()) // Add Names index.Add("1") index.Add("2") index.Add("4") index.Add("8") assert.Equal(t, int32(4), index.Len()) assert.Equal(t, int32(0), index.ToNumber("1")) assert.Equal(t, int32(1), index.ToNumber("2")) assert.Equal(t, int32(2), index.ToNumber("4")) assert.Equal(t, int32(3), index.ToNumber("8")) assert.Equal(t, NotId, index.ToNumber("1000")) assert.Equal(t, "1", index.ToName(0)) assert.Equal(t, "2", index.ToName(1)) assert.Equal(t, "4", index.ToName(2)) assert.Equal(t, "8", index.ToName(3)) // Get names assert.Equal(t, []string{"1", "2", "4", "8"}, index.GetNames()) // Encode and decode buf := bytes.NewBuffer(nil) err := MarshalIndex(buf, index) assert.NoError(t, err) indexCopy, err := UnmarshalIndex(buf) assert.NoError(t, err) assert.Equal(t, index, indexCopy) } ================================================ FILE: dataset/unified_index.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package dataset import ( "encoding/binary" "fmt" "io" "strconv" "github.com/gorse-io/gorse/common/log" "github.com/juju/errors" ) // UnifiedIndex maps users, items and labels into a unified encoding space. type UnifiedIndex interface { Len() int32 EncodeUser(userId string) int32 EncodeItem(itemId string) int32 EncodeUserLabel(userLabel string) int32 EncodeItemLabel(itemLabel string) int32 EncodeContextLabel(ctxLabel string) int32 GetUsers() []string GetItems() []string GetUserLabels() []string GetItemLabels() []string GetContextLabels() []string CountUsers() int32 CountItems() int32 CountUserLabels() int32 CountItemLabels() int32 CountContextLabels() int32 Marshal(w io.Writer) error Unmarshal(r io.Reader) error } const ( mapIndex uint8 = iota directIndex nilIndex ) // MarshalIndex marshal index into byte stream. func MarshalUnifiedIndex(w io.Writer, index UnifiedIndex) error { // if index is nil if index == nil { return binary.Write(w, binary.LittleEndian, nilIndex) } // write index type var indexType uint8 switch index.(type) { case *UnifiedMapIndex: indexType = mapIndex case *UnifiedDirectIndex: indexType = directIndex default: return errors.New("unknown index type") } err := binary.Write(w, binary.LittleEndian, indexType) if err != nil { return errors.Trace(err) } // write index return index.Marshal(w) } // UnmarshalIndex unmarshal index from byte stream. 
func UnmarshalUnifiedIndex(r io.Reader) (UnifiedIndex, error) { // read type index var indexType uint8 err := binary.Read(r, binary.LittleEndian, &indexType) if err != nil { return nil, errors.Trace(err) } var index UnifiedIndex switch indexType { case mapIndex: index = &UnifiedMapIndex{} case directIndex: index = &UnifiedDirectIndex{} case nilIndex: return nil, nil default: return nil, fmt.Errorf("unknown index type (%v)", indexType) } // read index err = index.Unmarshal(r) if err != nil { return nil, errors.Trace(err) } return index, nil } // UnifiedMapIndexBuilder is the builder for UnifiedMapIndex. type UnifiedMapIndexBuilder struct { UserIndex *Index ItemIndex *Index UserLabelIndex *Index ItemLabelIndex *Index CtxLabelIndex *Index } // NewUnifiedMapIndexBuilder creates a UnifiedMapIndexBuilder. func NewUnifiedMapIndexBuilder() *UnifiedMapIndexBuilder { return &UnifiedMapIndexBuilder{ UserIndex: NewMapIndex(), ItemIndex: NewMapIndex(), UserLabelIndex: NewMapIndex(), ItemLabelIndex: NewMapIndex(), CtxLabelIndex: NewMapIndex(), } } // AddUser adds a user into the unified index. func (builder *UnifiedMapIndexBuilder) AddUser(userId string) { builder.UserIndex.Add(userId) } // AddItem adds a item into the unified index. func (builder *UnifiedMapIndexBuilder) AddItem(itemId string) { builder.ItemIndex.Add(itemId) } // AddUserLabel adds a user label into the unified index. func (builder *UnifiedMapIndexBuilder) AddUserLabel(userLabel string) { builder.UserLabelIndex.Add(userLabel) } // AddItemLabel adds a item label into the unified index. func (builder *UnifiedMapIndexBuilder) AddItemLabel(itemLabel string) { builder.ItemLabelIndex.Add(itemLabel) } // AddCtxLabel adds a context label the unified index. func (builder *UnifiedMapIndexBuilder) AddCtxLabel(ctxLabel string) { builder.CtxLabelIndex.Add(ctxLabel) } // Build UnifiedMapIndex from UnifiedMapIndexBuilder. 
func (builder *UnifiedMapIndexBuilder) Build() UnifiedIndex { return &UnifiedMapIndex{ UserIndex: builder.UserIndex, ItemIndex: builder.ItemIndex, UserLabelIndex: builder.UserLabelIndex, ItemLabelIndex: builder.ItemLabelIndex, CtxLabelIndex: builder.CtxLabelIndex, } } // UnifiedMapIndex is the id -> index mapper for factorization machines. // The division of id is: | user | item | user label | item label | context label | type UnifiedMapIndex struct { UserIndex *Index ItemIndex *Index UserLabelIndex *Index ItemLabelIndex *Index CtxLabelIndex *Index } // GetUserLabels returns all user labels. func (unified *UnifiedMapIndex) GetUserLabels() []string { return unified.UserLabelIndex.GetNames() } // GetItemLabels returns all item labels. func (unified *UnifiedMapIndex) GetItemLabels() []string { return unified.ItemLabelIndex.GetNames() } // GetContextLabels returns all context labels. func (unified *UnifiedMapIndex) GetContextLabels() []string { return unified.CtxLabelIndex.GetNames() } // CountUserLabels returns the number of user labels. func (unified *UnifiedMapIndex) CountUserLabels() int32 { return unified.UserLabelIndex.Len() } // CountItemLabels returns the number of item labels. func (unified *UnifiedMapIndex) CountItemLabels() int32 { return unified.ItemLabelIndex.Len() } // CountContextLabels returns the number of context labels. func (unified *UnifiedMapIndex) CountContextLabels() int32 { return unified.CtxLabelIndex.Len() } // Len returns the size of unified index. func (unified *UnifiedMapIndex) Len() int32 { return unified.UserIndex.Len() + unified.ItemIndex.Len() + unified.UserLabelIndex.Len() + unified.ItemLabelIndex.Len() + unified.CtxLabelIndex.Len() } // EncodeUser converts a user id to a integer in the encoding space. func (unified *UnifiedMapIndex) EncodeUser(userId string) int32 { return unified.UserIndex.ToNumber(userId) } // EncodeItem converts a item id to a integer in the encoding space. 
func (unified *UnifiedMapIndex) EncodeItem(itemId string) int32 { itemIndex := unified.ItemIndex.ToNumber(itemId) if itemIndex != NotId { itemIndex += unified.UserIndex.Len() } return itemIndex } // EncodeUserLabel converts a user label to a integer in the encoding space. func (unified *UnifiedMapIndex) EncodeUserLabel(userLabel string) int32 { userLabelIndex := unified.UserLabelIndex.ToNumber(userLabel) if userLabelIndex != NotId { userLabelIndex += unified.UserIndex.Len() + unified.ItemIndex.Len() } return userLabelIndex } // EncodeItemLabel converts a item label to a integer in the encoding space. func (unified *UnifiedMapIndex) EncodeItemLabel(itemLabel string) int32 { itemLabelIndex := unified.ItemLabelIndex.ToNumber(itemLabel) if itemLabelIndex != NotId { itemLabelIndex += unified.UserIndex.Len() + unified.ItemIndex.Len() + unified.UserLabelIndex.Len() } return itemLabelIndex } // EncodeContextLabel converts a context label to a integer in the encoding space. func (unified *UnifiedMapIndex) EncodeContextLabel(label string) int32 { ctxLabelIndex := unified.CtxLabelIndex.ToNumber(label) if ctxLabelIndex != NotId { ctxLabelIndex += unified.UserIndex.Len() + unified.ItemIndex.Len() + unified.UserLabelIndex.Len() + unified.ItemLabelIndex.Len() } return ctxLabelIndex } // GetUsers returns all users. func (unified *UnifiedMapIndex) GetUsers() []string { return unified.UserIndex.GetNames() } // GetItems returns all items. func (unified *UnifiedMapIndex) GetItems() []string { return unified.ItemIndex.GetNames() } // CountUsers returns the number of users. func (unified *UnifiedMapIndex) CountUsers() int32 { return unified.UserIndex.Len() } // CountItems returns the number of items. func (unified *UnifiedMapIndex) CountItems() int32 { return unified.ItemIndex.Len() } // Marshal map index into byte stream. 
func (unified *UnifiedMapIndex) Marshal(w io.Writer) error { indices := []*Index{unified.UserIndex, unified.ItemIndex, unified.UserLabelIndex, unified.ItemLabelIndex, unified.CtxLabelIndex} for _, index := range indices { err := MarshalIndex(w, index) if err != nil { return errors.Trace(err) } } return nil } // Unmarshal map index from byte stream. func (unified *UnifiedMapIndex) Unmarshal(r io.Reader) error { indices := []**Index{&unified.UserIndex, &unified.ItemIndex, &unified.UserLabelIndex, &unified.ItemLabelIndex, &unified.CtxLabelIndex} for i := range indices { var err error *indices[i], err = UnmarshalIndex(r) if err != nil { return errors.Trace(err) } } return nil } // UnifiedDirectIndex maps string to integer in literal. type UnifiedDirectIndex struct { N int32 } // EncodeUserLabel should be used by unit testing only. func (unified *UnifiedDirectIndex) EncodeUserLabel(userLabel string) int32 { if val, err := strconv.Atoi(userLabel); err != nil { panic(err) } else { return int32(val) } } // EncodeItemLabel should be used by unit testing only. func (unified *UnifiedDirectIndex) EncodeItemLabel(itemLabel string) int32 { if val, err := strconv.Atoi(itemLabel); err != nil { panic(err) } else { return int32(val) } } // GetUserLabels should be used by unit testing only. func (unified *UnifiedDirectIndex) GetUserLabels() []string { var names []string begin, end := unified.N/5*3, unified.N/5*4 for i := begin; i < end; i++ { names = append(names, strconv.Itoa(int(i))) } return names } // GetItemLabels should be used by unit testing only. func (unified *UnifiedDirectIndex) GetItemLabels() []string { var names []string begin, end := unified.N/5*2, unified.N/5*3 for i := begin; i < end; i++ { names = append(names, strconv.Itoa(int(i))) } return names } // GetContextLabels should be used by unit testing only. 
func (unified *UnifiedDirectIndex) GetContextLabels() []string { var names []string begin, end := unified.N/5*4, unified.N for i := begin; i < end; i++ { names = append(names, strconv.Itoa(int(i))) } return names } // CountUserLabels should be used by unit testing only. func (unified *UnifiedDirectIndex) CountUserLabels() int32 { return unified.N / 5 } // CountItemLabels should be used by unit testing only. func (unified *UnifiedDirectIndex) CountItemLabels() int32 { return unified.N / 5 } // CountContextLabels should be used by unit testing only. func (unified *UnifiedDirectIndex) CountContextLabels() int32 { return unified.N - unified.N/5*4 } // NewUnifiedDirectIndex creates a UnifiedDirectIndex. func NewUnifiedDirectIndex(n int32) UnifiedIndex { return &UnifiedDirectIndex{N: n} } // Len should be used by unit testing only. func (unified *UnifiedDirectIndex) Len() int32 { return unified.N } // EncodeUser should be used by unit testing only. func (unified *UnifiedDirectIndex) EncodeUser(userId string) int32 { if val, err := strconv.Atoi(userId); err != nil { panic(err) } else { return int32(val) } } // EncodeItem should be used by unit testing only. func (unified *UnifiedDirectIndex) EncodeItem(itemId string) int32 { if val, err := strconv.Atoi(itemId); err != nil { panic(err) } else { return int32(val) } } // EncodeContextLabel should be used by unit testing only. func (unified *UnifiedDirectIndex) EncodeContextLabel(label string) int32 { if val, err := strconv.Atoi(label); err != nil { panic(err) } else { return int32(val) } } // GetUsers should be used by unit testing only. func (unified *UnifiedDirectIndex) GetUsers() []string { var names []string begin, end := unified.N/5, unified.N/5*2 for i := begin; i < end; i++ { names = append(names, strconv.Itoa(int(i))) } return names } // GetItems should be used by unit testing only. 
func (unified *UnifiedDirectIndex) GetItems() []string { log.Logger().Warn("") var names []string begin, end := int32(0), unified.N/5 for i := begin; i < end; i++ { names = append(names, strconv.Itoa(int(i))) } return names } // CountUsers should be used by unit testing only. func (unified *UnifiedDirectIndex) CountUsers() int32 { return unified.N / 5 } // CountItems should be used by unit testing only. func (unified *UnifiedDirectIndex) CountItems() int32 { return unified.N / 5 } // Marshal direct index into byte stream. func (unified *UnifiedDirectIndex) Marshal(w io.Writer) error { return binary.Write(w, binary.LittleEndian, unified.N) } // Unmarshal direct index from byte stream. func (unified *UnifiedDirectIndex) Unmarshal(r io.Reader) error { return binary.Read(r, binary.LittleEndian, &unified.N) } ================================================ FILE: dataset/unified_index_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package dataset

import (
	"bytes"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestUnifiedMapIndex(t *testing.T) {
	var numUsers, numItems, numUserLabels, numItemLabels, numCtxLabels int32 = 3, 4, 5, 6, 7
	// name builds the i-th identifier of a segment, e.g. name("user", 2) == "user2".
	name := func(prefix string, i int32) string {
		return fmt.Sprintf("%s%v", prefix, i)
	}

	// Populate every segment, in id-space order.
	builder := NewUnifiedMapIndexBuilder()
	for _, seg := range []struct {
		prefix string
		count  int32
		add    func(string)
	}{
		{"user", numUsers, builder.AddUser},
		{"item", numItems, builder.AddItem},
		{"user_label", numUserLabels, builder.AddUserLabel},
		{"item_label", numItemLabels, builder.AddItemLabel},
		{"ctx_label", numCtxLabels, builder.AddCtxLabel},
	} {
		for i := int32(0); i < seg.count; i++ {
			seg.add(name(seg.prefix, i))
		}
	}
	index := builder.Build()

	// Segment sizes and the total size.
	assert.Equal(t, numUsers+numItems+numUserLabels+numItemLabels+numCtxLabels, index.Len())
	assert.Equal(t, numUsers, index.CountUsers())
	assert.Equal(t, numItems, index.CountItems())
	assert.Equal(t, numUserLabels, index.CountUserLabels())
	assert.Equal(t, numItemLabels, index.CountItemLabels())
	assert.Equal(t, numCtxLabels, index.CountContextLabels())

	// Each segment is encoded right after the previous one, and its names are
	// listed in insertion order.
	for _, seg := range []struct {
		prefix string
		count  int32
		offset int32
		names  []string
		encode func(string) int32
	}{
		{"user", numUsers, 0, index.GetUsers(), index.EncodeUser},
		{"item", numItems, numUsers, index.GetItems(), index.EncodeItem},
		{"user_label", numUserLabels, numUsers + numItems, index.GetUserLabels(), index.EncodeUserLabel},
		{"item_label", numItemLabels, numUsers + numItems + numUserLabels, index.GetItemLabels(), index.EncodeItemLabel},
		{"ctx_label", numCtxLabels, numUsers + numItems + numUserLabels + numItemLabels, index.GetContextLabels(), index.EncodeContextLabel},
	} {
		for i := int32(0); i < seg.count; i++ {
			assert.Equal(t, seg.offset+i, seg.encode(name(seg.prefix, i)))
			assert.Equal(t, name(seg.prefix, i), seg.names[i])
		}
	}

	// A marshal/unmarshal round trip must reproduce the index exactly.
	var buf bytes.Buffer
	assert.NoError(t, MarshalUnifiedIndex(&buf, index))
	indexCopy, err := UnmarshalUnifiedIndex(&buf)
	assert.NoError(t, err)
	assert.Equal(t, index, indexCopy)
}

func TestUnifiedDirectIndex(t *testing.T) {
	index := NewUnifiedDirectIndex(10)
	assert.Equal(t, int32(10), index.Len())

	// With N = 10 every segment holds two ids, laid out as
	// | items | users | item labels | user labels | context labels |.
	for _, seg := range []struct {
		names  func() []string
		count  func() int32
		encode func(string) int32
		want   []string
		probe  string
		id     int32
	}{
		{index.GetItems, index.CountItems, index.EncodeItem, []string{"0", "1"}, "1", 1},
		{index.GetUsers, index.CountUsers, index.EncodeUser, []string{"2", "3"}, "2", 2},
		{index.GetItemLabels, index.CountItemLabels, index.EncodeItemLabel, []string{"4", "5"}, "3", 3},
		{index.GetUserLabels, index.CountUserLabels, index.EncodeUserLabel, []string{"6", "7"}, "4", 4},
		{index.GetContextLabels, index.CountContextLabels, index.EncodeContextLabel, []string{"8", "9"}, "5", 5},
	} {
		assert.Equal(t, seg.want, seg.names())
		assert.Equal(t, int32(2), seg.count())
		// Non-numeric input is rejected with a panic.
		assert.Panics(t, func() { seg.encode("abc") })
		// Numeric input is returned as-is, whatever the segment.
		assert.Equal(t, seg.id, seg.encode(seg.probe))
	}

	// A marshal/unmarshal round trip must reproduce the index exactly.
	var buf bytes.Buffer
	assert.NoError(t, MarshalUnifiedIndex(&buf, index))
	indexCopy, err := UnmarshalUnifiedIndex(&buf)
	assert.NoError(t, err)
	assert.Equal(t, index, indexCopy)
}
================================================
FILE: docker-bake.hcl
================================================
variable "VERSIONS" {
  default = "nightly"
}

variable versions {
  default = split(",", VERSIONS)
}

variable components {
  default = ["gorse-master", "gorse-server", "gorse-worker", "gorse-in-one"]
}

group "default" {
  targets = ["gorse-master", "gorse-server", "gorse-worker", "gorse-in-one"]
}

target "openblas" {
  matrix = {
    component = components
  }
  name = component
  context = "."
  dockerfile = "cmd/${component}/Dockerfile.openblas"
  platforms = ["linux/amd64", "linux/arm64", "linux/riscv64"]
  tags = [for v in versions : "zhenghaoz/${component}:${v}"]
  cache-from = ["type=gha"]
  cache-to = ["type=gha,mode=max"]
}

target "cuda" {
  matrix = {
    component = components
  }
  name = "${component}-cuda"
  context = "."
  dockerfile = "cmd/${component}/Dockerfile.cuda"
  platforms = ["linux/amd64"]
  tags = [for v in versions : "zhenghaoz/${component}:${v}-cuda12.8"]
  cache-from = ["type=s3,endpoint_url=https://b172f19b7e057975835d8d311a7b0dbd.r2.cloudflarestorage.com,bucket=github,region=auto"]
  cache-to = ["type=s3,endpoint_url=https://b172f19b7e057975835d8d311a7b0dbd.r2.cloudflarestorage.com,bucket=github,region=auto,mode=max"]
}

target "mkl" {
  matrix = {
    component = components
  }
  name = "${component}-mkl"
  context = "."
dockerfile = "cmd/${component}/Dockerfile.mkl" platforms = ["linux/amd64"] tags = [for v in versions : "zhenghaoz/${component}:${v}-mkl"] cache-from = ["type=s3,endpoint_url=https://b172f19b7e057975835d8d311a7b0dbd.r2.cloudflarestorage.com,bucket=github,region=auto"] cache-to = ["type=s3,endpoint_url=https://b172f19b7e057975835d8d311a7b0dbd.r2.cloudflarestorage.com,bucket=github,region=auto,mode=max"] } ================================================ FILE: docker-compose.yml ================================================ version: "3" services: mysql: image: mysql/mysql-server restart: unless-stopped ports: - 3306:3306 environment: MYSQL_ROOT_PASSWORD: root_pass MYSQL_DATABASE: gorse MYSQL_USER: gorse MYSQL_PASSWORD: gorse_pass volumes: - mysql_data:/var/lib/mysql # postgres: # image: postgres:10.0 # ports: # - 5432:5432 # environment: # POSTGRES_DB: gorse # POSTGRES_USER: gorse # POSTGRES_PASSWORD: gorse_pass # volumes: # - postgres_data:/var/lib/postgresql/data # mongo: # image: mongo:4.0 # ports: # - 27017:27017 # environment: # MONGO_INITDB_DATABASE: gorse # MONGO_INITDB_ROOT_USERNAME: root # MONGO_INITDB_ROOT_PASSWORD: password # volumes: # - mongo_data:/data/db # clickhouse: # image: clickhouse/clickhouse-server:22 # ports: # - 8123:8123 # environment: # CLICKHOUSE_DB: gorse # CLICKHOUSE_USER: gorse # CLICKHOUSE_PASSWORD: gorse_pass # volumes: # - clickhouse_data:/var/lib/clickhouse # redis: # image: redis/redis-stack # restart: unless-stopped # ports: # - 6379:6379 worker: image: zhenghaoz/gorse-worker restart: unless-stopped ports: - 8089:8089 command: > --master-host master --master-port 8086 --http-host 0.0.0.0 --http-port 8089 --log-path /var/log/gorse/worker.log --cache-path /var/lib/gorse/worker_cache.data volumes: - gorse_log:/var/log/gorse - worker_data:/var/lib/gorse depends_on: - master server: image: zhenghaoz/gorse-server restart: unless-stopped ports: - 8087:8087 command: > --master-host master --master-port 8086 --http-host 0.0.0.0 
--http-port 8087 --log-path /var/log/gorse/server.log --cache-path /var/lib/gorse/server_cache.data volumes: - gorse_log:/var/log/gorse - server_data:/var/lib/gorse depends_on: - master master: image: zhenghaoz/gorse-master restart: unless-stopped ports: - 8086:8086 - 8088:8088 environment: GORSE_CACHE_STORE: mysql://gorse:gorse_pass@tcp(mysql:3306)/gorse # GORSE_CACHE_STORE: postgres://gorse:gorse_pass@postgres/gorse?sslmode=disable # GORSE_CACHE_STORE: mongodb://root:password@mongo:27017/gorse?authSource=admin&connect=direct # GORSE_CACHE_STORE: redis://redis:6379 GORSE_DATA_STORE: mysql://gorse:gorse_pass@tcp(mysql:3306)/gorse # GORSE_DATA_STORE: postgres://gorse:gorse_pass@postgres/gorse?sslmode=disable # GORSE_DATA_STORE: mongodb://root:password@mongo:27017/gorse?authSource=admin&connect=direct # GORSE_DATA_STORE: clickhouse://gorse:gorse_pass@clickhouse:8123/gorse GORSE_BLOB_URI: /var/lib/gorse/blob # GORSE_BLOB_URI: s3://my-bucket/path # GORSE_BLOB_URI: gs://my-bucket/path # GORSE_BLOB_URI: az://container/path command: > -c /etc/gorse/config.toml --log-path /var/log/gorse/master.log --cache-path /var/lib/gorse/master volumes: - ./config/config.toml:/etc/gorse/config.toml - gorse_log:/var/log/gorse - master_data:/var/lib/gorse/master - blob_data:/var/lib/gorse/blob depends_on: - mysql # - postgres # - mongo # - clickhouse # - redis volumes: worker_data: server_data: master_data: gorse_log: mysql_data: # postgres_data: # mongo_data: # clickhouse_data: blob_data: ================================================ FILE: go.mod ================================================ module github.com/gorse-io/gorse go 1.26 require ( cloud.google.com/go/storage v1.61.3 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 github.com/XSAM/otelsql v0.41.0 github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de github.com/bits-and-blooms/bitset v1.24.4 github.com/c-bata/goptuna v0.8.1 
github.com/cenkalti/backoff/v5 v5.0.3 github.com/chewxy/math32 v1.11.1 github.com/coreos/go-oidc/v3 v3.17.0 github.com/deckarep/golang-set/v2 v2.8.0 github.com/emicklei/go-restful-openapi/v2 v2.12.0 github.com/emicklei/go-restful/v3 v3.13.0 github.com/expr-lang/expr v1.17.8 github.com/fsouza/fake-gcs-server v1.54.0 github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449 github.com/go-openapi/strfmt v0.26.1 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.30.1 github.com/go-sql-driver/mysql v1.9.3 github.com/go-viper/mapstructure/v2 v2.5.0 github.com/gomlx/gomlx v0.27.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.2 github.com/gorse-io/dashboard v0.0.0-20260223101641-33715448ded8 github.com/gorse-io/gorse-go v0.5.0-alpha.3 github.com/invopop/jsonschema v0.13.0 github.com/jaswdr/faker v1.19.1 github.com/jellydator/ttlcache/v3 v3.4.0 github.com/juju/errors v1.0.0 github.com/juju/ratelimit v1.0.2 github.com/klauspost/cpuid/v2 v2.3.0 github.com/lafikl/consistent v0.0.0-20220512074542-bdd3606bfc3e github.com/lib/pq v1.11.2 github.com/madflojo/testcerts v1.5.0 github.com/mailru/go-clickhouse/v2 v2.5.1 github.com/matttproud/golang_protobuf_extensions v1.0.4 github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 github.com/minio/minio-go/v7 v7.0.99 github.com/modern-go/reflect2 v1.0.2 github.com/nikolalohinski/gonja/v2 v2.7.0 github.com/olekukonko/tablewriter v1.1.4 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.2 github.com/qdrant/go-client v1.17.1 github.com/rakyll/statik v0.1.8 github.com/redis/go-redis/extra/redisotel/v9 v9.18.0 github.com/redis/go-redis/v9 v9.18.0 github.com/samber/lo v1.53.0 github.com/sashabaranov/go-openai v1.41.2 github.com/schollz/progressbar/v3 v3.19.0 github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.21.0 
github.com/steinfletcher/apitest v1.6.0 github.com/stretchr/testify v1.11.1 github.com/swaggest/swgui v1.8.5 github.com/tiktoken-go/tokenizer v0.7.0 github.com/weaviate/weaviate v1.27.0 github.com/weaviate/weaviate-go-client/v4 v4.16.1 github.com/yuin/goldmark v1.7.16 go.mongodb.org/mongo-driver v1.17.9 go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful v0.67.0 go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.67.0 go.opentelemetry.io/otel v1.42.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 go.opentelemetry.io/otel/exporters/zipkin v1.42.0 go.opentelemetry.io/otel/sdk v1.42.0 go.opentelemetry.io/otel/trace v1.42.0 go.uber.org/atomic v1.11.0 go.uber.org/zap v1.27.1 golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 golang.org/x/oauth2 v0.36.0 golang.org/x/sys v0.42.0 google.golang.org/api v0.272.0 google.golang.org/grpc v1.79.2 google.golang.org/grpc/security/advancedtls v1.0.0 google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/clickhouse v0.4.2 gorm.io/driver/mysql v1.3.4 gorm.io/driver/postgres v1.3.5 gorm.io/driver/sqlite v1.3.4 gorm.io/gorm v1.23.6 modernc.org/mathutil v1.7.1 modernc.org/quickjs v0.17.1 modernc.org/sortutil v1.2.1 modernc.org/sqlite v1.47.0 modernc.org/strutil v1.2.1 ) require ( cel.dev/expr v0.25.1 // indirect cloud.google.com/go v0.123.0 // indirect cloud.google.com/go/auth v0.18.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/monitoring v1.24.3 // indirect cloud.google.com/go/pubsub/v2 v2.4.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect 
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect github.com/cockroachdb/redact v1.1.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/erraggy/oastools v1.36.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/getsentry/sentry-go v0.30.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.1 // indirect github.com/go-openapi/errors v0.22.7 // indirect github.com/go-openapi/jsonpointer v0.22.4 // indirect github.com/go-openapi/jsonreference v0.21.4 // indirect github.com/go-openapi/loads v0.23.2 // indirect github.com/go-openapi/spec v0.22.2 // indirect 
github.com/go-openapi/swag v0.23.0 // indirect github.com/go-openapi/swag/conv v0.25.5 // indirect github.com/go-openapi/swag/fileutils v0.25.1 // indirect github.com/go-openapi/swag/jsonname v0.25.4 // indirect github.com/go-openapi/swag/jsonutils v0.25.5 // indirect github.com/go-openapi/swag/loading v0.25.5 // indirect github.com/go-openapi/swag/mangling v0.25.1 // indirect github.com/go-openapi/swag/stringutils v0.25.4 // indirect github.com/go-openapi/swag/typeutils v0.25.5 // indirect github.com/go-openapi/swag/yamlutils v0.25.5 // indirect github.com/go-openapi/validate v0.25.1 // indirect github.com/gofrs/flock v0.13.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v1.0.0 // indirect github.com/gomlx/go-xla v0.2.0 // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect github.com/googleapis/gax-go/v2 v2.18.0 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/hashicorp/go-version v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.12.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.11.0 // indirect github.com/jackc/pgx/v4 v4.16.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect 
github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/crc32 v1.3.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.21 // indirect github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect github.com/minio/crc64nvme v1.1.1 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oklog/ulid/v2 v2.1.1 // indirect github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.2.0 // indirect github.com/olekukonko/ll v0.1.6 // indirect github.com/openzipkin/zipkin-go v0.4.3 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/philhofer/fwd v1.2.0 // indirect github.com/pkg/xattr v0.4.12 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/redis/go-redis/extra/rediscmd/v9 v9.18.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.14.1 // 
indirect github.com/rs/xid v1.6.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tinylib/msgp v1.6.1 // indirect github.com/vearutop/statigz v1.4.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.uber.org/multierr v1.10.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.15.0 // indirect gonum.org/v1/gonum v0.16.0 // indirect google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d // indirect 
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/klog/v2 v2.140.0 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/libquickjs v0.12.3 // indirect modernc.org/memory v1.11.0 // indirect ) replace ( gorm.io/driver/clickhouse v0.4.2 => github.com/gorse-io/clickhouse v0.3.3-0.20251121080503-d578f146896d gorm.io/driver/sqlite v1.3.4 => github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e ) ================================================ FILE: go.sum ================================================ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM= cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/compute/metadata v0.9.0 
h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA= cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub/v2 v2.4.0 h1:oMKNiBQpXImRWnHYla9uSU66ZzByZwBSCJOEs/pTKVg= cloud.google.com/go/pubsub/v2 v2.4.0/go.mod h1:2lS/XQKq5qtOMs6kHBK+WX1ytUC36kLl2ig3zqsGUx8= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg= cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk= cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= codeberg.org/go-fonts/liberation v0.5.0 h1:SsKoMO1v1OZmzkG2DY+7ZkCL9U+rrWI09niOLfQ5Bo0= codeberg.org/go-fonts/liberation v0.5.0/go.mod h1:zS/2e1354/mJ4pGzIIaEtm/59VFCFnYC7YV6YdGl5GU= codeberg.org/go-latex/latex v0.1.0 
h1:hoGO86rIbWVyjtlDLzCqZPjNykpWQ9YuTZqAzPcfL3c= codeberg.org/go-latex/latex v0.1.0/go.mod h1:LA0q/AyWIYrqVd+A9Upkgsb+IqPcmSTKc9Dny04MHMw= codeberg.org/go-pdf/fpdf v0.10.0 h1:u+w669foDDx5Ds43mpiiayp40Ov6sZalgcPMDBcZRd4= codeberg.org/go-pdf/fpdf v0.10.0/go.mod h1:Y0DGRAdZ0OmnZPvjbMp/1bYxmIPxm0ws4tfoPOc4LjU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= git.sr.ht/~sbinet/gg v0.6.0 h1:RIzgkizAk+9r7uPzf/VfbJHBMKUr0F5hRFxTUGMnt38= git.sr.ht/~sbinet/gg v0.6.0/go.mod h1:uucygbfC9wVPQIfrmwM2et0imr8L7KQWywX0xpFMm94= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 h1:KpMC6LFL7mqpExyMC9jVOYRiVhLmamjeZfRsUpB7l4s= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8= 
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI= github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53/go.mod h1:+3IMCy2vIlbG1XG/0ggNQv0SvxCAIpPM5b1nCz56Xno= github.com/CloudyKit/jet/v3 v3.0.0/go.mod h1:HKQPgSJmdK8hdoAbKUUWajkHyHo4RaU5rMdUywE7VMo= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod 
h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/XSAM/otelsql v0.41.0 h1:uZifjQhZhv5EDYJh+IVk1DiYxQZJBlNSen0MBFnfxB8= github.com/XSAM/otelsql v0.41.0/go.mod h1:NMQT0PiKoFILp9QgjQz+D5mvW+9mT0suR7OejqrtMaM= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw= 
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs= github.com/awalterschulze/gographviz v2.0.3+incompatible/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE= github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/bool64/dev v0.2.43 h1:yQ7qiZVef6WtCl2vDYU0Y+qSq+0aBrQzY8KXkklk9cQ= github.com/bool64/dev v0.2.43/go.mod 
h1:iJbh1y/HkunEPhgebWRNcs8wfGq7sjvJ6W5iabL8ACg= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/c-bata/goptuna v0.8.1 h1:25+n1MLv0yvCsD56xv4nqIus3oLHL9GuPAZDLIqmX1U= github.com/c-bata/goptuna v0.8.1/go.mod h1:knmS8+Iyq5PPy1YUeIEq0pMFR4Y6x7z/CySc9HlZTCY= github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY= github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM= github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= github.com/chewxy/math32 v1.0.6/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= github.com/chewxy/math32 v1.11.1 h1:b7PGHlp8KjylDoU8RrcEsRuGZhJuz8haxnKfuMMRqy8= github.com/chewxy/math32 
v1.11.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cloudflare/cfssl v0.0.0-20190808011637-b1ec8c586c2a/go.mod h1:yMWuSON2oQp+43nFtAV/uvKQIFpSPerB57DCt9t8sSA= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/datadriven v1.0.2/go.mod h1:a9RdTaap04u637JoCzcUoIcDmvwSUtcUFtT/C3kJlTU= github.com/cockroachdb/errors v1.9.1 h1:yFVvsI0VxmRShfawbt/laCIDy/mtTqqnvoNgiy5bEV8= github.com/cockroachdb/errors v1.9.1/go.mod h1:2sxOtL2WIc096WSZqZ5h8fa17rdDq9HZOZLBCor4mBk= github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f h1:6jduT9Hfc0njg5jJ1DdKCFPdMBrp/mdZfCpa5h+WM74= github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= github.com/cockroachdb/redact v1.1.3 h1:AKZds10rFSIj7qADf0g46UixK8NNLwWTNdCIGS5wfSQ= github.com/cockroachdb/redact v1.1.3/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= github.com/codegangsta/inject 
v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cznic/cc v0.0.0-20181122101902-d673e9b70d4d/go.mod h1:m3fD/V+XTB35Kh9zw6dzjMY+We0Q7PMf6LLIC4vuG9k= github.com/cznic/golex v0.0.0-20181122101858-9c343928389c/go.mod h1:+bmmJDNmKlhWNG+gwWCkaBoTy39Fs+bzRxVBzoTQbIc= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/cznic/strutil 
v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0h6pVabbcbyGRK1DckRn7r/STdZEeIDzZc= github.com/cznic/xc v0.0.0-20181122101856-45b06973881e/go.mod h1:3oFoiOvCDBYH+swwf5+k/woVmWy7h1Fcyu8Qig/jjX0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/deckarep/golang-set/v2 v2.8.0 h1:swm0rlPCmdWn9mESxKOjWk8hXSqoxOp+ZlfuyaAdFlQ= github.com/deckarep/golang-set/v2 v2.8.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod 
h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/emicklei/go-restful-openapi/v2 v2.12.0 h1:6wE0/+4tD6CpV5RI7x3ZdH2toUrB74oQmyEXBi47tHc= github.com/emicklei/go-restful-openapi/v2 v2.12.0/go.mod h1:I/b/Q1A/wpKWJGZJeO4WPaIw0ME4jXp5Yrh5hdB1bBA= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/erikstmartin/go-testdb 
v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/erraggy/oastools v1.36.1 h1:mNxGZO1w7LCFmZhar9QPOhLFpchE0T9TzNwqa+rENN4= github.com/erraggy/oastools v1.36.1/go.mod h1:KVEs42aJN9Z+H9YpxqjGrXJRQdrp1W0aJ4w2R+UId+I= github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM= github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fsouza/fake-gcs-server v1.54.0 h1:DGO4EkFVbtP/A5Ha+CAHHx+Xa6O6LeskMB4hQ1wBE48= github.com/fsouza/fake-gcs-server v1.54.0/go.mod h1:ryXYE4debQs8GjOxwaOAwFRwM4Cvs6S+NKPPgdVJe6g= github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449 
h1:HOYnhuVrhAVGKdg3rZapII640so7QfXQmkLkefUN/uM= github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449/go.mod h1:i+vbdOOivRRh2j+WwBkjZXloGN/+KAqfKDwNfUJeugc= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TBdRL1M71JZW2c= github.com/getsentry/sentry-go v0.30.0 h1:lWUwDnY7sKHaVIoZ9wYqRHJ5iEmoc0pqcRqFkosKzBo= github.com/getsentry/sentry-go v0.30.0/go.mod h1:WU9B9/1/sHDqeV8T+3VwwbjeR5MSXs/6aqG3mqZrezA= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI= github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gota/gota v0.10.1/go.mod h1:NZLQccXn0rABmkXjsaugRY6l+UH2dDZSgIgF8E2ipmA= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-jose/go-jose/v4 v4.1.3 
h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= github.com/go-openapi/analysis v0.24.1 h1:Xp+7Yn/KOnVWYG8d+hPksOYnCYImE3TieBa7rBOesYM= github.com/go-openapi/analysis v0.24.1/go.mod h1:dU+qxX7QGU1rl7IYhBC8bIfmWQdX4Buoea4TGtxXY84= github.com/go-openapi/errors v0.22.7 h1:JLFBGC0Apwdzw3484MmBqspjPbwa2SHvpDm0u5aGhUA= github.com/go-openapi/errors v0.22.7/go.mod h1://QW6SD9OsWtH6gHllUCddOXDL0tk0ZGNYHwsw4sW3w= github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4= github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80= github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmGjjySRCHtC8= github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4= github.com/go-openapi/loads v0.23.2 h1:rJXAcP7g1+lWyBHC7iTY+WAF0rprtM+pm8Jxv1uQJp4= github.com/go-openapi/loads v0.23.2/go.mod 
h1:IEVw1GfRt/P2Pplkelxzj9BYFajiWOtY2nHZNj4UnWY= github.com/go-openapi/spec v0.22.2 h1:KEU4Fb+Lp1qg0V4MxrSCPv403ZjBl8Lx1a83gIPU8Qc= github.com/go-openapi/spec v0.22.2/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs= github.com/go-openapi/strfmt v0.26.1 h1:7zGCHji7zSYDC2tCXIusoxYQz/48jAf2q+sF6wXTG+c= github.com/go-openapi/strfmt v0.26.1/go.mod h1:Zslk5VZPOISLwmWTMBIS7oiVFem1o1EI6zULY8Uer7Y= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-openapi/swag/conv v0.25.5 h1:wAXBYEXJjoKwE5+vc9YHhpQOFj2JYBMF2DUi+tGu97g= github.com/go-openapi/swag/conv v0.25.5/go.mod h1:CuJ1eWvh1c4ORKx7unQnFGyvBbNlRKbnRyAvDvzWA4k= github.com/go-openapi/swag/fileutils v0.25.1 h1:rSRXapjQequt7kqalKXdcpIegIShhTPXx7yw0kek2uU= github.com/go-openapi/swag/fileutils v0.25.1/go.mod h1:+NXtt5xNZZqmpIpjqcujqojGFek9/w55b3ecmOdtg8M= github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= github.com/go-openapi/swag/jsonutils v0.25.5 h1:XUZF8awQr75MXeC+/iaw5usY/iM7nXPDwdG3Jbl9vYo= github.com/go-openapi/swag/jsonutils v0.25.5/go.mod h1:48FXUaz8YsDAA9s5AnaUvAmry1UcLcNVWUjY42XkrN4= github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.5 h1:SX6sE4FrGb4sEnnxbFL/25yZBb5Hcg1inLeErd86Y1U= github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.5/go.mod h1:/2KvOTrKWjVA5Xli3DZWdMCZDzz3uV/T7bXwrKWPquo= github.com/go-openapi/swag/loading v0.25.5 h1:odQ/umlIZ1ZVRteI6ckSrvP6e2w9UTF5qgNdemJHjuU= github.com/go-openapi/swag/loading v0.25.5/go.mod h1:I8A8RaaQ4DApxhPSWLNYWh9NvmX2YKMoB9nwvv6oW6g= github.com/go-openapi/swag/mangling v0.25.1 h1:XzILnLzhZPZNtmxKaz/2xIGPQsBsvmCjrJOWGNz/ync= github.com/go-openapi/swag/mangling v0.25.1/go.mod h1:CdiMQ6pnfAgyQGSOIYnZkXvqhnnwOn997uXZMAd/7mQ= github.com/go-openapi/swag/stringutils v0.25.4 
h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8= github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0= github.com/go-openapi/swag/typeutils v0.25.5 h1:EFJ+PCga2HfHGdo8s8VJXEVbeXRCYwzzr9u4rJk7L7E= github.com/go-openapi/swag/typeutils v0.25.5/go.mod h1:itmFmScAYE1bSD8C4rS0W+0InZUBrB2xSPbWt6DLGuc= github.com/go-openapi/swag/yamlutils v0.25.5 h1:kASCIS+oIeoc55j28T4o8KwlV2S4ZLPT6G0iq2SSbVQ= github.com/go-openapi/swag/yamlutils v0.25.5/go.mod h1:Gek1/SjjfbYvM+Iq4QGwa/2lEXde9n2j4a3wI3pNuOQ= github.com/go-openapi/testify/enable/yaml/v2 v2.4.0 h1:7SgOMTvJkM8yWrQlU8Jm18VeDPuAvB/xWrdxFJkoFag= github.com/go-openapi/testify/enable/yaml/v2 v2.4.0/go.mod h1:14iV8jyyQlinc9StD7w1xVPW3CO3q1Gj04Jy//Kw4VM= github.com/go-openapi/testify/v2 v2.4.1 h1:zB34HDKj4tHwyUQHrUkpV0Q0iXQ6dUCOQtIqn8hE6Iw= github.com/go-openapi/testify/v2 v2.4.1/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.4.1/go.mod h1:2lpHqI5OcWCtVElxXnPt+s8oJvMpySlOyM6xDCrzib4= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf 
v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY1WM= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod 
h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomlx/go-xla v0.2.0 h1:vPRgGjKUaN4Pq58ZswWtiKNGUkqcbUj95YmvELZNvTA= github.com/gomlx/go-xla v0.2.0/go.mod h1:T2CsL/E90te3k4qpuzlXv2uQU2FmLMLfUsRlAGqKSuI= github.com/gomlx/gomlx v0.27.1 
h1:WhWop7VzKWOgl7C0yDfvJ0kXiaNAWjw5XvWME5cP6Zg= github.com/gomlx/gomlx v0.27.1/go.mod h1:9fOMnTb7YMs/6zYVR6diliq25x9oSjMUPMlHAbcJWd4= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac/go.mod h1:P32wAyui1PQ58Oce/KYkOqQv8cVw1zAapXOl+dRFGbc= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/flatbuffers v1.10.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian 
v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.5/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 
v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI= github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorgonia/bindgen v0.0.0-20180812032444-09626750019e/go.mod h1:YzKk63P9jQHkwAo2rXHBv02yPxDzoQT2cBV0x5bGV/8= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorse-io/clickhouse v0.3.3-0.20251121080503-d578f146896d h1:s6U3e1W7NdviEq12zOisra1cAaG9+fkL6jnMK8sok/Y= github.com/gorse-io/clickhouse v0.3.3-0.20251121080503-d578f146896d/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= github.com/gorse-io/dashboard v0.0.0-20260223101641-33715448ded8 h1:IjZK+VF93JvEAZEFsOa8qwVKBfb4dXuCF+m7Qb2Yr5c= github.com/gorse-io/dashboard v0.0.0-20260223101641-33715448ded8/go.mod h1:Ako5pxQlNyCiJpLpcJAwYWMF5nXdYP0mgpGcEOaw0vQ= github.com/gorse-io/gorse-go v0.5.0-alpha.3 h1:GR/OWzq016VyFyTTxgQWeayGahRVzB1cGFIW/AaShC4= github.com/gorse-io/gorse-go v0.5.0-alpha.3/go.mod h1:ZxmVHzZPKm5pmEIlqaRDwK0rkfTRHlfziO033XZ+RW0= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e 
h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e/go.mod h1:PmIOwYnI+F1lRKd6F/PdLXGgI8GZ5H8x8z1yx0+0bmQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod 
h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.5.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hydrogen18/memlistener v0.0.0-20200120041712-dcc25e7acd91/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI= github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0= github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn 
v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.12.0 h1:/RvQ24k3TnNdfBSW0ou9EOi5jx2cX7zfE8n2nLKuiP0= github.com/jackc/pgconn v1.12.0/go.mod h1:ZkhRC59Llhrq3oSfrikvwQ5NaxYExr6twkdkMLaKono= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.0 h1:brH0pCGBDkBW07HWlN/oSBXrmo3WB0UvZd1pIuDcL8Y= github.com/jackc/pgproto3/v2 
v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= github.com/jackc/pgtype v1.6.2/go.mod h1:JCULISAZBFGrHaOXIIFiyfzW5VY0GRitRr8NeJsrdig= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= github.com/jackc/pgtype v1.11.0 h1:u4uiGPz/1hryuXzyaBhSk6dnIyyG2683olG2OV+UUgs= github.com/jackc/pgtype v1.11.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.5.0/go.mod 
h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= github.com/jackc/pgx/v4 v4.10.1/go.mod h1:QlrWebbs3kqEZPHCTGyxecvzG6tvIsYu+A5b1raylkA= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.16.0 h1:4k1tROTJctHotannFYzu77dY3bgtMRymQP7tXQjqpPk= github.com/jackc/pgx/v4 v4.16.0/go.mod h1:N0A9sFdWzkw/Jy1lwoiB64F2+ugFZi987zRxcPez/wI= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM= github.com/jaswdr/faker v1.19.1 h1:xBoz8/O6r0QAR8eEvKJZMdofxiRH+F0M/7MU9eNKhsM= github.com/jaswdr/faker v1.19.1/go.mod h1:x7ZlyB1AZqwqKZgyQlnqEG8FDptmHlncA5u2zY/yi6w= github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/jinzhu/gorm v1.9.16/go.mod h1:G3LB3wezTOWM2ITLzPxEXgSkOXAntiLHS7UdBefADcs= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod 
h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/juju/errors v1.0.0 h1:yiq7kjCLll1BiaRuNY53MGI0+EQ3rF6GB+wvboZDefM= github.com/juju/errors v1.0.0/go.mod h1:B5x9thDqx0wIMH3+aLIMP9HjItInYWObRovoCFM5Qe8= github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI= github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= 
github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8= github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE= github.com/kataras/neffos v0.0.14/go.mod h1:8lqADm8PnbeFfL7CLXh1WHw53dG27MC3pgi2R1rmoTE= github.com/kataras/pio v0.0.2/go.mod h1:hAoW0t9UmXi4R5Oyq5Z4irTbaTsOemSrDGUtaTl7Dro= github.com/kataras/sitemap v0.0.5/go.mod h1:KY2eugMKiPwsJgx7+U103YZehfvNGOXURubcGyk0Bz8= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod 
h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/lafikl/consistent v0.0.0-20220512074542-bdd3606bfc3e h1:DuhzIzxOx3aJ0j4enY7SQ9bvulrT/XjkGAqiychfavc= github.com/lafikl/consistent v0.0.0-20220512074542-bdd3606bfc3e/go.mod h1:JmowInJuqa6EpSut8NSMAZtlvK9uL+8Q1P7tyew5rQY= github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21/go.mod h1:N0SVk0uhy+E1PZ3C9ctsPRlvOPAFPkCNlcPBDkt0N3U= github.com/leesper/go_rng v0.0.0-20190531154944-a612b043e353/go.mod h1:N0SVk0uhy+E1PZ3C9ctsPRlvOPAFPkCNlcPBDkt0N3U= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod 
h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/madflojo/testcerts v1.5.0 h1:GhQllyAiGzXVZU+i8O/cQkPTHzN59RxMGtm3uETgXnU= github.com/madflojo/testcerts v1.5.0/go.mod h1:MW8sh39gLnkKh4K0Nc55AyHEDl9l/FBLDUsQhpmkuo0= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/go-clickhouse/v2 v2.5.1 h1:k+YfKvUrTOHngWNBEmsTs0KAaS1L4paEe6c8IYOVqa8= github.com/mailru/go-clickhouse/v2 v2.5.1/go.mod h1:mJ/E4F05qQolb98/uFHWwFwgiO9NWss2DzZkhjV+jgo= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= 
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w= github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod 
h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a h1:0B/8Fo66D8Aa23Il0yrQvg1KKz92tE/BJ5BvkUxxAAk= github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 h1:Xqf+S7iicElwYoS2Zly8Nf/zKHuZsNy1xQajfdtygVY= github.com/milvus-io/milvus-sdk-go/v2 v2.4.2/go.mod h1:ulO1YUXKH0PGg50q27grw048GDY9ayB4FPmh7D+FFTA= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ= github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI= github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/minio-go/v7 v7.0.99 h1:2vH/byrwUkIpFQFOilvTfaUpvAX3fEFhEzO+DR3DlCE= github.com/minio/minio-go/v7 v7.0.99/go.mod h1:EtGNKtlX20iL2yaYnxEigaIvj0G0GwSDnifnG8ClIdw= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/gox v0.4.0/go.mod 
h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= 
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nikolalohinski/gonja/v2 v2.7.0 h1:XuwnulQVPwzGaM0J/9AaQv0AFPBAxKI1GILifQ1r9pk= github.com/nikolalohinski/gonja/v2 v2.7.0/go.mod h1:UIzXPVuOsr5h7dZ5DUbqk3/Z7oFA/NLGQGMjqT4L2aU= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo= github.com/olekukonko/errors v1.2.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= github.com/olekukonko/ll v0.1.6 h1:lGVTHO+Qc4Qm+fce/2h2m5y9LvqaW+DCN7xW9hsU3uA= github.com/olekukonko/ll v0.1.6/go.mod h1:NVUmjBb/aCtUpjKk75BhWrOlARz3dqsM+OtszpY4o88= github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.3 h1:OoxbjfXVZyod1fmWYhI7SEyaD8B00ynP3T+D5GiyHOY= github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.37.0 
h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.9.2/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/xattr v0.4.12 h1:rRTkSyFNTRElv6pkA3zpjHpQ90p/OdHQC1GmGh1aTjM= github.com/pkg/xattr v0.4.12/go.mod h1:di8WF84zAKk8jzR1UBTEWh9AUlIZZ7M/JNt8e9B6ktU= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= 
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/qdrant/go-client v1.17.1 h1:7QmPwDddrHL3hC4NfycwtQlraVKRLcRi++BX6TTm+3g= github.com/qdrant/go-client v1.17.1/go.mod h1:n1h6GhkdAzcohoXt/5Z19I2yxbCkMA6Jejob3S6NZT8= github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc= github.com/rakyll/statik v0.1.8 h1:Fe7egWVZbW/2vlPUY8P/aL9o6qbtrBn71uIztkdafMU= github.com/rakyll/statik v0.1.8/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc= github.com/redis/go-redis/extra/rediscmd/v9 v9.18.0 h1:QY4nmPHLFAJjtT5O4OMUEOxP8WVaRNOFpcbmxT2NLZU= github.com/redis/go-redis/extra/rediscmd/v9 v9.18.0/go.mod h1:WH8cY/0fT41Bsf341qzo8v4nx0GCE8FykAA23IVbVmo= github.com/redis/go-redis/extra/redisotel/v9 v9.18.0 h1:2dKdoEYBJ0CZCLPiCdvvc7luz3DPwY6hKdzjL6m1eHE= github.com/redis/go-redis/extra/redisotel/v9 v9.18.0/go.mod h1:WzkrVG9ro9BwCQD0eJOWn6AGL4Z1CleGflM45w1hu10= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod 
h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/samber/lo v1.53.0 h1:t975lj2py4kJPQ6haz1QMgtId2gtmfktACxIXArw3HM= github.com/samber/lo v1.53.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/schollz/closestmatch v2.1.0+incompatible/go.mod h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g= github.com/schollz/progressbar/v3 v3.19.0 h1:Ea18xuIRQXLAUidVDox3AbwfUhD0/1IvohyTutOIFoc= github.com/schollz/progressbar/v3 v3.19.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1 h1:2cp8mZ+gRxIx/5GoLcQmmXzJuKdM1cSE2dmvFn3udJE= github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1/go.mod h1:PkEqqiiBYB87KgvpQj2r0wtRjDKEhhLRarGCubegp7E= github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal 
v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/jwalterweatherman v1.0.0/go.mod 
h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/steinfletcher/apitest v1.6.0 h1:BvZpQh0DECrDq7nFzDjwQqwXAEc+cykuVD4aUdXvQfA= github.com/steinfletcher/apitest v1.6.0/go.mod h1:mF+KnYaIkuHM0C4JgGzkIIOJAEjo+EA5tTjJ+bHXnQc= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/swaggest/swgui v1.8.5 h1:nceK5OJcpXpkfjmPNH6wtubbd8ZYwxy043xmx0SK18g= github.com/swaggest/swgui v1.8.5/go.mod h1:kvSzLC7+wK4l9n/YcQlb2AMeQtkno9i3C6imADv/fLQ= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 
h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw= github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w= github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/vearutop/statigz v1.4.0 h1:RQL0KG3j/uyA/PFpHeZ/L6l2ta920/MxlOAIGEOuwmU= github.com/vearutop/statigz v1.4.0/go.mod h1:LYTolBLiz9oJISwiVKnOQoIwhO1LWX1A7OECawGS8XE= github.com/weaviate/weaviate v1.27.0 h1:ovFnKER+HRpT5PPuR1ysbKgit0NSpHbBLcsjWR1UyWI= github.com/weaviate/weaviate v1.27.0/go.mod h1:ppTWDzt/atYk1KhyYzxVD8XckmaCaOYnnmelD5M4LK4= github.com/weaviate/weaviate-go-client/v4 v4.16.1 
h1:jkDYuRCYly6zG2ngqTpv6z8azzbqiMUXcmaJHJmAV0Q= github.com/weaviate/weaviate-go-client/v4 v4.16.1/go.mod h1:XmoRpzNpWrTW5/TE07dUtxy5kMZbG3uAG/3b69nuwFk= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod 
h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.einride.tech/aip v0.79.0 h1:19zdPlZzlUvxOA8syAFw4LkdJdXepzyTl6gt9XEeqdU= go.einride.tech/aip v0.79.0/go.mod h1:E8+wdTApA70odnpFzJgsGogHozC2JCIhFJBKPr8bVig= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU= go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod 
h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/rxAWVlHsHIZ3fT2sA= go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4= go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful v0.67.0 h1:0ONi/nCM00kjql6HSU2kVs6nYFx3l0efqDrR1qZ03/A= go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful v0.67.0/go.mod h1:2vCOOy1Om03w7th7aDKpgjJ3z0/K80RgIJwkORo5ab4= go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.67.0 h1:X0H+vhyjOczVijlJqIz2kqq0H3O/y3iwKgAKYTII7yU= go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.67.0/go.mod h1:hg41UE3tzwcEqMZD7xT4EjISCFFS5fgPbBsW37twGNs= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0= go.opentelemetry.io/contrib/propagators/b3 v1.42.0 h1:B2Pew5ufEtgkjLF+tSkXjgYZXQr9m7aCm1wLKB0URbU= go.opentelemetry.io/contrib/propagators/b3 v1.42.0/go.mod h1:iPgUcSEF5DORW6+yNbdw/YevUy+QqJ508ncjhrRSCjc= go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw= 
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 h1:zWWrB1U6nqhS/k6zYB74CjRpuiitRtLLi68VcgmOEto= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0/go.mod h1:2qXPNBX1OVRC0IwOnfo1ljoid+RD0QK3443EaqVlsOU= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 h1:uLXP+3mghfMf7XmV4PkGfFhFKuNWoCvvx5wP/wOXo0o= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0/go.mod h1:v0Tj04armyT59mnURNUJf7RCKcKzq+lgJs6QSjHjaTc= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 h1:s/1iRkCKDfhlh1JF26knRneorus8aOwVIDhvYx9WoDw= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0/go.mod h1:UI3wi0FXg1Pofb8ZBiBLhtMzgoTm1TYkMvn71fAqDzs= go.opentelemetry.io/otel/exporters/zipkin v1.42.0 h1:Z7ARHF7193vyVltPYcmuhSKPLf8dP5rtJZLtTQnbMH4= go.opentelemetry.io/otel/exporters/zipkin v1.42.0/go.mod h1:DW09+gaEg5kydlb9g8kp4Nos3yqo9YSA1uHXkeJihXc= go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= 
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.27.1 
h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go= go.yaml.in/yaml/v4 v4.0.0-rc.3/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod 
h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp 
v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.26.0 h1:4XjIFEZWQmCZi6Wv8BoxsDhRU3RVnLX04dToTDAEPlY= golang.org/x/image v0.26.0/go.mod h1:lcxbMFAovzpnJxzXS3nyL83K27tmqtKzIJpctK8YO5c= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod 
v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net 
v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod 
h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= 
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod 
h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod 
h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW55YmY9DMhajHcnkqVnEXmEtMyNI= gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= gonum.org/v1/gonum v0.8.1-0.20200930085651-eea0b5cb5cc9/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod 
h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/netlib v0.0.0-20201012070519-2390d26c3658/go.mod h1:zQa7n16lh3Z6FbSTYgjG+KNhz1bA/b9t3plFEaGMp+A= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gonum.org/v1/plot v0.15.2 h1:Tlfh/jBk2tqjLZ4/P8ZIwGrLEWQSPDLRm/SNWKNXiGI= gonum.org/v1/plot v0.15.2/go.mod h1:DX+x+DWso3LTha+AdkJEv5Txvi+Tql3KAGkehP0/Ubg= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA= google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto 
v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d h1:vsOm753cOAMkt76efriTCDKjpCbK18XGHMJHo0JUKhc= google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:0oz9d7g9QLSdv9/lgbIjowW1JoxMbxmBVNe8i6tORJI= google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d h1:EocjzKLywydp5uZ5tJ79iP6Q0UjDnyiHkGRWxuPBP8s= google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:48U2I+QQUYhsFrg2SY6r+nJzeOtjey7j//WBESw+qyQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c h1:xgCzyF2LFIO/0X2UAoVRiXKU5Xg6VjToG4i2/ecSswk= google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c/go.mod 
h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/grpc/cmd/protoc-gen-go-grpc v0.0.0-20200910201057-6591123024b3/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/examples v0.0.0-20250407062114-b368379ef8f6 h1:ExN12ndbJ608cboPYflpTny6mXSzPrDLh0iTaVrRrds= google.golang.org/grpc/examples v0.0.0-20250407062114-b368379ef8f6/go.mod h1:6ytKWczdvnpnO+m+JiG9NjEDzR1FJfsnmJdG7B8QVZ8= google.golang.org/grpc/security/advancedtls v1.0.0 h1:/KQ7VP/1bs53/aopk9QhuPyFAp9Dm9Ejix3lzYkCrDA= google.golang.org/grpc/security/advancedtls v1.0.0/go.mod h1:o+s4go+e1PJ2AjuQMY5hU82W7lDlefjJA6FqEHRVHWk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf 
v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod 
h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.51.1/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8= gorgonia.org/cu v0.9.3/go.mod h1:LgyAYDkN7HWhh8orGnCY2R8pP9PYbO44ivEbLMatkVU= gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws= gorgonia.org/dawson v1.2.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws= gorgonia.org/gorgonia v0.9.2/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec= gorgonia.org/gorgonia v0.9.16/go.mod h1:EnZtUbxgbqMx8eCTGPq8C0RfBlr/WllVtMyAFUYG+b4= gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w= gorgonia.org/tensor v0.9.16/go.mod h1:75SMdLLhZ+2oB0/EE8lFEIt1Caoykdd4bz1mAe59deg= gorgonia.org/tensor v0.9.19/go.mod h1:75SMdLLhZ+2oB0/EE8lFEIt1Caoykdd4bz1mAe59deg= gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8= gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q= gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= gorm.io/driver/mysql v1.0.3/go.mod h1:twGxftLBlFgNVNakL7F+P/x9oYqoymG3YYT8cAfI9oI= gorm.io/driver/mysql v1.3.4 h1:/KoBMgsUHC3bExsekDcmNYaBnfH2WNeFuXqqrqMc98Q= gorm.io/driver/mysql v1.3.4/go.mod h1:s4Tq0KmD0yhPGHbZEwg1VPlH0vT/GBHJZorPzhcxBUE= gorm.io/driver/postgres v1.0.8/go.mod h1:4eOzrI1MUfm6ObJU/UcmbXyiHSs8jSwH95G5P5dxcAg= gorm.io/driver/postgres v1.3.5 h1:oVLmefGqBTlgeEVG6LKnH6krOlo4TZ3Q/jIK21KUMlw= gorm.io/driver/postgres v1.3.5/go.mod h1:EGCWefLFQSVFrHGy4J8EtiHCWX5Q8t0yz2Jt9aKkGzU= gorm.io/driver/sqlite v1.1.4/go.mod h1:mJCeTFr7+crvS+TRnWc5Z3UvwxUN1BGBLMrf5LA9DYw= gorm.io/gorm v1.20.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.20.7/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.20.12/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.23.4/go.mod 
h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.6 h1:KFLdNgri4ExFFGTRGGFWON2P1ZN28+9SJRN8voOoYe0= gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc v1.0.0 h1:nPibNuDEx6tvYrUAtvDTTw98rx5juGsa5zuDnKwEEQQ= modernc.org/cc v1.0.0/go.mod h1:1Sk4//wdnYJiUIxnW8ddKpaOJCF37yAdqYnkxUpaYxw= modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/ccgo/v3 v3.0.0-20220428102840-41399a37e894/go.mod h1:eI31LL8EwEBKPpNpA4bU1/i+sKOwOrQy8D87zWUcRZc= modernc.org/ccgo/v3 v3.0.0-20220430103911-bc99d88307be/go.mod h1:bwdAnOoaIt8Ax9YdWGjxWsdkPcZyRPHqrOvJxaKAKGw= modernc.org/ccgo/v3 v3.16.4/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= modernc.org/ccgo/v3 v3.16.6/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= modernc.org/fileutil v1.4.0 
h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/golex v1.0.0/go.mod h1:b/QX9oBD/LhixY6NDh+IdGv17hgB+51fET1i2kPSmvk= modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= modernc.org/libc v0.0.0-20220428101251-2d5f3daf273b/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= modernc.org/libc v1.16.0/go.mod h1:N4LD6DBE9cf+Dzf9buBlzVJndKr/iJHG97vGLHYnb5A= modernc.org/libc v1.16.1/go.mod h1:JjJE0eu4yeK7tab2n4S1w8tlWd9MxXLRzheaRnAKymU= modernc.org/libc v1.16.7/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= modernc.org/libquickjs v0.12.3 h1:2IU9B6njBmce2PuYttJDkXeoLRV9WnvgP+eU5HAC8YI= modernc.org/libquickjs v0.12.3/go.mod h1:iCsgVxnHTX3i0YPxxHBmJk0GLA5sVUHXWI/090UXgeE= modernc.org/mathutil v1.0.0/go.mod h1:wU0vUrJsVWBZ4P6e7xtFJEhFSNsfRLJ8H458uRjg03k= modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.1.1/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= 
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/quickjs v0.17.1 h1:CbYnbTf7ksZk9YZ1rRM2Ab1Zfi+X6s50kXiOhpd2NIg= modernc.org/quickjs v0.17.1/go.mod h1:hATT7DIJc33I5Q/Fjffhm0tpUHNSqdKHma/ossibTA0= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sqlite v1.17.3/go.mod h1:10hPVYar9C0kfXuTWGz8s0XtB8uAGymUy51ZzStYe3k= modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk= modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= modernc.org/strutil v1.1.0/go.mod h1:lstksw84oURvj9y3tn8lGvRxyRC1S2+g5uuIzNfIOBs= modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/tcl v1.13.1/go.mod h1:XOLfOwzhkljL4itZkK6T72ckMgvj0BDsnKNdZVUOecw= modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/xc v1.0.0/go.mod h1:mRNCo0bvLjGhHO9WsyuKVU4q0ceiDDDoEeWDJHrNx8I= modernc.org/z v1.5.1/go.mod h1:eWFB510QWW5Th9YGZT81s+LwvaAs3Q2yr4sP0rmLkv8= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= ================================================ FILE: logics/cf.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache 
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package logics import ( "io" "sync" "time" "github.com/gorse-io/gorse/common/ann" "github.com/gorse-io/gorse/common/encoding" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/storage/cache" "github.com/pkg/errors" "github.com/samber/lo" "go.uber.org/zap" ) func distance(a, b []float32) float32 { return -floats.Dot(a, b) } type MatrixFactorizationItems struct { timestamp time.Time items []string itemsLock sync.Mutex index *ann.HNSW[[]float32] dimension int } func NewMatrixFactorizationItems(timestamp time.Time) *MatrixFactorizationItems { return &MatrixFactorizationItems{ timestamp: timestamp, items: make([]string, 0), index: ann.NewHNSW(distance), } } func (items *MatrixFactorizationItems) Add(itemId string, v []float32) { // Check dimension items.itemsLock.Lock() if items.dimension == 0 { items.dimension = len(v) } else if items.dimension != len(v) { log.Logger().Error("dimension mismatch", zap.Int("dimension", len(v))) return } // Push item items.items = append(items.items, "") items.itemsLock.Unlock() j := items.index.Add(v) items.itemsLock.Lock() items.items[j] = itemId items.itemsLock.Unlock() } func (items *MatrixFactorizationItems) Search(v []float32, n int) []cache.Score { scores := items.index.SearchVector(v, n, false) return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { return cache.Score{ Id: items.items[v.A], Score: -float64(v.B), Timestamp: items.timestamp, } 
}) } func (items *MatrixFactorizationItems) Marshal(w io.Writer) error { if err := encoding.WriteGob(w, items.timestamp); err != nil { return errors.WithStack(err) } if err := encoding.WriteGob(w, items.dimension); err != nil { return errors.WithStack(err) } if err := items.index.Marshal(w); err != nil { return errors.WithStack(err) } numItems := int64(len(items.items)) if err := encoding.WriteGob(w, numItems); err != nil { return errors.WithStack(err) } for _, item := range items.items { if err := encoding.WriteGob(w, item); err != nil { return errors.WithStack(err) } } return nil } func (items *MatrixFactorizationItems) Unmarshal(r io.Reader) error { if err := encoding.ReadGob(r, &items.timestamp); err != nil { return errors.WithStack(err) } if err := encoding.ReadGob(r, &items.dimension); err != nil { return errors.WithStack(err) } if err := items.index.Unmarshal(r); err != nil { return errors.WithStack(err) } var numItems int64 if err := encoding.ReadGob(r, &numItems); err != nil { return errors.WithStack(err) } items.items = make([]string, numItems) for i := int64(0); i < numItems; i++ { if err := encoding.ReadGob(r, &items.items[i]); err != nil { return errors.WithStack(err) } } return nil } type MatrixFactorizationUsers struct { embeddings map[string][]float32 } func NewMatrixFactorizationUsers() *MatrixFactorizationUsers { return &MatrixFactorizationUsers{ embeddings: make(map[string][]float32), } } func (users *MatrixFactorizationUsers) Add(userId string, v []float32) { users.embeddings[userId] = v } func (users *MatrixFactorizationUsers) Get(userId string) ([]float32, bool) { v, ok := users.embeddings[userId] return v, ok } func (users *MatrixFactorizationUsers) Marshal(w io.Writer) error { numUsers := int64(len(users.embeddings)) if err := encoding.WriteGob(w, numUsers); err != nil { return errors.WithStack(err) } for userId, embedding := range users.embeddings { if err := encoding.WriteString(w, userId); err != nil { return errors.WithStack(err) } if 
err := encoding.WriteSlice(w, embedding); err != nil { return errors.WithStack(err) } } return nil } func (users *MatrixFactorizationUsers) Unmarshal(r io.Reader) error { var numUsers int64 if err := encoding.ReadGob(r, &numUsers); err != nil { return errors.WithStack(err) } users.embeddings = make(map[string][]float32, numUsers) for i := int64(0); i < numUsers; i++ { userId, err := encoding.ReadString(r) if err != nil { return errors.WithStack(err) } var embedding []float32 if err = encoding.ReadSlice(r, &embedding); err != nil { return errors.WithStack(err) } users.embeddings[userId] = embedding } return nil } ================================================ FILE: logics/cf_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "github.com/gorse-io/gorse/storage/cache" "github.com/stretchr/testify/assert" "os" "path/filepath" "testing" "time" ) func TestMatrixFactorizationItems(t *testing.T) { ts := time.Now() items := NewMatrixFactorizationItems(ts) items.Add("1", []float32{1, 1, 1}) items.Add("2", []float32{2, 2, 2}) items.Add("3", []float32{3, 3, 3}) items.Add("4", []float32{4, 4, 4}) items.Add("5", []float32{5, 5, 5}) path := filepath.Join(t.TempDir(), "items") f, err := os.Create(path) assert.NoError(t, err) defer f.Close() err = items.Marshal(f) assert.NoError(t, err) f, err = os.Open(path) assert.NoError(t, err) defer f.Close() items2 := NewMatrixFactorizationItems(time.Time{}) err = items2.Unmarshal(f) assert.NoError(t, err) assert.Equal(t, items.timestamp.UnixNano(), items2.timestamp.UnixNano()) assert.Equal(t, items.dimension, items2.dimension) assert.Equal(t, items.items, items2.items) scores := items2.Search([]float32{1, 1, 1}, 3) assert.Equal(t, []cache.Score{ {Id: "5", Score: 15, Timestamp: items2.timestamp}, {Id: "4", Score: 12, Timestamp: items2.timestamp}, {Id: "3", Score: 9, Timestamp: items2.timestamp}, }, scores) } func TestNewMatrixFactorizationUsers(t *testing.T) { users := NewMatrixFactorizationUsers() users.Add("1", []float32{1, 1, 1}) users.Add("2", []float32{2, 2, 2}) users.Add("3", []float32{3, 3, 3}) path := filepath.Join(t.TempDir(), "users") f, err := os.Create(path) assert.NoError(t, err) defer f.Close() err = users.Marshal(f) assert.NoError(t, err) f, err = os.Open(path) assert.NoError(t, err) defer f.Close() users2 := NewMatrixFactorizationUsers() err = users2.Unmarshal(f) assert.NoError(t, err) assert.Equal(t, users.embeddings, users2.embeddings) } ================================================ FILE: logics/chat.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the 
License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package logics import ( "context" "encoding/json" "strings" "github.com/gorse-io/gorse/common/reranker" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/nikolalohinski/gonja/v2" "github.com/nikolalohinski/gonja/v2/exec" "github.com/sashabaranov/go-openai" "github.com/yuin/goldmark" "github.com/yuin/goldmark/ast" "github.com/yuin/goldmark/text" ) type FeedbackItem struct { FeedbackType string data.Item } type ChatReranker struct { queryTemplate *exec.Template docTemplate *exec.Template client *reranker.Client model string } func NewChatReranker(cfg config.RerankerAPIConfig, queryTemplate, docTemplate string) (*ChatReranker, error) { // create reranker client client := reranker.NewClient(cfg.AuthToken, cfg.URL) // create templates qTpl, err := gonja.FromString(queryTemplate) if err != nil { return nil, err } dTpl, err := gonja.FromString(docTemplate) if err != nil { return nil, err } return &ChatReranker{ queryTemplate: qTpl, docTemplate: dTpl, client: client, model: cfg.Model, }, nil } func (r *ChatReranker) Rank(ctx context.Context, user *data.User, feedback []*FeedbackItem, items []*data.Item) ([]cache.Score, error) { // render query var queryBuf strings.Builder queryCtx := exec.NewContext(map[string]any{ "user": user, "feedback": feedback, }) if err := r.queryTemplate.Execute(&queryBuf, queryCtx); err != nil { return nil, err } // render documents documents := make([]string, len(items)) for i, item := range items { var docBuf strings.Builder docCtx := 
exec.NewContext(map[string]any{ "item": item, }) if err := r.docTemplate.Execute(&docBuf, docCtx); err != nil { return nil, err } documents[i] = docBuf.String() } // rerank resp, err := r.client.Rerank(ctx, reranker.RerankRequest{ Model: r.model, Query: queryBuf.String(), Documents: documents, }) if err != nil { return nil, err } // sort items result := make([]cache.Score, len(resp.Results)) for i, rerankResult := range resp.Results { result[i].Id = items[rerankResult.Index].ItemId result[i].Score = rerankResult.RelevanceScore } return result, nil } // parseArrayFromCompletion parse JSON array from completion. // If the completion contains a JSON array, it will return each element in the array. // If the completion contains a JSON object, it will return the object as a string. // Otherwise, it will return the completion as a string. func parseArrayFromCompletion(completion string) []string { source := []byte(stripThinkInCompletion(completion)) root := goldmark.DefaultParser().Parse(text.NewReader(source)) for n := root.FirstChild(); n != nil; n = n.NextSibling() { if n.Kind() != ast.KindFencedCodeBlock { continue } if codeBlock, ok := n.(*ast.FencedCodeBlock); ok { if string(codeBlock.Language(source)) == "json" { bytes := codeBlock.Text(source) if bytes[0] == '[' { var temp []any err := json.Unmarshal(bytes, &temp) if err != nil { return []string{string(bytes)} } var result []string for _, v := range temp { var bytes []byte switch typed := v.(type) { case string: bytes = []byte(typed) default: bytes, err = json.Marshal(v) if err != nil { return []string{string(bytes)} } } result = append(result, string(bytes)) } return result } return []string{string(bytes)} } else if string(codeBlock.Language(source)) == "csv" { // If the code block is CSV, retrieve 1st column as IDs. 
bytes := codeBlock.Text(source) lines := strings.Split(string(bytes), "\n") var result []string for _, line := range lines { fields := strings.Split(line, ",") if len(fields) > 0 && strings.TrimSpace(fields[0]) != "" { result = append(result, strings.TrimSpace(fields[0])) } } return result } } } var result []string for _, line := range strings.Split(string(source), "\n") { line = strings.TrimSpace(line) if line != "" { result = append(result, line) } } return result } func isThrottled(err error) bool { switch e := err.(type) { case *openai.APIError: if e.HTTPStatusCode == 429 { return true } case *openai.RequestError: return e.HTTPStatusCode == 504 || e.HTTPStatusCode == 520 } return false } ================================================ FILE: logics/chat_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "net/http" "testing" "github.com/gorse-io/gorse/common/reranker" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/stretchr/testify/assert" ) func TestChatReranker(t *testing.T) { s := reranker.NewMockServer() go func() { err := s.Start() if err != nil && err != http.ErrServerClosed { panic(err) } }() s.Ready() defer s.Close() reranker, err := NewChatReranker(config.RerankerAPIConfig{ AuthToken: s.AuthToken(), URL: s.URL(), Model: "gte-rerank", }, "{{ user.UserId }} is a {{ user.Comment }} watched the following movies recently: {% for item in feedback %}{{ item.Comment }}, {% endfor %}", "{{ item.Comment }}") assert.NoError(t, err) items, err := reranker.Rank(t.Context(), &data.User{ UserId: "Tom", Comment: "horror movie enthusiast", }, []*FeedbackItem{ {Item: data.Item{ItemId: "tt0387564", Comment: "Saw"}}, {Item: data.Item{ItemId: "tt0432348", Comment: "Saw II"}}, {Item: data.Item{ItemId: "tt0435761", Comment: "Saw III"}}, }, []*data.Item{ {ItemId: "tt1233227", Comment: "Harry Potter and the Half-Blood Prince"}, {ItemId: "tt0926084", Comment: "Harry Potter and the Deathly Hallows: Part 1"}, {ItemId: "tt0890870", Comment: "Saw IV"}, {ItemId: "tt1132626", Comment: "Saw VI"}, {ItemId: "tt0435761", Comment: "Saw V"}, }) assert.NoError(t, err) assert.Equal(t, []cache.Score{ {Id: "tt1233227", Score: 1}, {Id: "tt0926084", Score: 0.5}, {Id: "tt0890870", Score: 0.3333333333333333}, {Id: "tt1132626", Score: 0.25}, {Id: "tt0435761", Score: 0.2}, }, items) } func TestParseArrayFromCompletion(t *testing.T) { // parse JSON object completion := "```json\n{\"a\": 1, \"b\": 2}\n```" parsed := parseArrayFromCompletion(completion) assert.Equal(t, []string{"{\"a\": 1, \"b\": 2}\n"}, parsed) // parse JSON array completion = "```json\n[1, 2]\n```" parsed = parseArrayFromCompletion(completion) assert.Equal(t, []string{"1", "2"}, parsed) // parse CSV completion = 
"```csv\n1\n2\n3\n```" parsed = parseArrayFromCompletion(completion) assert.Equal(t, []string{"1", "2", "3"}, parsed) // parse text completion = "Hello, world!\nThis is a test." parsed = parseArrayFromCompletion(completion) assert.Equal(t, []string{"Hello, world!", "This is a test."}, parsed) // strip think completion = "helloWorld!" assert.Equal(t, "World!", stripThinkInCompletion(completion)) } ================================================ FILE: logics/external.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "encoding/json" "io" "net/http" "os" "strings" "github.com/gorse-io/gorse/config" "github.com/pkg/errors" "github.com/samber/lo" "modernc.org/quickjs" ) type External struct { vm *quickjs.VM client *http.Client script string name string } func NewExternal(cfg config.ExternalConfig) (*External, error) { vm, err := quickjs.NewVM() if err != nil { return nil, errors.WithStack(err) } // Add environment variables env, err := vm.NewObjectValue() if err != nil { return nil, errors.WithStack(err) } for _, e := range os.Environ() { parts := strings.SplitN(e, "=", 2) if len(parts) != 2 { continue } key, err := vm.NewAtom(parts[0]) if err != nil { return nil, errors.WithStack(err) } value := parts[1] if err := env.SetProperty(key, value); err != nil { return nil, errors.WithStack(err) } } envKey, err := vm.NewAtom("env") if err != nil { return nil, errors.WithStack(err) } if err := vm.GlobalObject().SetProperty(envKey, env); err != nil { return nil, errors.WithStack(err) } // Register fetch function external := &External{ vm: vm, client: &http.Client{}, script: cfg.Script, name: cfg.Name, } if err = vm.RegisterFunc("fetch", external.fetch, false); err != nil { return nil, errors.WithStack(err) } return external, nil } func (e *External) Close() error { return e.vm.Close() } func (e *External) Pull(userId string) (res []string, err error) { defer func() { if r := recover(); r != nil { var ok bool err, ok = r.(error) if !ok { err = errors.Errorf("%v", r) } } }() userIdKey, err := e.vm.NewAtom("user_id") if err != nil { return nil, errors.WithStack(err) } err = e.vm.GlobalObject().SetProperty(userIdKey, userId) if err != nil { return nil, errors.WithStack(err) } result, err := e.vm.Eval(e.script, quickjs.EvalGlobal) if err != nil { return nil, errors.WithStack(err) } switch v := result.(type) { case string: var items []string if err := json.Unmarshal([]byte(v), &items); err != nil { return nil, errors.WithStack(err) } return items, nil case 
*quickjs.Object: var items []string if err := json.Unmarshal([]byte(v.String()), &items); err != nil { return nil, errors.WithStack(err) } return items, nil default: return nil, errors.New("script must return string or object") } } func (e *External) fetch(args ...quickjs.Value) quickjs.Value { var ( url string req = quickjs.UndefinedValue ) if len(args) == 1 { switch v := lo.Must(args[0].Any()).(type) { case string: url = v case *quickjs.Object: req = args[0] default: panic("fetch requires first argument to be string or object") } } else if len(args) == 2 { if _, ok := lo.Must(args[0].Any()).(string); !ok { panic("fetch requires first argument to be string") } if _, ok := lo.Must(args[1].Any()).(*quickjs.Object); !ok { panic("fetch requires second argument to be object") } url = lo.Must(args[0].Any()).(string) req = args[1] } else { panic("fetch requires 1 or 2 arguments") } r := e.parseRequest(url, req) resp, err := e.client.Do(r) if err != nil { panic(err) } return e.newResponse(resp) } // parseRequest parse Fetch API Request. 
func (e *External) parseRequest(url string, req quickjs.Value) *http.Request { method := "GET" headers := make(map[string]string) body := "" if !req.IsUndefined() { // Request.method methodKey := lo.Must(e.vm.NewAtom("method")) methodValue := lo.Must(req.GetPropertyValue(methodKey)) if !methodValue.IsUndefined() { method = lo.Must(methodValue.Any()).(string) } // Request.headers headersKey := lo.Must(e.vm.NewAtom("headers")) headersValue := lo.Must(req.GetPropertyValue(headersKey)) if !headersValue.IsUndefined() { if headersObj, ok := lo.Must(headersValue.Any()).(*quickjs.Object); ok { if err := json.Unmarshal([]byte(headersObj.String()), &headers); err != nil { panic(err) } } } // Request.url urlKey := lo.Must(e.vm.NewAtom("url")) urlValue := lo.Must(req.GetPropertyValue(urlKey)) if !urlValue.IsUndefined() { url = lo.Must(urlValue.Any()).(string) } // Request.body bodyKey := lo.Must(e.vm.NewAtom("body")) bodyValue := lo.Must(req.GetPropertyValue(bodyKey)) if !bodyValue.IsUndefined() { body = lo.Must(bodyValue.Any()).(string) } } r := lo.Must(http.NewRequest(method, url, strings.NewReader(body))) for key, value := range headers { r.Header.Add(key, value) } return r } // newResponse convert http.Response to Fetch API Response. 
func (e *External) newResponse(resp *http.Response) quickjs.Value { if resp == nil { return quickjs.UndefinedValue } response := lo.Must(e.vm.NewObjectValue()) // Response.ok okKey := lo.Must(e.vm.NewAtom("ok")) lo.Must0(response.SetProperty(okKey, resp.StatusCode >= 200 && resp.StatusCode < 300)) // Response.status statusKey := lo.Must(e.vm.NewAtom("status")) lo.Must0(response.SetProperty(statusKey, resp.StatusCode)) // Response.statusText statusTextKey := lo.Must(e.vm.NewAtom("statusText")) lo.Must0(response.SetProperty(statusTextKey, resp.Status)) // Response.body bodyKey := lo.Must(e.vm.NewAtom("body")) body := lo.Must(io.ReadAll(resp.Body)) lo.Must0(response.SetProperty(bodyKey, string(body))) // Response.headers headersKey := lo.Must(e.vm.NewAtom("headers")) headers := lo.Must(e.vm.NewObjectValue()) for key, values := range resp.Header { headerKey := lo.Must(e.vm.NewAtom(key)) lo.Must0(headers.SetProperty(headerKey, strings.Join(values, ", "))) } lo.Must0(response.SetProperty(headersKey, headers)) return response } ================================================ FILE: logics/external_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "testing" "github.com/gorse-io/gorse/config" "github.com/stretchr/testify/assert" "modernc.org/quickjs" ) func TestEnv(t *testing.T) { t.Setenv("TEST_ENV", "test_value") external, err := NewExternal(config.ExternalConfig{}) assert.NoError(t, err) defer external.Close() value, err := external.vm.Eval(`env.TEST_ENV`, quickjs.EvalGlobal) assert.NoError(t, err) assert.Equal(t, "test_value", value) } func TestFetch(t *testing.T) { var ( req *http.Request body string ) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req = r b, err := io.ReadAll(r.Body) assert.NoError(t, err) body = string(b) fmt.Fprintln(w, "Hello, client") })) defer ts.Close() external, err := NewExternal(config.ExternalConfig{}) assert.NoError(t, err) defer external.Close() response, err := external.vm.Eval(`fetch("`+ts.URL+`")`, quickjs.EvalGlobal) assert.NoError(t, err) if assert.NotNil(t, req) { assert.Equal(t, "GET", req.Method) } if assert.IsType(t, &quickjs.Object{}, response) { var resp map[string]any err = json.Unmarshal([]byte(response.(*quickjs.Object).String()), &resp) assert.NoError(t, err) assert.Equal(t, true, resp["ok"]) assert.Equal(t, float64(200), resp["status"]) assert.Equal(t, "200 OK", resp["statusText"]) assert.Equal(t, "Hello, client\n", resp["body"]) headers := resp["headers"].(map[string]any) assert.Contains(t, headers, "Content-Length") assert.Contains(t, headers, "Date") } _, err = external.vm.Eval(`fetch({method: "POST", url: "`+ts.URL+`"})`, quickjs.EvalGlobal) assert.NoError(t, err) if assert.NotNil(t, req) { assert.Equal(t, "POST", req.Method) } _, err = external.vm.Eval(`fetch("`+ts.URL+`", { method: "PUT", headers: { "Content-Type": "application/json" }, body: JSON.stringify({message: "Hello, server"}) })`, quickjs.EvalGlobal) assert.NoError(t, err) if assert.NotNil(t, req) { assert.Equal(t, "PUT", req.Method) assert.Equal(t, "application/json", 
req.Header.Get("Content-Type")) assert.Equal(t, `{"message":"Hello, server"}`, body) } } func TestExternal(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userId := r.URL.Query().Get("user_id") if userId == "1" { fmt.Fprintln(w, `["item_1", "item_2", "item_3"]`) } else { http.NotFound(w, r) } })) defer ts.Close() external, err := NewExternal(config.ExternalConfig{ Script: fmt.Sprintf(`fetch("%s?user_id=1").body`, ts.URL), Name: "test", }) assert.NoError(t, err) defer external.Close() items, err := external.Pull("1") assert.NoError(t, err) assert.Equal(t, []string{"item_1", "item_2", "item_3"}, items) } func TestException(t *testing.T) { external, err := NewExternal(config.ExternalConfig{ Script: `throw new Error("test error")`, Name: "test", }) assert.NoError(t, err) defer external.Close() _, err = external.Pull("1") assert.Error(t, err) } func TestPanic(t *testing.T) { external, err := NewExternal(config.ExternalConfig{ Script: `fetch({}, {})`, Name: "test", }) assert.NoError(t, err) defer external.Close() _, err = external.Pull("1") assert.Error(t, err) } ================================================ FILE: logics/item_to_item.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "context" "errors" "sort" "strings" "sync" "time" "github.com/cenkalti/backoff/v5" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" "github.com/gorse-io/gorse/common/ann" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/heap" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/parallel" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/nikolalohinski/gonja/v2" "github.com/nikolalohinski/gonja/v2/exec" "github.com/samber/lo" "github.com/sashabaranov/go-openai" "github.com/tiktoken-go/tokenizer" "go.uber.org/zap" ) var cl100kBaseTokenizer tokenizer.Codec func init() { var err error cl100kBaseTokenizer, err = tokenizer.Get(tokenizer.Cl100kBase) if err != nil { panic(err) } } type ItemToItemOptions struct { TagsIDF []float32 UsersIDF []float32 OpenAIConfig config.OpenAIConfig } type ItemToItem interface { Timestamp() time.Time Count() int Get(i int) *data.Item Push(item *data.Item, feedback []int32) PopAll(i int) []cache.Score } func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts *ItemToItemOptions) (ItemToItem, error) { switch cfg.Type { case "embedding": return newEmbeddingItemToItem(cfg, n, timestamp) case "tags": if opts == nil || opts.TagsIDF == nil { return nil, errors.New("tags IDF is required for tags item-to-item") } return newTagsItemToItem(cfg, n, timestamp, opts.TagsIDF) case "users": if opts == nil || opts.UsersIDF == nil { return nil, errors.New("users IDF is required for users item-to-item") } return newUsersItemToItem(cfg, n, timestamp, opts.UsersIDF) case "auto": if opts == nil || opts.TagsIDF == nil || opts.UsersIDF == nil { return nil, errors.New("tags and users IDF are required for auto item-to-item") } return newAutoItemToItem(cfg, n, timestamp, 
opts.TagsIDF, opts.UsersIDF) case "chat": if opts == nil || opts.OpenAIConfig.BaseURL == "" || opts.OpenAIConfig.AuthToken == "" { return nil, errors.New("OpenAI config is required for chat item-to-item") } return newChatItemToItem(cfg, n, timestamp, opts.OpenAIConfig) default: return nil, errors.New("invalid item-to-item type") } } type baseItemToItem[T any] struct { name string n int timestamp time.Time columnFunc *vm.Program index *ann.HNSW[T] items []*data.Item itemsLock sync.Mutex // Hidden items are stored separately without adding to the index, // and they have neighbors but are not neighbors of other items. hiddenItems []*data.Item hiddenVectors []T } func (b *baseItemToItem[T]) Timestamp() time.Time { return b.timestamp } func (b *baseItemToItem[T]) Count() int { return len(b.items) + len(b.hiddenItems) } func (b *baseItemToItem[T]) Get(i int) *data.Item { if i < len(b.items) { return b.items[i] } return b.hiddenItems[i-len(b.items)] } func (b *baseItemToItem[T]) pushItem(item *data.Item, v T) { if item.IsHidden { b.itemsLock.Lock() b.hiddenItems = append(b.hiddenItems, item) b.hiddenVectors = append(b.hiddenVectors, v) b.itemsLock.Unlock() } else { b.itemsLock.Lock() b.items = append(b.items, nil) b.itemsLock.Unlock() j := b.index.Add(v) b.itemsLock.Lock() b.items[j] = item b.itemsLock.Unlock() } } func (b *baseItemToItem[T]) PopAll(i int) []cache.Score { var results []lo.Tuple2[int, float32] if i < len(b.items) { // Non-hidden item: search by index var err error results, err = b.index.SearchIndex(i, b.n+1, true) if err != nil { log.Logger().Error("failed to search index", zap.Error(err)) return nil } } else { // Hidden item: search by vector results = b.index.SearchVector(b.hiddenVectors[i-len(b.items)], b.n, true) } return lo.Map(results, func(v lo.Tuple2[int, float32], _ int) cache.Score { return cache.Score{ Id: b.items[v.A].ItemId, Categories: b.items[v.A].Categories, Score: 1.0 / (1.0 + float64(v.B)), Timestamp: b.timestamp, } }) } type 
embeddingItemToItem struct { baseItemToItem[[]float32] dimension int } func newEmbeddingItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (*embeddingItemToItem, error) { // Compile column expression columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ "item": data.Item{}, })) if err != nil { return nil, err } return &embeddingItemToItem{baseItemToItem: baseItemToItem[[]float32]{ name: cfg.Name, n: n, timestamp: timestamp, columnFunc: columnFunc, index: ann.NewHNSW(floats.Euclidean), }}, nil } func (e *embeddingItemToItem) Push(item *data.Item, _ []int32) { // Evaluate filter function result, err := expr.Run(e.columnFunc, map[string]any{ "item": item, }) if err != nil { log.Logger().Error("failed to evaluate column expression", zap.Any("item", item), zap.Error(err)) return } // Check column type v, ok := result.([]float32) if !ok { log.Logger().Error("invalid column type", zap.Any("column", result)) return } // Check dimension e.itemsLock.Lock() if e.dimension == 0 && len(v) > 0 { e.dimension = len(v) } else if e.dimension != len(v) { log.Logger().Error("invalid column dimension", zap.Int("dimension", len(v))) e.itemsLock.Unlock() return } e.itemsLock.Unlock() e.pushItem(item, v) } type tagsItemToItem struct { baseItemToItem[[]dataset.ID] IDF[dataset.ID] } func newTagsItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, idf []float32) (ItemToItem, error) { // Compile column expression columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ "item": data.Item{}, })) if err != nil { return nil, err } t := &tagsItemToItem{IDF: idf} t.baseItemToItem = baseItemToItem[[]dataset.ID]{ name: cfg.Name, n: n, timestamp: timestamp, columnFunc: columnFunc, index: ann.NewHNSW(t.distance), } return t, nil } func (t *tagsItemToItem) Push(item *data.Item, _ []int32) { // Evaluate filter function result, err := expr.Run(t.columnFunc, map[string]any{ "item": item, }) if err != nil { log.Logger().Error("failed to evaluate 
column expression", zap.Any("item", item), zap.Error(err)) return } // Extract tags tSet := mapset.NewSet[dataset.ID]() flatten(result, tSet) v := tSet.ToSlice() sort.Slice(v, func(i, j int) bool { return v[i] < v[j] }) t.pushItem(item, v) } type usersItemToItem struct { baseItemToItem[[]int32] IDF[int32] } func newUsersItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, idf []float32) (ItemToItem, error) { if cfg.Column != "" { return nil, errors.New("column is not supported in users item-to-item") } u := &usersItemToItem{IDF: idf} u.baseItemToItem = baseItemToItem[[]int32]{ name: cfg.Name, n: n, timestamp: timestamp, index: ann.NewHNSW(u.distance), } return u, nil } func (u *usersItemToItem) Push(item *data.Item, feedback []int32) { // Sort feedback sort.Slice(feedback, func(i, j int) bool { return feedback[i] < feedback[j] }) u.pushItem(item, feedback) } type autoItemToItem struct { baseItemToItem[lo.Tuple2[[]dataset.ID, []int32]] tIDF IDF[dataset.ID] uIDF IDF[int32] } func newAutoItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, tIDF, uIDF []float32) (ItemToItem, error) { a := &autoItemToItem{ tIDF: tIDF, uIDF: uIDF, } a.baseItemToItem = baseItemToItem[lo.Tuple2[[]dataset.ID, []int32]]{ name: cfg.Name, n: n, timestamp: timestamp, index: ann.NewHNSW[lo.Tuple2[[]dataset.ID, []int32]](a.distance), } return a, nil } func (a *autoItemToItem) Push(item *data.Item, feedback []int32) { // Extract tags tSet := mapset.NewSet[dataset.ID]() flatten(item.Labels, tSet) v := tSet.ToSlice() sort.Slice(v, func(i, j int) bool { return v[i] < v[j] }) // Sort feedback sort.Slice(feedback, func(i, j int) bool { return feedback[i] < feedback[j] }) a.pushItem(item, lo.Tuple2[[]dataset.ID, []int32]{A: v, B: feedback}) } func (a *autoItemToItem) distance(u, v lo.Tuple2[[]dataset.ID, []int32]) float32 { return (a.tIDF.distance(u.A, v.A) + a.uIDF.distance(u.B, v.B)) / 2 } type IDF[T dataset.ID | int32] []float32 func (idf IDF[T]) distance(a, b []T) 
float32 { commonSum, commonCount := idf.weightedSumCommonElements(a, b) if len(a) == len(b) && commonCount == float32(len(a)) { // If two items have the same tags, its distance is zero. return 0 } else if commonCount > 0 && len(a) > 0 && len(b) > 0 { // Add shrinkage to avoid division by zero return 1 - commonSum*commonCount/ math32.Sqrt(idf.weightedSum(a))/ math32.Sqrt(idf.weightedSum(b))/ (commonCount+100) } else { // If two items have no common tags, its distance is one. return 1 } } func (idf IDF[T]) weightedSumCommonElements(a, b []T) (float32, float32) { i, j, sum, count := 0, 0, float32(0), float32(0) for i < len(a) && j < len(b) { if a[i] == b[j] { sum += idf[a[i]] count++ i++ j++ } else if a[i] < b[j] { i++ } else if a[i] > b[j] { j++ } } return sum, count } func (idf IDF[T]) weightedSum(a []T) float32 { var sum float32 for _, i := range a { sum += idf[i] } return sum } func flatten(o any, tSet mapset.Set[dataset.ID]) { switch typed := o.(type) { case dataset.ID: tSet.Add(typed) return case []dataset.ID: tSet.Append(typed...) 
return case map[string]any: for _, v := range typed { flatten(v, tSet) } } } type chatItemToItem struct { *embeddingItemToItem template *exec.Template client *openai.Client chatCompletionModel string embeddingModel string embeddingDimensions int poolSize int } func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*chatItemToItem, error) { // create embedding item-to-item recommender embedding, err := newEmbeddingItemToItem(cfg, n, timestamp) if err != nil { return nil, err } // parse template template, err := gonja.FromString(cfg.Prompt) if err != nil { return nil, err } // create openai client clientConfig := openai.DefaultConfig(openaiConfig.AuthToken) clientConfig.BaseURL = openaiConfig.BaseURL return &chatItemToItem{ embeddingItemToItem: embedding, template: template, client: openai.NewClientWithConfig(clientConfig), chatCompletionModel: openaiConfig.ChatCompletionModel, embeddingModel: openaiConfig.EmbeddingModel, embeddingDimensions: openaiConfig.EmbeddingDimensions, poolSize: min(openaiConfig.ChatCompletionRPM, openaiConfig.EmbeddingRPM), }, nil } func (g *chatItemToItem) PopAll(i int) []cache.Score { item := g.Get(i) // evaluate column expression and get embedding vector result, err := expr.Run(g.columnFunc, map[string]any{ "item": item, }) if err != nil { log.Logger().Error("failed to evaluate column expression", zap.Any("item", item), zap.Error(err)) return nil } embedding0, ok := result.([]float32) if !ok { log.Logger().Error("invalid column type", zap.Any("column", result)) return nil } // render template var buf strings.Builder ctx := exec.NewContext(map[string]any{ "item": item, }) if err := g.template.Execute(&buf, ctx); err != nil { log.Logger().Error("failed to execute template", zap.Error(err)) return nil } // chat completion start := time.Now() ids, _, _ := cl100kBaseTokenizer.Encode(buf.String()) resp, err := backoff.Retry(context.Background(), func() (openai.ChatCompletionResponse, 
error) { time.Sleep(parallel.ChatCompletionRequestsLimiter.Take(1)) time.Sleep(parallel.ChatCompletionTokensLimiter.Take(int64(len(ids)))) resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ Model: g.chatCompletionModel, Messages: []openai.ChatCompletionMessage{{ Role: openai.ChatMessageRoleUser, Content: buf.String(), }}, }) if err == nil { return resp, nil } if isThrottled(err) { return openai.ChatCompletionResponse{}, err } return openai.ChatCompletionResponse{}, backoff.Permanent(err) }, backoff.WithBackOff(backoff.NewExponentialBackOff())) if err != nil { log.Logger().Error("failed to chat completion", zap.String("item_id", item.ItemId), zap.Error(err)) return nil } duration := time.Since(start) parsed := parseArrayFromCompletion(resp.Choices[0].Message.Content) log.OpenAILogger().Info("chat completion", zap.String("prompt", buf.String()), zap.String("completion", resp.Choices[0].Message.Content), zap.Strings("parsed", parsed), zap.Int("prompt_tokens", resp.Usage.PromptTokens), zap.Int("completion_tokens", resp.Usage.CompletionTokens), zap.Int("total_tokens", resp.Usage.TotalTokens), zap.Duration("duration", duration)) // message embedding embeddings := make([][]float32, len(parsed)) for i, message := range parsed { ids, _, _ := cl100kBaseTokenizer.Encode(message) resp, err := backoff.Retry(context.Background(), func() (openai.EmbeddingResponse, error) { time.Sleep(parallel.EmbeddingRequestsLimiter.Take(1)) time.Sleep(parallel.EmbeddingTokensLimiter.Take(int64(len(ids)))) resp, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ Input: message, Model: openai.EmbeddingModel(g.embeddingModel), Dimensions: g.embeddingDimensions, }) if err == nil { return resp, nil } if isThrottled(err) { return openai.EmbeddingResponse{}, err } return openai.EmbeddingResponse{}, backoff.Permanent(err) }, backoff.WithBackOff(backoff.NewExponentialBackOff())) if err != nil { log.Logger().Error("failed to create 
embeddings", zap.String("item_id", g.items[i].ItemId), zap.Error(err)) return nil } embeddings[i] = resp.Data[0].Embedding } // search index pq := heap.NewPriorityQueue(true) for _, embedding := range embeddings { score0 := floats.Euclidean(embedding, embedding0) scores := g.index.SearchVector(embedding, g.n+1, true) for _, score := range scores { if score.A != i { pq.Push(int32(score.A), score.B*score0) if pq.Len() > g.n { pq.Pop() } } } } scores := make([]cache.Score, pq.Len()) for i := pq.Len() - 1; i >= 0; i-- { id, score := pq.Pop() scores[i] = cache.Score{ Id: g.items[id].ItemId, Categories: g.items[id].Categories, Score: 1.0 / (1.0 + float64(score)), Timestamp: g.timestamp, } } return scores } func stripThinkInCompletion(s string) string { if len(s) < 7 || s[:7] != "" { return s } end := strings.Index(s, "") if end == -1 { return s } return s[end+8:] } ================================================ FILE: logics/item_to_item_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "strconv" "testing" "time" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/mock" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/storage/data" "github.com/stretchr/testify/suite" ) type ItemToItemTestSuite struct { suite.Suite } func (suite *ItemToItemTestSuite) TestColumnFunc() { item2item, err := newEmbeddingItemToItem(config.ItemToItemConfig{ Column: "item.Labels.description", }, 10, time.Now()) suite.NoError(err) // Push success item2item.Push(&data.Item{ ItemId: "1", Labels: map[string]any{ "description": []float32{0.1, 0.2, 0.3}, }, }, nil) suite.Equal(1, item2item.Count()) // Hidden item2item.Push(&data.Item{ ItemId: "2", IsHidden: true, Labels: map[string]any{ "description": []float32{0.1, 0.2, 0.3}, }, }, nil) suite.Equal(2, item2item.Count()) // Dimension does not match item2item.Push(&data.Item{ ItemId: "1", Labels: map[string]any{ "description": []float32{0.1, 0.2}, }, }, nil) suite.Equal(2, item2item.Count()) // Type does not match item2item.Push(&data.Item{ ItemId: "1", Labels: map[string]any{ "description": "hello", }, }, nil) suite.Equal(2, item2item.Count()) // Column does not exist item2item.Push(&data.Item{ ItemId: "2", Labels: []float32{0.1, 0.2, 0.3}, }, nil) suite.Equal(2, item2item.Count()) } func (suite *ItemToItemTestSuite) TestEmbedding() { timestamp := time.Now() item2item, err := newEmbeddingItemToItem(config.ItemToItemConfig{ Column: "item.Labels.description", }, 10, timestamp) suite.NoError(err) for i := 0; i < 100; i++ { item2item.Push(&data.Item{ ItemId: strconv.Itoa(i), Labels: map[string]any{ "description": []float32{0.1 * float32(i), 0.2 * float32(i), 0.3 * float32(i)}, }, }, nil) } scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) } } func (suite *ItemToItemTestSuite) TestHidden() { timestamp := time.Now() item2item, err := 
newEmbeddingItemToItem(config.ItemToItemConfig{ Column: "item.Labels.description", }, 2, timestamp) suite.NoError(err) item2item.Push(&data.Item{ ItemId: "visible_1", Labels: map[string]any{ "description": []float32{0.0, 0.0, 0.0}, }, }, nil) item2item.Push(&data.Item{ ItemId: "visible_2", Labels: map[string]any{ "description": []float32{0.1, 0.0, 0.0}, }, }, nil) item2item.Push(&data.Item{ ItemId: "hidden_1", IsHidden: true, Labels: map[string]any{ "description": []float32{0.05, 0.0, 0.0}, }, }, nil) suite.Equal(3, item2item.Count()) // hidden item should have similar items generated from non-hidden index hiddenScores := item2item.PopAll(2) suite.Len(hiddenScores, 2) for _, score := range hiddenScores { suite.NotEqual("hidden_1", score.Id) } // non-hidden item should never get hidden item in similarity results visibleScores := item2item.PopAll(0) suite.Len(visibleScores, 1) for _, score := range visibleScores { suite.NotEqual("hidden_1", score.Id) } } func (suite *ItemToItemTestSuite) TestTags() { timestamp := time.Now() idf := make([]float32, 101) for i := range idf { idf[i] = 1 } item2item, err := newTagsItemToItem(config.ItemToItemConfig{ Column: "item.Labels", }, 10, timestamp, idf) suite.NoError(err) for i := 0; i < 100; i++ { labels := make(map[string]any) for j := 1; j <= 100-i; j++ { labels[strconv.Itoa(j)] = []dataset.ID{dataset.ID(j)} } item2item.Push(&data.Item{ ItemId: strconv.Itoa(i), Labels: labels, }, nil) } scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) } } func (suite *ItemToItemTestSuite) TestUsers() { timestamp := time.Now() idf := make([]float32, 101) for i := range idf { idf[i] = 1 } item2item, err := newUsersItemToItem(config.ItemToItemConfig{}, 10, timestamp, idf) suite.NoError(err) for i := 0; i < 100; i++ { feedback := make([]int32, 0, 100-i) for j := 1; j <= 100-i; j++ { feedback = append(feedback, int32(j)) } item2item.Push(&data.Item{ItemId: strconv.Itoa(i)}, 
feedback) } scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) } } func (suite *ItemToItemTestSuite) TestAuto() { timestamp := time.Now() idf := make([]float32, 101) for i := range idf { idf[i] = 1 } item2item, err := newAutoItemToItem(config.ItemToItemConfig{}, 10, timestamp, idf, idf) suite.NoError(err) for i := 0; i < 100; i++ { item := &data.Item{ItemId: strconv.Itoa(i)} feedback := make([]int32, 0, 100-i) if i%2 == 0 { labels := make(map[string]any) for j := 1; j <= 100-i; j++ { labels[strconv.Itoa(j)] = []dataset.ID{dataset.ID(j)} } item.Labels = labels } else { for j := 1; j <= 100-i; j++ { feedback = append(feedback, int32(j)) } } item2item.Push(item, feedback) } scores0 := item2item.PopAll(0) suite.Len(scores0, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i*2), scores0[i-1].Id) } scores1 := item2item.PopAll(1) suite.Len(scores1, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i*2+1), scores1[i-1].Id) } } func (suite *ItemToItemTestSuite) TestChat() { mockAI := mock.NewOpenAIServer() go func() { _ = mockAI.Start() }() mockAI.Ready() defer mockAI.Close() timestamp := time.Now() item2item, err := newChatItemToItem(config.ItemToItemConfig{ Column: "item.Labels.embeddings", Prompt: "Please generate similar items for {{ item.Labels.title }}.", }, 10, timestamp, config.OpenAIConfig{ BaseURL: mockAI.BaseURL(), AuthToken: mockAI.AuthToken(), ChatCompletionModel: "deepseek-r1", EmbeddingModel: "text-similarity-ada-001", }) suite.NoError(err) for i := 0; i < 100; i++ { embedding := mock.Hash("Please generate similar items for item_0.") floats.AddConst(embedding, float32(i+1)) item2item.Push(&data.Item{ ItemId: strconv.Itoa(i), Labels: map[string]any{ "title": "item_" + strconv.Itoa(i), "embeddings": embedding, }, }, nil) } scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) } } func TestItemToItem(t 
*testing.T) { suite.Run(t, new(ItemToItemTestSuite)) } ================================================ FILE: logics/non_personalized.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package logics import ( "reflect" "sort" "sync" "time" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" "github.com/gorse-io/gorse/common/heap" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/juju/errors" "github.com/samber/lo" "go.uber.org/zap" ) type NonPersonalized struct { sync.Mutex name string timestamp time.Time scoreFunc *vm.Program filterFunc *vm.Program heapSize int heaps map[string]*heap.TopKFilter[string, float64] } func NewNonPersonalized(cfg config.NonPersonalizedConfig, n int, timestamp time.Time) (*NonPersonalized, error) { // Compile score expression scoreFunc, err := expr.Compile(cfg.Score, expr.Env(map[string]any{ "item": data.Item{}, "feedback": []data.Feedback{}, })) if err != nil { return nil, err } switch scoreFunc.Node().Type().Kind() { case reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: default: return nil, errors.New("score function must return float64") } // Compile filter expression var filterFunc *vm.Program if cfg.Filter != "" { filterFunc, err = expr.Compile(cfg.Filter, expr.Env(map[string]any{ 
"item": data.Item{}, "feedback": []data.Feedback{}, })) if err != nil { return nil, err } if filterFunc.Node().Type().Kind() != reflect.Bool { return nil, errors.New("filter function must return bool") } } // Initialize heap heaps := make(map[string]*heap.TopKFilter[string, float64]) heaps[""] = heap.NewTopKFilter[string, float64](n) return &NonPersonalized{ name: cfg.Name, timestamp: timestamp, scoreFunc: scoreFunc, filterFunc: filterFunc, heapSize: n, heaps: heaps, }, nil } func (l *NonPersonalized) Push(item data.Item, feedback []data.Feedback) { // Skip hidden items if item.IsHidden { return } // Evaluate filter function if l.filterFunc != nil { result, err := expr.Run(l.filterFunc, map[string]any{ "item": item, "feedback": feedback, }) if err != nil { log.Logger().Error("evaluate filter function", zap.Error(err)) return } if !result.(bool) { return } } // Evaluate score function result, err := expr.Run(l.scoreFunc, map[string]any{ "item": item, "feedback": feedback, }) if err != nil { log.Logger().Error("evaluate score function", zap.Error(err)) return } var score float64 switch typed := result.(type) { case float64: score = typed case int: score = float64(typed) case int8: score = float64(typed) case int16: score = float64(typed) case int32: score = float64(typed) case int64: score = float64(typed) default: log.Logger().Error("score function must return float64", zap.Any("result", result)) return } // Add to heap l.Lock() defer l.Unlock() l.heaps[""].Push(item.ItemId, score) for _, group := range item.Categories { if _, exist := l.heaps[group]; !exist { l.heaps[group] = heap.NewTopKFilter[string, float64](l.heapSize) } l.heaps[group].Push(item.ItemId, score) } } func (l *NonPersonalized) PopAll() []cache.Score { scores := make(map[string]*cache.Score) l.Lock() defer l.Unlock() for category, h := range l.heaps { elems := h.PopAll() for _, elem := range elems { if _, exist := scores[elem.Value]; !exist { scores[elem.Value] = &cache.Score{ Id: elem.Value, Score: 
elem.Weight, Categories: []string{category}, Timestamp: l.timestamp, } } else { scores[elem.Value].Categories = append(scores[elem.Value].Categories, category) } } } result := lo.MapToSlice(scores, func(_ string, v *cache.Score) cache.Score { return *v }) sort.Slice(result, func(i, j int) bool { return result[i].Score > result[j].Score }) return result } func (l *NonPersonalized) Name() string { return l.name } func (l *NonPersonalized) Timestamp() time.Time { return l.timestamp } ================================================ FILE: logics/non_personalized_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics import ( "fmt" "strconv" "testing" "time" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/data" "github.com/stretchr/testify/assert" ) func TestLatest(t *testing.T) { timestamp := time.Now() latest, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "latest", Score: "item.Timestamp.Unix()", }, 10, timestamp) assert.NoError(t, err) for i := 0; i < 100; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(time.Duration(-i) * time.Second)} latest.Push(item, nil) } scores := latest.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(i), scores[i].Id) assert.Equal(t, float64(timestamp.Add(time.Duration(-i)*time.Second).Unix()), scores[i].Score) assert.Equal(t, timestamp, scores[i].Timestamp) } } func TestPopular(t *testing.T) { timestamp := time.Now() popular, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "popular", Score: "len(feedback)", }, 10, timestamp) assert.NoError(t, err) for i := 0; i < 100; i++ { item := data.Item{ItemId: strconv.Itoa(i)} feedback := make([]data.Feedback, i) popular.Push(item, feedback) } scores := popular.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(99-i), scores[i].Id) assert.Equal(t, float64(99-i), scores[i].Score) } } func TestPopularWindow(t *testing.T) { // Create popular recommender timestamp := time.Now() popular, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "popular", Score: "len(feedback)", Filter: fmt.Sprintf("(now() - item.Timestamp).Nanoseconds() < %d", time.Hour.Nanoseconds()), }, 10, timestamp) assert.NoError(t, err) // Add items for i := 0; i < 100; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(time.Second - time.Hour)} feedback := make([]data.Feedback, i) popular.Push(item, feedback) } // Add outdated items for i := 100; i < 110; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(-time.Hour)} 
feedback := make([]data.Feedback, i) popular.Push(item, feedback) } // Check result scores := popular.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(99-i), scores[i].Id) assert.Equal(t, float64(99-i), scores[i].Score) } } func TestFilter(t *testing.T) { timestamp := time.Now() latest, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "latest", Score: "item.Timestamp.Unix()", Filter: "!item.IsHidden", }, 10, timestamp) assert.NoError(t, err) for i := 0; i < 100; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(time.Duration(-i) * time.Second)} item.IsHidden = i < 10 latest.Push(item, nil) } scores := latest.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(i+10), scores[i].Id) assert.Equal(t, float64(timestamp.Add(time.Duration(-i-10)*time.Second).Unix()), scores[i].Score) assert.Equal(t, timestamp, scores[i].Timestamp) } } func TestHidden(t *testing.T) { timestamp := time.Now() latest, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "latest", Score: "item.Timestamp.Unix()", }, 10, timestamp) assert.NoError(t, err) for i := 0; i < 100; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(time.Duration(-i) * time.Second)} item.IsHidden = i < 10 latest.Push(item, nil) } scores := latest.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(i+10), scores[i].Id) assert.Equal(t, float64(timestamp.Add(time.Duration(-i-10)*time.Second).Unix()), scores[i].Score) assert.Equal(t, timestamp, scores[i].Timestamp) } } func TestMostStarredWeekly(t *testing.T) { // Create non-personalized recommender timestamp := time.Now() mostStarredWeekly, err := NewNonPersonalized(config.NonPersonalizedConfig{ Name: "most_starred_weekly", Score: "count(feedback, .FeedbackType == 'star')", Filter: "(now() - item.Timestamp).Hours() < 168", }, 10, timestamp) assert.NoError(t, err) // Add items for i := 0; i < 
100; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(-167 * time.Hour)} var feedback []data.Feedback for j := 0; j < i; j++ { feedback = append(feedback, data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "star", UserId: strconv.Itoa(j), ItemId: strconv.Itoa(i), }, Timestamp: timestamp, }) feedback = append(feedback, data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "like", UserId: strconv.Itoa(j), ItemId: strconv.Itoa(i), }, Timestamp: timestamp, }) } mostStarredWeekly.Push(item, feedback) } // Add outdated items for i := 100; i < 110; i++ { item := data.Item{ItemId: strconv.Itoa(i), Timestamp: timestamp.Add(-168 * time.Hour)} var feedback []data.Feedback for j := 0; j < i; j++ { feedback = append(feedback, data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "star", UserId: strconv.Itoa(j), ItemId: strconv.Itoa(i), }, Timestamp: timestamp.Add(-time.Hour * 169), }) } mostStarredWeekly.Push(item, feedback) } // Check result scores := mostStarredWeekly.PopAll() assert.Len(t, scores, 10) for i := 0; i < 10; i++ { assert.Equal(t, strconv.Itoa(99-i), scores[i].Id) assert.Equal(t, float64(99-i), scores[i].Score) } } ================================================ FILE: logics/recommend.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package logics

import (
	"context"
	"strings"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/expression"
	"github.com/gorse-io/gorse/common/heap"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/juju/errors"
	"github.com/samber/lo"
)

// Recommender names accepted by Recommender.parse. Names ending in "/" are
// prefixes followed by the name of a configured recommender, e.g.
// "item-to-item/<name>".
const (
	LatestRecommender          = "latest"
	NonPersonalizedRecommender = "non-personalized/"
	ItemToItemRecommender      = "item-to-item/"
	UserToUserRecommender      = "user-to-user/"
	ExternalRecommender        = "external/"
	CollaborativeRecommender   = "collaborative"
)

// Recommender produces recommendations for a single user. Items that must
// not be recommended (and every item already returned through Recommend or
// RecommendSequential) are tracked in excludeSet.
type Recommender struct {
	config      config.RecommendConfig
	cacheClient cache.Database
	dataClient  data.Database
	// online selects online behavior: seen items may be replaced and
	// item-to-item seeds are capped at ContextSize.
	online bool
	// coldstart is true when the user has no positive feedback.
	coldstart    bool
	userId       string
	userFeedback []data.Feedback
	// categories, when non-empty, restricts results to those categories.
	categories []string
	// excludeSet holds item IDs that must not be recommended.
	excludeSet mapset.Set[string]
}

// RecommenderFunc returns candidate items not in the Recommender's exclude
// set, together with a digest string describing the source of the
// candidates (read from the cache, a constant, or a config hash).
type RecommenderFunc func(ctx context.Context) ([]cache.Score, string, error)

// NewRecommender creates a Recommender for userId restricted to categories.
// It loads the user's feedback up to now. The user is cold-start if none of
// that feedback matches config.DataSource.PositiveFeedbackTypes. Items the
// user gave feedback on are excluded unless replacement is enabled AND the
// recommender is online.
func NewRecommender(config config.RecommendConfig, cacheClient cache.Database, dataClient data.Database, online bool, userId string, categories []string) (*Recommender, error) {
	// Load user feedback
	userFeedback, err := dataClient.GetUserFeedback(context.Background(), userId, lo.ToPtr(time.Now()))
	if err != nil {
		return nil, errors.Trace(err)
	}
	excludeSet := mapset.NewSet[string]()
	coldstart := true
	for _, feedback := range userFeedback {
		if !config.Replacement.EnableReplacement || !online {
			excludeSet.Add(feedback.ItemId)
		}
		if expression.MatchFeedbackTypeExpressions(config.DataSource.PositiveFeedbackTypes, feedback.FeedbackType, feedback.Value) {
			coldstart = false
		}
	}
	return &Recommender{
		config:       config,
		cacheClient:  cacheClient,
		dataClient:   dataClient,
		userId:       userId,
		userFeedback: userFeedback,
		online:       online,
		coldstart:    coldstart,
		categories:   categories,
		excludeSet:   excludeSet,
	}, nil
}

// ExcludeSet returns the set of item IDs that will not be recommended.
func (r *Recommender) ExcludeSet() mapset.Set[string] {
	return r.excludeSet
}

// UserFeedback returns the feedback loaded for the user.
func (r *Recommender) UserFeedback() []data.Feedback {
	return r.userFeedback
}

// IsColdStart reports whether the user has no positive feedback.
func (r *Recommender) IsColdStart() bool {
	return r.coldstart
}

// Recommend returns up to limit items for the user (everything if limit <= 0).
// If Ranker.Type is not "none", the ranked list is read from the cache.
// Otherwise the ranker's recommenders are run in order. When fewer than
// limit items are found, the fallback recommenders fill the remainder.
func (r *Recommender) Recommend(ctx context.Context, limit int) (result []cache.Score, err error) {
	if !strings.EqualFold(r.config.Ranker.Type, "none") {
		scores, err := r.cacheClient.SearchScores(ctx, cache.Recommend, r.userId, r.categories, 0, r.config.CacheSize)
		if err != nil {
			return nil, errors.Trace(err)
		}
		result = make([]cache.Score, 0, len(scores))
		for _, score := range scores {
			if !r.excludeSet.Contains(score.Id) {
				r.excludeSet.Add(score.Id)
				result = append(result, score)
			}
		}
	} else {
		result, _, err = r.RecommendSequential(ctx, result, r.config.CacheSize, r.config.Ranker.Recommenders...)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}
	if len(result) >= limit && limit > 0 {
		return result[:limit], nil
	}
	result, _, err = r.RecommendSequential(ctx, result, limit, r.config.Fallback.Recommenders...)
	return result, errors.Trace(err)
}

// RecommendSequential appends recommendations from the named recommenders,
// in order, to result until it holds at least limit items. The returned
// slice is then truncated to limit. If limit <= 0, all recommendations are
// returned. The second return value is util.MD5 of the digests of the
// recommenders that ran.
//
// Items already in the exclude set are skipped, and every appended item is
// added to it. This also drops duplicates within one recommender's output,
// such as an external script returning the same ID twice.
func (r *Recommender) RecommendSequential(ctx context.Context, result []cache.Score, limit int, names ...string) ([]cache.Score, string, error) {
	var digests []string
	for _, name := range names {
		recommenderFunc, err := r.parse(name)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		scores, digest, err := recommenderFunc(ctx)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		for _, score := range scores {
			if r.excludeSet.Contains(score.Id) {
				continue
			}
			r.excludeSet.Add(score.Id)
			result = append(result, score)
		}
		digests = append(digests, digest)
		if limit > 0 && len(result) >= limit {
			return result[:limit], util.MD5(digests...), nil
		}
	}
	return result, util.MD5(digests...), nil
}

// parse resolves a recommender name such as "latest" or
// "item-to-item/<name>" to the function implementing it.
func (r *Recommender) parse(fullname string) (RecommenderFunc, error) {
	switch {
	case fullname == CollaborativeRecommender:
		return r.recommendCollaborative, nil
	case fullname == LatestRecommender:
		return r.recommendLatest, nil
	case strings.HasPrefix(fullname, NonPersonalizedRecommender):
		return r.recommendNonPersonalized(strings.TrimPrefix(fullname, NonPersonalizedRecommender)), nil
	case strings.HasPrefix(fullname, ItemToItemRecommender):
		return r.recommendItemToItem(strings.TrimPrefix(fullname, ItemToItemRecommender)), nil
	case strings.HasPrefix(fullname, UserToUserRecommender):
		return r.recommendUserToUser(strings.TrimPrefix(fullname, UserToUserRecommender)), nil
	case strings.HasPrefix(fullname, ExternalRecommender):
		return r.recommendExternal(strings.TrimPrefix(fullname, ExternalRecommender)), nil
	default:
		return nil, errors.Errorf("unknown recommender: %s", fullname)
	}
}

// recommendLatest returns the latest items in the requested categories,
// scored by their Unix timestamp. Its digest is the constant "latest".
func (r *Recommender) recommendLatest(ctx context.Context) ([]cache.Score, string, error) {
	items, err := r.dataClient.GetLatestItems(ctx, r.config.CacheSize, r.categories)
	if err != nil {
		return nil, "", errors.Trace(err)
	}
	scores := make([]cache.Score, 0, len(items))
	for _, item := range items {
		if !r.excludeSet.Contains(item.ItemId) {
			scores = append(scores, cache.Score{
				Id:         item.ItemId,
				Score:      float64(item.Timestamp.Unix()),
				Categories: item.Categories,
			})
		}
	}
	return scores, "latest", nil
}

// recommendNonPersonalized returns a RecommenderFunc reading the cached
// non-personalized list called name. With no categories requested, the ""
// category is searched.
func (r *Recommender) recommendNonPersonalized(name string) RecommenderFunc {
	return func(ctx context.Context) ([]cache.Score, string, error) {
		var categories []string
		if len(r.categories) == 0 {
			categories = []string{""}
		} else {
			categories = r.categories
		}
		// fetch items from cache
		items, err := r.cacheClient.SearchScores(ctx, cache.NonPersonalized, name, categories, 0, r.config.CacheSize)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		// read digest
		digest, err := r.cacheClient.Get(ctx, cache.Key(cache.NonPersonalizedDigest, name)).String()
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		// remove excluded items
		return lo.Filter(items, func(item cache.Score, index int) bool {
			return !r.excludeSet.Contains(item.Id)
		}), digest, nil
	}
}

// recommendCollaborative returns the user's cached collaborative-filtering
// recommendations and their cached digest.
func (r *Recommender) recommendCollaborative(ctx context.Context) ([]cache.Score, string, error) {
	// fetch items from cache
	items, err := r.cacheClient.SearchScores(ctx, cache.CollaborativeFiltering, r.userId, r.categories, 0, r.config.CacheSize)
	if err != nil {
		return nil, "", errors.Trace(err)
	}
	// read digest
	digest, err := r.cacheClient.Get(ctx, cache.Key(cache.CollaborativeFilteringDigest, r.userId)).String()
	if err != nil {
		return nil, "", errors.Trace(err)
	}
	// remove excluded items
	return lo.Filter(items, func(item cache.Score, index int) bool {
		return !r.excludeSet.Contains(item.Id)
	}), digest, nil
}

// recommendItemToItem returns a RecommenderFunc that expands the user's
// positive feedback through the cached item-to-item lists called name. It
// sums the similarity of every candidate over all seed items and keeps the
// top CacheSize candidates. The digest concatenates the digests of the
// lists that contributed at least one candidate, in seed order.
func (r *Recommender) recommendItemToItem(name string) RecommenderFunc {
	return func(ctx context.Context) ([]cache.Score, string, error) {
		// Select seed items: positive feedback after data.SortFeedbacks,
		// capped at ContextSize entries when online.
		data.SortFeedbacks(r.userFeedback)
		userFeedback := make([]data.Feedback, 0, r.config.CacheSize)
		for _, feedback := range r.userFeedback {
			if r.online && r.config.ContextSize <= len(userFeedback) {
				break
			}
			if expression.MatchFeedbackTypeExpressions(r.config.DataSource.PositiveFeedbackTypes, feedback.FeedbackType, feedback.Value) {
				userFeedback = append(userFeedback, feedback)
			}
		}
		// Collect scores. Digests are kept in first-seen order rather than
		// joined from a set: set iteration order is random, which would make
		// the digest differ between calls over identical data.
		scores := make(map[string]float64)
		categories := make(map[string][]string)
		seenDigests := mapset.NewSet[string]()
		var digests []string
		for _, feedback := range userFeedback {
			similarItems, err := r.cacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key(name, feedback.ItemId), r.categories, 0, r.config.CacheSize)
			if err != nil {
				return nil, "", errors.Trace(err)
			}
			digest, err := r.cacheClient.Get(ctx, cache.Key(cache.ItemToItemDigest, name, feedback.ItemId)).String()
			if err != nil {
				return nil, "", errors.Trace(err)
			}
			contributed := false
			for _, item := range similarItems {
				if !r.excludeSet.Contains(item.Id) {
					scores[item.Id] += item.Score
					categories[item.Id] = item.Categories
					contributed = true
				}
			}
			if contributed && !seenDigests.Contains(digest) {
				seenDigests.Add(digest)
				digests = append(digests, digest)
			}
		}
		// collect top scores
		filter := heap.NewTopKFilter[string, float64](r.config.CacheSize)
		for id, score := range scores {
			filter.Push(id, score)
		}
		elems := filter.PopAll()
		return lo.Map(elems, func(elem heap.Elem[string, float64], _ int) cache.Score {
			return cache.Score{
				Id:         elem.Value,
				Score:      elem.Weight,
				Categories: categories[elem.Value],
			}
		}), strings.Join(digests, ""), nil
	}
}

// recommendUserToUser returns a RecommenderFunc that loads the user's cached
// similar users (list name), adds each similar user's score once per
// positive feedback record on an unexcluded item, keeps the top CacheSize
// items, and then drops items not containing every requested category.
func (r *Recommender) recommendUserToUser(name string) RecommenderFunc {
	return func(ctx context.Context) ([]cache.Score, string, error) {
		scores := make(map[string]float64)
		// load similar users
		similarUsers, err := r.cacheClient.SearchScores(ctx, cache.UserToUser, cache.Key(name, r.userId), nil, 0, r.config.CacheSize)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		// read digest
		digest, err := r.cacheClient.Get(ctx, cache.Key(cache.UserToUserDigest, name, r.userId)).String()
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		// aggregate scores
		for _, user := range similarUsers {
			// load historical feedback
			feedbacks, err := r.dataClient.GetUserFeedback(ctx, user.Id, lo.ToPtr(time.Now()), r.config.DataSource.PositiveFeedbackTypes...)
			if err != nil {
				return nil, "", errors.Trace(err)
			}
			// add unseen items
			for _, feedback := range feedbacks {
				if !r.excludeSet.Contains(feedback.ItemId) {
					scores[feedback.ItemId] += user.Score
				}
			}
		}
		// collect top k
		filter := heap.NewTopKFilter[string, float64](r.config.CacheSize)
		for id, score := range scores {
			filter.Push(id, score)
		}
		elems := filter.PopAll()
		// filter by categories
		results := make([]cache.Score, 0, len(elems))
		ids := lo.Map(elems, func(elem heap.Elem[string, float64], _ int) string {
			return elem.Value
		})
		items, err := r.dataClient.BatchGetItems(ctx, ids)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		itemsMap := make(map[string]data.Item)
		for _, item := range items {
			itemsMap[item.ItemId] = item
		}
		for _, elem := range elems {
			if item, ok := itemsMap[elem.Value]; ok && lo.Every(item.Categories, r.categories) {
				results = append(results, cache.Score{
					Id:         item.ItemId,
					Score:      elem.Weight,
					Categories: item.Categories,
				})
			}
		}
		return results, digest, nil
	}
}

// recommendExternal returns a RecommenderFunc that pulls item IDs for the
// user from the external recommender configured under name. Returned scores
// are zero. Category-filtered requests return no items because external
// recommenders do not support categories. The digest is the config's Hash.
// NOTE(review): if no external config matches name, the zero-value
// ExternalConfig is used; confirm that NewExternal rejects it.
func (r *Recommender) recommendExternal(name string) RecommenderFunc {
	return func(ctx context.Context) ([]cache.Score, string, error) {
		var externalConfig config.ExternalConfig
		for _, extConfig := range r.config.External {
			if extConfig.Name == name {
				externalConfig = extConfig
				break
			}
		}
		if len(r.categories) > 0 {
			// external recommenders do not support categories
			return nil, externalConfig.Hash(), nil
		}
		external, err := NewExternal(externalConfig)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		defer external.Close()
		items, err := external.Pull(r.userId)
		if err != nil {
			return nil, "", errors.Trace(err)
		}
		scores := make([]cache.Score, 0, len(items))
		for _, itemId := range items {
			if !r.excludeSet.Contains(itemId) {
				scores = append(scores, cache.Score{
					Id: itemId,
				})
			}
		}
		return scores, externalConfig.Hash(), nil
	}
}

================================================
FILE: logics/recommend_test.go
================================================
// Copyright 2025 gorse Project Authors
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package logics

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/stretchr/testify/suite"
)

// RecommenderTestSuite exercises the individual recommenders of Recommender
// against SQLite-backed data and cache stores. The stores are opened once in
// SetupSuite and shared by every test, so data written by one test remains
// visible to the others.
type RecommenderTestSuite struct {
	suite.Suite
	dataClient  data.Database
	cacheClient cache.Database
}

// SetupSuite opens and initializes SQLite data and cache databases in
// temporary directories.
func (suite *RecommenderTestSuite) SetupSuite() {
	var err error
	// open database
	suite.dataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir()), "")
	suite.NoError(err)
	suite.cacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", suite.T().TempDir()), "")
	suite.NoError(err)
	// init database
	err = suite.dataClient.Init()
	suite.NoError(err)
	err = suite.cacheClient.Init()
	suite.NoError(err)
}

// TearDownSuite closes both databases.
func (suite *RecommenderTestSuite) TearDownSuite() {
	err := suite.dataClient.Close()
	suite.NoError(err)
	err = suite.cacheClient.Close()
	suite.NoError(err)
}

// TestLatest inserts items item_0..item_19 with Unix timestamps 0..19 (even
// items in "cat_1") and clicks on item_0..item_9 by user_1. It checks that
// recommendLatest returns the unclicked items newest first, scored by
// timestamp, both without and with the "cat_1" filter.
func (suite *RecommenderTestSuite) TestLatest() {
	items := make([]data.Item, 20)
	for i := 0; i < 20; i++ {
		items[i] = data.Item{
			ItemId:    fmt.Sprintf("item_%d", i),
			Timestamp: time.Unix(int64(i), 0),
		}
		if i%2 == 0 {
			items[i].Categories = []string{"cat_1"}
		}
	}
	err := suite.dataClient.BatchInsertItems(suite.T().Context(), items)
	suite.NoError(err)
	feedback := make([]data.Feedback, 10)
	for i := 0; i < 10; i++ {
		feedback[i] = data.Feedback{
			FeedbackKey: data.FeedbackKey{
				FeedbackType: "click",
				UserId:       "user_1",
				ItemId:       fmt.Sprintf("item_%d", i),
			},
		}
	}
	err = suite.dataClient.BatchInsertFeedback(suite.T().Context(), feedback, true, true, false)
	suite.NoError(err)

	// No category filter: the 10 unclicked items, newest first.
	recommender, err := NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", nil)
	suite.NoError(err)
	scores, digest, err := recommender.recommendLatest(suite.T().Context())
	suite.NoError(err)
	suite.Equal("latest", digest)
	if suite.Equal(10, len(scores)) {
		for i := 0; i < 10; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 19-i), scores[i].Id)
			suite.Equal(float64(19-i), scores[i].Score)
		}
	}

	// "cat_1" filter: only the even unclicked items 18, 16, ..., 10.
	recommender, err = NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", []string{"cat_1"})
	suite.NoError(err)
	scores, digest, err = recommender.recommendLatest(suite.T().Context())
	suite.NoError(err)
	suite.Equal("latest", digest)
	if suite.Equal(5, len(scores)) {
		for i := 0; i < 5; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 18-2*i), scores[i].Id)
			suite.Equal(float64(18-2*i), scores[i].Score)
		}
	}
}

// TestCollaborative caches 20 collaborative-filtering scores for user_1
// (score i for item_i, even items in "cat_1") together with a digest. It
// checks that recommendCollaborative drops the clicked item_0..item_9 and
// honors the category filter.
func (suite *RecommenderTestSuite) TestCollaborative() {
	recommends := make([]cache.Score, 20)
	for i := 0; i < 20; i++ {
		recommends[i] = cache.Score{
			Id:    fmt.Sprintf("item_%d", i),
			Score: float64(i),
		}
		if i%2 == 0 {
			recommends[i].Categories = []string{"cat_1"}
		}
	}
	err := suite.cacheClient.AddScores(suite.T().Context(), cache.CollaborativeFiltering, "user_1", recommends)
	suite.NoError(err)
	err = suite.cacheClient.Set(suite.T().Context(), cache.String(cache.Key(cache.CollaborativeFilteringDigest, "user_1"), "digest"))
	suite.NoError(err)
	feedback := make([]data.Feedback, 10)
	for i := 0; i < 10; i++ {
		feedback[i] = data.Feedback{
			FeedbackKey: data.FeedbackKey{
				FeedbackType: "click",
				UserId:       "user_1",
				ItemId:       fmt.Sprintf("item_%d", i),
			},
		}
	}
	err = suite.dataClient.BatchInsertFeedback(suite.T().Context(), feedback, true, true, false)
	suite.NoError(err)

	recommender, err := NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", nil)
	suite.NoError(err)
	scores, digest, err := recommender.recommendCollaborative(suite.T().Context())
	suite.NoError(err)
	suite.Equal("digest", digest)
	if suite.Equal(10, len(scores)) {
		for i := 0; i < 10; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 19-i), scores[i].Id)
			suite.Equal(float64(19-i), scores[i].Score)
		}
	}

	recommender, err = NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", []string{"cat_1"})
	suite.NoError(err)
	scores, digest, err = recommender.recommendCollaborative(suite.T().Context())
	suite.NoError(err)
	suite.Equal("digest", digest)
	if suite.Equal(5, len(scores)) {
		for i := 0; i < 5; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 18-2*i), scores[i].Id)
			suite.Equal(float64(18-2*i), scores[i].Score)
		}
	}
}

// TestNonPersonalized caches a non-personalized list "a" in which every item
// is in the "" category and even items are also in "cat_1". It checks that
// recommendNonPersonalized searches "" when no category is requested and
// "cat_1" when it is, excluding clicked items in both cases.
func (suite *RecommenderTestSuite) TestNonPersonalized() {
	recommends := make([]cache.Score, 20)
	for i := 0; i < 20; i++ {
		recommends[i] = cache.Score{
			Id:    fmt.Sprintf("item_%d", i),
			Score: float64(i),
		}
		if i%2 == 0 {
			recommends[i].Categories = []string{"", "cat_1"}
		} else {
			recommends[i].Categories = []string{""}
		}
	}
	err := suite.cacheClient.AddScores(suite.T().Context(), cache.NonPersonalized, "a", recommends)
	suite.NoError(err)
	err = suite.cacheClient.Set(suite.T().Context(), cache.String(cache.Key(cache.NonPersonalizedDigest, "a"), "digest"))
	suite.NoError(err)
	feedback := make([]data.Feedback, 10)
	for i := 0; i < 10; i++ {
		feedback[i] = data.Feedback{
			FeedbackKey: data.FeedbackKey{
				FeedbackType: "click",
				UserId:       "user_1",
				ItemId:       fmt.Sprintf("item_%d", i),
			},
		}
	}
	err = suite.dataClient.BatchInsertFeedback(suite.T().Context(), feedback, true, true, false)
	suite.NoError(err)

	recommender, err := NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", nil)
	suite.NoError(err)
	recommendFunc := recommender.recommendNonPersonalized("a")
	scores, digest, err := recommendFunc(suite.T().Context())
	suite.NoError(err)
	suite.Equal("digest", digest)
	if suite.Equal(10, len(scores)) {
		for i := 0; i < 10; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 19-i), scores[i].Id)
			suite.Equal(float64(19-i), scores[i].Score)
		}
	}

	recommender, err = NewRecommender(config.RecommendConfig{}, suite.cacheClient, suite.dataClient, true, "user_1", []string{"cat_1"})
	suite.NoError(err)
	recommendFunc = recommender.recommendNonPersonalized("a")
	scores, digest, err = recommendFunc(suite.T().Context())
	suite.NoError(err)
	suite.Equal("digest", digest)
	if suite.Equal(5, len(scores)) {
		for i := 0; i < 5; i++ {
			suite.Equal(fmt.Sprintf("item_%d", 18-2*i), scores[i].Id)
			suite.Equal(float64(18-2*i), scores[i].Score)
		}
	}
}

// TestExternal serves a JSON list of item IDs for user_1 from a local HTTP
// server. An external recommender script fetches that list, and the test
// checks that the clicked item_1..item_3 are dropped, scores are zero, and
// the digest equals the config hash.
func (suite *RecommenderTestSuite) TestExternal() {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		userId := r.URL.Query().Get("user_id")
		if userId == "user_1" {
			fmt.Fprintln(w, `["item_1", "item_2", "item_3", "item_100", "item_200", "item_300"]`)
		} else {
			http.NotFound(w, r)
		}
	}))
	defer ts.Close()

	feedback := make([]data.Feedback, 10)
	for i := 0; i < 10; i++ {
		feedback[i] = data.Feedback{
			FeedbackKey: data.FeedbackKey{
				FeedbackType: "click",
				UserId:       "user_1",
				ItemId:       fmt.Sprintf("item_%d", i),
			},
		}
	}
	err := suite.dataClient.BatchInsertFeedback(suite.T().Context(), feedback, true, true, false)
	suite.NoError(err)

	cfg := config.RecommendConfig{
		External: []config.ExternalConfig{{
			Script: fmt.Sprintf(`fetch("%s?user_id=user_1").body`, ts.URL),
			Name:   "test",
		}},
	}
	recommender, err := NewRecommender(cfg, suite.cacheClient, suite.dataClient, true, "user_1", nil)
	suite.NoError(err)
	recommendFunc := recommender.recommendExternal("test")
	scores, digest, err := recommendFunc(suite.T().Context())
	suite.NoError(err)
	suite.Equal(cfg.External[0].Hash(), digest)
	suite.Equal([]cache.Score{
		{Id: "item_100", Score: 0},
		{Id: "item_200", Score: 0},
		{Id: "item_300", Score: 0},
	}, scores)
}

// TestRecommenderTestSuite runs RecommenderTestSuite.
func TestRecommenderTestSuite(t *testing.T) {
	suite.Run(t, new(RecommenderTestSuite))
}

================================================
FILE: logics/user_to_user.go
================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package logics import ( "sort" "sync" "time" mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" "github.com/gorse-io/gorse/common/ann" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/juju/errors" "github.com/samber/lo" "go.uber.org/zap" ) type UserToUserOptions struct { TagsIDF []float32 ItemsIDF []float32 } type UserToUser interface { Users() []*data.User Push(user *data.User, feedback []int32) PopAll(i int) []cache.Score Timestamp() time.Time } func NewUserToUser(cfg config.UserToUserConfig, n int, timestamp time.Time, opts *UserToUserOptions) (UserToUser, error) { switch cfg.Type { case "embedding": return newEmbeddingUserToUser(cfg, n, timestamp) case "tags": if opts == nil || opts.TagsIDF == nil { return nil, errors.New("tags IDF is required for tags user-to-user") } return newTagsUserToUser(cfg, n, timestamp, opts.TagsIDF) case "items": if opts == nil || opts.ItemsIDF == nil { return nil, errors.New("items IDF is required for items user-to-user") } return newItemsUserToUser(cfg, n, timestamp, opts.ItemsIDF) case "auto": if opts == nil || opts.TagsIDF == nil || opts.ItemsIDF == nil { 
return nil, errors.New("tags IDF and items IDF are required for auto user-to-user") } return newAutoUserToUser(cfg, n, timestamp, opts.TagsIDF, opts.ItemsIDF) } return nil, errors.New("unknown user-to-user method") } type baseUserToUser[T any] struct { name string n int timestamp time.Time columnFunc *vm.Program index *ann.HNSW[T] users []*data.User usersLock sync.Mutex } func (b *baseUserToUser[T]) Users() []*data.User { return b.users } func (b *baseUserToUser[T]) Timestamp() time.Time { return b.timestamp } func (b *baseUserToUser[T]) PopAll(i int) []cache.Score { scores, err := b.index.SearchIndex(i, b.n+1, true) if err != nil { log.Logger().Error("failed to search index", zap.Error(err)) return nil } return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { return cache.Score{ Id: b.users[v.A].UserId, Score: 1.0 / (1.0 + float64(v.B)), Timestamp: b.timestamp, } }) } type embeddingUserToUser struct { baseUserToUser[[]float32] dimension int } func newEmbeddingUserToUser(cfg config.UserToUserConfig, n int, timestamp time.Time) (UserToUser, error) { // Compile column expression columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ "user": data.User{}, })) if err != nil { return nil, err } return &embeddingUserToUser{baseUserToUser: baseUserToUser[[]float32]{ name: cfg.Name, n: n, timestamp: timestamp, columnFunc: columnFunc, index: ann.NewHNSW[[]float32](floats.Euclidean), users: []*data.User{}, }}, nil } func (e *embeddingUserToUser) Push(user *data.User, _ []int32) { // Evaluate filter function result, err := expr.Run(e.columnFunc, map[string]any{ "user": user, }) if err != nil { log.Logger().Error("failed to evaluate column expression", zap.Error(err)) return } // Check column type v, ok := result.([]float32) if !ok { log.Logger().Error("invalid column type", zap.Any("column", result)) return } // Check dimension e.usersLock.Lock() if e.dimension == 0 && len(v) > 0 { e.dimension = len(v) } else if e.dimension != len(v) { 
log.Logger().Error("invalid dimension", zap.Int("expected", e.dimension), zap.Int("actual", len(v))) return } // Push user e.users = append(e.users, nil) e.usersLock.Unlock() j := e.index.Add(v) e.usersLock.Lock() e.users[j] = user e.usersLock.Unlock() } type tagsUserToUser struct { baseUserToUser[[]dataset.ID] IDF[dataset.ID] } func newTagsUserToUser(cfg config.UserToUserConfig, n int, timestamp time.Time, idf []float32) (UserToUser, error) { // Compile column expression columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ "user": data.User{}, })) if err != nil { return nil, err } t := &tagsUserToUser{IDF: idf} t.baseUserToUser = baseUserToUser[[]dataset.ID]{ name: cfg.Name, n: n, timestamp: timestamp, columnFunc: columnFunc, index: ann.NewHNSW[[]dataset.ID](t.distance), } return t, nil } func (t *tagsUserToUser) Push(user *data.User, _ []int32) { // Evaluate filter function result, err := expr.Run(t.columnFunc, map[string]any{ "user": user, }) if err != nil { log.Logger().Error("failed to evaluate column expression", zap.Error(err)) return } // Extract tags tSet := mapset.NewSet[dataset.ID]() flatten(result, tSet) v := tSet.ToSlice() sort.Slice(v, func(i, j int) bool { return v[i] < v[j] }) // Push user t.usersLock.Lock() t.users = append(t.users, nil) t.usersLock.Unlock() j := t.index.Add(v) t.usersLock.Lock() t.users[j] = user t.usersLock.Unlock() } type itemsUserToUser struct { baseUserToUser[[]int32] IDF[int32] } func newItemsUserToUser(cfg config.UserToUserConfig, n int, timestamp time.Time, idf []float32) (UserToUser, error) { if cfg.Column != "" { return nil, errors.New("column is not supported in items user-to-user") } i := &itemsUserToUser{IDF: idf} i.baseUserToUser = baseUserToUser[[]int32]{ name: cfg.Name, n: n, timestamp: timestamp, index: ann.NewHNSW[[]int32](i.distance), } return i, nil } func (i *itemsUserToUser) Push(user *data.User, feedback []int32) { // Sort feedback sort.Slice(feedback, func(i, j int) bool { return feedback[i] 
< feedback[j] }) // Push user i.usersLock.Lock() i.users = append(i.users, nil) i.usersLock.Unlock() j := i.index.Add(feedback) i.usersLock.Lock() i.users[j] = user i.usersLock.Unlock() } type autoUserToUser struct { baseUserToUser[lo.Tuple2[[]dataset.ID, []int32]] tIDF IDF[dataset.ID] iIDF IDF[int32] } func newAutoUserToUser(cfg config.UserToUserConfig, n int, timestamp time.Time, tIDF, iIDF []float32) (UserToUser, error) { a := &autoUserToUser{ tIDF: tIDF, iIDF: iIDF, } a.baseUserToUser = baseUserToUser[lo.Tuple2[[]dataset.ID, []int32]]{ name: cfg.Name, n: n, timestamp: timestamp, index: ann.NewHNSW[lo.Tuple2[[]dataset.ID, []int32]](a.distance), } return a, nil } func (a *autoUserToUser) Push(user *data.User, feedback []int32) { // Extract tags tSet := mapset.NewSet[dataset.ID]() flatten(user.Labels, tSet) t := tSet.ToSlice() sort.Slice(t, func(i, j int) bool { return t[i] < t[j] }) // Sort feedback sort.Slice(feedback, func(i, j int) bool { return feedback[i] < feedback[j] }) // Push user a.usersLock.Lock() a.users = append(a.users, nil) a.usersLock.Unlock() j := a.index.Add(lo.Tuple2[[]dataset.ID, []int32]{A: t, B: feedback}) a.usersLock.Lock() a.users[j] = user a.usersLock.Unlock() } func (a *autoUserToUser) distance(u, v lo.Tuple2[[]dataset.ID, []int32]) float32 { return (a.tIDF.distance(u.A, v.A) + a.iIDF.distance(u.B, v.B)) / 2 } ================================================ FILE: logics/user_to_user_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

package logics

import (
	"strconv"
	"testing"
	"time"

	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/stretchr/testify/suite"
)

// UserToUserTestSuite tests the user-to-user recommender variants.
type UserToUserTestSuite struct {
	suite.Suite
}

// TestEmbedding pushes 100 users whose "description" label is i*(0.1, 0.2,
// 0.3). It checks that user 0's 10 neighbours are users 1..10 in order of
// increasing distance.
func (suite *UserToUserTestSuite) TestEmbedding() {
	timestamp := time.Now()
	user2user, err := newEmbeddingUserToUser(config.UserToUserConfig{
		Column: "user.Labels.description",
	}, 10, timestamp)
	suite.NoError(err)
	for i := 0; i < 100; i++ {
		user2user.Push(&data.User{
			UserId: strconv.Itoa(i),
			Labels: map[string]any{
				"description": []float32{0.1 * float32(i), 0.2 * float32(i), 0.3 * float32(i)},
			},
		}, nil)
	}
	scores := user2user.PopAll(0)
	suite.Len(scores, 10)
	for i := 1; i <= 10; i++ {
		suite.Equal(strconv.Itoa(i), scores[i-1].Id)
	}
}

// TestTags gives user i the tags 1..100-i (uniform IDF). It checks that user
// 0's neighbours are users 1..10, i.e. the users sharing the most tags.
func (suite *UserToUserTestSuite) TestTags() {
	timestamp := time.Now()
	idf := make([]float32, 101)
	for i := range idf {
		idf[i] = 1
	}
	user2user, err := newTagsUserToUser(config.UserToUserConfig{
		Column: "user.Labels",
	}, 10, timestamp, idf)
	suite.NoError(err)
	for i := 0; i < 100; i++ {
		labels := make(map[string]any)
		for j := 1; j <= 100-i; j++ {
			labels[strconv.Itoa(j)] = []dataset.ID{dataset.ID(j)}
		}
		user2user.Push(&data.User{
			UserId: strconv.Itoa(i),
			Labels: labels,
		}, nil)
	}
	scores := user2user.PopAll(0)
	suite.Len(scores, 10)
	for i := 1; i <= 10; i++ {
		suite.Equal(strconv.Itoa(i), scores[i-1].Id)
	}
}

// TestItems gives user i the feedback 1..100-i (uniform IDF). It checks that
// user 0's neighbours are users 1..10.
func (suite *UserToUserTestSuite) TestItems() {
	timestamp := time.Now()
	idf := make([]float32, 101)
	for i := range idf {
		idf[i] = 1
	}
	user2user, err := newItemsUserToUser(config.UserToUserConfig{}, 10, timestamp, idf)
	suite.NoError(err)
	for i := 0; i < 100; i++ {
		feedback := make([]int32, 0, 100-i)
		for j := 1; j <= 100-i; j++ {
			feedback = append(feedback, int32(j))
		}
		user2user.Push(&data.User{UserId: strconv.Itoa(i)}, feedback)
	}
	scores := user2user.PopAll(0)
	suite.Len(scores, 10)
	for i := 1; i <= 10; i++ {
		suite.Equal(strconv.Itoa(i), scores[i-1].Id)
	}
}

// TestAuto mixes users that have only tags (even i) with users that have
// only feedback (odd i). It checks that both kinds still get 10 neighbours.
func (suite *UserToUserTestSuite) TestAuto() {
	timestamp := time.Now()
	idf := make([]float32, 101)
	for i := range idf {
		idf[i] = 1
	}
	user2user, err := newAutoUserToUser(config.UserToUserConfig{}, 10, timestamp, idf, idf)
	suite.NoError(err)
	for i := 0; i < 100; i++ {
		user := &data.User{UserId: strconv.Itoa(i)}
		feedback := make([]int32, 0, 100-i)
		if i%2 == 0 {
			labels := make(map[string]any)
			for j := 1; j <= 100-i; j++ {
				labels[strconv.Itoa(j)] = []dataset.ID{dataset.ID(j)}
			}
			user.Labels = labels
		} else {
			for j := 1; j <= 100-i; j++ {
				feedback = append(feedback, int32(j))
			}
		}
		user2user.Push(user, feedback)
	}
	scores0 := user2user.PopAll(0)
	suite.Len(scores0, 10)
	scores1 := user2user.PopAll(1)
	suite.Len(scores1, 10)
}

// TestUserToUser runs UserToUserTestSuite.
func TestUserToUser(t *testing.T) {
	suite.Run(t, new(UserToUserTestSuite))
}

================================================
FILE: master/master.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master

import (
	"context"
	"encoding/json"
	"fmt"
	"math"
	"math/rand"
	"net"
	"sync"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/emicklei/go-restful/v3"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/parallel"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model/cf"
	"github.com/gorse-io/gorse/model/ctr"
	"github.com/gorse-io/gorse/protocol"
	"github.com/gorse-io/gorse/server"
	"github.com/gorse-io/gorse/storage/blob"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/gorse-io/gorse/storage/meta"
	"github.com/jellydator/ttlcache/v3"
	"github.com/juju/errors"
	"github.com/sashabaranov/go-openai"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/propagation"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"google.golang.org/grpc"
)

// Datasets groups two full datasets with their train/test splits: a ranking
// dataset (dataset.Dataset split into dataset.CFSplit) and a click dataset
// (ctr.Dataset). Presumably these feed the collaborative-filtering and
// click-through-rate models held by Master — confirm against the training
// code.
type Datasets struct {
	// ranking (collaborative filtering) data and its train/test splits
	rankingDataset  *dataset.Dataset
	rankingTrainSet dataset.CFSplit
	rankingTestSet  dataset.CFSplit
	// click data and its train/test splits
	clickDataset  *ctr.Dataset
	clickTrainSet *ctr.Dataset
	clickTestSet  *ctr.Dataset
}

// Master is the master node.
type Master struct { protocol.UnimplementedMasterServer server.RestServer grpcServer *grpc.Server tracer *monitor.Monitor remoteProgress sync.Map cachePath string configPath string standalone bool openAIClient *openai.Client // cluster meta cache metaStore meta.Database blobStore blob.Store blobServer *blob.MasterStoreServer // collaborative filtering collaborativeFilteringModelMutex sync.RWMutex collaborativeFilteringTrainSetSize int collaborativeFilteringMeta meta.Model[cf.Score] collaborativeFilteringTarget meta.Model[cf.Score] // click model clickThroughRateModelMutex sync.RWMutex clickThroughRateTrainSetSize int clickThroughRateMeta meta.Model[ctr.Score] clickThroughRateTarget meta.Model[ctr.Score] // oauth2 oauth2Config oauth2.Config verifier *oidc.IDTokenVerifier tokenCache *ttlcache.Cache[string, UserInfo] // events ticker *time.Ticker scheduled chan struct{} cancel context.CancelFunc } // NewMaster creates a master node. func NewMaster(cfg *config.Config, cacheFolder string, standalone bool, configPath string) *Master { rand.Seed(time.Now().UnixNano()) // setup trace provider tp, err := cfg.Tracing.NewTracerProvider() if err != nil { log.Logger().Fatal("failed to create trace provider", zap.Error(err)) } otel.SetTracerProvider(tp) otel.SetErrorHandler(log.GetErrorHandler()) otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) // setup OpenAI client clientConfig := openai.DefaultConfig(cfg.OpenAI.AuthToken) clientConfig.BaseURL = cfg.OpenAI.BaseURL // setup OpenAI logger log.InitOpenAILogger(cfg.OpenAI.LogFile) // setup OpenAI rate limiter parallel.InitChatCompletionLimiters(cfg.OpenAI.ChatCompletionRPM, cfg.OpenAI.ChatCompletionTPM) parallel.InitEmbeddingLimiters(cfg.OpenAI.EmbeddingRPM, cfg.OpenAI.EmbeddingTPM) duration := min(cfg.Recommend.Collaborative.FitPeriod, cfg.Recommend.Ranker.FitPeriod) m := &Master{ // create task monitor cachePath: cacheFolder, configPath: configPath, 
standalone: standalone, tracer: monitor.NewTracer("master"), openAIClient: openai.NewClientWithConfig(clientConfig), RestServer: server.RestServer{ Config: cfg, CacheClient: cache.NoDatabase{}, DataClient: data.NoDatabase{}, HttpHost: cfg.Master.HttpHost, HttpPort: cfg.Master.HttpPort, WebService: new(restful.WebService), }, ticker: time.NewTicker(duration), scheduled: make(chan struct{}, 1), cancel: func() {}, } return m } // Serve starts the master node. func (m *Master) Serve() { // connect blob store var err error m.blobServer = blob.NewMasterStoreServer(m.Config.Blob.URI) m.blobStore, err = blob.NewStore(m.Config.Blob, nil) if err != nil { log.Logger().Fatal("failed to create blob store", zap.Error(err)) } // connect meta database m.metaStore, err = meta.Open(fmt.Sprintf("sqlite://%s/meta.sqlite3", m.cachePath), m.Config.Master.MetaTimeout) if err != nil { log.Logger().Fatal("failed to connect meta database", zap.Error(err)) } if err = m.metaStore.Init(); err != nil { log.Logger().Fatal("failed to init meta database", zap.Error(err)) } // connect data database dataOpts := m.Config.Database.StorageOptions(m.Config.Database.DataStore) m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix, dataOpts...) if err != nil { log.Logger().Fatal("failed to connect data database", zap.Error(err), zap.String("database", log.RedactDBURL(m.Config.Database.DataStore))) } if err = m.DataClient.Init(); err != nil { log.Logger().Fatal("failed to init database", zap.Error(err)) } // connect cache database cacheOpts := m.Config.Database.StorageOptions(m.Config.Database.CacheStore) m.CacheClient, err = cache.Open(m.Config.Database.CacheStore, m.Config.Database.CacheTablePrefix, cacheOpts...) 
if err != nil { log.Logger().Fatal("failed to connect cache database", zap.Error(err), zap.String("database", log.RedactDBURL(m.Config.Database.CacheStore))) } if err = m.CacheClient.Init(); err != nil { log.Logger().Fatal("failed to init database", zap.Error(err)) } // load recommend config metaStr, err := m.metaStore.Get(meta.RECOMMEND_CONFIG) if err != nil && !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to load recommend config", zap.Error(err)) } else if metaStr != nil { err = json.Unmarshal([]byte(*metaStr), &m.Config.Recommend) if err != nil { log.Logger().Error("failed to unmarshal recommend config", zap.Error(err)) } } // load collective filtering model meta metaStr, err = m.metaStore.Get(meta.COLLABORATIVE_FILTERING_MODEL) if err != nil && !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to load collaborative filtering meta", zap.Error(err)) } else if metaStr != nil { if err = m.collaborativeFilteringMeta.FromJSON(*metaStr); err != nil { log.Logger().Error("failed to unmarshal collaborative filtering meta", zap.Error(err)) } else { log.Logger().Info("loaded collaborative filtering model", zap.String("type", m.collaborativeFilteringMeta.Type), zap.Any("params", m.collaborativeFilteringMeta.Params), zap.Any("score", m.collaborativeFilteringMeta.Score)) } } // load click-through rate model metaStr, err = m.metaStore.Get(meta.CLICK_THROUGH_RATE_MODEL) if err != nil && !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to load click-through rate meta", zap.Error(err)) } else if metaStr != nil { if err = m.clickThroughRateMeta.FromJSON(*metaStr); err != nil { log.Logger().Error("failed to unmarshal click-through rate meta", zap.Error(err)) } else { log.Logger().Info("loaded click-through rate model", zap.String("type", m.clickThroughRateMeta.Type), zap.Any("params", m.clickThroughRateMeta.Params), zap.Any("score", m.clickThroughRateMeta.Score)) } } go m.RunTasksLoop() // start rpc server go func() { 
log.Logger().Info("start rpc server", zap.String("host", m.Config.Master.Host), zap.Int("port", m.Config.Master.Port), zap.Bool("ssl_mode", m.Config.Master.SSLMode), zap.String("ssl_ca", m.Config.Master.SSLCA), zap.String("ssl_cert", m.Config.Master.SSLCert), zap.String("ssl_key", m.Config.Master.SSLKey)) opts := []grpc.ServerOption{grpc.MaxSendMsgSize(math.MaxInt)} if m.Config.Master.SSLMode { c, err := util.NewServerCreds(&util.TLSConfig{ SSLCA: m.Config.Master.SSLCA, SSLCert: m.Config.Master.SSLCert, SSLKey: m.Config.Master.SSLKey, }) if err != nil { log.Logger().Fatal("failed to load server TLS", zap.Error(err)) } opts = append(opts, grpc.Creds(c)) } lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", m.Config.Master.Host, m.Config.Master.Port)) if err != nil { log.Logger().Fatal("failed to listen", zap.Error(err)) } m.grpcServer = grpc.NewServer(opts...) protocol.RegisterMasterServer(m.grpcServer, m) protocol.RegisterCacheStoreServer(m.grpcServer, cache.NewProxyServer(m.CacheClient)) protocol.RegisterDataStoreServer(m.grpcServer, data.NewProxyServer(m.DataClient)) protocol.RegisterBlobStoreServer(m.grpcServer, m.blobServer) if err = m.grpcServer.Serve(lis); err != nil { log.Logger().Fatal("failed to start rpc server", zap.Error(err)) } }() if m.Config.OIDC.Enable { provider, err := oidc.NewProvider(context.Background(), m.Config.OIDC.Issuer) if err != nil { log.Logger().Error("failed to create oidc provider", zap.Error(err)) } else { m.verifier = provider.Verifier(&oidc.Config{ClientID: m.Config.OIDC.ClientID}) m.oauth2Config = oauth2.Config{ ClientID: m.Config.OIDC.ClientID, ClientSecret: m.Config.OIDC.ClientSecret, RedirectURL: m.Config.OIDC.RedirectURL, Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } m.tokenCache = ttlcache.New(ttlcache.WithTTL[string, UserInfo](time.Hour)) go m.tokenCache.Start() } } // start http server m.StartHttpServer() } func (m *Master) Shutdown() { // stop http server err := 
m.HttpServer.Shutdown(context.TODO()) if err != nil { log.Logger().Error("failed to shutdown http server", zap.Error(err)) } // stop grpc server m.grpcServer.GracefulStop() } func (m *Master) RunTasksLoop() { defer util.CheckPanic() select { case m.scheduled <- struct{}{}: default: } for { select { case <-m.ticker.C: case <-m.scheduled: } // download dataset var ctx context.Context ctx, m.cancel = context.WithCancel(context.Background()) err := m.runLoadDatasetTask(ctx) if err != nil { log.Logger().Error("failed to load ranking dataset", zap.Error(err)) continue } } } ================================================ FILE: master/master_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package master import ( "fmt" "testing" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/stretchr/testify/suite" ) type MasterTestSuite struct { suite.Suite Master } func (s *MasterTestSuite) SetupTest() { // open database var err error s.tracer = monitor.NewTracer("test") s.Config = config.GetDefaultConfig() s.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", s.T().TempDir()), "") s.NoError(err) s.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", s.T().TempDir()), "") s.NoError(err) // init database err = s.DataClient.Init() s.NoError(err) err = s.CacheClient.Init() s.NoError(err) } func (s *MasterTestSuite) TearDownTest() { s.NoError(s.DataClient.Close()) s.NoError(s.CacheClient.Close()) } func TestMaster(t *testing.T) { suite.Run(t, new(MasterTestSuite)) } ================================================ FILE: master/metrics.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package master import ( "time" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/storage/cache" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) const ( LabelFeedbackType = "feedback_type" LabelStep = "step" LabelData = "data" ) var ( LoadDatasetStepSecondsVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "load_dataset_step_seconds", }, []string{LabelStep}) LoadDatasetTotalSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "load_dataset_total_seconds", }) FindUserNeighborsSecondsVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "find_user_neighbors_seconds", }, []string{LabelStep}) FindUserNeighborsTotalSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "find_user_neighbors_total_seconds", }) FindItemNeighborsSecondsVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "find_item_neighbors_seconds", }, []string{"step"}) FindItemNeighborsTotalSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "find_item_neighbors_total_seconds", }) UpdateUserNeighborsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "update_user_neighbors_total", }) UpdateItemNeighborsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "update_item_neighbors_total", }) CacheScannedTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "cache_scanned_total", }) CacheReclaimedTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "cache_reclaimed_total", }) CacheScannedSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", 
Subsystem: "master", Name: "cache_scanned_seconds", }) CollaborativeFilteringFitSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_fit_seconds", }) CollaborativeFilteringSearchSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_search_seconds", }) CollaborativeFilteringNDCG10 = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_ndcg_10", }) CollaborativeFilteringPrecision10 = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_precision_10", }) CollaborativeFilteringRecall10 = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_recall_10", }) CollaborativeFilteringSearchPrecision10 = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "collaborative_filtering_search_precision_10", }) RankingFitSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_fit_seconds", }) RankingSearchSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_search_seconds", }) RankingPrecision = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_model_precision", }) RankingRecall = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_model_recall", }) RankingAUC = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_model_auc", }) RankingSearchPrecision = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "ranking_search_precision", }) UsersTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "users_total", }) ActiveUsersTotal = 
promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "active_users_total", }) InactiveUsersTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "inactive_users_total", }) ItemsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "items_total", }) ActiveItemsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "active_items_total", }) InactiveItemsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "inactive_items_total", }) UserLabelsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "user_labels_total", }) ItemLabelsTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "item_labels_total", }) FeedbacksTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "feedbacks_total", }) ImplicitFeedbacksTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "implicit_feedbacks_total", }) PositiveFeedbacksTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "positive_feedbacks_total", }) NegativeFeedbackTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "negative_feedbacks_total", }) MemoryInUseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "memory_inuse_bytes", }, []string{LabelData}) ) type OnlineEvaluator struct { ReadTypes []expression.FeedbackTypeExpression ReadFeedback []map[int32]mapset.Set[int32] PositiveTypes []expression.FeedbackTypeExpression PositiveFeedback map[string]map[int32]mapset.Set[int32] WindowSize int WindowEnd time.Time } func NewOnlineEvaluator(positiveTypes, readTypes []expression.FeedbackTypeExpression) *OnlineEvaluator { evaluator := 
new(OnlineEvaluator) evaluator.WindowSize = 30 evaluator.WindowEnd = time.Now().Truncate(time.Hour * 24) evaluator.ReadTypes = readTypes evaluator.ReadFeedback = make([]map[int32]mapset.Set[int32], evaluator.WindowSize) for i := 0; i < evaluator.WindowSize; i++ { evaluator.ReadFeedback[i] = make(map[int32]mapset.Set[int32]) } evaluator.PositiveTypes = positiveTypes evaluator.PositiveFeedback = make(map[string]map[int32]mapset.Set[int32]) evaluator.PositiveFeedback[""] = make(map[int32]mapset.Set[int32]) return evaluator } func (evaluator *OnlineEvaluator) Add(feedbackType string, value float64, userIndex int32, itemIndex int32, timestamp time.Time) { if expression.MatchFeedbackTypeExpressions(evaluator.ReadTypes, feedbackType, value) { truncated := timestamp.Truncate(time.Hour * 24) windowIndex := int(evaluator.WindowEnd.Sub(truncated) / time.Hour / 24) if windowIndex < 0 || windowIndex >= evaluator.WindowSize { return } if evaluator.ReadFeedback[windowIndex][userIndex] == nil { evaluator.ReadFeedback[windowIndex][userIndex] = mapset.NewSet[int32]() } evaluator.ReadFeedback[windowIndex][userIndex].Add(itemIndex) } if expression.MatchFeedbackTypeExpressions(evaluator.PositiveTypes, feedbackType, value) { if evaluator.PositiveFeedback[feedbackType] == nil { evaluator.PositiveFeedback[feedbackType] = make(map[int32]mapset.Set[int32]) } if evaluator.PositiveFeedback[feedbackType][userIndex] == nil { evaluator.PositiveFeedback[feedbackType][userIndex] = mapset.NewSet[int32]() } evaluator.PositiveFeedback[feedbackType][userIndex].Add(itemIndex) if evaluator.PositiveFeedback[""][userIndex] == nil { evaluator.PositiveFeedback[""][userIndex] = mapset.NewSet[int32]() } evaluator.PositiveFeedback[""][userIndex].Add(itemIndex) } } func (evaluator *OnlineEvaluator) Evaluate() []cache.TimeSeriesPoint { var points []cache.TimeSeriesPoint for feedbackType := range evaluator.PositiveFeedback { for i := 0; i < evaluator.WindowSize; i++ { date := evaluator.WindowEnd.AddDate(0, 0, -i) 
var ratioSum float64 var userCount int for userIndex, readItems := range evaluator.ReadFeedback[i] { positiveItems, ok := evaluator.PositiveFeedback[feedbackType][userIndex] if !ok { continue } positiveCount := float64(readItems.Intersect(positiveItems).Cardinality()) if readItems.Cardinality() > 0 { ratioSum += positiveCount / float64(readItems.Cardinality()) userCount++ } } if userCount > 0 { name := cache.PositiveFeedbackRatio if feedbackType != "" { name += "_" + feedbackType } points = append(points, cache.TimeSeriesPoint{ Name: name, Timestamp: date, Value: ratioSum / float64(userCount), }) } } } return points } ================================================ FILE: master/metrics_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package master import ( "testing" "time" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/storage/cache" "github.com/stretchr/testify/assert" ) func TestOnlineEvaluator(t *testing.T) { positiveTypes := []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("read>=100")} readTypes := []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("read")} evaluator := NewOnlineEvaluator(positiveTypes, readTypes) result := evaluator.Evaluate() assert.Empty(t, result) evaluator = NewOnlineEvaluator(positiveTypes, readTypes) evaluator.WindowEnd = time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC) evaluator.WindowSize = 2 evaluator.Add("read", 0, 1, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 1, 2, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 1, 3, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 100, 1, 4, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 100, 2, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 2, 2, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 2, 3, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 2, 4, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 0, 2, 3, time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC)) evaluator.Add("read", 100, 3, 3, time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC)) result = evaluator.Evaluate() assert.ElementsMatch(t, []cache.TimeSeriesPoint{ {Name: "positive_feedback_ratio_read", Timestamp: time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC), Value: 0.5}, {Name: "positive_feedback_ratio_read", Timestamp: time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC), Value: 0.25}, {Name: "positive_feedback_ratio", Timestamp: time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC), Value: 0.5}, {Name: "positive_feedback_ratio", Timestamp: time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC), Value: 0.25}, }, result) } 
================================================ FILE: master/rest.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package master import ( "context" "encoding/base64" "encoding/binary" "encoding/json" "fmt" "io" "net/http" "os" "sort" "strings" "time" "github.com/araddon/dateparse" mapset "github.com/deckarep/golang-set/v2" restfulspec "github.com/emicklei/go-restful-openapi/v2" "github.com/emicklei/go-restful/v3" "github.com/go-viper/mapstructure/v2" "github.com/gorilla/securecookie" _ "github.com/gorse-io/dashboard" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/logics" "github.com/gorse-io/gorse/model/cf" "github.com/gorse-io/gorse/model/ctr" "github.com/gorse-io/gorse/protocol" "github.com/gorse-io/gorse/server" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/gorse-io/gorse/storage/meta" "github.com/invopop/jsonschema" "github.com/juju/errors" "github.com/nikolalohinski/gonja/v2" "github.com/nikolalohinski/gonja/v2/exec" "github.com/rakyll/statik/fs" "github.com/samber/lo" "github.com/sashabaranov/go-openai" "go.uber.org/zap" "google.golang.org/protobuf/proto" 
"google.golang.org/protobuf/types/known/timestamppb" ) type UserInfo struct { Name string `json:"name"` FamilyName string `json:"family_name"` GivenName string `json:"given_name"` MiddleName string `json:"middle_name"` NickName string `json:"nickname"` Picture string `json:"picture"` UpdatedAt string `json:"updated_at"` Email string `json:"email"` Verified bool `json:"email_verified"` AuthType string `json:"auth_type"` } type RerankerPrompt struct { Query string `json:"query"` Documents []string `json:"documents"` } func (m *Master) CreateWebService() { ws := m.WebService ws.Consumes(restful.MIME_JSON).Produces(restful.MIME_JSON) ws.Path("/api/") ws.Filter(m.LoginFilter) ws.Route(ws.GET("/dashboard/userinfo").To(m.handleUserInfo). Doc("Get login user information."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", UserInfo{}). Writes(UserInfo{})) ws.Route(ws.GET("/dashboard/cluster").To(m.getCluster). Doc("Get nodes in the cluster."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", []meta.Node{}). Writes([]meta.Node{})) ws.Route(ws.GET("/dashboard/categories").To(m.getCategories). Doc("Get categories of items."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", []string{}). Writes([]string{})) ws.Route(ws.POST("/dashboard/config").To(m.postConfig). Doc("Update config."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", config.Config{}). Writes(config.Config{})) ws.Route(ws.GET("/dashboard/config").To(m.getConfig). Doc("Get config."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", config.Config{}). Writes(config.Config{})) ws.Route(ws.DELETE("/dashboard/config").To(m.deleteConfig). Doc("Delete config."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). 
Returns(http.StatusOK, "OK", struct{}{})) ws.Route(ws.GET("/dashboard/config/schema").To(m.getConfigSchema). Doc("Get config schema."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Returns(http.StatusOK, "OK", config.Config{}). Writes(jsonschema.Schema{})) ws.Route(ws.GET("/dashboard/stats").To(m.getStats). Doc("Get global status."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). Returns(http.StatusOK, "OK", Status{}). Writes(Status{})) ws.Route(ws.GET("/dashboard/tasks").To(m.getTasks). Doc("Get tasks."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). Returns(http.StatusOK, "OK", []monitor.Progress{}). Writes([]monitor.Progress{})) ws.Route(ws.GET("/dashboard/timeseries/{name}").To(m.getTimeseries). Doc("Get time series data."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). Param(ws.PathParameter("name", "name of the time series").DataType("string")). Returns(http.StatusOK, "OK", map[string][]cache.TimeSeriesPoint{}). Writes(map[string][]cache.TimeSeriesPoint{})) // Get a user ws.Route(ws.GET("/dashboard/user/{user-id}").To(m.getUser). Doc("Get a user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Returns(http.StatusOK, "OK", User{}). Writes(User{})) // Get a user feedback ws.Route(ws.GET("/dashboard/user/{user-id}/feedback/").To(m.getUserFeedback). Doc("Get feedback by user id."). Metadata(restfulspec.KeyOpenAPITags, []string{"feedback"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.QueryParameter("n", "number of returned feedback").DataType("int")). 
Param(ws.QueryParameter("offset", "offset of returned feedback").DataType("int")). Returns(http.StatusOK, "OK", []DetailedFeedback{}). Writes([]DetailedFeedback{})) ws.Route(ws.GET("/dashboard/user/{user-id}/feedback/{feedback-type}").To(m.getUserFeedback). Doc("Get feedback by user id with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{"feedback"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.PathParameter("feedback-type", "feedback type").DataType("string")). Param(ws.QueryParameter("n", "number of returned feedback").DataType("int")). Param(ws.QueryParameter("offset", "offset of returned feedback").DataType("int")). Returns(http.StatusOK, "OK", []DetailedFeedback{}). Writes([]DetailedFeedback{})) // Get users ws.Route(ws.GET("/dashboard/users").To(m.getUsers). Doc("Get users."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("n", "number of returned users").DataType("int")). Param(ws.QueryParameter("cursor", "cursor for next page").DataType("string")). Returns(http.StatusOK, "OK", UserIterator{}). Writes(UserIterator{})) // Get non-personalized recommendation ws.Route(ws.GET("/dashboard/latest").To(m.getLatest). Doc("Get latest items."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]ScoredItem{})) ws.Route(ws.GET("/dashboard/non-personalized/{name}").To(m.getNonPersonalized). Doc("Get non-personalized recommendations."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). 
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/dashboard/recommend/{user-id}").To(m.getRecommend). Doc("Get recommendation for user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.QueryParameter("category", "category of items").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]data.Item{})) ws.Route(ws.GET("/dashboard/recommend/{user-id}/{recommender}").To(m.getRecommend). Doc("Get recommendation for user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")). Param(ws.QueryParameter("category", "category of items").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]data.Item{})) ws.Route(ws.GET("/dashboard/recommend/{user-id}/{recommender}/{name}").To(m.getRecommend). Doc("Get recommendation for user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")). Param(ws.PathParameter("name", "name of the recommender").DataType("string")). Param(ws.QueryParameter("category", "category of items").DataType("string")). 
Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]data.Item{})) ws.Route(ws.GET("/dashboard/item-to-item/{name}/{item-id}").To(m.getItemToItem). Doc("get neighbors of a item"). Metadata(restfulspec.KeyOpenAPITags, []string{"recommendation"}). Param(ws.PathParameter("item-id", "identifier of the item").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). Returns(http.StatusOK, "OK", []ScoredItem{}). Writes([]ScoredItem{})) ws.Route(ws.GET("/dashboard/user-to-user/{name}/{user-id}").To(m.getUserToUser). Doc("get neighbors of a user"). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.QueryParameter("n", "number of returned users").DataType("int")). Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). Returns(http.StatusOK, "OK", []ScoreUser{}). Writes([]ScoreUser{})) ws.Route(ws.GET("/dashboard/external").To(m.getExternal). Doc("get external recommendations preview"). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("script", "external script").DataType("string")). Param(ws.QueryParameter("user-id", "identifier of the user").DataType("string")). Returns(http.StatusOK, "OK", []string{}). Writes([]string{})) ws.Route(ws.GET("/dashboard/ranker/prompt").To(m.getRankerPrompt). Doc("Get ranker prompt."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("query-template", "query template (base64)").DataType("string")). Param(ws.QueryParameter("document-template", "document template (base64)").DataType("string")). Param(ws.QueryParameter("user-id", "identifier of the user").DataType("string")). Returns(http.StatusOK, "OK", RerankerPrompt{}). 
Writes(RerankerPrompt{})) } // SinglePageAppFileSystem is the file system for single page app. type SinglePageAppFileSystem struct { root http.FileSystem } // Open index.html if required file not exists. func (fs *SinglePageAppFileSystem) Open(name string) (http.File, error) { f, err := fs.root.Open(name) if os.IsNotExist(err) { return fs.root.Open("/index.html") } return f, err } func (m *Master) StartHttpServer() { m.CreateWebService() container := restful.NewContainer() container.Handle("/", http.HandlerFunc(m.dashboard)) container.Handle("/login", http.HandlerFunc(m.login)) container.Handle("/logout", http.HandlerFunc(m.logout)) container.Handle("/callback/oauth2", http.HandlerFunc(m.handleOAuth2Callback)) container.Handle("/api/purge", http.HandlerFunc(m.purge)) container.Handle("/api/bulk/users", http.HandlerFunc(m.importExportUsers)) container.Handle("/api/bulk/items", http.HandlerFunc(m.importExportItems)) container.Handle("/api/bulk/feedback", http.HandlerFunc(m.importExportFeedback)) container.Handle("/api/dump", http.HandlerFunc(m.dump)) container.Handle("/api/restore", http.HandlerFunc(m.restore)) container.Handle("/api/chat", http.HandlerFunc(m.chat)) m.RestServer.StartHttpServer(container) } var ( cookieHandler = securecookie.New( securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32)) staticFileSystem http.FileSystem staticFileServer http.Handler ) func init() { var err error staticFileSystem, err = fs.New() if err != nil { log.Logger().Fatal("failed to load statik files", zap.Error(err)) } staticFileServer = http.FileServer(&SinglePageAppFileSystem{staticFileSystem}) // Create temporary directory if not exist tempDir := os.TempDir() if err = os.MkdirAll(tempDir, 1777); err != nil { log.Logger().Fatal("failed to create temporary directory", zap.String("directory", tempDir), zap.Error(err)) } } // Taken from https://github.com/mytrile/nocache var noCacheHeaders = map[string]string{ "Expires": time.Unix(0, 0).Format(time.RFC1123), 
"Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", "Pragma": "no-cache", "X-Accel-Expires": "0", } var etagHeaders = []string{ "ETag", "If-Modified-Since", "If-Match", "If-None-Match", "If-Range", "If-Unmodified-Since", } // noCache is a simple piece of middleware that sets a number of HTTP headers to prevent // a router (or subrouter) from being cached by an upstream proxy and/or client. // // As per http://wiki.nginx.org/HttpProxyModule - noCache sets: // // Expires: Thu, 01 Jan 1970 00:00:00 UTC // Cache-Control: no-cache, private, max-age=0 // X-Accel-Expires: 0 // Pragma: no-cache (for HTTP/1.0 proxies/clients) func noCache(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { // Delete any ETag headers that may have been set for _, v := range etagHeaders { if r.Header.Get(v) != "" { r.Header.Del(v) } } // Set our noCache headers for k, v := range noCacheHeaders { w.Header().Set(k, v) } h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } func (m *Master) dashboard(response http.ResponseWriter, request *http.Request) { _, err := staticFileSystem.Open(request.RequestURI) if request.RequestURI == "/" || os.IsNotExist(err) { if !m.checkLogin(request) { if m.Config.OIDC.Enable { // Redirect to OIDC login http.Redirect(response, request, m.oauth2Config.AuthCodeURL(""), http.StatusFound) } else { http.Redirect(response, request, "/login", http.StatusFound) log.Logger().Info(fmt.Sprintf("%s %s", request.Method, request.URL), zap.Int("status_code", http.StatusFound)) } return } noCache(staticFileServer).ServeHTTP(response, request) return } staticFileServer.ServeHTTP(response, request) } func (m *Master) login(response http.ResponseWriter, request *http.Request) { switch request.Method { case http.MethodGet: log.Logger().Info("GET /login", zap.Int("status_code", http.StatusOK)) staticFileServer.ServeHTTP(response, request) case http.MethodPost: name := request.FormValue("user_name") pass := 
request.FormValue("password") if m.Config.Master.DashboardUserName != "" || m.Config.Master.DashboardPassword != "" { if name != m.Config.Master.DashboardUserName || pass != m.Config.Master.DashboardPassword { http.Redirect(response, request, "login?msg=incorrect", http.StatusFound) log.Logger().Info("POST /login", zap.Int("status_code", http.StatusUnauthorized)) return } value := map[string]string{ "user_name": name, "password": pass, } if encoded, err := cookieHandler.Encode("session", value); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } else { cookie := &http.Cookie{ Name: "session", Value: encoded, Path: "/", } http.SetCookie(response, cookie) http.Redirect(response, request, "/", http.StatusFound) log.Logger().Info("POST /login", zap.Int("status_code", http.StatusFound)) return } } else { http.Redirect(response, request, "/", http.StatusFound) log.Logger().Info("POST /login", zap.Int("status_code", http.StatusFound)) } default: server.BadRequest(restful.NewResponse(response), errors.New("unsupported method")) } } func (m *Master) logout(response http.ResponseWriter, request *http.Request) { cookie := &http.Cookie{ Name: "session", Value: "", Path: "/", MaxAge: -1, } http.SetCookie(response, cookie) http.Redirect(response, request, "/login", http.StatusFound) log.Logger().Info(fmt.Sprintf("%s %s", request.Method, request.RequestURI), zap.Int("status_code", http.StatusFound)) } func (m *Master) LoginFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { if m.checkLogin(req.Request) { req.Request.Header.Set("X-API-Key", m.Config.Server.APIKey) chain.ProcessFilter(req, resp) } else if !strings.HasPrefix(req.SelectedRoutePath(), "/api/dashboard") { chain.ProcessFilter(req, resp) } else { if err := resp.WriteError(http.StatusUnauthorized, fmt.Errorf("unauthorized")); err != nil { log.ResponseLogger(resp).Error("failed to write error", zap.Error(err)) } } } func (m *Master) checkLogin(request 
*http.Request) bool { if m.Config.Master.AdminAPIKey != "" && m.Config.Master.AdminAPIKey == request.Header.Get("X-Api-Key") { return true } if m.Config.OIDC.Enable { if tokenCookie, err := request.Cookie("id_token"); err == nil { var token string if err = cookieHandler.Decode("id_token", tokenCookie.Value, &token); err == nil { if m.tokenCache.Get(token) != nil { return true } } } return false } else if m.Config.Master.DashboardUserName != "" || m.Config.Master.DashboardPassword != "" { if sessionCookie, err := request.Cookie("session"); err == nil { cookieValue := make(map[string]string) if err = cookieHandler.Decode("session", sessionCookie.Value, &cookieValue); err == nil { userName := cookieValue["user_name"] password := cookieValue["password"] if userName == m.Config.Master.DashboardUserName && password == m.Config.Master.DashboardPassword { return true } } } return false } return true } func (m *Master) handleUserInfo(request *restful.Request, response *restful.Response) { if m.Config.OIDC.Enable { if tokenCookie, err := request.Request.Cookie("id_token"); err == nil { var token string if err = cookieHandler.Decode("id_token", tokenCookie.Value, &token); err == nil { if item := m.tokenCache.Get(token); item != nil { userInfo := item.Value() userInfo.AuthType = "OIDC" server.Ok(response, userInfo) return } } } } else if m.Config.Master.DashboardUserName != "" { server.Ok(response, UserInfo{ Name: m.Config.Master.DashboardUserName, }) } else { response.Header().Set("Content-Type", "application/json") if _, err := response.Write([]byte("null")); err != nil { log.ResponseLogger(response).Error("failed to write response", zap.Error(err)) } } } func (m *Master) getCategories(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } categoryScores, err := m.CacheClient.SearchScores(ctx, cache.ItemCategories, "", nil, 0, -1) if err != nil { 
server.InternalServerError(response, err) return } categories := make([]string, len(categoryScores)) for i, score := range categoryScores { categories[i] = score.Id } server.Ok(response, categories) } func (m *Master) getCluster(_ *restful.Request, response *restful.Response) { nodes, err := m.metaStore.ListNodes() if err != nil { server.InternalServerError(response, err) return } sort.Slice(nodes, func(i, j int) bool { return nodes[i].Type < nodes[j].Type }) server.Ok(response, nodes) } func formatConfig(configMap map[string]interface{}) map[string]interface{} { return lo.MapValues(configMap, func(v interface{}, _ string) interface{} { switch value := v.(type) { case time.Duration: s := value.String() if strings.HasSuffix(s, "m0s") { s = s[:len(s)-2] } if strings.HasSuffix(s, "h0m") { s = s[:len(s)-2] } return s case map[string]interface{}: return formatConfig(value) default: return v } }) } func (m *Master) postConfig(request *restful.Request, response *restful.Response) { var newConfigMap map[string]any if err := request.ReadEntity(&newConfigMap); err != nil { server.BadRequest(response, err) return } var newConfig config.Config decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), config.StringToFeedbackTypeHookFunc(), ), Result: &newConfig, }) if err != nil { server.InternalServerError(response, err) return } if err = decoder.Decode(newConfigMap); err != nil { server.BadRequest(response, err) return } configForValidation := *m.Config configForValidation.Recommend = newConfig.Recommend if err = configForValidation.Validate(); err != nil { server.BadRequest(response, err) return } recommendConfigBytes, err := json.Marshal(newConfig.Recommend) if err != nil { server.InternalServerError(response, err) return } if err = m.metaStore.Put(meta.RECOMMEND_CONFIG, string(recommendConfigBytes)); err != nil { server.InternalServerError(response, err) return } 
m.Config.Recommend = newConfig.Recommend m.cancel() select { case m.scheduled <- struct{}{}: default: } server.Ok(response, newConfig) } func (m *Master) getConfig(_ *restful.Request, response *restful.Response) { var configMap map[string]interface{} err := mapstructure.Decode(m.Config, &configMap) if err != nil { server.InternalServerError(response, err) return } if m.Config.Master.DashboardRedacted { delete(configMap, "database") } server.Ok(response, formatConfig(configMap)) } func (m *Master) deleteConfig(_request *restful.Request, response *restful.Response) { if err := m.metaStore.Delete(meta.RECOMMEND_CONFIG); err != nil { server.InternalServerError(response, err) return } newConfig, err := config.LoadConfig(m.configPath) if err != nil { server.InternalServerError(response, err) return } m.Config.Recommend = newConfig.Recommend m.cancel() select { case m.scheduled <- struct{}{}: default: } server.Ok(response, struct{}{}) } func (m *Master) getConfigSchema(_ *restful.Request, response *restful.Response) { server.Ok(response, jsonschema.Reflect(m.Config)) } type Status struct { BinaryVersion string NumServers int NumWorkers int NumUsers int NumItems int NumUserLabels int NumItemLabels int NumTotalPosFeedback int NumValidPosFeedback int NumValidNegFeedback int PopularItemsUpdateTime time.Time LatestItemsUpdateTime time.Time MatchingModelFitTime time.Time MatchingModelScore cf.Score RankingModelFitTime time.Time RankingModelScore ctr.Score } func (m *Master) getStats(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } status := Status{BinaryVersion: version.Version} var err error // read number of users if status.NumUsers, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumUsers)).Integer(); err != nil { log.ResponseLogger(response).Warn("failed to get number of users", zap.Error(err)) } // read number of items if status.NumItems, err = 
m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumItems)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of items", zap.Error(err))
	}
	// read number of user labels. Like every cache counter in getStats, a
	// failed read is only logged as a warning and does not abort the request.
	if status.NumUserLabels, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumUserLabels)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of user labels", zap.Error(err))
	}
	// read number of item labels
	if status.NumItemLabels, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumItemLabels)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of item labels", zap.Error(err))
	}
	// read number of total positive feedback
	if status.NumTotalPosFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumTotalPosFeedbacks)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of total positive feedbacks", zap.Error(err))
	}
	// read number of valid positive feedback
	if status.NumValidPosFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of valid positive feedbacks", zap.Error(err))
	}
	// read number of valid negative feedback
	if status.NumValidNegFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks)).Integer(); err != nil {
		log.ResponseLogger(response).Warn("failed to get number of valid negative feedbacks", zap.Error(err))
	}
	// count the number of workers and servers. Unlike the cache reads above,
	// a metastore failure aborts the request with 500.
	nodes, err := m.metaStore.ListNodes()
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	for _, node := range nodes {
		switch node.Type {
		case protocol.NodeType_Server.String():
			status.NumServers++
		case protocol.NodeType_Worker.String():
			status.NumWorkers++
		}
	}
	// read popular items update time
	if status.PopularItemsUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta,
cache.LastUpdatePopularItemsTime)).Time(); err != nil {
		log.ResponseLogger(response).Warn("failed to get popular items update time", zap.Error(err))
	}
	// read the latest items update time
	if status.LatestItemsUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastUpdateLatestItemsTime)).Time(); err != nil {
		log.ResponseLogger(response).Warn("failed to get latest items update time", zap.Error(err))
	}
	// read last fit matching model time
	if status.MatchingModelFitTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime)).Time(); err != nil {
		log.ResponseLogger(response).Warn("failed to get last fit matching model time", zap.Error(err))
	}
	// read last fit ranking model time
	if status.RankingModelFitTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime)).Time(); err != nil {
		log.ResponseLogger(response).Warn("failed to get last fit ranking model time", zap.Error(err))
	}
	// NOTE(review): Status.MatchingModelScore and Status.RankingModelScore are
	// never assigned in getStats, so they are reported as zero values — confirm
	// this is intended.
	server.Ok(response, status)
}

// getTasks returns the progress of tasks tracked locally by m.tracer, plus the
// progress reported remotely by workers that are still registered in the
// metastore. Progress from workers no longer listed is skipped.
func (m *Master) getTasks(_ *restful.Request, response *restful.Response) {
	// List workers
	workers := mapset.NewSet[string]()
	nodes, err := m.metaStore.ListNodes()
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	for _, node := range nodes {
		if node.Type == protocol.NodeType_Worker.String() {
			workers.Add(node.UUID)
		}
	}
	// List local progress
	progressList := m.tracer.List()
	// list remote progress; keys are worker UUIDs and values are
	// []monitor.Progress (the type assertions below panic otherwise)
	m.remoteProgress.Range(func(key, value interface{}) bool {
		if workers.Contains(key.(string)) {
			progressList = append(progressList, value.([]monitor.Progress)...)
} return true }) server.Ok(response, progressList) } func (m *Master) getTimeseries(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // get time series name name := request.PathParameter("name") // get begin time beginStr := request.QueryParameter("begin") if beginStr == "" { beginStr = time.Now().Add(-7 * 24 * time.Hour).Format(time.RFC3339) } begin, err := dateparse.ParseAny(beginStr) if err != nil { server.BadRequest(response, err) return } // get end time endStr := request.QueryParameter("end") if endStr == "" { endStr = time.Now().Format(time.RFC3339) } end, err := dateparse.ParseAny(endStr) if err != nil { server.BadRequest(response, err) return } // get duration durationStr := request.QueryParameter("duration") if durationStr == "" { durationStr = "24h" } duration, err := time.ParseDuration(durationStr) if err != nil { server.BadRequest(response, err) return } // get time series data data, err := m.CacheClient.GetTimeSeriesPoints(ctx, cache.Key(name), begin, end, duration) if err != nil { server.InternalServerError(response, err) return } server.Ok(response, data) } type UserIterator struct { Cursor string Users []User } type User struct { data.User LastActiveTime time.Time LastUpdateTime time.Time } func (m *Master) getUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // get user id userId := request.PathParameter("user-id") // get user user, err := m.DataClient.GetUser(ctx, userId) if err != nil { if errors.Is(err, errors.NotFound) { server.PageNotFound(response, err) } else { server.InternalServerError(response, err) } return } detail := User{User: user} if detail.LastActiveTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) { 
server.InternalServerError(response, err) return } if detail.LastUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) { server.InternalServerError(response, err) return } server.Ok(response, detail) } func (m *Master) getUsers(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Authorize cursor := request.QueryParameter("cursor") n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN) if err != nil { server.BadRequest(response, err) return } // get all users cursor, users, err := m.DataClient.GetUsers(ctx, cursor, n) if err != nil { server.InternalServerError(response, err) return } details := make([]User, len(users)) for i, user := range users { details[i].User = user if details[i].LastActiveTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) { server.InternalServerError(response, err) return } if details[i].LastUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) { server.InternalServerError(response, err) return } } server.Ok(response, UserIterator{Cursor: cursor, Users: details}) } func (m *Master) getRecommend(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // parse arguments recommenderType := request.PathParameter("recommender") recommenderName := request.PathParameter("name") userId := request.PathParameter("user-id") categories := server.ReadCategories(request, nil) n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN) if err != nil { server.BadRequest(response, err) return } recommender, err := logics.NewRecommender(m.Config.Recommend, 
m.CacheClient, m.DataClient, true, userId, categories) if err != nil { server.InternalServerError(response, err) return } var scores []cache.Score if recommenderType != "" { var name string if recommenderName != "" { name = recommenderType + "/" + recommenderName } else { name = recommenderType } scores, _, err = recommender.RecommendSequential(ctx, scores, n, name) } else { scores, err = recommender.Recommend(ctx, n) } if err != nil { server.InternalServerError(response, err) return } results := lo.Map(scores, func(item cache.Score, index int) string { return item.Id }) // Get item details items, err := m.DataClient.BatchGetItems(ctx, lo.Map(results, func(id string, _ int) string { return id })) if err != nil { server.InternalServerError(response, err) return } itemsMap := make(map[string]data.Item, len(items)) for _, item := range items { itemsMap[item.ItemId] = item } // Send result details := make([]ScoredItem, 0, len(results)) for i := range results { detail, exist := itemsMap[results[i]] if exist { details = append(details, ScoredItem{Item: detail, Score: scores[i].Score}) } else { log.Logger().Warn("recommended item doesn't exist", zap.String("item_id", results[i])) } } server.Ok(response, details) } type DetailedFeedback struct { FeedbackType string UserId string Item data.Item Value float64 Timestamp time.Time Comment string } // get feedback by user-id with feedback type func (m *Master) getUserFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // parse feedback type feedbackType := request.PathParameter("feedback-type") var feedbackTypeExpressions []expression.FeedbackTypeExpression if feedbackType != "" { feedbackTypeExpressions = make([]expression.FeedbackTypeExpression, 1) if err := feedbackTypeExpressions[0].FromString(feedbackType); err != nil { server.BadRequest(response, fmt.Errorf("invalid feedback type `%s`: %w", 
feedbackType, err))
			return
		}
	} else {
		feedbackTypeExpressions = m.Config.Recommend.DataSource.PositiveFeedbackTypes
	}
	// parse n
	n := m.Config.Server.DefaultN
	if request.QueryParameter("n") != "" {
		var err error
		if n, err = server.ParseInt(request, "n", m.Config.Server.DefaultN); err != nil {
			server.BadRequest(response, err)
			return
		}
	}
	// parse offset
	offset := 0
	if request.QueryParameter("offset") != "" {
		var err error
		if offset, err = server.ParseInt(request, "offset", 0); err != nil {
			server.BadRequest(response, err)
			return
		}
	}
	// Negative values used to make the slicing below panic; reject them.
	if n < 0 || offset < 0 {
		server.BadRequest(response, fmt.Errorf("n and offset must be non-negative"))
		return
	}
	userId := request.PathParameter("user-id")
	feedback, err := m.DataClient.GetUserFeedback(ctx, userId, m.Config.Now(), feedbackTypeExpressions...)
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	// Newest feedback first.
	sort.Slice(feedback, func(i, j int) bool {
		return feedback[i].Timestamp.After(feedback[j].Timestamp)
	})
	// Cut out the requested page. An offset past the end now yields an empty
	// page; previously the whole list was returned from the start.
	if offset > len(feedback) {
		offset = len(feedback)
	}
	feedback = feedback[offset:]
	if n < len(feedback) {
		feedback = feedback[:n]
	}
	// Get item details
	items, err := m.DataClient.BatchGetItems(ctx, lo.Map(feedback, func(f data.Feedback, _ int) string {
		return f.ItemId
	}))
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	itemsMap := make(map[string]data.Item, len(items))
	for _, item := range items {
		itemsMap[item.ItemId] = item
	}
	details := make([]DetailedFeedback, len(feedback))
	for i := range feedback {
		details[i].FeedbackType = feedback[i].FeedbackType
		details[i].UserId = feedback[i].UserId
		details[i].Value = feedback[i].Value
		details[i].Timestamp = feedback[i].Timestamp
		details[i].Comment = feedback[i].Comment
		var exist bool
		details[i].Item, exist = itemsMap[feedback[i].ItemId]
		if !exist {
			// Keep the feedback visible even if its item was deleted.
			details[i].Item = data.Item{ItemId: feedback[i].ItemId, Comment: "** This item doesn't exist in Gorse **"}
		}
	}
	server.Ok(response, details)
}

// ScoredItem is an item paired with a score (a relevance score, or a Unix
// timestamp for the latest-items list).
type ScoredItem struct {
	data.Item
	Score float64
}

// ScoreUser is a user paired with a similarity score.
type ScoreUser struct {
	data.User
	Score float64
}

// GetItem resolves a cached score into a ScoredItem by loading the item it
// refers to. It is used as the document loader for SearchDocuments.
func (m *Master) GetItem(score cache.Score) (any, error) {
	item, err := m.DataClient.GetItem(context.Background(), score.Id)
	if err != nil {
		return nil, err
	}
	return ScoredItem{Item: item, Score: score.Score}, nil
}

// GetUser resolves a cached score into a ScoreUser by loading the user it
// refers to. It is used as the document loader for SearchDocuments.
func (m *Master) GetUser(score cache.Score) (any, error) {
	user, err := m.DataClient.GetUser(context.Background(), score.Id)
	if err != nil {
		return nil, err
	}
	return ScoreUser{User: user, Score: score.Score}, nil
}

// getLatest returns the newest items, optionally filtered by category, paged
// by the "offset" and "n" query parameters. Each item is scored by its Unix
// timestamp.
func (m *Master) getLatest(request *restful.Request, response *restful.Response) {
	// Use the request context like the other handlers, so a client disconnect
	// cancels the lookup.
	ctx := context.Background()
	if request != nil && request.Request != nil {
		ctx = request.Request.Context()
	}
	categories := server.ReadCategories(request, nil)
	offset, err := server.ParseInt(request, "offset", 0)
	if err != nil {
		server.BadRequest(response, err)
		return
	}
	n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
	if err != nil {
		server.BadRequest(response, err)
		return
	}
	// Negative values used to make the slicing below panic; reject them.
	if n < 0 || offset < 0 {
		server.BadRequest(response, fmt.Errorf("n and offset must be non-negative"))
		return
	}
	// Fetch enough items to cover the window, then cut it out.
	items, err := m.DataClient.GetLatestItems(ctx, offset+n, categories)
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	// An offset past the end now yields an empty page instead of page 0.
	if offset > len(items) {
		offset = len(items)
	}
	items = items[offset:]
	if n < len(items) {
		items = items[:n]
	}
	scores := make([]ScoredItem, len(items))
	for i := range items {
		scores[i] = ScoredItem{Item: items[i], Score: float64(items[i].Timestamp.Unix())}
	}
	m.SetLastModified(request, response, time.Now().String())
	server.Ok(response, scores)
}

// getNonPersonalized returns the named non-personalized list, filtered by
// category (the empty category is used when none is given).
func (m *Master) getNonPersonalized(request *restful.Request, response *restful.Response) {
	name := request.PathParameter("name")
	categories := server.ReadCategories(request, []string{""})
	m.SetLastModified(request, response, cache.Key(cache.NonPersonalizedUpdateTime, name))
	m.SearchDocuments(cache.NonPersonalized, name, categories, m.GetItem, request, response)
}

// getItemToItem returns the neighbors of an item from the named item-to-item index.
func (m *Master) getItemToItem(request *restful.Request, response *restful.Response) {
	name := request.PathParameter("name")
	itemId := request.PathParameter("item-id")
	categories := request.QueryParameters("category")
	m.SetLastModified(request, response, cache.Key(cache.ItemToItemUpdateTime, name, itemId))
m.SearchDocuments(cache.ItemToItem, cache.Key(name, itemId), categories, m.GetItem, request, response)
}

// getUserToUser returns the neighbors of a user from the named user-to-user index.
func (m *Master) getUserToUser(request *restful.Request, response *restful.Response) {
	userId, name := request.PathParameter("user-id"), request.PathParameter("name")
	m.SetLastModified(request, response, cache.Key(cache.UserToUserUpdateTime, name, userId))
	m.SearchDocuments(cache.UserToUser, cache.Key(name, userId), nil, m.GetUser, request, response)
}

// getExternal previews an external recommender.
//
// The base64 "script" query parameter is executed through logics.NewExternal
// and the items it pulls for "user-id" are returned.
//
// NOTE(review): this runs caller-supplied script code; it assumes the route
// is only reachable behind LoginFilter under /api/dashboard — confirm.
func (m *Master) getExternal(request *restful.Request, response *restful.Response) {
	encoded := request.QueryParameter("script")
	if encoded == "" {
		server.BadRequest(response, fmt.Errorf("script is required"))
		return
	}
	userId := request.QueryParameter("user-id")
	script, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		server.BadRequest(response, fmt.Errorf("invalid script encoding: %w", err))
		return
	}
	previewConfig := config.ExternalConfig{
		Name:   "preview",
		Script: string(script),
	}
	external, err := logics.NewExternal(previewConfig)
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	defer external.Close()
	items, err := external.Pull(userId)
	if err != nil {
		server.InternalServerError(response, err)
		return
	}
	m.SetLastModified(request, response, time.Now().String())
	server.Ok(response, items)
}

// getRankerPrompt renders a reranker prompt preview for a user.
//
// The base64 "query-template" is rendered with the user and their recent
// positive feedback. The base64 "document-template" is rendered once per
// latest item. All three query parameters are required.
func (m *Master) getRankerPrompt(request *restful.Request, response *restful.Response) {
	ctx := context.Background()
	if request != nil && request.Request != nil {
		ctx = request.Request.Context()
	}
	queryTplBase64, documentTplBase64 := request.QueryParameter("query-template"), request.QueryParameter("document-template")
	if queryTplBase64 == "" || documentTplBase64 == "" {
		server.BadRequest(response, fmt.Errorf("query-template and document-template are required"))
		return
	}
	userId := request.QueryParameter("user-id")
	if userId == "" {
		server.BadRequest(response, fmt.Errorf("user-id is required"))
		return
	}
	queryTplBytes, err := base64.StdEncoding.DecodeString(queryTplBase64)
	if err != nil {
server.BadRequest(response, fmt.Errorf("invalid query-template encoding: %w", err)) return } queryTpl, err := gonja.FromString(string(queryTplBytes)) if err != nil { server.BadRequest(response, err) return } documentTplBytes, err := base64.StdEncoding.DecodeString(documentTplBase64) if err != nil { server.BadRequest(response, fmt.Errorf("invalid document-template encoding: %w", err)) return } documentTpl, err := gonja.FromString(string(documentTplBytes)) if err != nil { server.BadRequest(response, err) return } user, err := m.DataClient.GetUser(ctx, userId) if err != nil { if errors.Is(err, errors.NotFound) { server.PageNotFound(response, err) } else { server.InternalServerError(response, err) } return } feedbacks, err := m.DataClient.GetUserFeedback(ctx, userId, m.Config.Now(), m.Config.Recommend.DataSource.PositiveFeedbackTypes...) if err != nil { server.InternalServerError(response, err) return } data.SortFeedbacks(feedbacks) if len(feedbacks) > 10 { feedbacks = feedbacks[:10] } itemsById := map[string]data.Item{} if len(feedbacks) > 0 { feedbackItemIds := lo.Map(feedbacks, func(fb data.Feedback, _ int) string { return fb.ItemId }) feedbackItems, err := m.DataClient.BatchGetItems(ctx, feedbackItemIds) if err != nil { server.InternalServerError(response, err) return } itemsById = make(map[string]data.Item, len(feedbackItems)) for _, item := range feedbackItems { itemsById[item.ItemId] = item } } feedbackItems := make([]*logics.FeedbackItem, 0, len(feedbacks)) for _, fb := range feedbacks { if item, exist := itemsById[fb.ItemId]; exist { feedbackItems = append(feedbackItems, &logics.FeedbackItem{ FeedbackType: fb.FeedbackType, Item: item, }) } } latestItems, err := m.DataClient.GetLatestItems(ctx, 100, nil) if err != nil { server.InternalServerError(response, err) return } // render query var queryBuf strings.Builder queryCtx := exec.NewContext(map[string]any{ "user": &user, "feedback": feedbackItems, }) if err := queryTpl.Execute(&queryBuf, queryCtx); err != nil 
{ server.InternalServerError(response, err) return } // render documents documents := make([]string, len(latestItems)) for i, item := range latestItems { var docBuf strings.Builder docCtx := exec.NewContext(map[string]any{ "item": item, }) if err := documentTpl.Execute(&docBuf, docCtx); err != nil { server.InternalServerError(response, err) return } documents[i] = docBuf.String() } server.Ok(response, RerankerPrompt{ Query: queryBuf.String(), Documents: documents, }) } func (m *Master) importExportUsers(response http.ResponseWriter, request *http.Request) { ctx := context.Background() if request != nil { ctx = request.Context() } if !m.checkLogin(request) { resp := restful.NewResponse(response) err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized") if err != nil { server.InternalServerError(resp, err) return } return } switch request.Method { case http.MethodGet: var err error response.Header().Set("Content-Type", "application/jsonl") response.Header().Set("Content-Disposition", "attachment;filename=users.jsonl") encoder := json.NewEncoder(response) userStream, errChan := m.DataClient.GetUserStream(ctx, batchSize) for users := range userStream { for _, user := range users { if err = encoder.Encode(user); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } } } if err = <-errChan; err != nil { server.InternalServerError(restful.NewResponse(response), errors.Trace(err)) return } case http.MethodPost: // open file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() // parse and import users decoder := json.NewDecoder(file) lineCount := 0 timeStart := time.Now() users := make([]data.User, 0, batchSize) for { // parse line var user data.User if err = decoder.Decode(&user); err != nil { if errors.Is(err, io.EOF) { break } server.BadRequest(restful.NewResponse(response), err) return } // validate user id if err = 
util.ValidateId(user.UserId); err != nil {
				// NOTE(review): lineCount counts records already accepted, so the
				// reported line number is 0-based (the first record is "line 0").
				server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid user id `%v` at line %d (%s)", user.UserId, lineCount, err.Error()))
				return
			}
			users = append(users, user)
			// batch insert once a full batch has accumulated
			if len(users) == batchSize {
				err = m.DataClient.BatchInsertUsers(ctx, users)
				if err != nil {
					server.InternalServerError(restful.NewResponse(response), err)
					return
				}
				users = make([]data.User, 0, batchSize)
			}
			lineCount++
		}
		// flush the final partial batch
		if len(users) > 0 {
			err = m.DataClient.BatchInsertUsers(ctx, users)
			if err != nil {
				server.InternalServerError(restful.NewResponse(response), err)
				return
			}
		}
		// Cancel in-flight work and request a new scheduling round without
		// blocking if one is already pending.
		m.cancel()
		select {
		case m.scheduled <- struct{}{}:
		default:
		}
		timeUsed := time.Since(timeStart)
		log.Logger().Info("complete import users",
			zap.Duration("time_used", timeUsed),
			zap.Int("num_users", lineCount))
		server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount})
	default:
		writeError(response, http.StatusMethodNotAllowed, "method not allowed")
	}
}

// importExportItems handles bulk item transfer. It requires a dashboard login.
//   - GET streams every item as JSON Lines.
//   - POST imports a JSON Lines upload from the multipart field "file".
func (m *Master) importExportItems(response http.ResponseWriter, request *http.Request) {
	ctx := context.Background()
	if request != nil {
		ctx = request.Context()
	}
	if !m.checkLogin(request) {
		resp := restful.NewResponse(response)
		err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized")
		if err != nil {
			server.InternalServerError(resp, err)
			return
		}
		return
	}
	switch request.Method {
	case http.MethodGet:
		var err error
		response.Header().Set("Content-Type", "application/jsonl")
		response.Header().Set("Content-Disposition", "attachment;filename=items.jsonl")
		encoder := json.NewEncoder(response)
		itemStream, errChan := m.DataClient.GetItemStream(ctx, batchSize, nil)
		for items := range itemStream {
			for _, item := range items {
				if err = encoder.Encode(item); err != nil {
					server.InternalServerError(restful.NewResponse(response), err)
					return
				}
			}
		}
		// the stream's error channel is checked only after the item channel closes
		if err = <-errChan; err != nil {
			server.InternalServerError(restful.NewResponse(response), errors.Trace(err))
			return
		}
	case http.MethodPost:
		// open
file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() // parse and import items decoder := json.NewDecoder(file) lineCount := 0 timeStart := time.Now() items := make([]data.Item, 0, batchSize) for { // parse line var item server.Item if err = decoder.Decode(&item); err != nil { if errors.Is(err, io.EOF) { break } server.BadRequest(restful.NewResponse(response), err) return } // validate item id if err = util.ValidateId(item.ItemId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid item id `%v` at line %d (%s)", item.ItemId, lineCount, err.Error())) return } // validate categories for _, category := range item.Categories { if err = util.ValidateId(category); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid category `%v` at line %d (%s)", category, lineCount, err.Error())) return } } // parse timestamp var timestamp time.Time if item.Timestamp != "" { timestamp, err = dateparse.ParseAny(item.Timestamp) if err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("failed to parse datetime `%v` at line %v", item.Timestamp, lineCount)) return } } items = append(items, data.Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment, }) // batch insert if len(items) == batchSize { err = m.DataClient.BatchInsertItems(ctx, items) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return } items = make([]data.Item, 0, batchSize) } lineCount++ } if len(items) > 0 { err = m.DataClient.BatchInsertItems(ctx, items) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return } } m.cancel() select { case m.scheduled <- struct{}{}: default: } timeUsed := time.Since(timeStart) log.Logger().Info("complete import items", zap.Duration("time_used", timeUsed), zap.Int("num_items", 
lineCount)) server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) default: writeError(response, http.StatusMethodNotAllowed, "method not allowed") } } func (m *Master) importExportFeedback(response http.ResponseWriter, request *http.Request) { ctx := context.Background() if request != nil { ctx = request.Context() } if !m.checkLogin(request) { writeError(response, http.StatusUnauthorized, "unauthorized") return } switch request.Method { case http.MethodGet: var err error response.Header().Set("Content-Type", "application/jsonl") response.Header().Set("Content-Disposition", "attachment;filename=feedback.jsonl") encoder := json.NewEncoder(response) feedbackStream, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, data.WithEndTime(*m.Config.Now())) for feedback := range feedbackStream { for _, v := range feedback { if err = encoder.Encode(v); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } } } if err = <-errChan; err != nil { server.InternalServerError(restful.NewResponse(response), errors.Trace(err)) return } case http.MethodPost: // open file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() // parse and import feedback decoder := json.NewDecoder(file) lineCount := 0 timeStart := time.Now() feedbacks := make([]data.Feedback, 0, batchSize) for { // parse line var feedback server.Feedback if err = decoder.Decode(&feedback); err != nil { if errors.Is(err, io.EOF) { break } server.BadRequest(restful.NewResponse(response), err) return } // validate feedback type if err = util.ValidateId(feedback.FeedbackType); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid feedback type `%v` at line %d (%s)", feedback.FeedbackType, lineCount, err.Error())) return } // validate user id if err = util.ValidateId(feedback.UserId); err != nil { server.BadRequest(restful.NewResponse(response), 
fmt.Errorf("invalid user id `%v` at line %d (%s)", feedback.UserId, lineCount, err.Error())) return } // validate item id if err = util.ValidateId(feedback.ItemId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid item id `%v` at line %d (%s)", feedback.ItemId, lineCount, err.Error())) return } // parse timestamp var timestamp time.Time if feedback.Timestamp != "" { timestamp, err = dateparse.ParseAny(feedback.Timestamp) if err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("failed to parse datetime `%v` at line %d", feedback.Timestamp, lineCount)) return } } feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: feedback.FeedbackKey, Value: feedback.Value, Timestamp: timestamp, Comment: feedback.Comment, }) // batch insert if len(feedbacks) == batchSize { // batch insert to data store err = m.DataClient.BatchInsertFeedback(ctx, feedbacks, m.Config.Server.AutoInsertUser, m.Config.Server.AutoInsertItem, true) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return } feedbacks = make([]data.Feedback, 0, batchSize) } lineCount++ } // insert to cache store if len(feedbacks) > 0 { // insert to data store err = m.DataClient.BatchInsertFeedback(ctx, feedbacks, m.Config.Server.AutoInsertUser, m.Config.Server.AutoInsertItem, true) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return } } m.cancel() select { case m.scheduled <- struct{}{}: default: } timeUsed := time.Since(timeStart) log.Logger().Info("complete import feedback", zap.Duration("time_used", timeUsed), zap.Int("num_items", lineCount)) server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) default: writeError(response, http.StatusMethodNotAllowed, "method not allowed") } } var checkList = mapset.NewSet("delete_users", "delete_items", "delete_feedback", "delete_cache") func (m *Master) purge(response http.ResponseWriter, request *http.Request) { // check method if 
request.Method != http.MethodPost { writeError(response, http.StatusMethodNotAllowed, "method not allowed") return } // check login if !m.checkLogin(request) { resp := restful.NewResponse(response) err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized") if err != nil { server.InternalServerError(resp, err) return } return } // check password if m.Config.Master.DashboardPassword == "" { writeError(response, http.StatusUnauthorized, "purge is not allowed without dashboard password") return } // check list if err := request.ParseForm(); err != nil { server.BadRequest(restful.NewResponse(response), err) return } checkedList := strings.Split(request.Form.Get("check_list"), ",") if !checkList.Equal(mapset.NewSet(checkedList...)) { writeError(response, http.StatusUnauthorized, "please confirm by checking all") return } // purge data if err := m.DataClient.Purge(); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } if err := m.CacheClient.Purge(); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } } func writeError(response http.ResponseWriter, httpStatus int, message string) { log.Logger().Error(strings.ToLower(http.StatusText(httpStatus)), zap.String("error", message)) response.Header().Set("Access-Control-Allow-Origin", "*") response.WriteHeader(httpStatus) if _, err := response.Write([]byte(message)); err != nil { log.Logger().Error("failed to write error", zap.Error(err)) } } func (m *Master) checkAdmin(request *http.Request) bool { if m.Config.Master.AdminAPIKey == "" { return true } if request.FormValue("X-API-Key") == m.Config.Master.AdminAPIKey { return true } return false } const ( EOF = int64(0) UserStream = int64(-1) ItemStream = int64(-2) FeedbackStream = int64(-3) ) type DumpStats struct { Users int Items int Feedback int Duration time.Duration } func writeDump[T proto.Message](w io.Writer, data T) error { bytes, err := proto.Marshal(data) if err != nil { return err } if 
err = binary.Write(w, binary.LittleEndian, int64(len(bytes))); err != nil { return err } if _, err = w.Write(bytes); err != nil { return err } return nil } func readDump[T proto.Message](r io.Reader, data T) (int64, error) { var size int64 if err := binary.Read(r, binary.LittleEndian, &size); err != nil { return 0, err } if size <= 0 { return size, nil } bytes := make([]byte, size) if _, err := io.ReadFull(r, bytes); err != nil { return 0, err } return size, proto.Unmarshal(bytes, data) } func (m *Master) dump(response http.ResponseWriter, request *http.Request) { if !m.checkAdmin(request) { writeError(response, http.StatusUnauthorized, "unauthorized") return } if request.Method != http.MethodGet { writeError(response, http.StatusMethodNotAllowed, "method not allowed") return } response.Header().Set("Content-Type", "application/octet-stream") var stats DumpStats start := time.Now() // dump users if err := binary.Write(response, binary.LittleEndian, UserStream); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } userStream, errChan := m.DataClient.GetUserStream(context.Background(), batchSize) for users := range userStream { for _, user := range users { labels, err := json.Marshal(user.Labels) if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } if err := writeDump(response, &protocol.User{ UserId: user.UserId, Labels: labels, Comment: user.Comment, }); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } stats.Users++ } } if err := <-errChan; err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } // dump items if err := binary.Write(response, binary.LittleEndian, ItemStream); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } itemStream, errChan := m.DataClient.GetItemStream(context.Background(), batchSize, nil) for items := range itemStream { for _, item := range items { labels, err := 
json.Marshal(item.Labels) if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } if err := writeDump(response, &protocol.Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: timestamppb.New(item.Timestamp), Labels: labels, Comment: item.Comment, }); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } stats.Items++ } } if err := <-errChan; err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } // dump feedback if err := binary.Write(response, binary.LittleEndian, FeedbackStream); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } feedbackStream, errChan := m.DataClient.GetFeedbackStream(context.Background(), batchSize, data.WithEndTime(*m.Config.Now())) for feedbacks := range feedbackStream { for _, feedback := range feedbacks { if err := writeDump(response, &protocol.Feedback{ FeedbackType: feedback.FeedbackType, UserId: feedback.UserId, ItemId: feedback.ItemId, Value: feedback.Value, Timestamp: timestamppb.New(feedback.Timestamp), Comment: feedback.Comment, }); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } stats.Feedback++ } } if err := <-errChan; err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } // dump EOF if err := binary.Write(response, binary.LittleEndian, EOF); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } stats.Duration = time.Since(start) log.Logger().Info("complete dump", zap.Int("users", stats.Users), zap.Int("items", stats.Items), zap.Int("feedback", stats.Feedback), zap.Duration("duration", stats.Duration)) server.Ok(restful.NewResponse(response), stats) } func (m *Master) Restore(r io.ReadCloser) (stats DumpStats, err error) { flag := EOF if err = binary.Read(r, binary.LittleEndian, &flag); err != nil { return } for flag != EOF { switch flag { case 
UserStream: users := make([]data.User, 0, batchSize) for { var user protocol.User if flag, err = readDump(r, &user); err != nil { return } if flag <= 0 { break } var labels any if err = json.Unmarshal(user.Labels, &labels); err != nil { return } users = append(users, data.User{ UserId: user.UserId, Labels: labels, Comment: user.Comment, }) stats.Users++ if len(users) == batchSize { if err = m.DataClient.BatchInsertUsers(context.Background(), users); err != nil { return } users = users[:0] } } if len(users) > 0 { if err = m.DataClient.BatchInsertUsers(context.Background(), users); err != nil { return } } case ItemStream: items := make([]data.Item, 0, batchSize) for { var item protocol.Item if flag, err = readDump(r, &item); err != nil { return } if flag <= 0 { break } var labels any if err = json.Unmarshal(item.Labels, &labels); err != nil { return } items = append(items, data.Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: item.Timestamp.AsTime(), Labels: labels, Comment: item.Comment, }) stats.Items++ if len(items) == batchSize { if err = m.DataClient.BatchInsertItems(context.Background(), items); err != nil { return } items = items[:0] } } if len(items) > 0 { if err = m.DataClient.BatchInsertItems(context.Background(), items); err != nil { return } } case FeedbackStream: feedbacks := make([]data.Feedback, 0, batchSize) for { var feedback protocol.Feedback if flag, err = readDump(r, &feedback); err != nil { return } if flag <= 0 { break } feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: feedback.FeedbackType, UserId: feedback.UserId, ItemId: feedback.ItemId, }, Value: feedback.Value, Timestamp: feedback.Timestamp.AsTime(), Comment: feedback.Comment, }) stats.Feedback++ if len(feedbacks) == batchSize { if err = m.DataClient.BatchInsertFeedback(context.Background(), feedbacks, false, false, true); err != nil { return } feedbacks = feedbacks[:0] } } if len(feedbacks) > 0 { if err = 
m.DataClient.BatchInsertFeedback(context.Background(), feedbacks, false, false, true); err != nil { return } } default: err = fmt.Errorf("unknown flag %v", flag) return } } return } func (m *Master) restore(response http.ResponseWriter, request *http.Request) { if !m.checkAdmin(request) { writeError(response, http.StatusUnauthorized, "unauthorized") return } if request.Method != http.MethodPost { writeError(response, http.StatusMethodNotAllowed, "method not allowed") return } start := time.Now() stats, err := m.Restore(request.Body) if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } m.cancel() select { case m.scheduled <- struct{}{}: default: } stats.Duration = time.Since(start) log.Logger().Info("complete restore", zap.Int("users", stats.Users), zap.Int("items", stats.Items), zap.Int("feedback", stats.Feedback), zap.Duration("duration", stats.Duration)) server.Ok(restful.NewResponse(response), stats) } func (m *Master) handleOAuth2Callback(w http.ResponseWriter, r *http.Request) { // Verify state and errors. oauth2Token, err := m.oauth2Config.Exchange(r.Context(), r.URL.Query().Get("code")) if err != nil { server.InternalServerError(restful.NewResponse(w), err) return } // Extract the ID Token from OAuth2 token. rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { server.InternalServerError(restful.NewResponse(w), errors.New("missing id_token")) return } // Parse and verify ID Token payload. 
idToken, err := m.verifier.Verify(r.Context(), rawIDToken) if err != nil { server.InternalServerError(restful.NewResponse(w), err) return } // Extract custom claims var claims UserInfo if err := idToken.Claims(&claims); err != nil { server.InternalServerError(restful.NewResponse(w), err) return } // Set token cache and cookie m.tokenCache.Set(rawIDToken, claims, time.Until(idToken.Expiry)) if encoded, err := cookieHandler.Encode("id_token", rawIDToken); err != nil { server.InternalServerError(restful.NewResponse(w), err) return } else { http.SetCookie(w, &http.Cookie{ Name: "id_token", Value: encoded, Path: "/", Expires: idToken.Expiry, }) http.Redirect(w, r, "/", http.StatusFound) log.Logger().Info("login success via OIDC", zap.String("name", claims.Name), zap.String("email", claims.Email)) } } func (m *Master) chat(response http.ResponseWriter, request *http.Request) { if !m.checkAdmin(request) { writeError(response, http.StatusUnauthorized, "unauthorized") return } content, err := io.ReadAll(request.Body) if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } stream, err := m.openAIClient.CreateChatCompletionStream( request.Context(), openai.ChatCompletionRequest{ Model: m.Config.OpenAI.ChatCompletionModel, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, Content: string(content), }, }, Stream: true, }, ) if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } // read response defer stream.Close() for { var resp openai.ChatCompletionStreamResponse resp, err = stream.Recv() if errors.Is(err, io.EOF) { return } if err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return } if len(resp.Choices) == 0 { continue } if _, err = response.Write([]byte(resp.Choices[0].Delta.Content)); err != nil { log.Logger().Error("failed to write response", zap.Error(err)) return } // flush response if f, ok := response.(http.Flusher); ok { f.Flush() } } } 
================================================ FILE: master/rest_test.go ================================================
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package master

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/emicklei/go-restful/v3"
	"github.com/go-viper/mapstructure/v2"
	"github.com/gorse-io/gorse/common/expression"
	"github.com/gorse-io/gorse/common/mock"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/protocol"
	"github.com/gorse-io/gorse/server"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/gorse-io/gorse/storage/meta"
	"github.com/invopop/jsonschema"
	"github.com/samber/lo"
	"github.com/sashabaranov/go-openai"
	"github.com/steinfletcher/apitest"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
)

// Dashboard credentials configured in SetupTest and used for the login request.
const (
	mockMasterUsername = "admin"
	mockMasterPassword = "pass"
)

// marshal JSON-encodes v and fails the test on error.
func marshal(t *testing.T, v interface{}) string {
	s, err := json.Marshal(v)
	assert.NoError(t, err)
	return string(s)
}

// marshalJSONLines encodes each element of v with a json.Encoder (one JSON
// value per line). It is the expected output of the export handlers.
func marshalJSONLines[T any](t *testing.T, v []T) string {
	var buf bytes.Buffer
	encoder := json.NewEncoder(&buf)
	for _, item := range v {
		err := encoder.Encode(item)
		assert.NoError(t, err)
	}
	return buf.String()
}

// convertToMapStructure decodes v into a generic map via mapstructure.
func convertToMapStructure(t *testing.T, v interface{}) map[string]interface{} {
	var m map[string]interface{}
	err := mapstructure.Decode(v, &m)
	assert.NoError(t, err)
	return m
}

// MasterAPITestSuite runs the master REST handlers against SQLite-backed meta,
// data and cache stores, plus a mock OpenAI server. cookie holds the session
// cookie returned by the login in SetupTest.
type MasterAPITestSuite struct {
	suite.Suite
	Master
	handler      *restful.Container
	openAIServer *mock.OpenAIServer
	cookie       string
}

// SetupTest opens fresh stores in per-test temporary directories, builds the
// web service and a mock OpenAI server, and logs in to obtain a session cookie.
func (suite *MasterAPITestSuite) SetupTest() {
	// open database
	var err error
	suite.Config = config.GetDefaultConfig()
	suite.Config.Recommend.Ranker.Type = "fm"
	suite.Config.Database.DataStore = fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir())
	suite.Config.Database.CacheStore = fmt.Sprintf("sqlite://%s/cache.db", suite.T().TempDir())
	suite.metaStore, err = meta.Open(fmt.Sprintf("sqlite://%s/meta.db", suite.T().TempDir()), suite.Config.Master.MetaTimeout)
	suite.NoError(err)
	suite.DataClient, err = data.Open(suite.Config.Database.DataStore, "")
	suite.NoError(err)
	suite.CacheClient, err = cache.Open(suite.Config.Database.CacheStore, "")
	suite.NoError(err)
	// init database
	err = suite.metaStore.Init()
	suite.NoError(err)
	err = suite.DataClient.Init()
	suite.NoError(err)
	err = suite.CacheClient.Init()
	suite.NoError(err)
	// create server
	suite.Config.Master.DashboardUserName = mockMasterUsername
	suite.Config.Master.DashboardPassword = mockMasterPassword
	suite.WebService = new(restful.WebService)
	suite.CreateWebService()
	suite.RestServer.CreateWebService()
	// no-op cancel and a 1-slot channel so the import handlers' signals can be observed
	suite.cancel = func() {}
	suite.scheduled = make(chan struct{}, 1)
	// create handler
	suite.handler = restful.NewContainer()
	suite.handler.Add(suite.WebService)
	// create mock AI server
	suite.openAIServer = mock.NewOpenAIServer()
	go func() {
		_ = suite.openAIServer.Start()
	}()
	suite.openAIServer.Ready()
	clientConfig := openai.DefaultConfig(suite.openAIServer.AuthToken())
	clientConfig.BaseURL = suite.openAIServer.BaseURL()
	suite.openAIClient = openai.NewClientWithConfig(clientConfig)
	// login
	req, err := http.NewRequest("POST", "/login",
		strings.NewReader(fmt.Sprintf("user_name=%s&password=%s", mockMasterUsername, mockMasterPassword)))
	suite.NoError(err)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	resp := httptest.NewRecorder()
	suite.login(resp, req)
	suite.Equal(http.StatusFound, resp.Code)
	suite.cookie = resp.Header().Get("Set-Cookie")
}

// TearDownTest closes the stores and the mock OpenAI server.
func (suite *MasterAPITestSuite) TearDownTest() {
	err := suite.metaStore.Close()
	suite.NoError(err)
	err = suite.DataClient.Close()
	suite.NoError(err)
	err = suite.CacheClient.Close()
	suite.NoError(err)
	err = suite.openAIServer.Close()
	suite.NoError(err)
}

// TestExportUsers checks the headers and JSON Lines body of a user export.
func (suite *MasterAPITestSuite) TestExportUsers() {
	ctx := suite.T().Context()
	// insert users
	users := []data.User{
		{UserId: "1", Labels: map[string]any{"gender": "male", "job": "engineer"}},
		{UserId: "2", Labels: map[string]any{"gender": "male", "job": "lawyer"}},
		{UserId: "3", Labels: map[string]any{"gender": "female", "job": "teacher"}},
	}
	err := suite.DataClient.BatchInsertUsers(ctx, users)
	suite.NoError(err)
	// send request
	req := httptest.NewRequest("GET", "https://example.com/", nil)
	req.Header.Set("Cookie", suite.cookie)
	w := httptest.NewRecorder()
	suite.importExportUsers(w, req)
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.Equal("application/jsonl", w.Header().Get("Content-Type"))
	suite.Equal("attachment;filename=users.jsonl", w.Header().Get("Content-Disposition"))
	suite.Equal(marshalJSONLines(suite.T(), users), w.Body.String())
}

// TestExportItems checks an item export, including comments containing
// commas, CRLF and quotes.
func (suite *MasterAPITestSuite) TestExportItems() {
	ctx := suite.T().Context()
	// insert items
	items := []data.Item{
		{
			ItemId:     "1",
			IsHidden:   false,
			Categories: []string{"x"},
			Timestamp:  time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     map[string]any{"genre": []string{"comedy", "sci-fi"}},
			Comment:    "o,n,e",
		},
		{
			ItemId:     "2",
			IsHidden:   false,
			Categories: []string{"x", "y"},
			Timestamp:  time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     map[string]any{"genre": []string{"documentary", "sci-fi"}},
			Comment:    "t\r\nw\r\no",
		},
		{
			ItemId:     "3",
			IsHidden:   true,
			Categories: nil,
			Timestamp:  time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     nil,
			Comment:    "\"three\"",
		},
	}
	err := suite.DataClient.BatchInsertItems(ctx, items)
	suite.NoError(err)
	// send request
	req := httptest.NewRequest("GET", "https://example.com/", nil)
	req.Header.Set("Cookie", suite.cookie)
	w := httptest.NewRecorder()
	suite.importExportItems(w, req)
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.Equal("application/jsonl", w.Header().Get("Content-Type"))
	suite.Equal("attachment;filename=items.jsonl", w.Header().Get("Content-Disposition"))
	suite.Equal(marshalJSONLines(suite.T(), items), w.Body.String())
}

// TestExportFeedback checks the headers and JSON Lines body of a feedback export.
func (suite *MasterAPITestSuite) TestExportFeedback() {
	ctx := suite.T().Context()
	// insert feedback
	feedbacks := []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}},
	}
	err := suite.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true)
	suite.NoError(err)
	// send request
	req := httptest.NewRequest("GET", "https://example.com/", nil)
	req.Header.Set("Cookie", suite.cookie)
	w := httptest.NewRecorder()
	suite.importExportFeedback(w, req)
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.Equal("application/jsonl", w.Header().Get("Content-Type"))
	suite.Equal("attachment;filename=feedback.jsonl", w.Header().Get("Content-Disposition"))
	suite.Equal(marshalJSONLines(suite.T(), feedbacks), w.Body.String())
}

// TestImportUsers uploads a JSON Lines file with non-ASCII labels and checks
// the stored users, the row count, and that a schedule signal was sent.
func (suite *MasterAPITestSuite) TestImportUsers() {
	ctx := suite.T().Context()
	// send request
	buf := bytes.NewBuffer(nil)
	writer := multipart.NewWriter(buf)
	file, err := writer.CreateFormFile("file", "users.jsonl")
	suite.NoError(err)
	_, err = file.Write([]byte(`{"UserId":"1","Labels":{"性别":"男","职业":"工程师"}}
{"UserId":"2","Labels":{"性别":"男","职业":"律师"}}
{"UserId":"3","Labels":{"性别":"女","职业":"教师"}}`))
	suite.NoError(err)
	err = writer.Close()
	suite.NoError(err)
	req := httptest.NewRequest("POST", "https://example.com/", buf)
	req.Header.Set("Cookie", suite.cookie)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()
	suite.importExportUsers(w, req)
	// check
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String())
	_, items, err := suite.DataClient.GetUsers(ctx, "", 100)
	suite.NoError(err)
	suite.Equal([]data.User{
		{UserId: "1", Labels: map[string]any{"性别": "男", "职业": "工程师"}},
		{UserId: "2", Labels: map[string]any{"性别": "男", "职业": "律师"}},
		{UserId: "3", Labels: map[string]any{"性别": "女", "职业": "教师"}},
	}, items)
	suite.NotEmpty(suite.scheduled)
}

// TestImportItems uploads items with string timestamps (parsed by the handler)
// and checks the stored items and the schedule signal.
func (suite *MasterAPITestSuite) TestImportItems() {
	ctx := suite.T().Context()
	// send request
	buf := bytes.NewBuffer(nil)
	writer := multipart.NewWriter(buf)
	file, err := writer.CreateFormFile("file", "items.jsonl")
	suite.NoError(err)
	_, err = file.Write([]byte(`{"ItemId":"1","IsHidden":false,"Categories":["x"],"Timestamp":"2020-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["喜剧","科幻"]},"Comment":"one"}
{"ItemId":"2","IsHidden":false,"Categories":["x","y"],"Timestamp":"2021-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["卡通","科幻"]},"Comment":"two"}
{"ItemId":"3","IsHidden":true,"Timestamp":"2022-01-01 01:01:01.000000001 +0000 UTC","Comment":"three"}`))
	suite.NoError(err)
	err = writer.Close()
	suite.NoError(err)
	req := httptest.NewRequest("POST", "https://example.com/", buf)
	req.Header.Set("Cookie", suite.cookie)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()
	suite.importExportItems(w, req)
	// check
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String())
	_, items, err := suite.DataClient.GetItems(ctx, "", 100, nil)
	suite.NoError(err)
	suite.Equal([]data.Item{
		{
			ItemId:     "1",
			IsHidden:   false,
			Categories: []string{"x"},
			Timestamp:  time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     map[string]any{"类型": []any{"喜剧", "科幻"}},
			Comment:    "one"},
		{
			ItemId:     "2",
			IsHidden:   false,
			Categories: []string{"x", "y"},
			Timestamp:  time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     map[string]any{"类型": []any{"卡通", "科幻"}},
			Comment:    "two",
		},
		{
			ItemId:     "3",
			IsHidden:   true,
			Categories: nil,
			Timestamp:  time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC),
			Labels:     nil,
			Comment:    "three",
		},
	}, items)
	suite.NotEmpty(suite.scheduled)
}

// TestImportFeedback uploads feedback with zero-time timestamps and checks the
// stored feedback and the schedule signal.
func (suite *MasterAPITestSuite) TestImportFeedback() {
	// send request
	ctx := suite.T().Context()
	buf := bytes.NewBuffer(nil)
	writer := multipart.NewWriter(buf)
	file, err := writer.CreateFormFile("file", "feedback.jsonl")
	suite.NoError(err)
	_, err = file.Write([]byte(`{"FeedbackType":"click","UserId":"0","ItemId":"2","Timestamp":"0001-01-01 00:00:00 +0000 UTC"}
{"FeedbackType":"read","UserId":"2","ItemId":"6","Timestamp":"0001-01-01 00:00:00 +0000 UTC"}
{"FeedbackType":"share","UserId":"1","ItemId":"4","Timestamp":"0001-01-01 00:00:00 +0000 UTC"}`))
	suite.NoError(err)
	err = writer.Close()
	suite.NoError(err)
	req := httptest.NewRequest("POST", "https://example.com/", buf)
	req.Header.Set("Cookie", suite.cookie)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()
	suite.importExportFeedback(w, req)
	// check
	suite.Equal(http.StatusOK, w.Result().StatusCode)
	suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String())
	_, feedback, err := suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now()))
	suite.NoError(err)
	suite.Equal([]data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}},
	}, feedback)
	suite.NotEmpty(suite.scheduled)
}

// TestGetCluster registers a server node and a worker node in the meta store
// and checks /api/dashboard/cluster returns both.
func (suite *MasterAPITestSuite) TestGetCluster() {
	// add nodes
	serverNode := &meta.Node{
		UUID:       "alan turnin",
		Hostname:   "192.168.1.100",
		Type:       protocol.NodeType_Server.String(),
		Version:    "server_version",
		UpdateTime: time.Now().UTC(),
	}
	workerNode := &meta.Node{
		UUID:       "dennis ritchie",
		Hostname:   "192.168.1.101",
		Type:       protocol.NodeType_Worker.String(),
		Version:    "worker_version",
		UpdateTime: time.Now().UTC(),
	}
	err := suite.metaStore.UpdateNode(serverNode)
	suite.NoError(err)
	err = suite.metaStore.UpdateNode(workerNode)
	suite.NoError(err)
	// get nodes
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/cluster").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), []*meta.Node{serverNode, workerNode})).
		End()
}

// TestGetStats stores global counters in the cache and checks
// /api/dashboard/stats reports them.
func (suite *MasterAPITestSuite) TestGetStats() {
	ctx := suite.T().Context()
	// set stats
	err := suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumUsers), 123))
	suite.NoError(err)
	err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumItems), 234))
	suite.NoError(err)
	err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks), 345))
	suite.NoError(err)
	err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks), 456))
	suite.NoError(err)
	// get stats
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/stats").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), Status{
			NumUsers:            123,
			NumItems:            234,
			NumValidPosFeedback: 345,
			NumValidNegFeedback: 456,
			BinaryVersion:       "unknown-version",
		})).
		End()
}

// TestGetCategories checks /api/dashboard/categories lists the cached
// categories in descending score order.
func (suite *MasterAPITestSuite) TestGetCategories() {
	ctx := suite.T().Context()
	// insert categories
	categoryScores := []cache.Score{
		{Id: "a", Score: 3},
		{Id: "b", Score: 2},
		{Id: "c", Score: 1},
	}
	err := suite.CacheClient.AddScores(ctx, cache.ItemCategories, "", categoryScores)
	suite.NoError(err)
	// get categories
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/categories").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), []string{"a", "b", "c"})).
		End()
}

// TestGetUsers checks the user list and single-user endpoints. Both include
// the last-active and last-update times stored in the cache.
func (suite *MasterAPITestSuite) TestGetUsers() {
	ctx := suite.T().Context()
	// add users
	users := []User{
		{data.User{UserId: "0"}, time.Date(2000, 1, 1, 1, 1, 1, 1, time.UTC), time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC)},
		{data.User{UserId: "1"}, time.Date(2001, 1, 1, 1, 1, 1, 1, time.UTC), time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC)},
		{data.User{UserId: "2"}, time.Date(2002, 1, 1, 1, 1, 1, 1, time.UTC), time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC)},
	}
	for _, user := range users {
		err := suite.DataClient.BatchInsertUsers(ctx, []data.User{user.User})
		suite.NoError(err)
		err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, user.UserId), user.LastActiveTime))
		suite.NoError(err)
		err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.RecommendUpdateTime, user.UserId), user.LastUpdateTime))
		suite.NoError(err)
	}
	// get users
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/users").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), UserIterator{
			Cursor: "",
			Users:  users,
		})).
		End()
	// get a user
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/user/1").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), users[1])).
		End()
}

// TestGetLatestItems checks /api/dashboard/latest returns items newest first,
// each scored by its Unix timestamp.
func (suite *MasterAPITestSuite) TestGetLatestItems() {
	ctx := suite.T().Context()
	// add items
	items := []data.Item{
		{ItemId: "0", Timestamp: time.Date(2025, 1, 1, 1, 1, 1, 1, time.UTC)},
		{ItemId: "1", Timestamp: time.Date(2024, 1, 1, 1, 1, 1, 1, time.UTC)},
		{ItemId: "2", Timestamp: time.Date(2023, 1, 1, 1, 1, 1, 1, time.UTC)},
	}
	err := suite.DataClient.BatchInsertItems(ctx, items)
	suite.NoError(err)
	// get latest items
	scores := lo.Map(items, func(item data.Item, _ int) ScoredItem {
		return ScoredItem{Item: item, Score: float64(item.Timestamp.Unix())}
	})
	apitest.New().
		Handler(suite.handler).
		Get("/api/dashboard/latest").
		Header("Cookie", suite.cookie).
		Expect(suite.T()).
		Status(http.StatusOK).
		Body(marshal(suite.T(), scores)).
		End()
}

// TestSearchDocumentsOfItems checks the item-to-item and non-personalized
// dashboard lists, with and without a category. Each case hides one item via
// PATCH and expects it to be omitted, along with a Last-Modified header.
func (suite *MasterAPITestSuite) TestSearchDocumentsOfItems() {
	type ListOperator struct {
		Name       string
		Collection string
		Subset     string
		Category   string
		Get        string
	}
	ctx := suite.T().Context()
	operators := []ListOperator{
		{"ItemToItem", cache.ItemToItem, cache.Key("neighbors", "0"), "", "/api/dashboard/item-to-item/neighbors/0"},
		{"ItemToItemCategory", cache.ItemToItem, cache.Key("neighbors", "0"), "*", "/api/dashboard/item-to-item/neighbors/0"},
		{"LatestItems", cache.NonPersonalized, "latest", "", "/api/dashboard/non-personalized/latest/"},
		{"PopularItems", cache.NonPersonalized, "popular", "", "/api/dashboard/non-personalized/popular/"},
		{"LatestItemsCategory", cache.NonPersonalized, "latest", "*", "/api/dashboard/non-personalized/latest/"},
		{"PopularItemsCategory", cache.NonPersonalized, "popular", "*", "/api/dashboard/non-personalized/popular/"},
	}
	lastModified := time.Now()
	for i, operator := range operators {
		suite.T().Run(operator.Name, func(t *testing.T) {
			// Put scores; ids are prefixed with the case index so cases do not collide
			scores := []cache.Score{
				{Id: strconv.Itoa(i) + "0", Score: 100, Categories: []string{operator.Category}},
				{Id: strconv.Itoa(i) + "1", Score: 99, Categories: []string{operator.Category}},
				{Id: strconv.Itoa(i) + "2", Score: 98, Categories: []string{operator.Category}},
				{Id: strconv.Itoa(i) + "3", Score: 97, Categories: []string{operator.Category}},
				{Id: strconv.Itoa(i) + "4", Score: 96, Categories: []string{operator.Category}},
			}
			err := suite.CacheClient.AddScores(ctx, operator.Collection, operator.Subset, scores)
			suite.NoError(err)
			err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(operator.Collection+"_update_time", operator.Subset), lastModified))
			assert.NoError(t, err)
			items := make([]ScoredItem, 0)
			for _, score := range scores {
				items = append(items, ScoredItem{Item: data.Item{ItemId: score.Id}, Score: score.Score})
				err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: score.Id}})
				suite.NoError(err)
			}
			// hide item "<i>3"; it must be excluded from the listing below
			apitest.New().
				Handler(suite.handler).
				Patch("/api/item/"+strconv.Itoa(i)+"3").
				Header("Cookie", suite.cookie).
				JSON(data.ItemPatch{IsHidden: new(true)}).
				Expect(t).
				Status(http.StatusOK).
				End()
			apitest.New().
				Handler(suite.handler).
				Get(operator.Get).
				Header("Cookie", suite.cookie).
				Query("category", operator.Category).
				Expect(t).
				Status(http.StatusOK).
				HeaderPresent("Last-Modified").
				Body(marshal(t, []ScoredItem{items[0], items[1], items[2], items[4]})).
				End()
		})
	}
}

// TestSearchDocumentsOfUsers checks the user-to-user neighbor list endpoint
// against scores and users inserted by the test.
func (suite *MasterAPITestSuite) TestSearchDocumentsOfUsers() {
	type ListOperator struct {
		Prefix string
		Label  string
		Get    string
	}
	ctx := suite.T().Context()
	operators := []ListOperator{
		{cache.UserToUser, cache.Key("neighbors", "0"), "/api/dashboard/user-to-user/neighbors/0/"},
	}
	lastModified := time.Now()
	for _, operator := range operators {
		suite.T().Logf("test RESTful API: %v", operator.Get)
		// Put scores
		scores := []cache.Score{
			{Id: "0", Score: 100, Categories: []string{""}},
			{Id: "1", Score: 99, Categories: []string{""}},
			{Id: "2", Score: 98, Categories: []string{""}},
			{Id: "3", Score: 97, Categories: []string{""}},
			{Id: "4", Score: 96, Categories: []string{""}},
		}
		err := suite.CacheClient.AddScores(ctx, operator.Prefix, operator.Label, scores)
		suite.NoError(err)
		err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(operator.Prefix+"_update_time", operator.Label), lastModified))
		suite.NoError(err)
		users := make([]ScoreUser, 0)
		for _, score := range scores {
			users = append(users, ScoreUser{User: data.User{UserId: score.Id}, Score: score.Score})
			err = suite.DataClient.BatchInsertUsers(ctx, []data.User{{UserId: score.Id}})
			suite.NoError(err)
		}
		apitest.New().
			Handler(suite.handler).
			Get(operator.Get).
			Header("Cookie", suite.cookie).
			Expect(suite.T()).
			Status(http.StatusOK).
			HeaderPresent("Last-Modified").
			Body(marshal(suite.T(), users)).
End() } } func (suite *MasterAPITestSuite) TestFeedback() { ctx := suite.T().Context() // insert feedback feedback := []DetailedFeedback{ {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "0"}}, {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "2"}}, {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "4"}}, {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "6"}}, {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "8"}}, } for _, v := range feedback { err := suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{{ FeedbackKey: data.FeedbackKey{FeedbackType: v.FeedbackType, UserId: v.UserId, ItemId: v.Item.ItemId}, }}, true, true, true) suite.NoError(err) } // get feedback apitest.New(). Handler(suite.handler). Get("/api/dashboard/user/0/feedback/click"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), feedback)). End() } func (suite *MasterAPITestSuite) TestGetRecommends() { // inset recommendation itemIds := []cache.Score{ {Id: "1", Score: 99, Categories: []string{""}}, {Id: "2", Score: 98, Categories: []string{""}}, {Id: "3", Score: 97, Categories: []string{""}}, {Id: "4", Score: 96, Categories: []string{""}}, {Id: "5", Score: 95, Categories: []string{""}}, {Id: "6", Score: 94, Categories: []string{""}}, {Id: "7", Score: 93, Categories: []string{""}}, {Id: "8", Score: 92, Categories: []string{""}}, } ctx := suite.T().Context() err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", itemIds) suite.NoError(err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) suite.NoError(err) // insert items for _, item := range itemIds { err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: item.Id}}) suite.NoError(err) } 
apitest.New(). Handler(suite.handler). Get("/api/dashboard/recommend/0"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), []ScoredItem{ {data.Item{ItemId: "1"}, 99}, {data.Item{ItemId: "3"}, 97}, {data.Item{ItemId: "5"}, 95}, {data.Item{ItemId: "6"}, 94}, {data.Item{ItemId: "7"}, 93}, {data.Item{ItemId: "8"}, 92}, })). End() } func (suite *MasterAPITestSuite) TestGetNonPersonalizedRecommends() { ctx := suite.T().Context() // insert offline recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99, Categories: []string{""}}, {Id: "2", Score: 98, Categories: []string{""}}, {Id: "3", Score: 97, Categories: []string{""}}, }) suite.NoError(err) // insert non-personalized latest recommendation err = suite.CacheClient.AddScores(ctx, cache.NonPersonalized, "popular", []cache.Score{ {Id: "10", Score: 100, Categories: []string{""}}, {Id: "20", Score: 99, Categories: []string{""}}, {Id: "30", Score: 98, Categories: []string{""}}, }) suite.NoError(err) // insert items err = suite.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "1", Timestamp: time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC)}, {ItemId: "2", Timestamp: time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC)}, {ItemId: "3", Timestamp: time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC)}, {ItemId: "10", Timestamp: time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC)}, {ItemId: "20", Timestamp: time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC)}, {ItemId: "30", Timestamp: time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC)}, }) suite.NoError(err) apitest.New(). Handler(suite.handler). Get("/api/dashboard/recommend/0/non-personalized/popular"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). 
Body(marshal(suite.T(), []ScoredItem{ {data.Item{ItemId: "10", Timestamp: time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC)}, 100}, {data.Item{ItemId: "20", Timestamp: time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC)}, 99}, {data.Item{ItemId: "30", Timestamp: time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC)}, 98}, })). End() } func (suite *MasterAPITestSuite) TestGetExternal() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("user_id") == "1" { fmt.Fprintln(w, `["x","y"]`) return } http.NotFound(w, r) })) defer ts.Close() script := fmt.Sprintf(`fetch("%s?user_id=" + user_id).body`, ts.URL) scriptBase64 := base64.StdEncoding.EncodeToString([]byte(script)) apitest.New(). Handler(suite.handler). Get("/api/dashboard/external"). Header("Cookie", suite.cookie). Query("script", scriptBase64). Query("user-id", "1"). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), []string{"x", "y"})). End() } func (suite *MasterAPITestSuite) TestPurge() { ctx := suite.T().Context() // insert data err := suite.CacheClient.Set(ctx, cache.String("key", "value")) suite.NoError(err) ret, err := suite.CacheClient.Get(ctx, "key").String() suite.NoError(err) suite.Equal("value", ret) err = suite.CacheClient.AddScores(ctx, "sorted", "", []cache.Score{ {Id: "a", Score: 1, Categories: []string{""}}, {Id: "b", Score: 2, Categories: []string{""}}, {Id: "c", Score: 3, Categories: []string{""}}}) suite.NoError(err) z, err := suite.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) suite.NoError(err) suite.ElementsMatch([]cache.Score{ {Id: "a", Score: 1, Categories: []string{""}}, {Id: "b", Score: 2, Categories: []string{""}}, {Id: "c", Score: 3, Categories: []string{""}}}, z) err = suite.DataClient.BatchInsertFeedback(ctx, lo.Map(lo.Range(100), func(t int, i int) data.Feedback { return data.Feedback{FeedbackKey: data.FeedbackKey{ FeedbackType: "click", UserId: strconv.Itoa(t), ItemId: strconv.Itoa(t), }} }), true, true, true) 
suite.NoError(err) _, users, err := suite.DataClient.GetUsers(ctx, "", 100) suite.NoError(err) suite.Equal(100, len(users)) _, items, err := suite.DataClient.GetItems(ctx, "", 100, nil) suite.NoError(err) suite.Equal(100, len(items)) _, feedbacks, err := suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal(100, len(feedbacks)) // purge data req := httptest.NewRequest("POST", "https://example.com/", strings.NewReader("check_list=delete_users,delete_items,delete_feedback,delete_cache")) req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() suite.purge(w, req) suite.Equal(http.StatusOK, w.Code) v := suite.CacheClient.Get(ctx, "key") suite.False(v.Exists()) z, err = suite.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) suite.NoError(err) suite.Empty(z) _, users, err = suite.DataClient.GetUsers(ctx, "", 100) suite.NoError(err) suite.Empty(users) _, items, err = suite.DataClient.GetItems(ctx, "", 100, nil) suite.NoError(err) suite.Empty(items) _, feedbacks, err = suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) suite.NoError(err) suite.Empty(feedbacks) } func (suite *MasterAPITestSuite) TestConfig() { suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("a")} suite.Config.Recommend.DataSource.ReadFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("b")} suite.Config.Recommend.Ranker.Recommenders = []string{"latest"} apitest.New(). Handler(suite.handler). Get("/api/dashboard/config"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), formatConfig(convertToMapStructure(suite.T(), suite.Config)))). 
End() suite.Config.Master.DashboardRedacted = true redactedConfig := formatConfig(convertToMapStructure(suite.T(), suite.Config)) delete(redactedConfig, "database") apitest.New(). Handler(suite.handler). Get("/api/dashboard/config"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), redactedConfig)). End() suite.Config.Master.DashboardRedacted = false newConfig := *suite.Config newConfig.Recommend.Ranker.Type = "llm" apitest.New(). Handler(suite.handler). Post("/api/dashboard/config"). Header("Cookie", suite.cookie). Header("Content-Type", "application/json"). Body(marshal(suite.T(), formatConfig(convertToMapStructure(suite.T(), newConfig)))). Expect(suite.T()). Status(http.StatusOK). End() suite.Equal("llm", suite.Config.Recommend.Ranker.Type) suite.NotEmpty(suite.scheduled) newConfig.Recommend.Ranker.Type = "xxx" apitest.New(). Handler(suite.handler). Post("/api/dashboard/config"). Header("Cookie", suite.cookie). JSON(newConfig). Expect(suite.T()). Status(http.StatusBadRequest). End() suite.Equal("llm", suite.Config.Recommend.Ranker.Type) configBytes, err := json.Marshal(suite.Config.Recommend) suite.NoError(err) err = suite.metaStore.Put(meta.RECOMMEND_CONFIG, string(configBytes)) suite.NoError(err) apitest.New(). Handler(suite.handler). Delete("/api/dashboard/config"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). End() value, err := suite.metaStore.Get(meta.RECOMMEND_CONFIG) suite.NoError(err) suite.Nil(value) } func (suite *MasterAPITestSuite) TestGetConfigSchema() { apitest.New(). Handler(suite.handler). Get("/api/dashboard/config/schema"). Header("Cookie", suite.cookie). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), jsonschema.Reflect(suite.Config))). 
End() } func (suite *MasterAPITestSuite) TestGetTimeseries() { ctx := suite.T().Context() err := suite.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ {Name: "test_timeseries", Timestamp: time.Now().Add(-24 * time.Hour), Value: 1}, {Name: "test_timeseries", Timestamp: time.Now().Add(-48 * time.Hour), Value: 2}, {Name: "test_timeseries", Timestamp: time.Now().Add(-72 * time.Hour), Value: 3}, }) suite.NoError(err) req := httptest.NewRequest("GET", "/api/dashboard/timeseries/test_timeseries", nil) req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() suite.handler.ServeHTTP(w, req) suite.Equal(http.StatusOK, w.Code, w.Body.String()) var got []cache.TimeSeriesPoint err = json.Unmarshal(w.Body.Bytes(), &got) suite.NoError(err) suite.Len(got, 3) } func (suite *MasterAPITestSuite) TestGetRankerPrompt() { ctx := suite.T().Context() // insert user user := data.User{UserId: "u1"} err := suite.DataClient.BatchInsertUsers(ctx, []data.User{user}) suite.NoError(err) // insert items for feedback (hidden items) feedbackItems := make([]data.Item, 12) feedbacks := make([]data.Feedback, 12) for i := 0; i < 12; i++ { itemId := fmt.Sprintf("fb-%02d", i) feedbackItems[i] = data.Item{ ItemId: itemId, IsHidden: true, Timestamp: time.Date(2020, 1, 1, 0, 0, i, 0, time.UTC), } feedbacks[i] = data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "click", UserId: user.UserId, ItemId: itemId, }, Timestamp: time.Date(2021, 1, 1, 0, 0, i, 0, time.UTC), } } err = suite.DataClient.BatchInsertItems(ctx, feedbackItems) suite.NoError(err) err = suite.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) suite.NoError(err) // insert latest items (visible) latestItems := []data.Item{ {ItemId: "lt-1", IsHidden: false, Timestamp: time.Date(2024, 1, 1, 0, 0, 1, 0, time.UTC)}, {ItemId: "lt-2", IsHidden: false, Timestamp: time.Date(2024, 1, 1, 0, 0, 2, 0, time.UTC)}, {ItemId: "lt-3", IsHidden: false, Timestamp: time.Date(2024, 1, 1, 0, 0, 3, 0, time.UTC)}, } err = 
suite.DataClient.BatchInsertItems(ctx, latestItems) suite.NoError(err) // render template queryTpl := "user={{ user.UserId }}\n" + "feedback={% for f in feedback %}{{ f.Item.ItemId }}{% if not loop.last %}, {% endif %}{% endfor %}" documentTpl := "item={{ item.ItemId }}" queryTplBase64 := base64.StdEncoding.EncodeToString([]byte(queryTpl)) documentTplBase64 := base64.StdEncoding.EncodeToString([]byte(documentTpl)) // latest 10 feedback items: fb-11 to fb-02 feedbackList := []string{} for i := 11; i >= 2; i-- { feedbackList = append(feedbackList, fmt.Sprintf("fb-%02d", i)) } expectedQuery := fmt.Sprintf( "user=%s\nfeedback=%s", user.UserId, strings.Join(feedbackList, ", "), ) expectedDocuments := []string{"item=lt-3", "item=lt-2", "item=lt-1"} apitest.New(). Handler(suite.handler). Get("/api/dashboard/ranker/prompt"). Header("Cookie", suite.cookie). Query("query-template", queryTplBase64). Query("document-template", documentTplBase64). Query("user-id", user.UserId). Expect(suite.T()). Status(http.StatusOK). Body(marshal(suite.T(), RerankerPrompt{ Query: expectedQuery, Documents: expectedDocuments, })). 
End() } func (suite *MasterAPITestSuite) TestDumpAndRestore() { ctx := suite.T().Context() // insert users users := make([]data.User, batchSize+1) for i := range users { users[i] = data.User{ UserId: fmt.Sprintf("%05d", i), Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } err := suite.DataClient.BatchInsertUsers(ctx, users) suite.NoError(err) // insert items items := make([]data.Item, batchSize+1) for i := range items { items[i] = data.Item{ ItemId: fmt.Sprintf("%05d", i), Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } err = suite.DataClient.BatchInsertItems(ctx, items) suite.NoError(err) // insert feedback feedback := make([]data.Feedback, batchSize+1) for i := range feedback { feedback[i] = data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "click", UserId: fmt.Sprintf("%05d", i), ItemId: fmt.Sprintf("%05d", i), }, } } err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) suite.NoError(err) // dump data req := httptest.NewRequest("GET", "https://example.com/", nil) req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() suite.dump(w, req) suite.Equal(http.StatusOK, w.Code) // restore data err = suite.DataClient.Purge() suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", bytes.NewReader(w.Body.Bytes())) req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", "application/octet-stream") w = httptest.NewRecorder() suite.restore(w, req) suite.Equal(http.StatusOK, w.Code) // check data _, returnUsers, err := suite.DataClient.GetUsers(ctx, "", len(users)) suite.NoError(err) if suite.Equal(len(users), len(returnUsers)) { suite.Equal(users, returnUsers) } _, returnItems, err := suite.DataClient.GetItems(ctx, "", len(items), nil) suite.NoError(err) if suite.Equal(len(items), len(returnItems)) { suite.Equal(items, returnItems) } _, returnFeedback, err := suite.DataClient.GetFeedback(ctx, "", len(feedback), nil, 
lo.ToPtr(time.Now())) suite.NoError(err) if suite.Equal(len(feedback), len(returnFeedback)) { suite.Equal(feedback, returnFeedback) } suite.NotEmpty(suite.scheduled) } func (suite *MasterAPITestSuite) TestExportAndImport() { ctx := suite.T().Context() // insert users users := make([]data.User, batchSize+1) for i := range users { users[i] = data.User{ UserId: fmt.Sprintf("%05d", i), Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } err := suite.DataClient.BatchInsertUsers(ctx, users) suite.NoError(err) // insert items items := make([]data.Item, batchSize+1) for i := range items { items[i] = data.Item{ ItemId: fmt.Sprintf("%05d", i), Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } err = suite.DataClient.BatchInsertItems(ctx, items) suite.NoError(err) // insert feedback feedback := make([]data.Feedback, batchSize+1) for i := range feedback { feedback[i] = data.Feedback{ FeedbackKey: data.FeedbackKey{ FeedbackType: "click", UserId: fmt.Sprintf("%05d", i), ItemId: fmt.Sprintf("%05d", i), }, Value: 1.0, } } err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) suite.NoError(err) // export users req := httptest.NewRequest("GET", "https://example.com/", nil) req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() suite.importExportUsers(w, req) suite.Equal(http.StatusOK, w.Code) usersData := w.Body.Bytes() // export items req = httptest.NewRequest("GET", "https://example.com/", nil) req.Header.Set("Cookie", suite.cookie) w = httptest.NewRecorder() suite.importExportItems(w, req) suite.Equal(http.StatusOK, w.Code) itemsData := w.Body.Bytes() // export feedback req = httptest.NewRequest("GET", "https://example.com/", nil) req.Header.Set("Cookie", suite.cookie) w = httptest.NewRecorder() suite.importExportFeedback(w, req) suite.Equal(http.StatusOK, w.Code) feedbackData := w.Body.Bytes() err = suite.DataClient.Purge() suite.NoError(err) // import users buf := 
bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) file, err := writer.CreateFormFile("file", "users.jsonl") suite.NoError(err) _, err = file.Write(usersData) suite.NoError(err) err = writer.Close() suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() suite.importExportUsers(w, req) suite.Equal(http.StatusOK, w.Code) // import items buf = bytes.NewBuffer(nil) writer = multipart.NewWriter(buf) file, err = writer.CreateFormFile("file", "items.jsonl") suite.NoError(err) _, err = file.Write(itemsData) suite.NoError(err) err = writer.Close() suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() suite.importExportItems(w, req) suite.Equal(http.StatusOK, w.Code) // import feedback buf = bytes.NewBuffer(nil) writer = multipart.NewWriter(buf) file, err = writer.CreateFormFile("file", "feedback.jsonl") suite.NoError(err) _, err = file.Write(feedbackData) suite.NoError(err) err = writer.Close() suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() suite.importExportFeedback(w, req) suite.Equal(http.StatusOK, w.Code) // check data _, returnUsers, err := suite.DataClient.GetUsers(ctx, "", len(users)) suite.NoError(err) if suite.Equal(len(users), len(returnUsers)) { suite.Equal(users, returnUsers) } _, returnItems, err := suite.DataClient.GetItems(ctx, "", len(items), nil) suite.NoError(err) if suite.Equal(len(items), len(returnItems)) { suite.Equal(items, returnItems) } _, returnFeedback, err := suite.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) suite.NoError(err) if 
suite.Equal(len(feedback), len(returnFeedback)) { suite.Equal(feedback, returnFeedback) } } func (suite *MasterAPITestSuite) TestChat() { content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" + " my mind ever since. \"Whenever you feel like criticizing any one,\" he told me, \" just remember that all " + "the people in this world haven't had the advantages that you've had.\"" buf := strings.NewReader(content) req := httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() suite.chat(w, req) suite.Equal(http.StatusOK, w.Code, w.Body.String()) suite.Equal(content, w.Body.String()) } func TestMasterAPI(t *testing.T) { suite.Run(t, new(MasterAPITestSuite)) } ================================================ FILE: master/rpc.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package master import ( "context" "encoding/json" "time" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/protocol" "github.com/gorse-io/gorse/storage/meta" "github.com/juju/errors" ) // GetMeta returns latest configuration. 
func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*protocol.Meta, error) { // register node node := &meta.Node{ UUID: nodeInfo.Uuid, Hostname: nodeInfo.Hostname, Type: nodeInfo.NodeType.String(), Version: nodeInfo.BinaryVersion, UpdateTime: time.Now().UTC(), } if err := m.metaStore.UpdateNode(node); err != nil { return nil, err } // marshall config s, err := json.Marshal(m.Config) if err != nil { return nil, err } // save ranking model version m.collaborativeFilteringModelMutex.RLock() collaborativeFilteringModelId := m.collaborativeFilteringMeta.ID m.collaborativeFilteringModelMutex.RUnlock() // save click model version m.clickThroughRateModelMutex.RLock() clickThroughRateModelId := m.clickThroughRateMeta.ID m.clickThroughRateModelMutex.RUnlock() // collect nodes workers := make([]string, 0) servers := make([]string, 0) nodes, err := m.metaStore.ListNodes() if err != nil { return nil, err } for _, n := range nodes { switch n.Type { case protocol.NodeType_Worker.String(): workers = append(workers, n.UUID) case protocol.NodeType_Server.String(): servers = append(servers, n.UUID) } } return &protocol.Meta{ Config: string(s), CollaborativeFilteringModelId: collaborativeFilteringModelId, ClickThroughRateModelId: clickThroughRateModelId, Me: nodeInfo.Uuid, Workers: workers, Servers: servers, }, nil } func (m *Master) PushProgress( _ context.Context, in *protocol.PushProgressRequest) (*protocol.PushProgressResponse, error) { // check empty progress if len(in.Progress) == 0 { return &protocol.PushProgressResponse{}, nil } // check tracers tracer := in.Progress[0].Tracer for _, p := range in.Progress { if p.Tracer != tracer { return nil, errors.Errorf("tracers must be the same, expect %v, got %v", tracer, p.Tracer) } } // store progress m.remoteProgress.Store(tracer, monitor.DecodeProgress(in)) return &protocol.PushProgressResponse{}, nil } ================================================ FILE: master/rpc_test.go 
================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package master import ( "encoding/json" "fmt" "net" "os" "path/filepath" "testing" "time" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/gorse-io/gorse/model/cf" "github.com/gorse-io/gorse/model/ctr" "github.com/gorse-io/gorse/protocol" "github.com/gorse-io/gorse/server" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/gorse-io/gorse/storage/meta" "github.com/madflojo/testcerts" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) func newRankingDataset() (*dataset.Dataset, *dataset.Dataset) { return dataset.NewDataset(time.Now(), 0, 0), dataset.NewDataset(time.Now(), 0, 0) } func newClickDataset() (*ctr.Dataset, *ctr.Dataset) { dataSet := &ctr.Dataset{ Index: dataset.NewUnifiedMapIndexBuilder().Build(), } return dataSet, dataSet } type mockMasterRPC struct { Master addr chan string grpcServer *grpc.Server } func newMockMasterRPC(t *testing.T) *mockMasterRPC { // create meta store metaStore, err := meta.Open(fmt.Sprintf("sqlite://%s/meta.db", t.TempDir()), time.Second) assert.NoError(t, err) err = metaStore.Init() assert.NoError(t, err) // create click model train, test := 
newClickDataset() fm := ctr.NewAFM(model.Params{model.NEpochs: 0}) fm.Fit(t.Context(), train, test, &ctr.FitConfig{}) // create ranking model trainSet, testSet := newRankingDataset() bpr := cf.NewBPR(model.Params{model.NEpochs: 0}) bpr.Fit(t.Context(), trainSet, testSet, cf.NewFitConfig()) return &mockMasterRPC{ Master: Master{ RestServer: server.RestServer{ Config: config.GetDefaultConfig(), CacheClient: cache.NoDatabase{}, DataClient: data.NoDatabase{}, }, metaStore: metaStore, collaborativeFilteringMeta: meta.Model[cf.Score]{ID: 123}, clickThroughRateMeta: meta.Model[ctr.Score]{ID: 456}, }, addr: make(chan string), } } func (m *mockMasterRPC) Start(t *testing.T) { listen, err := net.Listen("tcp", ":0") assert.NoError(t, err) m.addr <- listen.Addr().String() var opts []grpc.ServerOption m.grpcServer = grpc.NewServer(opts...) protocol.RegisterMasterServer(m.grpcServer, m) err = m.grpcServer.Serve(listen) assert.NoError(t, err) } func (m *mockMasterRPC) StartTLS(t *testing.T, o *util.TLSConfig) { listen, err := net.Listen("tcp", ":0") assert.NoError(t, err) m.addr <- listen.Addr().String() creds, err := util.NewServerCreds(&util.TLSConfig{ SSLCA: o.SSLCA, SSLCert: o.SSLCert, SSLKey: o.SSLKey, }) assert.NoError(t, err) m.grpcServer = grpc.NewServer(grpc.Creds(creds)) protocol.RegisterMasterServer(m.grpcServer, m) err = m.grpcServer.Serve(listen) assert.NoError(t, err) } func (m *mockMasterRPC) Stop() { _ = m.metaStore.Close() m.grpcServer.Stop() } func TestRPC(t *testing.T) { rpcServer := newMockMasterRPC(t) go rpcServer.Start(t) defer rpcServer.Stop() address := <-rpcServer.addr conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials())) assert.NoError(t, err) client := protocol.NewMasterClient(conn) ctx := t.Context() progressList := []monitor.Progress{{ Tracer: "tracer", Name: "a", Status: monitor.StatusRunning, Total: 12, Count: 6, StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), FinishTime: time.Date(2018, 
time.January, 2, 0, 0, 0, 0, time.Local), }} _, err = client.PushProgress(ctx, monitor.EncodeProgress(progressList)) assert.NoError(t, err) i, ok := rpcServer.remoteProgress.Load("tracer") assert.True(t, ok) remoteProgressList := i.([]monitor.Progress) assert.Equal(t, progressList, remoteProgressList) // test get meta _, err = client.GetMeta(ctx, &protocol.NodeInfo{NodeType: protocol.NodeType_Server, Uuid: "server1", Hostname: "yoga"}) assert.NoError(t, err) metaResp, err := client.GetMeta(ctx, &protocol.NodeInfo{NodeType: protocol.NodeType_Worker, Uuid: "worker1", Hostname: "yoga"}) assert.NoError(t, err) assert.Equal(t, int64(123), metaResp.CollaborativeFilteringModelId) assert.Equal(t, int64(456), metaResp.ClickThroughRateModelId) assert.Equal(t, "worker1", metaResp.Me) assert.Equal(t, []string{"server1"}, metaResp.Servers) assert.Equal(t, []string{"worker1"}, metaResp.Workers) var cfg config.Config err = json.Unmarshal([]byte(metaResp.Config), &cfg) assert.NoError(t, err) assert.Equal(t, rpcServer.Config, &cfg) time.Sleep(time.Second * 2) metaResp, err = client.GetMeta(ctx, &protocol.NodeInfo{NodeType: protocol.NodeType_Worker, Uuid: "worker2", Hostname: "yoga"}) assert.NoError(t, err) assert.Equal(t, []string{"worker2"}, metaResp.Workers) rpcServer.Stop() } func generateToTempFile(t *testing.T) (string, string, string) { // Generate Certificate Authority ca := testcerts.NewCA() // Create a signed Certificate and Key certs, err := ca.NewKeyPair() assert.NoError(t, err) // Write certificates to a file caFile := filepath.Join(t.TempDir(), "ca.pem") certFile := filepath.Join(t.TempDir(), "cert.pem") keyFile := filepath.Join(t.TempDir(), "key.pem") pem := ca.PublicKey() err = os.WriteFile(caFile, pem, 0640) assert.NoError(t, err) err = certs.ToFile(certFile, keyFile) assert.NoError(t, err) return caFile, certFile, keyFile } func TestSSL(t *testing.T) { caFile, certFile, keyFile := generateToTempFile(t) o := &util.TLSConfig{ SSLCA: caFile, SSLCert: certFile, SSLKey: 
keyFile, } rpcServer := newMockMasterRPC(t) go rpcServer.StartTLS(t, o) defer rpcServer.Stop() address := <-rpcServer.addr // success c, err := util.NewClientCreds(o) assert.NoError(t, err) conn, err := grpc.Dial(address, grpc.WithTransportCredentials(c)) assert.NoError(t, err) client := protocol.NewMasterClient(conn) _, err = client.GetMeta(t.Context(), &protocol.NodeInfo{NodeType: protocol.NodeType_Server, Uuid: "server1", Hostname: "yoga"}) assert.NoError(t, err) // insecure conn, err = grpc.Dial(address, grpc.WithInsecure()) assert.NoError(t, err) client = protocol.NewMasterClient(conn) _, err = client.GetMeta(t.Context(), &protocol.NodeInfo{NodeType: protocol.NodeType_Server, Uuid: "server1", Hostname: "yoga"}) assert.Error(t, err) // certificate mismatch caFile2, certFile2, keyFile2 := generateToTempFile(t) o2 := &util.TLSConfig{ SSLCA: caFile2, SSLCert: certFile2, SSLKey: keyFile2, } c, err = util.NewClientCreds(o2) assert.NoError(t, err) conn, err = grpc.Dial(address, grpc.WithTransportCredentials(c)) assert.NoError(t, err) client = protocol.NewMasterClient(conn) _, err = client.GetMeta(t.Context(), &protocol.NodeInfo{NodeType: protocol.NodeType_Server, Uuid: "server1", Hostname: "yoga"}) assert.Error(t, err) } ================================================ FILE: master/tasks.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package master import ( "context" "sort" "strconv" "strings" "sync" "time" "github.com/c-bata/goptuna" "github.com/c-bata/goptuna/tpe" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/parallel" "github.com/gorse-io/gorse/common/sizeof" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/logics" "github.com/gorse-io/gorse/model" "github.com/gorse-io/gorse/model/cf" "github.com/gorse-io/gorse/model/ctr" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/gorse-io/gorse/storage/meta" "github.com/gorse-io/gorse/worker" "github.com/juju/errors" "github.com/samber/lo" "go.uber.org/zap" ) const batchSize = 10000 func (m *Master) loadDataset(parent context.Context) (datasets Datasets, err error) { ctx, span := m.tracer.Start(parent, "Load Dataset", 1) defer span.End() // Build non-personalized recommenders initialStartTime := time.Now() nonPersonalizedRecommenders := make([]*logics.NonPersonalized, 0, len(m.Config.Recommend.NonPersonalized)) for _, cfg := range m.Config.Recommend.NonPersonalized { recommender, err := logics.NewNonPersonalized(cfg, m.Config.Recommend.CacheSize, initialStartTime) if err != nil { return Datasets{}, errors.Trace(err) } nonPersonalizedRecommenders = append(nonPersonalizedRecommenders, recommender) } log.Logger().Info("load dataset", zap.Any("positive_feedback_types", m.Config.Recommend.DataSource.PositiveFeedbackTypes), zap.Any("read_feedback_types", m.Config.Recommend.DataSource.ReadFeedbackTypes), zap.Uint("item_ttl", m.Config.Recommend.DataSource.ItemTTL), zap.Uint("feedback_ttl", m.Config.Recommend.DataSource.PositiveFeedbackTTL)) evaluator := NewOnlineEvaluator( m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes) datasets.clickDataset, 
datasets.rankingDataset, err = m.LoadDataFromDatabase(ctx, m.DataClient, m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes, m.Config.Recommend.DataSource.ItemTTL, m.Config.Recommend.DataSource.PositiveFeedbackTTL, evaluator, nonPersonalizedRecommenders) if err != nil { return Datasets{}, errors.Trace(err) } // save non-personalized recommenders to cache for i, recommender := range nonPersonalizedRecommenders { scores := recommender.PopAll() if err = m.CacheClient.AddScores(ctx, cache.NonPersonalized, recommender.Name(), scores); err != nil { log.Logger().Error("failed to cache non-personalized recommenders", zap.Error(err)) } if err = m.CacheClient.DeleteScores(ctx, []string{cache.NonPersonalized}, cache.ScoreCondition{ Subset: new(recommender.Name()), Before: lo.ToPtr(recommender.Timestamp()), }); err != nil { log.Logger().Error("failed to reclaim outdated items", zap.Error(err)) } if err = m.CacheClient.Set(ctx, cache.Time(cache.Key(cache.NonPersonalizedUpdateTime, recommender.Name()), recommender.Timestamp()), cache.String(cache.Key(cache.NonPersonalizedDigest, recommender.Name()), m.Config.Recommend.NonPersonalized[i].Hash()), ); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) } } // write statistics to database if err = m.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ {Name: cache.NumUsers, Value: float64(datasets.rankingDataset.CountUsers()), Timestamp: datasets.rankingDataset.GetTimestamp()}, {Name: cache.NumItems, Value: float64(datasets.rankingDataset.CountItems()), Timestamp: datasets.rankingDataset.GetTimestamp()}, {Name: cache.NumFeedback, Value: float64(len(datasets.clickDataset.Target)), Timestamp: datasets.rankingDataset.GetTimestamp()}, {Name: cache.NumPosFeedbacks, Value: float64(datasets.clickDataset.PositiveCount), Timestamp: datasets.rankingDataset.GetTimestamp()}, {Name: cache.NumNegFeedbacks, Value: float64(datasets.clickDataset.NegativeCount), Timestamp: 
datasets.rankingDataset.GetTimestamp()}, }); err != nil { log.Logger().Error("failed to write timeseries points", zap.Error(err)) } UsersTotal.Set(float64(datasets.rankingDataset.CountUsers())) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumUsers), datasets.rankingDataset.CountUsers())); err != nil { log.Logger().Error("failed to write number of users", zap.Error(err)) } ItemsTotal.Set(float64(datasets.rankingDataset.CountItems())) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumItems), datasets.rankingDataset.CountItems())); err != nil { log.Logger().Error("failed to write number of items", zap.Error(err)) } ImplicitFeedbacksTotal.Set(float64(datasets.rankingDataset.CountFeedback())) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumTotalPosFeedbacks), datasets.rankingDataset.CountFeedback())); err != nil { log.Logger().Error("failed to write number of positive feedbacks", zap.Error(err)) } UserLabelsTotal.Set(float64(datasets.clickDataset.Index.CountUserLabels())) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumUserLabels), int(datasets.clickDataset.Index.CountUserLabels()))); err != nil { log.Logger().Error("failed to write number of user labels", zap.Error(err)) } ItemLabelsTotal.Set(float64(datasets.clickDataset.Index.CountItemLabels())) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumItemLabels), int(datasets.clickDataset.Index.CountItemLabels()))); err != nil { log.Logger().Error("failed to write number of item labels", zap.Error(err)) } ImplicitFeedbacksTotal.Set(float64(datasets.rankingDataset.CountFeedback())) PositiveFeedbacksTotal.Set(float64(datasets.clickDataset.PositiveCount)) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks), datasets.clickDataset.PositiveCount)); err != nil { log.Logger().Error("failed to write number of 
positive feedbacks", zap.Error(err)) } NegativeFeedbackTotal.Set(float64(datasets.clickDataset.NegativeCount)) if err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks), datasets.clickDataset.NegativeCount)); err != nil { log.Logger().Error("failed to write number of negative feedbacks", zap.Error(err)) } // evaluate positive feedback rate points := evaluator.Evaluate() if err = m.CacheClient.AddTimeSeriesPoints(ctx, points); err != nil { log.Logger().Error("failed to insert measurement", zap.Error(err)) } // collect active users and items activeUsers, activeItems, inactiveUsers, inactiveItems := 0, 0, 0, 0 for _, userFeedback := range datasets.rankingDataset.GetUserFeedback() { if len(userFeedback) > 0 { activeUsers++ } else { inactiveUsers++ } } for _, itemFeedback := range datasets.rankingDataset.GetItemFeedback() { if len(itemFeedback) > 0 { activeItems++ } else { inactiveItems++ } } ActiveUsersTotal.Set(float64(activeUsers)) ActiveItemsTotal.Set(float64(activeItems)) InactiveUsersTotal.Set(float64(inactiveUsers)) InactiveItemsTotal.Set(float64(inactiveItems)) // write categories to cache categories := datasets.rankingDataset.GetCategories() categoryScores := make([]cache.Score, 0, len(categories)) for category, count := range categories { categoryScores = append(categoryScores, cache.Score{ Id: category, Score: float64(count), Timestamp: initialStartTime, }) } if err = m.CacheClient.AddScores(ctx, cache.ItemCategories, "", categoryScores); err != nil { log.Logger().Error("failed to write categories to cache", zap.Error(err)) } // split ranking dataset startTime := time.Now() datasets.rankingTrainSet, datasets.rankingTestSet = datasets.rankingDataset.SplitCF(0, 0) LoadDatasetStepSecondsVec.WithLabelValues("split_ranking_dataset").Set(time.Since(startTime).Seconds()) MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_train_set").Set(float64(sizeof.DeepSize(datasets.rankingTrainSet))) 
MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_test_set").Set(float64(sizeof.DeepSize(datasets.rankingTestSet))) // split click dataset startTime = time.Now() datasets.clickTrainSet, datasets.clickTestSet = datasets.clickDataset.Split(0.2, 0) // After splitting, clickDataset is set to nil to free memory. // WARNING: Do not access datasets.clickDataset after this point; use clickTrainSet and clickTestSet instead. datasets.clickDataset = nil LoadDatasetStepSecondsVec.WithLabelValues("split_click_dataset").Set(time.Since(startTime).Seconds()) MemoryInUseBytesVec.WithLabelValues("ranking_train_set").Set(float64(sizeof.DeepSize(datasets.clickTrainSet))) MemoryInUseBytesVec.WithLabelValues("ranking_test_set").Set(float64(sizeof.DeepSize(datasets.clickTestSet))) LoadDatasetTotalSeconds.Set(time.Since(initialStartTime).Seconds()) return } // runLoadDatasetTask loads dataset. func (m *Master) runLoadDatasetTask(ctx context.Context) error { datasets, err := m.loadDataset(ctx) if err != nil { return errors.Trace(err) } useCollaborativeFilteringTasks := !strings.EqualFold(m.Config.Recommend.Collaborative.Type, "none") useClickThroughRateTasks := strings.EqualFold(m.Config.Recommend.Ranker.Type, "fm") if err = m.updateUserToUser(ctx, datasets.rankingDataset); err != nil { log.Logger().Error("failed to update user-to-user recommendation", zap.Error(err)) } if err = m.updateItemToItem(ctx, datasets.rankingDataset); err != nil { log.Logger().Error("failed to update item-to-item recommendation", zap.Error(err)) } if useCollaborativeFilteringTasks { if err = m.trainCollaborativeFiltering(ctx, datasets.rankingTrainSet, datasets.rankingTestSet); err != nil { log.Logger().Error("failed to train collaborative filtering model", zap.Error(err)) } } if useClickThroughRateTasks { if err = m.trainClickThroughRatePrediction(ctx, datasets.clickTrainSet, datasets.clickTestSet); err != nil { log.Logger().Error("failed to train click-through rate prediction model", zap.Error(err)) } 
} if m.standalone { if err = m.updateRecommend(ctx); err != nil { log.Logger().Error("failed to update recommendation", zap.Error(err)) } } if err = m.collectGarbage(ctx, datasets.rankingDataset); err != nil { log.Logger().Error("failed to collect garbage in cache", zap.Error(err)) } if useCollaborativeFilteringTasks && m.Config.Recommend.Collaborative.OptimizePeriod > 0 { if err = m.optimizeCollaborativeFiltering(ctx, datasets.rankingTrainSet, datasets.rankingTestSet); err != nil { log.Logger().Error("failed to optimize collaborative filtering model", zap.Error(err)) } } if useClickThroughRateTasks && m.Config.Recommend.Ranker.OptimizePeriod > 0 { if err = m.optimizeClickThroughRatePrediction(ctx, datasets.clickTrainSet, datasets.clickTestSet); err != nil { log.Logger().Error("failed to optimize click-through rate prediction model", zap.Error(err)) } } return nil } // LoadDataFromDatabase loads dataset from data store. func (m *Master) LoadDataFromDatabase( ctx context.Context, database data.Database, posFeedbackTypes, readTypes []expression.FeedbackTypeExpression, itemTTL, positiveFeedbackTTL uint, evaluator *OnlineEvaluator, nonPersonalizedRecommenders []*logics.NonPersonalized, ) (ctrDataset *ctr.Dataset, dataSet *dataset.Dataset, err error) { // Estimate the number of users, items, and feedbacks estimatedNumUsers, err := m.DataClient.CountUsers(ctx) if err != nil { return nil, nil, errors.Trace(err) } estimatedNumItems, err := m.DataClient.CountItems(ctx) if err != nil { return nil, nil, errors.Trace(err) } estimatedNumFeedbacks, err := m.DataClient.CountFeedback(ctx) if err != nil { return nil, nil, errors.Trace(err) } dataSet = dataset.NewDataset(time.Now(), estimatedNumUsers, estimatedNumItems) newCtx, span := monitor.Start(ctx, "LoadDataFromDatabase", estimatedNumUsers+estimatedNumItems+estimatedNumFeedbacks) defer span.End() // setup time limit var feedbackTimeLimit data.ScanOption var itemTimeLimit *time.Time if itemTTL > 0 { temp := 
time.Now().AddDate(0, 0, -int(itemTTL)) itemTimeLimit = &temp } if positiveFeedbackTTL > 0 { temp := time.Now().AddDate(0, 0, -int(positiveFeedbackTTL)) feedbackTimeLimit = data.WithBeginTime(temp) } // STEP 1: pull users userLabelCount := make(map[string]int) userLabelFirst := make(map[string]int32) userLabelIndex := dataset.NewMapIndex() userLabels := make([][]lo.Tuple2[int32, float32], 0, estimatedNumUsers) start := time.Now() userChan, errChan := database.GetUserStream(newCtx, batchSize) for users := range userChan { for _, user := range users { dataSet.AddUser(user) userIndex := dataSet.GetUserDict().Id(user.UserId) if len(userLabels) == int(userIndex) { userLabels = append(userLabels, nil) } features := ctr.ConvertLabels(user.Labels) userLabels[userIndex] = make([]lo.Tuple2[int32, float32], 0, len(features)) for _, feature := range features { userLabelCount[feature.Name]++ // Memorize the first occurrence. if userLabelCount[feature.Name] == 1 { userLabelFirst[feature.Name] = userIndex } // Add the label to the index in second occurrence. if userLabelCount[feature.Name] == 2 { userLabelIndex.Add(feature.Name) firstUserIndex := userLabelFirst[feature.Name] userLabels[firstUserIndex] = append(userLabels[firstUserIndex], lo.Tuple2[int32, float32]{ A: userLabelIndex.ToNumber(feature.Name), B: feature.Value, }) } // Add the label to the user. 
if userLabelCount[feature.Name] > 1 { userLabels[userIndex] = append(userLabels[userIndex], lo.Tuple2[int32, float32]{ A: userLabelIndex.ToNumber(feature.Name), B: feature.Value, }) } } } span.Add(len(users)) } if err = <-errChan; err != nil { return nil, nil, errors.Trace(err) } log.Logger().Debug("pulled users from database", zap.Int("n_users", dataSet.CountUsers()), zap.Int32("n_user_labels", userLabelIndex.Len()), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_users").Set(time.Since(start).Seconds()) // STEP 2: pull items var items []data.Item itemLabelCount := make(map[string]int) itemLabelFirst := make(map[string]int32) itemLabelIndex := dataset.NewMapIndex() itemLabels := make([][]lo.Tuple2[int32, float32], 0, estimatedNumItems) itemEmbeddingIndexer := dataset.NewMapIndex() itemEmbeddingDimension := make([]map[int]int, 0) itemEmbeddings := make([][][]float32, 0, estimatedNumItems) start = time.Now() itemChan, errChan := database.GetItemStream(newCtx, batchSize, itemTimeLimit) for batchItems := range itemChan { items = append(items, batchItems...) for _, item := range batchItems { dataSet.AddItem(item) itemIndex := dataSet.GetItemDict().Id(item.ItemId) if len(itemLabels) == int(itemIndex) { itemLabels = append(itemLabels, nil) } if len(itemEmbeddings) == int(itemIndex) { itemEmbeddings = append(itemEmbeddings, nil) } // load labels labels := ctr.ConvertLabels(item.Labels) itemLabels[itemIndex] = make([]lo.Tuple2[int32, float32], 0, len(labels)) for _, feature := range labels { itemLabelCount[feature.Name]++ // Memorize the first occurrence. if itemLabelCount[feature.Name] == 1 { itemLabelFirst[feature.Name] = itemIndex } // Add the label to the index in second occurrence. 
if itemLabelCount[feature.Name] == 2 { itemLabelIndex.Add(feature.Name) firstItemIndex := itemLabelFirst[feature.Name] itemLabels[firstItemIndex] = append(itemLabels[firstItemIndex], lo.Tuple2[int32, float32]{ A: itemLabelIndex.ToNumber(feature.Name), B: feature.Value, }) } // Add the label to the item. if itemLabelCount[feature.Name] > 1 { itemLabels[itemIndex] = append(itemLabels[itemIndex], lo.Tuple2[int32, float32]{ A: itemLabelIndex.ToNumber(feature.Name), B: feature.Value, }) } } // load embeddings embeddings := ctr.ConvertEmbeddings(item.Labels) itemEmbeddings[itemIndex] = make([][]float32, 0, len(embeddings)) for _, embedding := range embeddings { itemEmbeddingIndexer.Add(embedding.Name) itemEmbeddingIndex := itemEmbeddingIndexer.ToNumber(embedding.Name) for len(itemEmbeddings[itemIndex]) <= int(itemEmbeddingIndex) { itemEmbeddings[itemIndex] = append(itemEmbeddings[itemIndex], nil) } itemEmbeddings[itemIndex][itemEmbeddingIndex] = embedding.Value for len(itemEmbeddingDimension) <= int(itemEmbeddingIndex) { itemEmbeddingDimension = append(itemEmbeddingDimension, make(map[int]int)) } itemEmbeddingDimension[itemEmbeddingIndex][len(itemEmbeddings[itemIndex][itemEmbeddingIndex])]++ } } span.Add(len(batchItems)) } if err = <-errChan; err != nil { return nil, nil, errors.Trace(err) } log.Logger().Debug("pulled items from database", zap.Int("n_items", dataSet.CountItems()), zap.Int32("n_item_labels", itemLabelIndex.Len()), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_items").Set(time.Since(start).Seconds()) // create positive set positiveSet := make([]mapset.Set[int32], dataSet.CountUsers()) for i := range positiveSet { positiveSet[i] = mapset.NewSet[int32]() } // split item groups sort.Slice(items, func(i, j int) bool { return items[i].ItemId < items[j].ItemId }) itemGroups := parallel.Split(items, m.Config.Master.NumJobs) // STEP 3: pull positive feedback var mu sync.Mutex var posFeedbackCount int start = 
time.Now() err = parallel.Parallel(newCtx, len(itemGroups), m.Config.Master.NumJobs, func(_, i int) error { var itemFeedback []data.Feedback var itemGroupIndex int itemHasFeedback := make([]bool, len(itemGroups[i])) feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize, data.WithBeginItemId(itemGroups[i][0].ItemId), data.WithEndItemId(itemGroups[i][len(itemGroups[i])-1].ItemId), feedbackTimeLimit, data.WithEndTime(*m.Config.Now()), data.WithFeedbackTypes(posFeedbackTypes...), data.WithOrderByItemId()) for feedback := range feedbackChan { for _, f := range feedback { // convert user and item id to index userIndex := dataSet.GetUserDict().Id(f.UserId) if userIndex == dataset.NotId { continue } itemIndex := dataSet.GetItemDict().Id(f.ItemId) if itemIndex == dataset.NotId { continue } // insert feedback to positive set positiveSet[userIndex].Add(itemIndex) mu.Lock() posFeedbackCount++ // insert feedback to evaluator evaluator.Add(f.FeedbackType, f.Value, userIndex, itemIndex, f.Timestamp) mu.Unlock() // append item feedback if len(itemFeedback) == 0 || itemFeedback[len(itemFeedback)-1].ItemId == f.ItemId { itemFeedback = append(itemFeedback, f) } else { // add item to non-personalized recommenders itemHasFeedback[itemGroupIndex] = true for _, recommender := range nonPersonalizedRecommenders { recommender.Push(itemGroups[i][itemGroupIndex], itemFeedback) } itemFeedback = itemFeedback[:0] itemFeedback = append(itemFeedback, f) } // find item group index for itemGroupIndex = 0; itemGroupIndex < len(itemGroups[i]); itemGroupIndex++ { if itemGroups[i][itemGroupIndex].ItemId == f.ItemId { break } } dataSet.AddFeedback(f.UserId, f.ItemId, f.Timestamp) } span.Add(len(feedback)) } // add item to non-personalized recommenders if len(itemFeedback) > 0 { itemHasFeedback[itemGroupIndex] = true for _, recommender := range nonPersonalizedRecommenders { recommender.Push(itemGroups[i][itemGroupIndex], itemFeedback) } } for index, hasFeedback := range itemHasFeedback { 
if !hasFeedback { for _, recommender := range nonPersonalizedRecommenders { recommender.Push(itemGroups[i][index], nil) } } } if err = <-errChan; err != nil { return errors.Trace(err) } return nil }) if err != nil { return nil, nil, errors.Trace(err) } log.Logger().Debug("pulled positive feedback from database", zap.Int("n_positive_feedback", posFeedbackCount), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_positive_feedback").Set(time.Since(start).Seconds()) // create negative set negativeSet := make([]mapset.Set[int32], dataSet.CountUsers()) for i := range negativeSet { negativeSet[i] = mapset.NewSet[int32]() } // STEP 4: pull negative feedback start = time.Now() var negativeFeedbackCount float64 err = parallel.Parallel(newCtx, len(itemGroups), m.Config.Master.NumJobs, func(_, i int) error { feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize, data.WithBeginItemId(itemGroups[i][0].ItemId), data.WithEndItemId(itemGroups[i][len(itemGroups[i])-1].ItemId), feedbackTimeLimit, data.WithEndTime(*m.Config.Now()), data.WithFeedbackTypes(readTypes...)) for feedback := range feedbackChan { for _, f := range feedback { userIndex := dataSet.GetUserDict().Id(f.UserId) if userIndex == dataset.NotId { continue } itemIndex := dataSet.GetItemDict().Id(f.ItemId) if itemIndex == dataset.NotId { continue } negativeSet[userIndex].Add(itemIndex) mu.Lock() negativeFeedbackCount++ evaluator.Add(f.FeedbackType, f.Value, userIndex, itemIndex, f.Timestamp) mu.Unlock() } span.Add(len(feedback)) } if err = <-errChan; err != nil { return errors.Trace(err) } return nil }) if err != nil { return nil, nil, errors.Trace(err) } log.Logger().Debug("pulled negative feedback from database", zap.Int("n_negative_feedback", int(negativeFeedbackCount)), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_negative_feedback").Set(time.Since(start).Seconds()) // STEP 5: create click-through rate dataset 
start = time.Now() unifiedIndex := dataset.NewUnifiedMapIndexBuilder() unifiedIndex.ItemIndex = dataSet.GetItemDict().ToIndex() unifiedIndex.UserIndex = dataSet.GetUserDict().ToIndex() unifiedIndex.ItemLabelIndex = itemLabelIndex unifiedIndex.UserLabelIndex = userLabelIndex ctrDataset = &ctr.Dataset{ Index: unifiedIndex.Build(), UserLabels: userLabels, ItemLabels: itemLabels, Users: make([]int32, 0, estimatedNumFeedbacks), Items: make([]int32, 0, estimatedNumFeedbacks), Target: make([]float32, 0, estimatedNumFeedbacks), } ctrDataset.ItemEmbeddingIndex = itemEmbeddingIndexer ctrDataset.ItemEmbeddingDimension = make([]int, len(itemEmbeddingDimension)) for i, dimension := range itemEmbeddingDimension { for dim, cnt := range dimension { if cnt > itemEmbeddingDimension[i][ctrDataset.ItemEmbeddingDimension[i]] { ctrDataset.ItemEmbeddingDimension[i] = dim } } } for i, embeddings := range itemEmbeddings { for j, embedding := range embeddings { if len(embedding) != ctrDataset.ItemEmbeddingDimension[j] { itemEmbeddings[i][j] = nil } } } ctrDataset.ItemEmbeddings = itemEmbeddings for userIndex := range positiveSet { // insert positive feedback for _, itemIndex := range positiveSet[userIndex].ToSlice() { ctrDataset.Users = append(ctrDataset.Users, int32(userIndex)) ctrDataset.Items = append(ctrDataset.Items, itemIndex) ctrDataset.Target = append(ctrDataset.Target, 1) ctrDataset.PositiveCount++ } // insert negative feedback for _, itemIndex := range negativeSet[userIndex].ToSlice() { ctrDataset.Users = append(ctrDataset.Users, int32(userIndex)) ctrDataset.Items = append(ctrDataset.Items, itemIndex) ctrDataset.Target = append(ctrDataset.Target, -1) ctrDataset.NegativeCount++ } // release positive set and negative set positiveSet[userIndex] = nil negativeSet[userIndex] = nil } log.Logger().Debug("created ranking dataset", zap.Int("n_valid_positive", ctrDataset.PositiveCount), zap.Int("n_valid_negative", ctrDataset.NegativeCount), zap.Duration("used_time", time.Since(start))) 
LoadDatasetStepSecondsVec.WithLabelValues("create_ranking_dataset").Set(time.Since(start).Seconds()) return ctrDataset, dataSet, nil } func (m *Master) updateItemToItem(parent context.Context, dataset *dataset.Dataset) error { if len(m.Config.Recommend.ItemToItem) == 0 { return nil } ctx, span := m.tracer.Start(parent, "Generate item-to-item recommendation", len(dataset.GetItems())*(len(m.Config.Recommend.ItemToItem))*2) defer span.End() // Build item-to-item recommenders itemToItemRecommenders := make([]logics.ItemToItem, 0, len(m.Config.Recommend.ItemToItem)) for _, cfg := range m.Config.Recommend.ItemToItem { recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.ItemToItemOptions{ TagsIDF: dataset.GetItemColumnValuesIDF(), UsersIDF: dataset.GetUserIDF(), OpenAIConfig: m.Config.OpenAI, }) if err != nil { return errors.Trace(err) } itemToItemRecommenders = append(itemToItemRecommenders, recommender) } // Push items to item-to-item recommenders if err := parallel.ForEach(ctx, dataset.GetItems(), m.Config.Master.NumJobs, func(i int, item data.Item) { for _, recommender := range itemToItemRecommenders { recommender.Push(&item, dataset.GetItemFeedback()[i]) span.Add(1) } }); err != nil { return errors.Trace(err) } // Save item-to-item recommendations to cache for i, recommender := range itemToItemRecommenders { if err := parallel.For(ctx, recommender.Count(), m.Config.Master.NumJobs, func(j int) { item := recommender.Get(j) itemToItemConfig := m.Config.Recommend.ItemToItem[i] if m.needUpdateItemToItem(ctx, item.ItemId, itemToItemConfig) { defer span.Add(1) score := recommender.PopAll(j) if score == nil { return } log.Logger().Debug("update item-to-item recommendation", zap.String("item_id", item.ItemId), zap.String("name", itemToItemConfig.Name), zap.Int("n_recommendations", len(score))) // Save item-to-item recommendation to cache if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, 
cache.Key(itemToItemConfig.Name, item.ItemId), score); err != nil { log.Logger().Error("failed to save item-to-item recommendation to cache", zap.String("item_id", item.ItemId), zap.Error(err)) return } // Save item-to-item digest and last update time to cache if err := m.CacheClient.Set(ctx, cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, item.ItemId), itemToItemConfig.Hash(&m.Config.Recommend)), cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, item.ItemId), time.Now()), ); err != nil { log.Logger().Error("failed to save item-to-item digest to cache", zap.String("item_id", item.ItemId), zap.Error(err)) return } // Remove stale item-to-item recommendation if err := m.CacheClient.DeleteScores(ctx, []string{cache.ItemToItem}, cache.ScoreCondition{ Subset: lo.ToPtr(cache.Key(itemToItemConfig.Name, item.ItemId)), Before: lo.ToPtr(recommender.Timestamp()), }); err != nil { log.Logger().Error("failed to remove stale item-to-item recommendation", zap.String("item_id", item.ItemId), zap.Error(err)) return } } else { span.Add(1) } }); err != nil { return errors.Trace(err) } } return nil } // needUpdateItemToItem checks if item-to-item recommendation needs to be updated. 
func (m *Master) needUpdateItemToItem(ctx context.Context, itemId string, itemToItemConfig config.ItemToItemConfig) bool { // check cache items, err := m.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, itemId), nil, 0, -1) if err != nil { log.Logger().Error("failed to fetch item-to-item recommendation", zap.String("item_id", itemId), zap.Error(err)) return true } else if len(items) == 0 { return true } // check digest digest, err := m.CacheClient.Get(ctx, cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, itemId)).String() if err != nil { if !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to read item-to-item digest", zap.Error(err)) } return true } if digest != itemToItemConfig.Hash(&m.Config.Recommend) { return true } // check update time updateTime, err := m.CacheClient.Get(ctx, cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, itemId)).Time() if err != nil { if !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to read last update item neighbors time", zap.Error(err)) } return true } return updateTime.Before(time.Now().Add(-m.Config.Recommend.CacheExpire)) } func (m *Master) updateUserToUser(parent context.Context, dataset *dataset.Dataset) error { if len(m.Config.Recommend.UserToUser) == 0 { return nil } ctx, span := m.tracer.Start(parent, "Generate user-to-user recommendation", len(dataset.GetUsers())*(len(m.Config.Recommend.UserToUser))*2) defer span.End() userToUserRecommenders := make([]logics.UserToUser, 0, len(m.Config.Recommend.UserToUser)) for _, cfg := range m.Config.Recommend.UserToUser { recommender, err := logics.NewUserToUser(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.UserToUserOptions{ TagsIDF: dataset.GetUserColumnValuesIDF(), ItemsIDF: dataset.GetItemIDF(), }) if err != nil { return errors.Trace(err) } userToUserRecommenders = append(userToUserRecommenders, recommender) } // Push users to user-to-user recommender if err := parallel.ForEach(ctx, 
dataset.GetUsers(), m.Config.Master.NumJobs, func(i int, user data.User) { for _, recommender := range userToUserRecommenders { recommender.Push(&user, dataset.GetUserFeedback()[i]) span.Add(1) } }); err != nil { return errors.Trace(err) } // Save user-to-user recommendations to cache for i, recommender := range userToUserRecommenders { if err := parallel.ForEach(ctx, recommender.Users(), m.Config.Master.NumJobs, func(j int, user *data.User) { userToUserConfig := m.Config.Recommend.UserToUser[i] if m.needUpdateUserToUser(ctx, user.UserId, userToUserConfig) { score := recommender.PopAll(j) if score == nil { return } log.Logger().Debug("update user neighbors", zap.String("user_id", user.UserId), zap.Int("n_recommendations", len(score))) // Save user-to-user recommendations to cache if err := m.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key(userToUserConfig.Name, user.UserId), score); err != nil { log.Logger().Error("failed to save user neighbors to cache", zap.String("user_id", user.UserId), zap.Error(err)) return } // Save user-to-user digest and last update time to cache if err := m.CacheClient.Set(ctx, cache.String(cache.Key(cache.UserToUserDigest, cache.Key(userToUserConfig.Name, user.UserId)), userToUserConfig.Hash(&m.Config.Recommend)), cache.Time(cache.Key(cache.UserToUserUpdateTime, cache.Key(userToUserConfig.Name, user.UserId)), time.Now()), ); err != nil { log.Logger().Error("failed to save user neighbors digest to cache", zap.String("user_id", user.UserId), zap.Error(err)) return } // Delete stale user-to-user recommendations if err := m.CacheClient.DeleteScores(ctx, []string{cache.UserToUser}, cache.ScoreCondition{ Subset: lo.ToPtr(cache.Key(userToUserConfig.Name, user.UserId)), Before: lo.ToPtr(recommender.Timestamp()), }); err != nil { log.Logger().Error("failed to remove stale user neighbors", zap.String("user_id", user.UserId), zap.Error(err)) } } span.Add(1) }); err != nil { return errors.Trace(err) } } return nil } // needUpdateUserToUser 
checks if user-to-user recommendation needs to be updated. func (m *Master) needUpdateUserToUser(ctx context.Context, userId string, userToUserConfig config.UserToUserConfig) bool { // check cache if items, err := m.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key(userToUserConfig.Name, userId), nil, 0, -1); err != nil { log.Logger().Error("failed to load user neighbors", zap.String("user_id", userId), zap.Error(err)) return true } else if len(items) == 0 { return true } // read digest cacheDigest, err := m.CacheClient.Get(ctx, cache.Key(cache.UserToUserDigest, cache.Key(userToUserConfig.Name, userId))).String() if err != nil { if !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to read user neighbors digest", zap.Error(err)) } return true } if cacheDigest != userToUserConfig.Hash(&m.Config.Recommend) { return true } // check update time updateTime, err := m.CacheClient.Get(ctx, cache.Key(cache.UserToUserUpdateTime, cache.Key(userToUserConfig.Name, userId))).Time() if err != nil { if !errors.Is(err, errors.NotFound) { log.Logger().Error("failed to read last update user neighbors time", zap.Error(err)) } return true } return updateTime.Before(time.Now().Add(-m.Config.Recommend.CacheExpire)) } func (m *Master) trainCollaborativeFiltering(parent context.Context, trainSet, testSet dataset.CFSplit) error { ctx, span := m.tracer.Start(parent, "Train Collaborative Filtering Model", 2) defer span.End() if trainSet.CountUsers() == 0 { span.Fail(errors.New("No user found.")) return nil } else if trainSet.CountItems() == 0 { span.Fail(errors.New("No item found.")) return nil } else if trainSet.CountFeedback() == 0 { span.Fail(errors.New("No feedback found.")) return nil } else if trainSet.CountFeedback() == m.collaborativeFilteringTrainSetSize { log.Logger().Info("collaborative filtering dataset not changed") return nil } m.collaborativeFilteringModelMutex.Lock() collaborativeFilteringType := m.collaborativeFilteringMeta.Type collaborativeFilteringParams 
:= m.collaborativeFilteringMeta.Params if m.collaborativeFilteringTarget.Score.NDCG > 0 && (!m.collaborativeFilteringTarget.Equal(m.collaborativeFilteringMeta)) && (m.collaborativeFilteringTarget.Score.NDCG > m.collaborativeFilteringMeta.Score.NDCG) { // 1. best ranking model must have been found. // 2. best ranking model must be different from current model // 3. best ranking model must perform better than current model collaborativeFilteringType = m.collaborativeFilteringTarget.Type collaborativeFilteringParams = m.collaborativeFilteringTarget.Params log.Logger().Info("find better collaborative filtering model", zap.Any("score", m.collaborativeFilteringTarget.Score), zap.String("type", collaborativeFilteringType), zap.Any("params", collaborativeFilteringParams)) } m.collaborativeFilteringModelMutex.Unlock() startFitTime := time.Now() fitCtx, fitSpan := monitor.Start(ctx, "Fit", 1) collaborativeFilteringModel := m.newCollaborativeFilteringModel(collaborativeFilteringType, collaborativeFilteringParams) score := collaborativeFilteringModel.Fit(fitCtx, trainSet, testSet, cf.NewFitConfig(). SetJobs(m.Config.Master.NumJobs). 
SetPatience(m.Config.Recommend.Collaborative.EarlyStopping.Patience)) CollaborativeFilteringFitSeconds.Set(time.Since(startFitTime).Seconds()) span.Add(1) fitSpan.End() _, indexSpan := monitor.Start(ctx, "Index", trainSet.CountItems()) matrixFactorizationItems := logics.NewMatrixFactorizationItems(time.Now()) if err := parallel.For(ctx, trainSet.CountItems(), m.Config.Master.NumJobs, func(i int) { defer indexSpan.Add(1) if itemId, ok := trainSet.GetItemDict().String(int32(i)); ok && collaborativeFilteringModel.IsItemPredictable(int32(i)) { matrixFactorizationItems.Add(itemId, collaborativeFilteringModel.GetItemFactor(int32(i))) } }); err != nil { return errors.Trace(err) } span.Add(1) indexSpan.End() matrixFactorizationUsers := logics.NewMatrixFactorizationUsers() for i := 0; i < trainSet.CountUsers(); i++ { if userId, ok := trainSet.GetUserDict().String(int32(i)); ok && collaborativeFilteringModel.IsUserPredictable(int32(i)) { matrixFactorizationUsers.Add(userId, collaborativeFilteringModel.GetUserFactor(int32(i))) } } // update ranking model m.collaborativeFilteringModelMutex.Lock() m.collaborativeFilteringTrainSetSize = trainSet.CountFeedback() m.collaborativeFilteringModelMutex.Unlock() collaborativeFilteringModelId := time.Now().UnixMilli() log.Logger().Info("fit collaborative filtering model completed", zap.Int64("id", collaborativeFilteringModelId)) CollaborativeFilteringNDCG10.Set(float64(score.NDCG)) CollaborativeFilteringRecall10.Set(float64(score.Recall)) CollaborativeFilteringPrecision10.Set(float64(score.Precision)) if err := m.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime), time.Now())); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) } // upload model w, done, err := m.blobStore.Create(strconv.FormatInt(collaborativeFilteringModelId, 10)) if err != nil { log.Logger().Error("failed to create blob for collaborative filtering model", zap.Int64("id", collaborativeFilteringModelId), 
zap.Error(err)) return err } if err = matrixFactorizationItems.Marshal(w); err != nil { log.Logger().Error("failed to matrix factorization items", zap.Int64("id", collaborativeFilteringModelId), zap.Error(err)) return err } if err = matrixFactorizationUsers.Marshal(w); err != nil { log.Logger().Error("failed to matrix factorization users", zap.Int64("id", collaborativeFilteringModelId), zap.Error(err)) return err } if err = w.Close(); err != nil { log.Logger().Error("failed to close blob for collaborative filtering model", zap.Int64("id", collaborativeFilteringModelId), zap.Error(err)) return err } <-done // update meta m.collaborativeFilteringModelMutex.Lock() m.collaborativeFilteringMeta.ID = collaborativeFilteringModelId m.collaborativeFilteringMeta.Type = collaborativeFilteringType m.collaborativeFilteringMeta.Params = collaborativeFilteringParams m.collaborativeFilteringMeta.Score = score m.collaborativeFilteringModelMutex.Unlock() if err = m.metaStore.Put(meta.COLLABORATIVE_FILTERING_MODEL, m.collaborativeFilteringMeta.ToJSON()); err != nil { log.Logger().Error("failed to write collaborative filtering model meta", zap.Error(err)) return err } else { log.Logger().Info("write collaborative filtering model meta", zap.Int64("id", collaborativeFilteringModelId), zap.Float32("ndcg", score.NDCG), zap.Float32("recall", score.Recall), zap.Float32("precision", score.Precision)) } // update statistics if err = m.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ {Name: cache.CFNDCG, Value: float64(score.NDCG), Timestamp: time.Now()}, {Name: cache.CFPrecision, Value: float64(score.Precision), Timestamp: time.Now()}, {Name: cache.CFRecall, Value: float64(score.Recall), Timestamp: time.Now()}, }); err != nil { log.Logger().Error("failed to write time series points", zap.Error(err)) return nil } m.removeOutOfDateModels() return nil } func (m *Master) newCollaborativeFilteringModel(modelType string, params model.Params) cf.MatrixFactorization { switch modelType { 
case "BPR": return cf.NewBPR(params) case "ALS": return cf.NewALS(params) default: return cf.NewBPR(params) } } func (m *Master) trainClickThroughRatePrediction(parent context.Context, trainSet, testSet *ctr.Dataset) error { ctx, span := m.tracer.Start(parent, "Train Click-Through Rate Prediction Model", 1) defer span.End() if trainSet.CountUsers() == 0 { span.Fail(errors.New("No user found.")) return nil } else if trainSet.CountItems() == 0 { span.Fail(errors.New("No item found.")) return nil } else if trainSet.Count() == 0 { span.Fail(errors.New("No feedback found.")) return nil } else if trainSet.Count() == m.clickThroughRateTrainSetSize { log.Logger().Info("click dataset not changed") return nil } m.clickThroughRateModelMutex.Lock() clickThroughRateType := m.clickThroughRateMeta.Type clickThroughRateParams := m.clickThroughRateMeta.Params if m.clickThroughRateTarget.Score.AUC > 0 && (!m.clickThroughRateTarget.Equal(m.clickThroughRateMeta)) && (m.clickThroughRateTarget.Score.AUC > m.clickThroughRateMeta.Score.AUC) { // 1. best click model must have been found. // 2. best click model must be different from current model // 3. best click model must perform better than current model clickThroughRateType = m.clickThroughRateTarget.Type clickThroughRateParams = m.clickThroughRateTarget.Params log.Logger().Info("find better click model", zap.Float32("Precision", m.clickThroughRateTarget.Score.Precision), zap.Float32("Recall", m.clickThroughRateTarget.Score.Recall), zap.Any("params", clickThroughRateParams)) } clickModel := ctr.NewAFM(clickThroughRateParams) m.clickThroughRateModelMutex.Unlock() startFitTime := time.Now() score := clickModel.Fit(ctx, trainSet, testSet, ctr.NewFitConfig(). SetJobs(m.Config.Master.NumJobs). 
SetPatience(m.Config.Recommend.Ranker.EarlyStopping.Patience)) RankingFitSeconds.Set(time.Since(startFitTime).Seconds()) // update match model m.clickThroughRateModelMutex.Lock() m.clickThroughRateTrainSetSize = trainSet.Count() clickThroughRateModelId := time.Now().UnixMilli() m.clickThroughRateModelMutex.Unlock() log.Logger().Info("fit click model complete", zap.Int64("id", clickThroughRateModelId)) RankingPrecision.Set(float64(score.Precision)) RankingRecall.Set(float64(score.Recall)) RankingAUC.Set(float64(score.AUC)) if err := m.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime), time.Now())); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) } // upload model w, done, err := m.blobStore.Create(strconv.FormatInt(clickThroughRateModelId, 10)) if err != nil { log.Logger().Error("failed to create blob for click-through rate model", zap.Int64("id", clickThroughRateModelId), zap.Error(err)) return err } if err = ctr.MarshalModel(w, clickModel); err != nil { log.Logger().Error("failed to marshal click-through rate model", zap.Int64("id", clickThroughRateModelId), zap.Error(err)) return err } if err = w.Close(); err != nil { log.Logger().Error("failed to close blob for click-through rate model", zap.Int64("id", clickThroughRateModelId), zap.Error(err)) return err } <-done // update meta m.clickThroughRateModelMutex.Lock() m.clickThroughRateMeta.ID = clickThroughRateModelId m.clickThroughRateMeta.Type = clickThroughRateType m.clickThroughRateMeta.Params = clickThroughRateParams m.clickThroughRateMeta.Score = score m.clickThroughRateModelMutex.Unlock() if err = m.metaStore.Put(meta.CLICK_THROUGH_RATE_MODEL, m.clickThroughRateMeta.ToJSON()); err != nil { log.Logger().Error("failed to write click-through rate model meta", zap.Error(err)) return err } else { log.Logger().Info("write click-through rate model meta", zap.Int64("id", clickThroughRateModelId), zap.Float32("precision", score.Precision), 
zap.Float32("recall", score.Recall), zap.Float32("auc", score.AUC), zap.Any("params", clickThroughRateParams)) } // update statistics if err = m.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ {Name: cache.CTRPrecision, Value: float64(score.Precision), Timestamp: time.Now()}, {Name: cache.CTRRecall, Value: float64(score.Recall), Timestamp: time.Now()}, {Name: cache.CTRAUC, Value: float64(score.AUC), Timestamp: time.Now()}, }); err != nil { log.Logger().Error("failed to write time series points", zap.Error(err)) return err } m.removeOutOfDateModels() return nil } func (m *Master) removeOutOfDateModels() { m.collaborativeFilteringModelMutex.RLock() m.clickThroughRateModelMutex.RLock() timestamp := min(m.collaborativeFilteringMeta.ID, m.clickThroughRateMeta.ID) m.clickThroughRateModelMutex.RUnlock() m.collaborativeFilteringModelMutex.RUnlock() files, err := m.blobStore.List() if err != nil { log.Logger().Error("failed to list models in blob store", zap.Error(err)) return } for _, file := range files { id, err := strconv.ParseInt(file, 10, 64) if err != nil { log.Logger().Info("failed to parse model id", zap.String("file", file), zap.Error(err)) continue } if id < timestamp { if err = m.blobStore.Remove(file); err != nil { log.Logger().Error("failed to delete model from blob store", zap.Int64("id", id), zap.Error(err)) } else { log.Logger().Info("deleted out-of-date model from blob store", zap.Int64("id", id)) } } } } func (m *Master) collectGarbage(parent context.Context, dataSet *dataset.Dataset) error { ctx, span := m.tracer.Start(parent, "Collect Garbage in Cache", 1) defer span.End() err := m.CacheClient.ScanScores(ctx, func(collection, id, subset string, timestamp time.Time) error { switch collection { case cache.NonPersonalized: if !lo.ContainsBy(m.Config.Recommend.NonPersonalized, func(cfg config.NonPersonalizedConfig) bool { return cfg.Name == subset }) { return m.CacheClient.DeleteScores(ctx, []string{cache.NonPersonalized}, 
cache.ScoreCondition{ Subset: lo.ToPtr(subset), }) } case cache.UserToUser: splits := strings.Split(subset, "/") if len(splits) != 2 { log.Logger().Error("invalid subset", zap.String("subset", subset)) return nil } if dataSet.GetUserDict().Id(splits[1]) == dataset.NotId || !lo.ContainsBy(m.Config.Recommend.UserToUser, func(cfg config.UserToUserConfig) bool { return cfg.Name == splits[0] }) { return m.CacheClient.DeleteScores(ctx, []string{cache.UserToUser}, cache.ScoreCondition{ Subset: lo.ToPtr(subset), Before: lo.ToPtr(dataSet.GetTimestamp()), }) } case cache.ItemToItem: splits := strings.Split(subset, "/") if len(splits) != 2 { log.Logger().Error("invalid subset", zap.String("subset", subset)) return nil } if dataSet.GetItemDict().Id(splits[1]) == dataset.NotId || !lo.ContainsBy(m.Config.Recommend.ItemToItem, func(cfg config.ItemToItemConfig) bool { return cfg.Name == splits[0] }) { return m.CacheClient.DeleteScores(ctx, []string{cache.ItemToItem}, cache.ScoreCondition{ Subset: lo.ToPtr(subset), Before: lo.ToPtr(dataSet.GetTimestamp()), }) } case cache.CollaborativeFiltering: if dataSet.GetUserDict().Id(subset) == dataset.NotId { return m.CacheClient.DeleteScores(ctx, []string{cache.CollaborativeFiltering}, cache.ScoreCondition{ Subset: lo.ToPtr(subset), Before: lo.ToPtr(dataSet.GetTimestamp()), }) } } return nil }) return errors.Trace(err) } func (m *Master) optimizeCollaborativeFiltering(parent context.Context, trainSet, testSet dataset.CFSplit) error { ctx, span := m.tracer.Start(parent, "Optimize Collaborative Filtering Model", m.Config.Recommend.Collaborative.OptimizeTrials) defer span.End() if trainSet.CountUsers() == 0 { span.Fail(errors.New("No user found.")) return nil } else if trainSet.CountItems() == 0 { span.Fail(errors.New("No item found.")) return nil } else if trainSet.CountFeedback() == 0 { span.Fail(errors.New("No feedback found.")) return nil } search := cf.NewModelSearch(map[string]cf.ModelCreator{ "BPR": func() cf.MatrixFactorization { 
return cf.NewBPR(nil) }, "ALS": func() cf.MatrixFactorization { return cf.NewALS(nil) }, }, trainSet, testSet, cf.NewFitConfig(). SetJobs(m.Config.Master.NumJobs). SetPatience(m.Config.Recommend.Collaborative.EarlyStopping.Patience)). WithContext(ctx). WithSpan(span) study, err := goptuna.CreateStudy("optimizeCollaborativeFiltering", goptuna.StudyOptionDirection(goptuna.StudyDirectionMaximize), goptuna.StudyOptionSampler(tpe.NewSampler()), goptuna.StudyOptionLogger(log.NewOptunaLogger(log.Logger()))) if err != nil { return errors.Trace(err) } study.WithContext(ctx) if err = study.Optimize(search.Objective, m.Config.Recommend.Collaborative.OptimizeTrials); err != nil { return errors.Trace(err) } m.collaborativeFilteringModelMutex.Lock() m.collaborativeFilteringTarget = search.Result() m.collaborativeFilteringModelMutex.Unlock() log.Logger().Info("optimize collaborative filtering model completed", zap.Any("score", m.collaborativeFilteringTarget.Score), zap.String("type", m.collaborativeFilteringTarget.Type), zap.Any("params", m.collaborativeFilteringTarget.Params)) return nil } func (m *Master) optimizeClickThroughRatePrediction(parent context.Context, trainSet, testSet *ctr.Dataset) error { ctx, span := m.tracer.Start(parent, "Optimize Click-Through Rate Prediction Model", m.Config.Recommend.Ranker.OptimizeTrials) defer span.End() if trainSet.CountUsers() == 0 { span.Fail(errors.New("No user found.")) return nil } else if trainSet.CountItems() == 0 { span.Fail(errors.New("No item found.")) return nil } else if trainSet.Count() == 0 { span.Fail(errors.New("No feedback found.")) return nil } search := ctr.NewModelSearch(map[string]ctr.ModelCreator{ "FM": func() ctr.FactorizationMachines { return ctr.NewAFM(nil) }, }, trainSet, testSet, ctr.NewFitConfig(). SetJobs(m.Config.Master.NumJobs). SetPatience(m.Config.Recommend.Ranker.EarlyStopping.Patience)). WithContext(ctx). 
WithSpan(span) study, err := goptuna.CreateStudy("optimizeClickThroughRatePrediction", goptuna.StudyOptionDirection(goptuna.StudyDirectionMaximize), goptuna.StudyOptionSampler(tpe.NewSampler()), goptuna.StudyOptionLogger(log.NewOptunaLogger(log.Logger()))) if err != nil { return errors.Trace(err) } study.WithContext(ctx) if err = study.Optimize(search.Objective, m.Config.Recommend.Ranker.OptimizeTrials); err != nil { return errors.Trace(err) } m.clickThroughRateModelMutex.Lock() m.clickThroughRateTarget = search.Result() m.clickThroughRateModelMutex.Unlock() log.Logger().Info("optimize click-through rate model completed", zap.Any("score", m.clickThroughRateTarget.Score), zap.String("type", m.clickThroughRateTarget.Type), zap.Any("params", m.clickThroughRateTarget.Params)) return nil } // updateRecommend updates recommendations for all user in standalone mode. func (m *Master) updateRecommend(ctx context.Context) error { pipeline := &worker.Pipeline{ Config: m.Config, DataClient: m.DataClient, CacheClient: m.CacheClient, Tracer: m.tracer, Jobs: m.Config.Master.NumJobs, MatrixFactorizationItems: logics.NewMatrixFactorizationItems(time.Time{}), MatrixFactorizationUsers: logics.NewMatrixFactorizationUsers(), } // load matrix factorization model if m.collaborativeFilteringMeta.ID > 0 { r, err := m.blobStore.Open(strconv.FormatInt(m.collaborativeFilteringMeta.ID, 10)) if err != nil { log.Logger().Error("failed to load collaborative filtering model from blob store", zap.Int64("id", m.collaborativeFilteringMeta.ID), zap.Error(err)) return errors.Trace(err) } if err = pipeline.MatrixFactorizationItems.Unmarshal(r); err != nil { log.Logger().Error("failed to unmarshal matrix factorization items", zap.Error(err)) } else if err = pipeline.MatrixFactorizationUsers.Unmarshal(r); err != nil { log.Logger().Error("failed to unmarshal matrix factorization users", zap.Error(err)) } } // load click-through rate model when FM ranker is enabled if 
strings.EqualFold(m.Config.Recommend.Ranker.Type, "fm") && m.clickThroughRateMeta.ID > 0 { r, err := m.blobStore.Open(strconv.FormatInt(m.clickThroughRateMeta.ID, 10)) if err != nil { log.Logger().Error("failed to open click-through rate model", zap.Error(err)) return errors.Trace(err) } pipeline.ClickThroughRateModel, err = ctr.UnmarshalModel(r) if err != nil { log.Logger().Error("failed to unmarshal click-through rate model", zap.Error(err)) return errors.Trace(err) } } // Pull all users from database users, err := m.pullAllUsers(ctx) if err != nil { log.Logger().Error("failed to pull users", zap.Error(err)) return errors.Trace(err) } pipeline.Recommend(ctx, users, func(completed, throughput int) { log.Logger().Info("ranking recommendation", zap.Int("n_complete_users", completed), zap.Int("throughput", throughput)) }) return nil } // pullAllUsers pulls all users from the data store. func (m *Master) pullAllUsers(ctx context.Context) ([]data.User, error) { var users []data.User userChan, errChan := m.DataClient.GetUserStream(ctx, batchSize) for batchUsers := range userChan { users = append(users, batchUsers...) } if err := <-errChan; err != nil { return nil, errors.Trace(err) } return users, nil } ================================================ FILE: master/tasks_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package master import ( "runtime" "strconv" "time" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/samber/lo" ) func (s *MasterTestSuite) TestFindItemToItem() { ctx := s.T().Context() // create config s.Config = &config.Config{} s.Config.Recommend.CacheSize = 3 s.Config.Master.NumJobs = 4 // collect similar items := []data.Item{ {ItemId: "0", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d"}, Comment: ""}, {ItemId: "1", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, {ItemId: "2", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"b", "c", "d"}, Comment: ""}, {ItemId: "3", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, {ItemId: "4", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{"b", "c"}, Comment: ""}, {ItemId: "5", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, {ItemId: "6", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"c"}, Comment: ""}, {ItemId: "7", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, {ItemId: "8", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d", "e"}, Comment: ""}, {ItemId: "9", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { for j := 0; j <= i; j++ { if i%2 == 1 { feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ ItemId: strconv.Itoa(i), UserId: strconv.Itoa(j), FeedbackType: "FeedbackType", }, Timestamp: time.Now(), }) } } } var err error err = s.DataClient.BatchInsertItems(ctx, 
items) s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) // insert hidden item err = s.DataClient.BatchInsertItems(ctx, []data.Item{{ ItemId: "10", Labels: []string{"a", "b", "c", "d", "e"}, IsHidden: true, }}) s.NoError(err) for i := 0; i <= 10; i++ { err = s.DataClient.BatchInsertFeedback(ctx, []data.Feedback{{ FeedbackKey: data.FeedbackKey{UserId: strconv.Itoa(i), ItemId: "10", FeedbackType: "FeedbackType"}, }}, true, true, true) s.NoError(err) } // load mock dataset _, dataSet, err := s.LoadDataFromDatabase(s.T().Context(), s.DataClient, []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("FeedbackType")}, nil, 0, 0, NewOnlineEvaluator(nil, nil), nil) s.NoError(err) // similar items (common users) s.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default", Type: "users"}} s.NoError(s.updateItemToItem(s.T().Context(), dataSet)) similar, err := s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "9"), nil, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) // similar items in category (common users) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "9"), []string{"*"}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5"}, cache.ConvertDocumentsToValues(similar)) // digest digest, err := s.CacheClient.Get(ctx, cache.Key(cache.ItemToItemDigest, "default", "9")).String() s.NoError(err) s.Equal(s.Config.Recommend.ItemToItem[0].Hash(&s.Config.Recommend), digest) // similar items (common labels) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now())) s.NoError(err) s.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default", Type: "tags", Column: "item.Labels"}} s.NoError(s.updateItemToItem(s.T().Context(), dataSet)) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "8"), nil, 0, 100) 
s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) // similar items in category (common labels) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "8"), []string{"*"}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2"}, cache.ConvertDocumentsToValues(similar)) // similar items (auto) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now())) s.NoError(err) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyItemTime, "9"), time.Now())) s.NoError(err) s.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default", Type: "auto"}} s.NoError(s.updateItemToItem(s.T().Context(), dataSet)) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "8"), nil, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "9"), nil, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) } func (s *MasterTestSuite) TestUserToUser() { ctx := s.T().Context() // create config s.Config = &config.Config{} s.Config.Recommend.CacheSize = 3 s.Config.Master.NumJobs = 4 // collect similar users := []data.User{ {UserId: "0", Labels: []string{"a", "b", "c", "d"}, Comment: ""}, {UserId: "1", Labels: []string{}, Comment: ""}, {UserId: "2", Labels: []string{"b", "c", "d"}, Comment: ""}, {UserId: "3", Labels: []string{}, Comment: ""}, {UserId: "4", Labels: []string{"b", "c"}, Comment: ""}, {UserId: "5", Labels: []string{}, Comment: ""}, {UserId: "6", Labels: []string{"c"}, Comment: ""}, {UserId: "7", Labels: []string{}, Comment: ""}, {UserId: "8", Labels: []string{"a", "b", "c", "d", "e"}, Comment: ""}, {UserId: "9", Labels: []string{}, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { for j := 0; j <= i; j++ { if i%2 == 1 { feedbacks = 
append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ ItemId: strconv.Itoa(j), UserId: strconv.Itoa(i), FeedbackType: "FeedbackType", }, Timestamp: time.Now(), }) } } } var err error err = s.DataClient.BatchInsertUsers(ctx, users) s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) _, dataSet, err := s.LoadDataFromDatabase(s.T().Context(), s.DataClient, []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("FeedbackType")}, nil, 0, 0, NewOnlineEvaluator(nil, nil), nil) s.NoError(err) // similar items (common users) s.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default", Type: "items"}} s.NoError(s.updateUserToUser(s.T().Context(), dataSet)) similar, err := s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "9"), nil, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) digest, err := s.CacheClient.Get(ctx, cache.Key(cache.UserToUserDigest, "default", "9")).String() s.NoError(err) s.Equal(s.Config.Recommend.UserToUser[0].Hash(&s.Config.Recommend), digest) // similar items (common labels) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) s.NoError(err) s.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default", Type: "tags", Column: "user.Labels"}} s.NoError(s.updateUserToUser(s.T().Context(), dataSet)) similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "8"), nil, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) // similar items (auto) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) s.NoError(err) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "9"), time.Now())) s.NoError(err) s.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default", Type: "auto"}} 
s.NoError(s.updateUserToUser(s.T().Context(), dataSet)) similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "8"), nil, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "9"), nil, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) } func (s *MasterTestSuite) TestLoadDataFromDatabase() { ctx := s.T().Context() // create config s.Config = &config.Config{} s.Config.Recommend.CacheSize = 3 s.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("positive")} s.Config.Recommend.DataSource.ReadFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("negative")} s.Config.Master.NumJobs = runtime.NumCPU() // insert items var items []data.Item for i := 0; i < 9; i++ { items = append(items, data.Item{ ItemId: strconv.Itoa(i), Timestamp: time.Date(2000+i, 1, 1, 1, 1, 0, 0, time.UTC), Labels: []any{strconv.Itoa(i % 3), strconv.Itoa(i*10 + 10)}, Categories: []string{strconv.Itoa(i % 3)}, }) } err := s.DataClient.BatchInsertItems(ctx, items) s.NoError(err) err = s.DataClient.BatchInsertItems(ctx, []data.Item{{ ItemId: "9", Timestamp: time.Date(2020, 1, 1, 1, 1, 0, 0, time.UTC), IsHidden: true, }}) s.NoError(err) // insert users var users []data.User for i := 0; i <= 10; i++ { users = append(users, data.User{ UserId: strconv.Itoa(i), Labels: []string{strconv.Itoa(i % 5), strconv.Itoa(i*10 + 10)}, }) } err = s.DataClient.BatchInsertUsers(ctx, users) s.NoError(err) // insert feedback feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { // positive feedback // item 0: user 0 // ... // item 9: user 0 ... 
user 9 for j := 0; j <= i; j++ { feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ ItemId: strconv.Itoa(i), UserId: strconv.Itoa(j), FeedbackType: "positive", }, Timestamp: time.Now(), }) } // negative feedback // item 0: user 1 .. user 10 // ... // item 9: user 10 for j := i + 1; j < 11; j++ { feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ ItemId: strconv.Itoa(i), UserId: strconv.Itoa(j), FeedbackType: "negative", }, Timestamp: time.Now(), }) } } err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, false, false, true) s.NoError(err) // load dataset datasets, err := s.loadDataset(ctx) s.NoError(err) s.Equal(11, datasets.rankingTrainSet.CountUsers()) s.Equal(10, datasets.rankingTrainSet.CountItems()) s.Equal(11, datasets.rankingTestSet.CountUsers()) s.Equal(10, datasets.rankingTestSet.CountItems()) s.Equal(55, datasets.rankingTrainSet.CountFeedback()+datasets.rankingTestSet.CountFeedback()) s.Equal(11, datasets.clickTrainSet.CountUsers()) s.Equal(10, datasets.clickTrainSet.CountItems()) s.Equal(11, datasets.clickTestSet.CountUsers()) s.Equal(10, datasets.clickTestSet.CountItems()) s.Equal(int32(3), datasets.clickTrainSet.Index.CountItemLabels()) s.Equal(int32(5), datasets.clickTrainSet.Index.CountUserLabels()) s.Equal(int32(3), datasets.clickTestSet.Index.CountItemLabels()) s.Equal(int32(5), datasets.clickTestSet.Index.CountUserLabels()) s.Equal(110, datasets.clickTrainSet.Count()+datasets.clickTestSet.Count()) s.Equal(55, datasets.clickTrainSet.PositiveCount+datasets.clickTestSet.PositiveCount) s.Equal(55, datasets.clickTrainSet.NegativeCount+datasets.clickTestSet.NegativeCount) // check latest items latest, err := s.DataClient.GetLatestItems(ctx, 3, nil) s.NoError(err) s.Equal([]data.Item{ items[8], items[7], items[6], }, latest) latest, err = s.DataClient.GetLatestItems(ctx, 3, []string{"2"}) s.NoError(err) s.Equal([]data.Item{ items[8], items[5], items[2], }, latest) // check categories categoryScores, err 
:= s.CacheClient.SearchScores(ctx, cache.ItemCategories, "", nil, 0, -1) s.NoError(err) categories := make([]string, len(categoryScores)) for i, score := range categoryScores { categories[i] = score.Id } s.Equal([]string{"0", "1", "2"}, categories) } func (s *MasterTestSuite) TestNonPersonalizedRecommend() { ctx := s.T().Context() // create config s.Config = &config.Config{} s.Config.Recommend.CacheSize = 3 s.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("positive")} s.Config.Recommend.DataSource.ReadFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("negative")} s.Config.Recommend.NonPersonalized = []config.NonPersonalizedConfig{{Name: "latest", Score: "item.Timestamp.Unix()"}} s.Config.Master.NumJobs = runtime.NumCPU() // insert items var items []data.Item for i := 0; i < 10; i++ { items = append(items, data.Item{ ItemId: strconv.Itoa(i), Timestamp: time.Date(2000+i%2, 1, 1, i, 1, 0, 0, time.UTC), }) } err := s.DataClient.BatchInsertItems(ctx, items) s.NoError(err) // insert users var users []data.User for i := 0; i < 10; i++ { users = append(users, data.User{ UserId: strconv.Itoa(i), }) } err = s.DataClient.BatchInsertUsers(ctx, users) s.NoError(err) // insert feedback feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { // positive feedback // item 0: user 0 // ... // item 8: user 0 ... 
user 8 if i%2 == 0 { for j := 0; j <= i; j++ { feedbacks = append(feedbacks, data.Feedback{ FeedbackKey: data.FeedbackKey{ ItemId: strconv.Itoa(i), UserId: strconv.Itoa(j), FeedbackType: "positive", }, Timestamp: time.Now(), }) } } } err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, false, false, true) s.NoError(err) // load dataset _, err = s.loadDataset(ctx) s.NoError(err) // check latest items latest, err := s.CacheClient.SearchScores(ctx, cache.NonPersonalized, "latest", []string{""}, 0, 3) s.NoError(err) s.Equal([]cache.Score{ {Id: items[9].ItemId, Score: float64(items[9].Timestamp.Unix())}, {Id: items[7].ItemId, Score: float64(items[7].Timestamp.Unix())}, {Id: items[5].ItemId, Score: float64(items[5].Timestamp.Unix())}, }, lo.Map(latest, func(document cache.Score, _ int) cache.Score { return cache.Score{Id: document.Id, Score: document.Score} })) // check digest digest, err := s.CacheClient.Get(ctx, cache.Key(cache.NonPersonalizedDigest, "latest")).String() s.NoError(err) s.Equal(s.Config.Recommend.NonPersonalized[0].Hash(), digest) } func (s *MasterTestSuite) TestNeedUpdateItemToItem() { s.Config = config.GetDefaultConfig() recommendConfig := config.ItemToItemConfig{Name: "default"} ctx := s.T().Context() // empty cache s.True(s.needUpdateItemToItem(ctx, "1", recommendConfig)) err := s.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "1"), []cache.Score{ {Id: "2", Score: 1, Categories: []string{""}}, {Id: "3", Score: 2, Categories: []string{""}}, {Id: "4", Score: 3, Categories: []string{""}}, }) s.NoError(err) // digest mismatch err = s.CacheClient.Set(ctx, cache.String(cache.Key(cache.ItemToItemDigest, "default", "1"), "digest")) s.NoError(err) s.True(s.needUpdateItemToItem(ctx, "1", recommendConfig)) // staled cache err = s.CacheClient.Set(ctx, cache.String(cache.Key(cache.ItemToItemDigest, "default", "1"), recommendConfig.Hash(&s.Config.Recommend))) s.NoError(err) s.True(s.needUpdateItemToItem(ctx, "1", recommendConfig)) err = 
s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.ItemToItemUpdateTime, "default", "1"), time.Now().Add(-s.Config.Recommend.CacheExpire))) s.NoError(err) s.True(s.needUpdateItemToItem(ctx, "1", recommendConfig)) // not staled cache err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.ItemToItemUpdateTime, "default", "1"), time.Now())) s.NoError(err) s.False(s.needUpdateItemToItem(ctx, "1", recommendConfig)) } func (s *MasterTestSuite) TestNeedUpdateUserToUser() { ctx := s.T().Context() s.Config = config.GetDefaultConfig() recommendConfig := config.UserToUserConfig{Name: "default"} // empty cache s.True(s.needUpdateUserToUser(ctx, "1", recommendConfig)) err := s.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "1"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, {Id: "3", Score: 3, Categories: []string{""}}, }) s.NoError(err) // digest mismatch err = s.CacheClient.Set(ctx, cache.String(cache.Key(cache.UserToUserDigest, "default", "1"), "digest")) s.NoError(err) s.True(s.needUpdateUserToUser(ctx, "1", recommendConfig)) // staled cache err = s.CacheClient.Set(ctx, cache.String(cache.Key(cache.UserToUserDigest, "default", "1"), recommendConfig.Hash(&s.Config.Recommend))) s.NoError(err) s.True(s.needUpdateUserToUser(ctx, "1", recommendConfig)) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.UserToUserUpdateTime, "default", "1"), time.Now().Add(-s.Config.Recommend.CacheExpire))) s.NoError(err) s.True(s.needUpdateUserToUser(ctx, "1", recommendConfig)) // not staled cache err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.UserToUserUpdateTime, "default", "1"), time.Now())) s.NoError(err) s.False(s.needUpdateUserToUser(ctx, "1", recommendConfig)) } func (s *MasterTestSuite) TestGarbageCollection() { // create config s.Config = &config.Config{} s.Config.Master.NumJobs = 1 s.Config.Recommend.NonPersonalized = []config.NonPersonalizedConfig{{Name: "custom", Score: "1"}} 
s.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default", Type: "users"}} s.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default", Type: "items"}} // insert items ctx := s.T().Context() err := s.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "1", Timestamp: time.Now(), Categories: []string{"*"}, Labels: []string{"a", "b", "c", "d"}, Comment: ""}, {ItemId: "2", Timestamp: time.Now(), Categories: []string{"*"}, Labels: []string{}, Comment: ""}, }) s.NoError(err) // insert users err = s.DataClient.BatchInsertUsers(ctx, []data.User{ {UserId: "1", Labels: []string{"a", "b", "c", "d"}, Comment: ""}, {UserId: "2", Labels: []string{}, Comment: ""}, }) s.NoError(err) // insert non-personalized cache timestamp := time.Now().Add(time.Hour) err = s.CacheClient.AddScores(ctx, cache.NonPersonalized, "custom", []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}, Timestamp: timestamp}, {Id: "2", Score: 2, Categories: []string{""}, Timestamp: timestamp}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.NonPersonalized, "unknown", []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}, Timestamp: timestamp}, {Id: "2", Score: 2, Categories: []string{""}, Timestamp: timestamp}, }) s.NoError(err) // insert item-to-item cache err = s.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "1"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "3"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("unknown", "1"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) // insert user-to-user cache err = s.CacheClient.AddScores(ctx, cache.UserToUser, 
cache.Key("default", "1"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "3"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("unknown", "1"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) // insert collaborative filtering cache err = s.CacheClient.AddScores(ctx, cache.CollaborativeFiltering, "1", []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) err = s.CacheClient.AddScores(ctx, cache.CollaborativeFiltering, "3", []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{""}}, }) s.NoError(err) // load dataset and run garbage collection datasets, err := s.loadDataset(ctx) s.NoError(err) err = s.collectGarbage(ctx, datasets.rankingDataset) s.NoError(err) // check non-personalized cache np, err := s.CacheClient.SearchScores(ctx, cache.NonPersonalized, "custom", nil, 0, 100) s.NoError(err) s.Equal([]string{"2", "1"}, cache.ConvertDocumentsToValues(np)) np, err = s.CacheClient.SearchScores(ctx, cache.NonPersonalized, "unknown", nil, 0, 100) s.NoError(err) s.Empty(np) // check item-to-item cache similar, err := s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "1"), nil, 0, 100) s.NoError(err) s.Equal([]string{"2", "1"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("default", "3"), nil, 0, 100) s.NoError(err) s.Empty(similar) similar, err = s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key("unknown", "1"), nil, 0, 100) s.NoError(err) s.Empty(similar) // check user-to-user cache 
similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "1"), nil, 0, 100) s.NoError(err) s.Equal([]string{"2", "1"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("default", "3"), nil, 0, 100) s.NoError(err) s.Empty(similar) similar, err = s.CacheClient.SearchScores(ctx, cache.UserToUser, cache.Key("unknown", "1"), nil, 0, 100) s.NoError(err) s.Empty(similar) // check collaborative filtering cache cf, err := s.CacheClient.SearchScores(ctx, cache.CollaborativeFiltering, "1", nil, 0, 100) s.NoError(err) s.Equal([]string{"2", "1"}, cache.ConvertDocumentsToValues(cf)) cf, err = s.CacheClient.SearchScores(ctx, cache.CollaborativeFiltering, "3", nil, 0, 100) s.NoError(err) s.Empty(cf) } ================================================ FILE: model/built_in.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package model import ( "archive/zip" "fmt" "io" "net/http" "os" "os/user" "path/filepath" "strings" "github.com/gorse-io/gorse/common/log" "go.uber.org/zap" ) type DatasetFormat int const ( FormatNCF DatasetFormat = iota FormatLibFM ) // Built-in Data set type _BuiltInDataSet struct { downloadURL string trainFile string testFile string format DatasetFormat } var builtInDataSets = map[string]_BuiltInDataSet{ "pinterest-20": { downloadURL: "https://cdn.gorse.io/datasets/pinterest-20.zip", trainFile: "pinterest-20/train.txt", testFile: "pinterest-20/test.txt", format: FormatNCF, }, "ml-100k": { downloadURL: "https://cdn.gorse.io/datasets/ml-100k.zip", trainFile: "ml-100k/train.txt", testFile: "ml-100k/test.txt", format: FormatNCF, }, "ml-1m": { downloadURL: "https://cdn.gorse.io/datasets/ml-1m.zip", trainFile: "ml-1m/train.txt", testFile: "ml-1m/test.txt", format: FormatNCF, }, "ml-tag": { downloadURL: "https://cdn.gorse.io/datasets/ml-tag.zip", trainFile: "ml-tag/train.libfm", testFile: "ml-tag/test.libfm", format: FormatLibFM, }, "frappe": { downloadURL: "https://cdn.gorse.io/datasets/frappe.zip", trainFile: "frappe/train.libfm", testFile: "frappe/test.libfm", format: FormatLibFM, }, "criteo": { downloadURL: "https://cdn.gorse.io/datasets/criteo.zip", trainFile: "criteo/train.libfm", testFile: "criteo/test.libfm", format: FormatLibFM, }, } // The Data directories var ( GorseDir string DataSetDir string TempDir string ) func init() { usr, err := user.Current() if err != nil { log.Logger().Fatal("failed to get user directory", zap.Error(err)) } GorseDir = usr.HomeDir + "/.gorse" DataSetDir = GorseDir + "/dataset" TempDir = GorseDir + "/temp" // create all folders if err = os.MkdirAll(DataSetDir, os.ModePerm); err != nil { log.Logger().Fatal("failed to create directory", zap.Error(err), zap.String("path", DataSetDir)) } if err = os.MkdirAll(TempDir, os.ModePerm); err != nil { log.Logger().Fatal("failed to create directory", zap.Error(err), zap.String("path", TempDir)) 
} } func LocateBuiltInDataset(name string, format DatasetFormat) (string, string, error) { // Extract Data set information dataSet, exist := builtInDataSets[name] if !exist { return "", "", fmt.Errorf("no such dataset %v", name) } if dataSet.format != format { return "", "", fmt.Errorf("format not matchs %v != %v", format, dataSet.format) } // Download if not exists trainFilePah := filepath.Join(DataSetDir, dataSet.trainFile) testFilePath := filepath.Join(DataSetDir, dataSet.testFile) if _, err := os.Stat(trainFilePah); os.IsNotExist(err) { zipFileName, _ := downloadFromUrl(dataSet.downloadURL, TempDir) if _, err := unzip(zipFileName, DataSetDir); err != nil { return "", "", err } } return trainFilePah, testFilePath, nil } // downloadFromUrl downloads file from URL. func downloadFromUrl(src, dst string) (string, error) { log.Logger().Info("Download dataset", zap.String("source", src)) // Extract file name tokens := strings.Split(src, "/") fileName := filepath.Join(dst, tokens[len(tokens)-1]) // Create file if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil { return fileName, err } output, err := os.Create(fileName) if err != nil { log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName)) return fileName, err } defer output.Close() // Download file response, err := http.Get(src) if err != nil { log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src)) return fileName, err } defer response.Body.Close() // Save file _, err = io.Copy(output, response.Body) if err != nil { log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src)) return fileName, err } return fileName, nil } // unzip zip file. 
func unzip(src, dst string) ([]string, error) { var fileNames []string // Open zip file r, err := zip.OpenReader(src) if err != nil { return fileNames, err } defer r.Close() // Extract files for _, f := range r.File { // Open file rc, err := f.Open() if err != nil { return fileNames, err } // Store filename/path for returning and using later on filePath := filepath.Join(dst, f.Name) // Check for ZipSlip. More Info: http://bit.ly/2MsjAWE if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) { return fileNames, fmt.Errorf("%s: illegal file path", filePath) } // Add filename fileNames = append(fileNames, filePath) if f.FileInfo().IsDir() { // Create folder if err = os.MkdirAll(filePath, os.ModePerm); err != nil { return fileNames, err } } else { // Create all folders if err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { return fileNames, err } // Create file outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { return fileNames, err } // Save file _, err = io.Copy(outFile, rc) if err != nil { return nil, err } // Close the file without defer to close before next iteration of loop err = outFile.Close() if err != nil { return nil, err } } // Close file err = rc.Close() if err != nil { return nil, err } } return fileNames, nil } ================================================ FILE: model/built_in_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package model import ( "github.com/stretchr/testify/assert" "os" "path/filepath" "testing" ) func TestUnzip(t *testing.T) { // Download zipName, err := downloadFromUrl("https://cdn.gorse.io/datasets/yelp.zip", os.TempDir()) assert.Nil(t, err, "download file failed ") // Extract files fileNames, err := unzip(zipName, DataSetDir) // Check assert.Nil(t, err, "unzip file failed ") assert.Equal(t, 2, len(fileNames), "Number of file doesn't match") } func TestLocateBuiltInDataset(t *testing.T) { trainFilePath, testFilePath, err := LocateBuiltInDataset("ml-1m", FormatNCF) assert.NoError(t, err) assert.Equal(t, filepath.Join(DataSetDir, "ml-1m", "train.txt"), trainFilePath) assert.Equal(t, filepath.Join(DataSetDir, "ml-1m", "test.txt"), testFilePath) } ================================================ FILE: model/cf/evaluator.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cf import ( "context" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/heap" "github.com/gorse-io/gorse/common/parallel" "github.com/gorse-io/gorse/dataset" "github.com/samber/lo" ) /* Evaluate Item Ranking */ // Metric is used by evaluators in personalized ranking tasks. 
type Metric func(targetSet mapset.Set[int32], rankList []int32) float32

// Evaluate evaluates a model in top-n tasks.
//
// For every user of testSet with at least one test item, the user's test items
// are ranked together with the negatives returned by
// testSet.SampleUserNegatives(trainSet, numCandidates). Rank keeps the topK
// best, and each metric in scorers scores the resulting list.
//
// The returned slice has one entry per scorer, in scorer order, holding the
// mean score over the evaluated users. It is all zeros when no user has test
// items. Users are processed with nJobs workers; values below 1 are treated
// as 1.
func Evaluate(estimator MatrixFactorization, testSet, trainSet dataset.CFSplit, topK, numCandidates, nJobs int, scorers ...Metric) []float32 {
	if nJobs < 1 {
		nJobs = 1
	}
	// Per-worker accumulators indexed by workerId (assumed to lie in
	// [0, nJobs)), merged after the parallel loop.
	partSum := make([][]float32, nJobs)
	partCount := make([]float32, nJobs)
	for i := range partSum {
		partSum[i] = make([]float32, len(scorers))
	}
	negatives := testSet.SampleUserNegatives(trainSet, numCandidates)
	// The callback never returns an error, so Parallel's result is ignored.
	_ = parallel.Parallel(context.Background(), testSet.CountUsers(), nJobs, func(workerId, userIndex int) error {
		// Items the user interacted with in the test set are the targets.
		positives := testSet.GetUserFeedback()[userIndex]
		targetSet := mapset.NewSet(positives...)
		if targetSet.Cardinality() == 0 {
			return nil
		}
		negativeSample := negatives[userIndex]
		candidates := make([]int32, 0, len(positives)+len(negativeSample))
		candidates = append(candidates, positives...)
		candidates = append(candidates, negativeSample...)
		rankList := Rank(estimator, int32(userIndex), candidates, topK)
		partCount[workerId]++
		for i, metric := range scorers {
			partSum[workerId][i] += metric(targetSet, rankList)
		}
		return nil
	})
	sum := make([]float32, len(scorers))
	for _, part := range partSum {
		for j, v := range part {
			sum[j] += v
		}
	}
	count := lo.Sum(partCount)
	if count == 0 {
		// No user was evaluated; scaling by 1/0 would turn every score into NaN.
		return sum
	}
	floats.MulConst(sum, 1/count)
	return sum
}

// NDCG means Normalized Discounted Cumulative Gain.
func NDCG(targetSet mapset.Set[int32], rankList []int32) float32 { // IDCG = \sum^{|REL|}_{i=1} \frac {1} {\log_2(i+1)} idcg := float32(0) for i := 0; i < targetSet.Cardinality() && i < len(rankList); i++ { idcg += 1.0 / math32.Log2(float32(i)+2.0) } // DCG = \sum^{N}_{i=1} \frac {2^{rel_i}-1} {\log_2(i+1)} dcg := float32(0) for i, itemId := range rankList { if targetSet.Contains(itemId) { dcg += 1.0 / math32.Log2(float32(i)+2.0) } } return dcg / idcg } // Precision is the fraction of relevant ItemFeedback among the recommended ItemFeedback. // // \frac{|relevant documents| \cap |retrieved documents|} {|{retrieved documents}|} func Precision(targetSet mapset.Set[int32], rankList []int32) float32 { hit := float32(0) for _, itemId := range rankList { if targetSet.Contains(itemId) { hit++ } } return hit / float32(len(rankList)) } // Recall is the fraction of relevant ItemFeedback that have been recommended over the total // amount of relevant ItemFeedback. // // \frac{|relevant documents| \cap |retrieved documents|} {|{relevant documents}|} func Recall(targetSet mapset.Set[int32], rankList []int32) float32 { hit := 0 for _, itemId := range rankList { if targetSet.Contains(itemId) { hit++ } } return float32(hit) / float32(targetSet.Cardinality()) } // HR means Hit Ratio. func HR(targetSet mapset.Set[int32], rankList []int32) float32 { for _, itemId := range rankList { if targetSet.Contains(itemId) { return 1 } } return 0 } // MAP means Mean Average Precision. // mAP: http://sdsawtelle.github.io/blog/output/mean-average-precision-MAP-for-recommender-systems.html func MAP(targetSet mapset.Set[int32], rankList []int32) float32 { sumPrecision := float32(0) hit := 0 for i, itemId := range rankList { if targetSet.Contains(itemId) { hit++ sumPrecision += float32(hit) / float32(i+1) } } return sumPrecision / float32(targetSet.Cardinality()) } // MRR means Mean Reciprocal Rank. 
// // The mean reciprocal rank is a statistic measure for evaluating any process // that produces a list of possible responses to a sample of queries, ordered // by probability of correctness. The reciprocal rank of a query response is // the multiplicative inverse of the rank of the first correct answer: 1 for // first place, ​1⁄2 for second place, ​1⁄3 for third place and so on. The // mean reciprocal rank is the average of the reciprocal ranks of results for // a sample of queries Q: // // MRR = \frac{1}{Q} \sum^{|Q|}_{i=1} \frac{1}{rank_i} func MRR(targetSet mapset.Set[int32], rankList []int32) float32 { for i, itemId := range rankList { if targetSet.Contains(itemId) { return 1 / float32(i+1) } } return 0 } func Rank(model MatrixFactorization, userId int32, candidates []int32, topN int) []int32 { // Get top-n list itemsHeap := heap.NewTopKFilter[int32, float32](topN) for _, itemId := range candidates { itemsHeap.Push(itemId, model.internalPredict(userId, itemId)) } return itemsHeap.PopAllValues() } ================================================ FILE: model/cf/evaluator_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cf

import (
	"context"
	"io"
	"strconv"
	"testing"
	"time"

	"github.com/c-bata/goptuna"
	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/stretchr/testify/assert"
)

// evalEpsilon is the absolute tolerance used when comparing metric values.
const evalEpsilon = 0.00001

// TestNDCG: relevant items 1, 3, 5, 7 sit at 0-based positions 1, 3, 5, 7.
func TestNDCG(t *testing.T) {
	targetSet := mapset.NewSet[int32](1, 3, 5, 7)
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 0.6766372989, NDCG(targetSet, rankList), evalEpsilon)
}

// TestPrecision: 4 of the 10 ranked items are relevant.
func TestPrecision(t *testing.T) {
	targetSet := mapset.NewSet[int32](1, 3, 5, 7)
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 0.4, Precision(targetSet, rankList), evalEpsilon)
}

// TestRecall: 2 of the 5 relevant items (1 and 3) are ranked.
func TestRecall(t *testing.T) {
	targetSet := mapset.NewSet[int32](1, 3, 15, 17, 19)
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 0.4, Recall(targetSet, rankList), evalEpsilon)
}

// TestAP: (1/2 + 2/4 + 3/8 + 4/10) / 4 = 0.44375.
func TestAP(t *testing.T) {
	targetSet := mapset.NewSet[int32](1, 3, 7, 9)
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 0.44375, MAP(targetSet, rankList), evalEpsilon)
}

// TestRR: the only relevant item is at 0-based position 3, so 1/4.
func TestRR(t *testing.T) {
	targetSet := mapset.NewSet[int32](3)
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 0.25, MRR(targetSet, rankList), evalEpsilon)
}

// TestHR checks both a hit (item 3) and a miss (item 30).
func TestHR(t *testing.T) {
	rankList := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	assert.InDelta(t, 1, HR(mapset.NewSet[int32](3), rankList), evalEpsilon)
	assert.InDelta(t, 0, HR(mapset.NewSet[int32](30), rankList), evalEpsilon)
}

// mockMatrixFactorizationForEval is a stub model for TestEvaluate.
// internalPredict returns 1 for items in positive[userId], -1 for items in
// negative[userId] and 0 otherwise. Clear is a no-op, and every other method
// panics because Evaluate is not expected to call it.
type mockMatrixFactorizationForEval struct {
	model.BaseModel
	positive []mapset.Set[int32]
	negative []mapset.Set[int32]
}

func (m *mockMatrixFactorizationForEval) GetUserFactor(_ int32) []float32 {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) GetItemFactor(_ int32) []float32 {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) IsUserPredictable(_ int32) bool {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) IsItemPredictable(_ int32) bool {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) Marshal(_ io.Writer) error {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) Unmarshal(_ io.Reader) error {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) Invalid() bool {
	panic("implement me")
}

func (m *mockMatrixFactorizationForEval) GetUserIndex() *dataset.FreqDict {
	panic("don't call me")
}

func (m *mockMatrixFactorizationForEval) GetItemIndex() *dataset.FreqDict {
	panic("don't call me")
}

func (m *mockMatrixFactorizationForEval) Fit(_ context.Context, _, _ dataset.CFSplit, _ *FitConfig) Score {
	panic("don't call me")
}

func (m *mockMatrixFactorizationForEval) Predict(_, _ string) float32 {
	panic("don't call me")
}

// internalPredict scores +1 / -1 / 0 for positive / negative / unknown items.
func (m *mockMatrixFactorizationForEval) internalPredict(userId, itemId int32) float32 {
	if m.positive[userId].Contains(itemId) {
		return 1
	}
	if m.negative[userId].Contains(itemId) {
		return -1
	}
	return 0
}

func (m *mockMatrixFactorizationForEval) Clear() {
	// do nothing
}

func (m *mockMatrixFactorizationForEval) SuggestParams(trial goptuna.Trial) model.Params {
	panic("not implemented")
}

// TestEvaluate builds a test split with 4 users × 4 items each (user u owns
// items 4u..4u+3) and checks Precision@4 with the mock model.
//
// Assuming sampled negatives score 0, the per-user precision is 4/4, 3/4,
// 2/4 and 1/4 (the mock's positives per user), so the mean is 0.625.
func TestEvaluate(t *testing.T) {
	// create dataset
	train, test := dataset.NewDataset(time.Now(), 0, 0), dataset.NewDataset(time.Now(), 0, 0)
	//train.UserFeedback = make([][]int32, 4)
	for i := 0; i < 4; i++ {
		train.AddUser(data.User{UserId: strconv.Itoa(i)})
		// NOTE(review): i / 4 is 0 for every i in [0, 4), so this adds user "0"
		// four times; strconv.Itoa(i) was presumably intended — confirm.
		test.AddUser(data.User{UserId: strconv.Itoa(i / 4)})
	}
	for i := 0; i < 16; i++ {
		test.AddItem(data.Item{ItemId: strconv.Itoa(i)})
		test.AddFeedback(strconv.Itoa(i/4), strconv.Itoa(i), time.Time{})
	}
	assert.Equal(t, 16, test.CountFeedback())
	assert.Equal(t, 4, test.CountUsers())
	assert.Equal(t, 16, test.CountItems())
	// create model
	m := &mockMatrixFactorizationForEval{
		positive: []mapset.Set[int32]{
			mapset.NewSet[int32](0, 1, 2, 3),
			mapset.NewSet[int32](4, 5, 6),
			mapset.NewSet[int32](8, 9),
			mapset.NewSet[int32](12),
		},
		negative: []mapset.Set[int32]{
			mapset.NewSet[int32](),
			mapset.NewSet[int32](7),
			mapset.NewSet[int32](10, 11),
			mapset.NewSet[int32](13, 14, 15),
		},
	}
	// evaluate model
	s := Evaluate(m, test, train, 4, test.CountItems(), 4, Precision)
	assert.Equal(t, 1, len(s))
	assert.Equal(t, float32(0.625), s[0])
}

================================================
FILE: model/cf/model.go
================================================
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cf

import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"reflect"
	"time"

	"github.com/bits-and-blooms/bitset"
	"github.com/c-bata/goptuna"
	"github.com/chewxy/math32"
	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/encoding"
	"github.com/gorse-io/gorse/common/floats"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/parallel"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model"
	"github.com/gorse-io/gorse/protocol"
	"github.com/juju/errors"
	"github.com/matttproud/golang_protobuf_extensions/pbutil"
	"github.com/samber/lo"
	"go.uber.org/zap"
)

// Score holds the ranking metrics (NDCG, precision and recall) reported by a
// model's Fit method.
type Score struct {
	NDCG      float32
	Precision float32
	Recall    float32
}

// FitConfig holds the options passed to a model's Fit method.
type FitConfig struct {
	Jobs       int
	Verbose    int
	Candidates int
	TopK       int
	Patience   int
}

// NewFitConfig returns a FitConfig with Jobs=1, Verbose=10, Candidates=100 and
// TopK=10. Patience is left at its zero value.
func NewFitConfig() *FitConfig {
	return &FitConfig{
		Jobs:       1,
		Verbose:    10,
		Candidates: 100,
		TopK:       10,
	}
}

// SetVerbose sets Verbose and returns config so calls can be chained.
func (config *FitConfig) SetVerbose(verbose int) *FitConfig {
	config.Verbose = verbose
	return config
}

func
(config *FitConfig) SetJobs(jobs int) *FitConfig { config.Jobs = jobs return config } func (config *FitConfig) SetPatience(patience int) *FitConfig { config.Patience = patience return config } type Model interface { model.Model // Fit a model with a train set and parameters. Fit(ctx context.Context, trainSet, validateSet dataset.CFSplit, config *FitConfig) Score // GetItemIndex returns item index. GetItemIndex() *dataset.FreqDict // Marshal model into byte stream. Marshal(w io.Writer) error // Unmarshal model from byte stream. Unmarshal(r io.Reader) error // GetUserFactor returns latent factor of a user. GetUserFactor(userIndex int32) []float32 // GetItemFactor returns latent factor of an item. GetItemFactor(itemIndex int32) []float32 } type MatrixFactorization interface { Model // Predict the rating given by a user (userId) to a item (itemId). Predict(userId, itemId string) float32 // InternalPredict predicts rating given by a user index and a item index internalPredict(userIndex, itemIndex int32) float32 // GetUserIndex returns user index. GetUserIndex() *dataset.FreqDict // GetItemIndex returns item index. GetItemIndex() *dataset.FreqDict // IsUserPredictable returns false if user has no feedback and its embedding vector never be trained. IsUserPredictable(userIndex int32) bool // IsItemPredictable returns false if item has no feedback and its embedding vector never be trained. IsItemPredictable(itemIndex int32) bool // Marshal model into byte stream. Marshal(w io.Writer) error // Unmarshal model from byte stream. 
Unmarshal(r io.Reader) error } type BaseMatrixFactorization struct { model.BaseModel UserIndex *dataset.FreqDict ItemIndex *dataset.FreqDict UserPredictable *bitset.BitSet ItemPredictable *bitset.BitSet // Model parameters UserFactor [][]float32 // p_u ItemFactor [][]float32 // q_i } func (baseModel *BaseMatrixFactorization) Init(trainSet dataset.CFSplit) { baseModel.UserIndex = trainSet.GetUserDict() baseModel.ItemIndex = trainSet.GetItemDict() // set user trained flags baseModel.UserPredictable = bitset.New(uint(baseModel.UserIndex.Count())) for userIndex := int32(0); userIndex < baseModel.UserIndex.Count(); userIndex++ { if len(trainSet.GetUserFeedback()[userIndex]) > 0 { baseModel.UserPredictable.Set(uint(userIndex)) } } // set item trained flags baseModel.ItemPredictable = bitset.New(uint(baseModel.ItemIndex.Count())) for itemIndex := int32(0); itemIndex < baseModel.ItemIndex.Count(); itemIndex++ { if len(trainSet.GetItemFeedback()[itemIndex]) > 0 { baseModel.ItemPredictable.Set(uint(itemIndex)) } } } func (baseModel *BaseMatrixFactorization) GetUserIndex() *dataset.FreqDict { return baseModel.UserIndex } func (baseModel *BaseMatrixFactorization) GetItemIndex() *dataset.FreqDict { return baseModel.ItemIndex } // IsUserPredictable returns false if user has no feedback and its embedding vector never be trained. func (baseModel *BaseMatrixFactorization) IsUserPredictable(userIndex int32) bool { if userIndex >= baseModel.UserIndex.Count() || userIndex < 0 { return false } return baseModel.UserPredictable.Test(uint(userIndex)) } // IsItemPredictable returns false if item has no feedback and its embedding vector never be trained. func (baseModel *BaseMatrixFactorization) IsItemPredictable(itemIndex int32) bool { if itemIndex >= baseModel.ItemIndex.Count() || itemIndex < 0 { return false } return baseModel.ItemPredictable.Test(uint(itemIndex)) } // GetUserFactor returns the latent factor of a user. 
func (baseModel *BaseMatrixFactorization) GetUserFactor(userIndex int32) []float32 {
	return baseModel.UserFactor[userIndex]
}

// GetItemFactor returns the latent factor of an item.
func (baseModel *BaseMatrixFactorization) GetItemFactor(itemIndex int32) []float32 {
	return baseModel.ItemFactor[itemIndex]
}

// Predict scores itemId for userId as the dot product of their latent factors.
// It returns 0 (and logs a warning) when either ID is unknown to the model.
func (baseModel *BaseMatrixFactorization) Predict(userId, itemId string) float32 {
	// Convert sparse IDs to dense indices; unknown IDs map to a negative index.
	userIndex := baseModel.UserIndex.Id(userId)
	itemIndex := baseModel.ItemIndex.Id(itemId)
	if userIndex < 0 || itemIndex < 0 {
		if userIndex < 0 {
			log.Logger().Warn("unknown user", zap.String("user_id", userId))
		}
		if itemIndex < 0 {
			log.Logger().Warn("unknown item", zap.String("item_id", itemId))
		}
		// internalPredict would also return 0 here, but it would emit a second,
		// less specific warning for the same lookup.
		return 0
	}
	return baseModel.internalPredict(userIndex, itemIndex)
}

// internalPredict returns <p_u, q_i> for dense indices, or 0 with a warning
// when either index is negative.
func (baseModel *BaseMatrixFactorization) internalPredict(userIndex, itemIndex int32) float32 {
	if itemIndex < 0 || userIndex < 0 {
		log.Logger().Warn("unknown user or item")
		return 0
	}
	return floats.Dot(baseModel.UserFactor[userIndex], baseModel.ItemFactor[itemIndex])
}

// Marshal model into byte stream.
func (baseModel *BaseMatrixFactorization) Marshal(w io.Writer) error { // write params err := encoding.WriteGob(w, baseModel.Params) if err != nil { return errors.Trace(err) } // write predictable user count if err := binary.Write(w, binary.LittleEndian, int64(baseModel.UserPredictable.Count())); err != nil { return errors.Trace(err) } // write user latent factors for userIndex := int32(0); userIndex < baseModel.UserIndex.Count(); userIndex++ { if baseModel.UserPredictable.Test(uint(userIndex)) { userId, _ := baseModel.UserIndex.String(userIndex) latentFactor := &protocol.LatentFactor{ Id: userId, Data: baseModel.UserFactor[userIndex], } if _, err := pbutil.WriteDelimited(w, latentFactor); err != nil { return errors.Trace(err) } } } // write predictable item count if err := binary.Write(w, binary.LittleEndian, int64(baseModel.ItemPredictable.Count())); err != nil { return errors.Trace(err) } // write item latent factors for itemIndex := int32(0); itemIndex < baseModel.ItemIndex.Count(); itemIndex++ { if baseModel.ItemPredictable.Test(uint(itemIndex)) { itemId, _ := baseModel.ItemIndex.String(itemIndex) latentFactor := &protocol.LatentFactor{ Id: itemId, Data: baseModel.ItemFactor[itemIndex], } if _, err := pbutil.WriteDelimited(w, latentFactor); err != nil { return errors.Trace(err) } } } return nil } // Unmarshal model from byte stream. 
func (baseModel *BaseMatrixFactorization) Unmarshal(r io.Reader) error { // read params if err := encoding.ReadGob(r, &baseModel.Params); err != nil { return errors.Trace(err) } // read predictable user count var userPredictableCount int64 if err := binary.Read(r, binary.LittleEndian, &userPredictableCount); err != nil { return errors.Trace(err) } // read user latent factors baseModel.UserIndex = dataset.NewFreqDict() baseModel.UserPredictable = bitset.New(uint(userPredictableCount)) baseModel.UserFactor = make([][]float32, userPredictableCount) for i := 0; i < int(userPredictableCount); i++ { latentFactor := new(protocol.LatentFactor) if _, err := pbutil.ReadDelimited(r, latentFactor); err != nil { return errors.Trace(err) } userIndex := baseModel.UserIndex.Add(latentFactor.Id) baseModel.UserPredictable.Set(uint(userIndex)) baseModel.UserFactor[userIndex] = latentFactor.Data } // read predictable item count var itemPredictableCount int64 if err := binary.Read(r, binary.LittleEndian, &itemPredictableCount); err != nil { return errors.Trace(err) } // read item latent factors baseModel.ItemIndex = dataset.NewFreqDict() baseModel.ItemPredictable = bitset.New(uint(itemPredictableCount)) baseModel.ItemFactor = make([][]float32, itemPredictableCount) for i := 0; i < int(itemPredictableCount); i++ { latentFactor := new(protocol.LatentFactor) if _, err := pbutil.ReadDelimited(r, latentFactor); err != nil { return errors.Trace(err) } itemIndex := baseModel.ItemIndex.Add(latentFactor.Id) baseModel.ItemPredictable.Set(uint(itemIndex)) baseModel.ItemFactor[itemIndex] = latentFactor.Data } return nil } func (baseModel *BaseMatrixFactorization) Clear() { baseModel.UserIndex = nil baseModel.ItemIndex = nil baseModel.ItemFactor = nil baseModel.UserFactor = nil } func (baseModel *BaseMatrixFactorization) Invalid() bool { return baseModel == nil || baseModel.UserIndex == nil || baseModel.ItemIndex == nil || baseModel.ItemFactor == nil || baseModel.UserFactor == nil } func 
GetModelName(m Model) string { switch m.(type) { case *BPR: return "bpr" case *ALS: return "als" default: return reflect.TypeOf(m).String() } } func MarshalModel(w io.Writer, m Model) error { if err := encoding.WriteString(w, GetModelName(m)); err != nil { return errors.Trace(err) } if err := m.Marshal(w); err != nil { return errors.Trace(err) } return nil } func UnmarshalModel(r io.Reader) (MatrixFactorization, error) { name, err := encoding.ReadString(r) if err != nil { return nil, errors.Trace(err) } switch name { case "bpr": var bpr BPR if err := bpr.Unmarshal(r); err != nil { return nil, errors.Trace(err) } return &bpr, nil case "als": var als ALS if err := als.Unmarshal(r); err != nil { return nil, errors.Trace(err) } return &als, nil } return nil, fmt.Errorf("unknown model %v", name) } // BPR means Bayesian Personal Ranking, is a pairwise learning algorithm for matrix factorization // model with implicit feedback. The pairwise ranking between item i and j for user u is estimated // by: // // p(i >_u j) = \sigma( p_u^T (q_i - q_j) ) // // Hyper-parameters: // // Reg - The regularization parameter of the cost function that is // optimized. Default is 0.01. // Lr - The learning rate of SGD. Default is 0.05. // nFactors - The number of latent factors. Default is 10. // NEpochs - The number of iteration of the SGD procedure. Default is 100. // InitMean - The mean of initial random latent factors. Default is 0. // InitStdDev - The standard deviation of initial random latent factors. Default is 0.001. type BPR struct { BaseMatrixFactorization // Hyper parameters nFactors int nEpochs int lr float32 reg float32 initMean float32 initStdDev float32 } // NewBPR creates a BPR model. func NewBPR(params model.Params) *BPR { bpr := new(BPR) bpr.SetParams(params) return bpr } // SetParams sets hyper-parameters of the BPR model. 
func (bpr *BPR) SetParams(params model.Params) { bpr.BaseMatrixFactorization.SetParams(params) // Setup hyper-parameters bpr.nFactors = bpr.Params.GetInt(model.NFactors, 16) bpr.nEpochs = bpr.Params.GetInt(model.NEpochs, 100) bpr.lr = bpr.Params.GetFloat32(model.Lr, 0.05) bpr.reg = bpr.Params.GetFloat32(model.Reg, 0.01) bpr.initMean = bpr.Params.GetFloat32(model.InitMean, 0) bpr.initStdDev = bpr.Params.GetFloat32(model.InitStdDev, 0.001) } func (bpr *BPR) SuggestParams(trial goptuna.Trial) model.Params { return model.Params{ model.NFactors: 16, model.Lr: lo.Must(trial.SuggestLogFloat(string(model.Lr), 0.001, 0.1)), model.Reg: lo.Must(trial.SuggestLogFloat(string(model.Reg), 0.001, 0.1)), model.InitMean: 0, model.InitStdDev: lo.Must(trial.SuggestLogFloat(string(model.InitStdDev), 0.001, 0.1)), } } // Fit the BPR model. Its task complexity is O(bpr.nEpochs). func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, config *FitConfig) Score { log.Logger().Info("fit bpr", zap.Int("train_set_size", trainSet.CountFeedback()), zap.Int("test_set_size", valSet.CountFeedback()), zap.Any("params", bpr.GetParams()), zap.Any("config", config)) bpr.Init(trainSet) // Create buffers temp := util.NewMatrix32(config.Jobs, bpr.nFactors) userFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) positiveItemFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) negativeItemFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) rng := make([]util.RandomGenerator, config.Jobs) for i := 0; i < config.Jobs; i++ { rng[i] = util.NewRandomGenerator(bpr.GetRandomGenerator().Int63()) } // Convert array to hashmap userFeedback := make([]mapset.Set[int32], trainSet.CountUsers()) for u := range userFeedback { userFeedback[u] = mapset.NewSet[int32]() for _, i := range trainSet.GetUserFeedback()[u] { userFeedback[u].Add(i) } } evalStart := time.Now() score := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) scores := 
[]lo.Tuple2[int, float32]{{A: 0, B: score[0]}} evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", 0, bpr.nEpochs), zap.String("eval_time", evalTime.String()), zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) // Training _, span := monitor.Start(ctx, "BPR.Fit", bpr.nEpochs) defer span.End() for epoch := 1; epoch <= bpr.nEpochs; epoch++ { fitStart := time.Now() // Training epoch cost := make([]float32, config.Jobs) if err := parallel.Parallel(ctx, trainSet.CountFeedback(), config.Jobs, func(workerId, _ int) error { // Select a user var userIndex int32 var ratingCount int for { userIndex = rng[workerId].Int31n(int32(trainSet.CountUsers())) ratingCount = len(trainSet.GetUserFeedback()[userIndex]) if ratingCount > 0 { break } } posIndex := trainSet.GetUserFeedback()[userIndex][rng[workerId].Intn(ratingCount)] // Select a negative sample negIndex := int32(-1) for { temp := rng[workerId].Int31n(int32(trainSet.CountItems())) if !userFeedback[userIndex].Contains(temp) { negIndex = temp break } } diff := bpr.internalPredict(userIndex, posIndex) - bpr.internalPredict(userIndex, negIndex) cost[workerId] += math32.Log1p(math32.Exp(-diff)) grad := math32.Exp(-diff) / (1.0 + math32.Exp(-diff)) // Pairwise update copy(userFactor[workerId], bpr.UserFactor[userIndex]) copy(positiveItemFactor[workerId], bpr.ItemFactor[posIndex]) copy(negativeItemFactor[workerId], bpr.ItemFactor[negIndex]) // Update positive item latent factor: +w_u floats.MulConstTo(userFactor[workerId], grad, temp[workerId]) floats.MulConstAdd(positiveItemFactor[workerId], -bpr.reg, temp[workerId]) floats.MulConstAdd(temp[workerId], bpr.lr, bpr.ItemFactor[posIndex]) // Update negative item latent factor: -w_u floats.MulConstTo(userFactor[workerId], -grad, temp[workerId]) floats.MulConstAdd(negativeItemFactor[workerId], -bpr.reg, temp[workerId]) 
floats.MulConstAdd(temp[workerId], bpr.lr, bpr.ItemFactor[negIndex]) // Update user latent factor: h_i-h_j floats.SubTo(positiveItemFactor[workerId], negativeItemFactor[workerId], temp[workerId]) floats.MulConst(temp[workerId], grad) floats.MulConstAdd(userFactor[workerId], -bpr.reg, temp[workerId]) floats.MulConstAdd(temp[workerId], bpr.lr, bpr.UserFactor[userIndex]) return nil }); err != nil { log.Logger().Info("fit bpr canceled", zap.Int("epoch", epoch), zap.Error(err)) return Score{} } fitTime := time.Since(fitStart) // Cross validation if epoch%config.Verbose == 0 || epoch == bpr.nEpochs { evalStart = time.Now() score = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) scores = append(scores, lo.Tuple2[int, float32]{A: epoch, B: score[0]}) evalTime = time.Since(evalStart) log.Logger().Info(fmt.Sprintf("fit bpr %v/%v", epoch, bpr.nEpochs), zap.String("fit_time", fitTime.String()), zap.String("eval_time", evalTime.String()), zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) // early stopping if no improvement in last `patience` epochs if config.Patience > 0 && epoch > config.Patience { epochScore := lo.MaxBy(scores, func(a, b lo.Tuple2[int, float32]) bool { return a.B > b.B }) if epochScore.A <= epoch-config.Patience { log.Logger().Info("early stopping", zap.Int("best_epoch", epochScore.A), zap.Float32("best_NDCG", epochScore.B), zap.Int("patience", config.Patience)) break } } } span.Add(1) } log.Logger().Info("fit bpr complete", zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) return Score{ NDCG: score[0], Precision: score[1], Recall: score[2], } } func (bpr *BPR) Init(trainSet dataset.CFSplit) { // Initialize parameters newUserFactor := 
bpr.GetRandomGenerator().NormalMatrix(trainSet.CountUsers(), bpr.nFactors, bpr.initMean, bpr.initStdDev) newItemFactor := bpr.GetRandomGenerator().NormalMatrix(trainSet.CountItems(), bpr.nFactors, bpr.initMean, bpr.initStdDev) // Initialize base bpr.UserFactor = newUserFactor bpr.ItemFactor = newItemFactor bpr.BaseMatrixFactorization.Init(trainSet) } // Marshal model into byte stream. func (bpr *BPR) Marshal(w io.Writer) error { if err := bpr.BaseMatrixFactorization.Marshal(w); err != nil { return errors.Trace(err) } return nil } // Unmarshal model from byte stream. func (bpr *BPR) Unmarshal(r io.Reader) error { if err := bpr.BaseMatrixFactorization.Unmarshal(r); err != nil { return errors.Trace(err) } bpr.SetParams(bpr.Params) return nil } type ALS struct { BaseMatrixFactorization // Hyper parameters nFactors int nEpochs int reg float32 initMean float32 initStdDev float32 weight float32 } // NewALS creates a eALS model. func NewALS(params model.Params) *ALS { fast := new(ALS) fast.SetParams(params) return fast } // SetParams sets hyper-parameters for the ALS model. 
func (als *ALS) SetParams(params model.Params) { als.BaseMatrixFactorization.SetParams(params) als.nFactors = als.Params.GetInt(model.NFactors, 16) als.nEpochs = als.Params.GetInt(model.NEpochs, 50) als.initMean = als.Params.GetFloat32(model.InitMean, 0) als.initStdDev = als.Params.GetFloat32(model.InitStdDev, 0.1) als.reg = als.Params.GetFloat32(model.Reg, 0.06) als.weight = als.Params.GetFloat32(model.Alpha, 0.001) } func (als *ALS) SuggestParams(trial goptuna.Trial) model.Params { return model.Params{ model.NFactors: 16, model.InitMean: 0, model.InitStdDev: lo.Must(trial.SuggestLogFloat(string(model.InitStdDev), 0.001, 0.1)), model.Reg: lo.Must(trial.SuggestLogFloat(string(model.Reg), 0.001, 0.1)), model.Alpha: lo.Must(trial.SuggestLogFloat(string(model.Alpha), 0.001, 0.1)), } } func (als *ALS) Init(trainSet dataset.CFSplit) { // Initialize newUserFactor := als.GetRandomGenerator().NormalMatrix(trainSet.CountUsers(), als.nFactors, als.initMean, als.initStdDev) newItemFactor := als.GetRandomGenerator().NormalMatrix(trainSet.CountItems(), als.nFactors, als.initMean, als.initStdDev) // Initialize base als.UserFactor = newUserFactor als.ItemFactor = newItemFactor als.BaseMatrixFactorization.Init(trainSet) } // Fit the ALS model. Its task complexity is O(ccd.nEpochs). 
func (als *ALS) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, config *FitConfig) Score { log.Logger().Info("fit als", zap.Int("train_set_size", trainSet.CountFeedback()), zap.Int("test_set_size", valSet.CountFeedback()), zap.Any("params", als.GetParams()), zap.Any("config", config)) als.Init(trainSet) // Create temporary matrix s := util.NewMatrix32(als.nFactors, als.nFactors) userPredictions := make([][]float32, config.Jobs) itemPredictions := make([][]float32, config.Jobs) userRes := make([][]float32, config.Jobs) itemRes := make([][]float32, config.Jobs) for i := 0; i < config.Jobs; i++ { userPredictions[i] = make([]float32, trainSet.CountItems()) itemPredictions[i] = make([]float32, trainSet.CountUsers()) userRes[i] = make([]float32, trainSet.CountItems()) itemRes[i] = make([]float32, trainSet.CountUsers()) } // evaluate initial model evalStart := time.Now() score := Evaluate(als, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) scores := []lo.Tuple2[int, float32]{{A: 0, B: score[0]}} evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit als %v/%v", 0, als.nEpochs), zap.String("eval_time", evalTime.String()), zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) _, span := monitor.Start(ctx, "ALS.Fit", als.nEpochs) defer span.End() for ep := 1; ep <= als.nEpochs; ep++ { fitStart := time.Now() // Update user factors // S^q <- \sum^N_{itemIndex=1} c_i q_i q_i^T floats.MatZero(s) for itemIndex := 0; itemIndex < trainSet.CountItems(); itemIndex++ { if err := ctx.Err(); err != nil { log.Logger().Info("fit als canceled", zap.Int("epoch", ep), zap.Error(err)) return Score{} } if len(trainSet.GetItemFeedback()[itemIndex]) > 0 { for i := 0; i < als.nFactors; i++ { for j := 0; j < als.nFactors; j++ { s[i][j] += als.ItemFactor[itemIndex][i] * als.ItemFactor[itemIndex][j] } } 
} } if err := parallel.Parallel(ctx, trainSet.CountUsers(), config.Jobs, func(workerId, userIndex int) error { userFeedback := trainSet.GetUserFeedback()[userIndex] for _, i := range userFeedback { userPredictions[workerId][i] = als.internalPredict(int32(userIndex), i) } for f := 0; f < als.nFactors; f++ { // for itemIndex \in R_u do \hat_{r}^f_{ui} <- \hat_{r}_{ui} - p_{uf]q_{if} for _, i := range userFeedback { userRes[workerId][i] = userPredictions[workerId][i] - als.UserFactor[userIndex][f]*als.ItemFactor[i][f] } // p_{uf} <- a, b, c := float32(0), float32(0), float32(0) for _, i := range userFeedback { a += (1 - (1-als.weight)*userRes[workerId][i]) * als.ItemFactor[i][f] c += (1 - als.weight) * als.ItemFactor[i][f] * als.ItemFactor[i][f] } for k := 0; k < als.nFactors; k++ { if k != f { b += als.weight * als.UserFactor[userIndex][k] * s[k][f] } } als.UserFactor[userIndex][f] = (a - b) / (c + als.weight*s[f][f] + als.reg) // for itemIndex \in R_u do \hat_{r}_{ui} <- \hat_{r}^f_{ui} - p_{uf]q_{if} for _, i := range userFeedback { userPredictions[workerId][i] = userRes[workerId][i] + als.UserFactor[userIndex][f]*als.ItemFactor[i][f] } } return nil }); err != nil { log.Logger().Info("fit als canceled", zap.Int("epoch", ep), zap.Error(err)) return Score{} } // Update item factors // S^p <- P^T P floats.MatZero(s) for userIndex := 0; userIndex < trainSet.CountUsers(); userIndex++ { if ctx.Err() != nil { log.Logger().Info("fit als canceled", zap.Int("epoch", ep), zap.Error(ctx.Err())) return Score{} } if len(trainSet.GetUserFeedback()[userIndex]) > 0 { for i := 0; i < als.nFactors; i++ { for j := 0; j < als.nFactors; j++ { s[i][j] += als.UserFactor[userIndex][i] * als.UserFactor[userIndex][j] } } } } if err := parallel.Parallel(ctx, trainSet.CountItems(), config.Jobs, func(workerId, itemIndex int) error { itemFeedback := trainSet.GetItemFeedback()[itemIndex] for _, u := range itemFeedback { itemPredictions[workerId][u] = als.internalPredict(u, int32(itemIndex)) } for 
f := 0; f < als.nFactors; f++ { // for itemIndex \in R_u do \hat_{r}^f_{ui} <- \hat_{r}_{ui} - p_{uf]q_{if} for _, u := range itemFeedback { itemRes[workerId][u] = itemPredictions[workerId][u] - als.UserFactor[u][f]*als.ItemFactor[itemIndex][f] } // q_{if} <- a, b, c := float32(0), float32(0), float32(0) for _, u := range itemFeedback { a += (1 - (1-als.weight)*itemRes[workerId][u]) * als.UserFactor[u][f] c += (1 - als.weight) * als.UserFactor[u][f] * als.UserFactor[u][f] } for k := 0; k < als.nFactors; k++ { if k != f { b += als.weight * als.ItemFactor[itemIndex][k] * s[k][f] } } als.ItemFactor[itemIndex][f] = (a - b) / (c + als.weight*s[f][f] + als.reg) // for itemIndex \in R_u do \hat_{r}_{ui} <- \hat_{r}^f_{ui} - p_{uf]q_{if} for _, u := range itemFeedback { itemPredictions[workerId][u] = itemRes[workerId][u] + als.UserFactor[u][f]*als.ItemFactor[itemIndex][f] } } return nil }); err != nil { log.Logger().Info("fit als canceled", zap.Int("epoch", ep), zap.Error(ctx.Err())) return Score{} } fitTime := time.Since(fitStart) // Cross validation if ep%config.Verbose == 0 || ep == als.nEpochs { evalStart = time.Now() score = Evaluate(als, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) scores = append(scores, lo.Tuple2[int, float32]{A: ep, B: score[0]}) evalTime = time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit als %v/%v", ep, als.nEpochs), zap.String("fit_time", fitTime.String()), zap.String("eval_time", evalTime.String()), zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) // early stopping if no improvement in last `patience` epochs if config.Patience > 0 && ep > config.Patience { epochScore := lo.MaxBy(scores, func(a, b lo.Tuple2[int, float32]) bool { return a.B > b.B }) if epochScore.A <= ep-config.Patience { log.Logger().Info("early stopping", zap.Int("best_epoch", epochScore.A), 
zap.Float32("best_NDCG", epochScore.B), zap.Int("patience", config.Patience)) break } } } span.Add(1) } log.Logger().Info("fit als complete", zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), score[0]), zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), score[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), score[2])) return Score{ NDCG: score[0], Precision: score[1], Recall: score[2], } } // Marshal model into byte stream. func (als *ALS) Marshal(w io.Writer) error { if err := als.BaseMatrixFactorization.Marshal(w); err != nil { return errors.Trace(err) } return nil } // Unmarshal model from byte stream. func (als *ALS) Unmarshal(r io.Reader) error { if err := als.BaseMatrixFactorization.Unmarshal(r); err != nil { return errors.Trace(err) } als.SetParams(als.Params) return nil } ================================================ FILE: model/cf/model_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cf import ( "bytes" "math" "runtime" "testing" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/stretchr/testify/assert" ) const benchDelta = 0.01 func newFitConfig(_ int) *FitConfig { cfg := NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU()) return cfg } func TestBPR_MovieLens(t *testing.T) { trainSet, testSet, err := dataset.LoadDataFromBuiltIn("ml-1m") assert.NoError(t, err) m := NewBPR(model.Params{ model.NFactors: 8, model.Reg: 0.01, model.Lr: 0.05, model.NEpochs: 30, model.InitMean: 0, model.InitStdDev: 0.001, }) fitConfig := newFitConfig(30) score := m.Fit(t.Context(), trainSet, testSet, fitConfig) assert.InDelta(t, 0.36, score.NDCG, benchDelta) assert.Equal(t, trainSet.GetUserDict(), m.GetUserIndex()) assert.Equal(t, testSet.GetItemDict(), m.GetItemIndex()) // test predict assert.Equal(t, m.Predict("1", "1"), m.internalPredict(1, 1)) assert.Equal(t, m.internalPredict(1, 1), floats.Dot(m.GetUserFactor(1), m.GetItemFactor(1))) assert.True(t, m.IsUserPredictable(1)) assert.True(t, m.IsItemPredictable(1)) assert.False(t, m.IsUserPredictable(math.MaxInt32)) assert.False(t, m.IsItemPredictable(math.MaxInt32)) // test encode/decode model and increment training buf := bytes.NewBuffer(nil) err = MarshalModel(buf, m) assert.NoError(t, err) tmp, err := UnmarshalModel(buf) assert.NoError(t, err) assert.Equal(t, m.Params, tmp.GetParams()) assert.Equal(t, m.Predict("1", "1"), tmp.Predict("1", "1")) assert.True(t, m.IsUserPredictable(1)) assert.True(t, m.IsItemPredictable(1)) assert.False(t, m.IsUserPredictable(math.MaxInt32)) assert.False(t, m.IsItemPredictable(math.MaxInt32)) // test clear m.Clear() assert.True(t, m.Invalid()) } //func TestBPR_Pinterest(t *testing.T) { // trainSet, testSet, err := LoadDataFromBuiltIn("pinterest-20") // assert.NoError(t, err) // m := NewBPR(model.Params{ // model.NFactors: 8, // model.Reg: 0.005, // model.Lr: 0.05, // model.NEpochs: 50, // 
model.InitMean: 0, // model.InitStdDev: 0.001, // }) // score := m.Fit(trainSet, testSet, fitConfig) // assertEpsilon(t, 0.53, score.NDCG, benchDelta) //} func TestCCD_MovieLens(t *testing.T) { trainSet, testSet, err := dataset.LoadDataFromBuiltIn("ml-1m") assert.NoError(t, err) m := NewALS(model.Params{ model.NFactors: 8, model.Reg: 0.015, model.NEpochs: 30, model.Alpha: 0.05, }) fitConfig := newFitConfig(30) score := m.Fit(t.Context(), trainSet, testSet, fitConfig) assert.InDelta(t, 0.36, score.NDCG, benchDelta) // test predict assert.Equal(t, m.Predict("1", "1"), m.internalPredict(1, 1)) assert.Equal(t, m.internalPredict(1, 1), floats.Dot(m.GetUserFactor(1), m.GetItemFactor(1))) // test encode/decode model and increment training buf := bytes.NewBuffer(nil) err = MarshalModel(buf, m) assert.NoError(t, err) tmp, err := UnmarshalModel(buf) assert.NoError(t, err) assert.Equal(t, m.Params, tmp.GetParams()) assert.Equal(t, m.Predict("1", "1"), tmp.Predict("1", "1")) // test clear m.Clear() assert.True(t, m.Invalid()) } //func TestCCD_Pinterest(t *testing.T) { // trainSet, testSet, err := LoadDataFromBuiltIn("pinterest-20") // assert.NoError(t, err) // m := NewALS(model.Params{ // model.NFactors: 8, // model.Reg: 0.01, // model.NEpochs: 20, // model.InitStdDev: 0.01, // model.Alpha: 0.001, // }) // score := m.Fit(trainSet, testSet, fitConfig) // assertEpsilon(t, 0.52, score.NDCG, benchDelta) //} ================================================ FILE: model/cf/optimize.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cf import ( "context" "errors" "github.com/c-bata/goptuna" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/storage/meta" "golang.org/x/exp/maps" ) type ModelCreator func() MatrixFactorization type ModelSearch struct { modelCreators map[string]ModelCreator modelTypes []string trainSet dataset.CFSplit valSet dataset.CFSplit config *FitConfig ctx context.Context span *monitor.Span result meta.Model[Score] } func NewModelSearch(models map[string]ModelCreator, trainSet, valSet dataset.CFSplit, config *FitConfig) *ModelSearch { return &ModelSearch{ modelCreators: models, modelTypes: maps.Keys(models), trainSet: trainSet, valSet: valSet, config: config, } } func (ms *ModelSearch) WithContext(ctx context.Context) *ModelSearch { ms.ctx = ctx return ms } func (ms *ModelSearch) WithSpan(span *monitor.Span) *ModelSearch { ms.span = span return ms } func (ms *ModelSearch) Objective(trial goptuna.Trial) (float64, error) { if len(ms.modelCreators) == 0 { return 0, errors.New("no model to search") } modelType, err := trial.SuggestCategorical("Model", ms.modelTypes) if err != nil { return 0, err } m := ms.modelCreators[modelType]() m.SetParams(m.SuggestParams(trial)) score := m.Fit(ms.ctx, ms.trainSet, ms.valSet, ms.config) if score.NDCG > ms.result.Score.NDCG { ms.result.Type = modelType ms.result.Params = m.GetParams() ms.result.Score = score } if ms.span != nil { ms.span.Add(1) } return float64(score.NDCG), nil } func (ms *ModelSearch) Result() meta.Model[Score] { return ms.result } 
================================================ FILE: model/cf/optimize_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cf import ( "context" "io" "testing" "github.com/c-bata/goptuna" "github.com/c-bata/goptuna/tpe" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/samber/lo" "github.com/stretchr/testify/assert" ) type mockMatrixFactorizationForSearch struct { model.BaseModel } func newMockMatrixFactorizationForSearch(numEpoch int) *mockMatrixFactorizationForSearch { return &mockMatrixFactorizationForSearch{model.BaseModel{Params: model.Params{model.NEpochs: numEpoch}}} } func (m *mockMatrixFactorizationForSearch) GetUserFactor(_ int32) []float32 { panic("implement me") } func (m *mockMatrixFactorizationForSearch) GetItemFactor(_ int32) []float32 { panic("implement me") } func (m *mockMatrixFactorizationForSearch) IsUserPredictable(_ int32) bool { panic("implement me") } func (m *mockMatrixFactorizationForSearch) IsItemPredictable(_ int32) bool { panic("implement me") } func (m *mockMatrixFactorizationForSearch) Marshal(_ io.Writer) error { panic("implement me") } func (m *mockMatrixFactorizationForSearch) Unmarshal(_ io.Reader) error { panic("implement me") } func (m *mockMatrixFactorizationForSearch) Invalid() bool { panic("implement me") } func (m *mockMatrixFactorizationForSearch) GetUserIndex() *dataset.FreqDict { panic("don't 
call me") } func (m *mockMatrixFactorizationForSearch) GetItemIndex() *dataset.FreqDict { panic("don't call me") } func (m *mockMatrixFactorizationForSearch) Fit(_ context.Context, _, _ dataset.CFSplit, _ *FitConfig) Score { score := float32(0) score += m.Params.GetFloat32(model.NFactors, 0.0) score += m.Params.GetFloat32(model.InitMean, 0.0) score += m.Params.GetFloat32(model.InitStdDev, 0.0) return Score{NDCG: score} } func (m *mockMatrixFactorizationForSearch) Predict(_, _ string) float32 { panic("don't call me") } func (m *mockMatrixFactorizationForSearch) internalPredict(_, _ int32) float32 { panic("don't call me") } func (m *mockMatrixFactorizationForSearch) Clear() { // do nothing } func (m *mockMatrixFactorizationForSearch) SuggestParams(trial goptuna.Trial) model.Params { return model.Params{ model.NFactors: lo.Must(trial.SuggestDiscreteFloat(string(model.NFactors), 1, 4, 1)), model.InitMean: lo.Must(trial.SuggestDiscreteFloat(string(model.InitMean), 1, 4, 1)), model.InitStdDev: lo.Must(trial.SuggestDiscreteFloat(string(model.InitStdDev), 4, 4, 1)), } } func TestTPE(t *testing.T) { search := NewModelSearch(map[string]ModelCreator{ "mock": func() MatrixFactorization { return newMockMatrixFactorizationForSearch(10) }, }, nil, nil, nil) study, err := goptuna.CreateStudy("TestTPE", goptuna.StudyOptionDirection(goptuna.StudyDirectionMaximize), goptuna.StudyOptionSampler(tpe.NewSampler())) assert.NoError(t, err) err = study.Optimize(search.Objective, 10) assert.NoError(t, err) v, _ := study.GetBestValue() assert.Equal(t, float64(12), v) result := search.Result() assert.Equal(t, "mock", result.Type) assert.Equal(t, model.Params{ model.NFactors: float64(4), model.InitMean: float64(4), model.InitStdDev: float64(4), }, result.Params) assert.Equal(t, Score{NDCG: 12}, result.Score) } ================================================ FILE: model/ctr/data.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the 
Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ctr import ( "bufio" "encoding/json" "os" "strconv" "strings" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/jsonutil" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/juju/errors" "github.com/samber/lo" "modernc.org/mathutil" ) type Label struct { Name string Value float32 } func ConvertLabels(o any) []Label { features := make([]Label, 0) return convertLabels(features, "", o) } func convertLabels(result []Label, prefix string, o any) []Label { if o == nil { return nil } switch labels := o.(type) { case []any: if len(labels) == 0 { return nil } switch labels[0].(type) { case string: for _, val := range labels { if s, ok := val.(string); ok { result = append(result, Label{ Name: prefix + s, Value: 1, }) } else { panic("unsupported labels: " + jsonutil.MustMarshal(labels)) } } } case map[string]any: for key, val := range labels { result = convertLabels(result, prefix+key+".", val) } case string: result = append(result, Label{ Name: prefix + labels, Value: 1, }) case json.Number: value, _ := labels.Float64() result = append(result, Label{ Name: prefix, Value: float32(value), }) } return result } type Embedding struct { Name string Value []float32 } func ConvertEmbeddings(o any) []Embedding { embeddings := make([]Embedding, 0) return convertEmbeddings(embeddings, "", o) } func convertEmbeddings(result []Embedding, prefix string, o any) 
[]Embedding { if o == nil { return nil } switch embeddings := o.(type) { case []any: if len(embeddings) == 0 { return nil } var value []float32 for _, val := range embeddings { switch v := val.(type) { case float64: value = append(value, float32(v)) case float32: value = append(value, v) default: return result } } result = append(result, Embedding{ Name: prefix, Value: value, }) case []float64: result = append(result, Embedding{ Name: prefix, Value: lo.Map(embeddings, func(f float64, _ int) float32 { return float32(f) }), }) case []float32: result = append(result, Embedding{ Name: prefix, Value: embeddings, }) case map[string]any: for key, val := range embeddings { if prefix == "" { result = convertEmbeddings(result, key, val) } else { result = convertEmbeddings(result, prefix+"."+key, val) } } } return result } // Dataset for click-through-rate models. type Dataset struct { Index dataset.UnifiedIndex UserLabels [][]lo.Tuple2[int32, float32] ItemLabels [][]lo.Tuple2[int32, float32] ContextLabels [][]lo.Tuple2[int32, float32] Users []int32 Items []int32 Target []float32 ItemEmbeddings [][][]float32 // Index by row id, embedding id, embedding dimension ItemEmbeddingDimension []int ItemEmbeddingIndex *dataset.Index PositiveCount int NegativeCount int } // CountUsers returns the number of users. func (dataset *Dataset) CountUsers() int { return int(dataset.Index.CountUsers()) } // CountItems returns the number of items. 
func (dataset *Dataset) CountItems() int { return int(dataset.Index.CountItems()) } func (dataset *Dataset) CountUserLabels() int { return int(dataset.Index.CountUserLabels()) } func (dataset *Dataset) CountItemLabels() int { return int(dataset.Index.CountItemLabels()) } func (dataset *Dataset) CountContextLabels() int { return int(dataset.Index.CountContextLabels()) } func (dataset *Dataset) CountPositive() int { return dataset.PositiveCount } func (dataset *Dataset) CountNegative() int { return dataset.NegativeCount } func (dataset *Dataset) GetIndex() dataset.UnifiedIndex { return dataset.Index } // Count returns the number of samples. func (dataset *Dataset) Count() int { if len(dataset.Users) != len(dataset.Items) { panic("len(dataset.Users) != len(dataset.Items)") } if len(dataset.Users) > 0 && len(dataset.Users) != len(dataset.Target) { panic("len(dataset.Users) != len(dataset.Target)") } if dataset.ContextLabels != nil && len(dataset.ContextLabels) != len(dataset.Target) { panic("len(dataset.ContextLabels) != len(dataset.Target)") } return len(dataset.Target) } func (dataset *Dataset) GetTarget(i int) float32 { return dataset.Target[i] } // Get returns the i-th sample. 
func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) { var ( indices []int32 values []float32 embedding [][]float32 position int32 ) // append user id if len(dataset.Users) > 0 { indices = append(indices, dataset.Users[i]) values = append(values, 1) position += int32(dataset.CountUsers()) } // append item id if len(dataset.Items) > 0 { indices = append(indices, position+dataset.Items[i]) values = append(values, 1) position += int32(dataset.CountItems()) if len(dataset.ItemEmbeddings) > 0 { embedding = dataset.ItemEmbeddings[dataset.Items[i]] } } // append user indices if len(dataset.Users) > 0 { userFeatures := dataset.UserLabels[dataset.Users[i]] for _, feature := range userFeatures { indices = append(indices, position+feature.A) values = append(values, feature.B) } position += dataset.Index.CountUserLabels() } // append item indices if len(dataset.Items) > 0 { itemFeatures := dataset.ItemLabels[dataset.Items[i]] for _, feature := range itemFeatures { indices = append(indices, position+feature.A) values = append(values, feature.B) } } // append context indices if dataset.ContextLabels != nil { contextIndices, contextValues := lo.Unzip2(dataset.ContextLabels[i]) indices = append(indices, contextIndices...) values = append(values, contextValues...) } return indices, values, embedding, dataset.Target[i] } // LoadLibFMFile loads libFM format file. 
func LoadLibFMFile(path string) (features [][]lo.Tuple2[int32, float32], targets []float32, maxLabel int32, err error) { // open file file, err := os.Open(path) if err != nil { return nil, []float32{}, 0, errors.Trace(err) } defer file.Close() // read lines scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() fields := strings.Split(line, " ") // fetch target target, err := strconv.ParseFloat(fields[0], 32) if err != nil { return nil, []float32{}, 0, errors.Trace(err) } targets = append(targets, float32(target)) // fetch features lineFeatures := make([]lo.Tuple2[int32, float32], 0, len(fields[1:])) for _, field := range fields[1:] { if len(strings.TrimSpace(field)) > 0 { kv := strings.Split(field, ":") k, v := kv[0], kv[1] // append feature feature, err := strconv.Atoi(k) if err != nil { return nil, []float32{}, 0, errors.Trace(err) } value, err := strconv.ParseFloat(v, 32) if err != nil { return nil, []float32{}, 0, errors.Trace(err) } lineFeatures = append(lineFeatures, lo.Tuple2[int32, float32]{ A: int32(feature), B: float32(value), }) maxLabel = mathutil.MaxInt32Val(maxLabel, int32(feature)) } } features = append(features, lineFeatures) } // check error if err = scanner.Err(); err != nil { return nil, []float32{}, 0, errors.Trace(err) } return } // LoadDataFromBuiltIn loads built-in dataset. 
func LoadDataFromBuiltIn(name string) (train, test *Dataset, err error) { trainFilePath, testFilePath, err := model.LocateBuiltInDataset(name, model.FormatLibFM) if err != nil { return nil, nil, err } train, test = &Dataset{}, &Dataset{} var trainMaxLabel, testMaxLabel int32 if train.ContextLabels, train.Target, trainMaxLabel, err = LoadLibFMFile(trainFilePath); err != nil { return nil, nil, err } if test.ContextLabels, test.Target, testMaxLabel, err = LoadLibFMFile(testFilePath); err != nil { return nil, nil, err } unifiedIndex := dataset.NewUnifiedDirectIndex(mathutil.MaxInt32(trainMaxLabel, testMaxLabel) + 1) train.Index = unifiedIndex test.Index = unifiedIndex return } // Split a dataset to training set and test set. func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) { // create train/test dataset trainSet := &Dataset{ Index: dataset.Index, UserLabels: dataset.UserLabels, ItemLabels: dataset.ItemLabels, ItemEmbeddings: dataset.ItemEmbeddings, ItemEmbeddingIndex: dataset.ItemEmbeddingIndex, ItemEmbeddingDimension: dataset.ItemEmbeddingDimension, } testSet := &Dataset{ Index: dataset.Index, UserLabels: dataset.UserLabels, ItemLabels: dataset.ItemLabels, ItemEmbeddings: dataset.ItemEmbeddings, ItemEmbeddingIndex: dataset.ItemEmbeddingIndex, ItemEmbeddingDimension: dataset.ItemEmbeddingDimension, } // split by random numTestSize := int(float32(dataset.Count()) * ratio) rng := util.NewRandomGenerator(seed) sampledIndex := mapset.NewSet(rng.Sample(0, dataset.Count(), numTestSize)...) 
for i := 0; i < len(dataset.Target); i++ { if sampledIndex.Contains(i) { // add samples into test set testSet.Users = append(testSet.Users, dataset.Users[i]) testSet.Items = append(testSet.Items, dataset.Items[i]) if dataset.ContextLabels != nil { testSet.ContextLabels = append(testSet.ContextLabels, dataset.ContextLabels[i]) } testSet.Target = append(testSet.Target, dataset.Target[i]) if dataset.Target[i] > 0 { testSet.PositiveCount++ } else { testSet.NegativeCount++ } } else { // add samples into train set trainSet.Users = append(trainSet.Users, dataset.Users[i]) trainSet.Items = append(trainSet.Items, dataset.Items[i]) if dataset.ContextLabels != nil { trainSet.ContextLabels = append(trainSet.ContextLabels, dataset.ContextLabels[i]) } trainSet.Target = append(trainSet.Target, dataset.Target[i]) if dataset.Target[i] > 0 { trainSet.PositiveCount++ } else { trainSet.NegativeCount++ } } } return trainSet, testSet } func (dataset *Dataset) GetItemEmbeddingDim() []int { return dataset.ItemEmbeddingDimension } func (dataset *Dataset) GetItemEmbeddingIndex() *dataset.Index { return dataset.ItemEmbeddingIndex } ================================================ FILE: model/ctr/data_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package ctr

import (
	"encoding/json"
	"fmt"
	"testing"

	"github.com/gorse-io/gorse/dataset"
	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
)

// TestConvertLabels covers categorical (strings, string arrays, nested
// objects), numerical (json.Number) and unsupported (non-string arrays) inputs.
func TestConvertLabels(t *testing.T) {
	features := ConvertLabels(nil)
	assert.Nil(t, features)

	// categorical features
	features = ConvertLabels("label")
	assert.ElementsMatch(t, []Label{{Name: "label", Value: 1}}, features)
	features = ConvertLabels([]any{"1", "2", "3"})
	assert.ElementsMatch(t, []Label{
		{Name: "1", Value: 1},
		{Name: "2", Value: 1},
		{Name: "3", Value: 1},
	}, features)
	features = ConvertLabels(map[string]any{"city": "wenzhou", "tags": []any{"1", "2", "3"}})
	assert.ElementsMatch(t, []Label{
		{Name: "city.wenzhou", Value: 1},
		{Name: "tags.1", Value: 1},
		{Name: "tags.2", Value: 1},
		{Name: "tags.3", Value: 1},
	}, features)
	features = ConvertLabels(map[string]any{"address": map[string]any{"province": "zhejiang", "city": "wenzhou"}})
	assert.ElementsMatch(t, []Label{
		{Name: "address.province.zhejiang", Value: 1},
		{Name: "address.city.wenzhou", Value: 1},
	}, features)

	// numerical features: the name is the key path including its trailing "."
	features = ConvertLabels(json.Number("1"))
	assert.Equal(t, []Label{{Name: "", Value: 1}}, features)
	features = ConvertLabels(map[string]any{"city": "wenzhou", "tags": json.Number("0.5")})
	assert.ElementsMatch(t, []Label{
		{Name: "city.wenzhou", Value: 1},
		{Name: "tags.", Value: 0.5},
	}, features)

	// not supported: arrays of non-strings are ignored
	features = ConvertLabels([]any{float64(1), float64(2), float64(3)})
	assert.Empty(t, features)
	features = ConvertLabels(map[string]any{"city": "wenzhou", "tags": []any{float64(1), float64(2), float64(3)}})
	assert.ElementsMatch(t, []Label{{Name: "city.wenzhou", Value: 1}}, features)
}

// TestConvertEmbeddings covers typed float slices, mixed []any of floats,
// nested objects ("a.embedding2" naming) and ignored non-vector values.
func TestConvertEmbeddings(t *testing.T) {
	embeddings := ConvertEmbeddings(nil)
	assert.Nil(t, embeddings)

	embeddings = ConvertEmbeddings([]float32{1, 2, 3})
	if assert.Len(t, embeddings, 1) {
		assert.Equal(t, "", embeddings[0].Name)
		assert.Equal(t, []float32{1, 2, 3}, embeddings[0].Value)
	}

	embeddings = ConvertEmbeddings([]float64{1, 2, 3})
	if assert.Len(t, embeddings, 1) {
		assert.Equal(t, "", embeddings[0].Name)
		assert.Equal(t, []float32{1, 2, 3}, embeddings[0].Value)
	}

	embeddings = ConvertEmbeddings([]any{float64(1), float32(2), float64(3)})
	if assert.Len(t, embeddings, 1) {
		assert.Equal(t, "", embeddings[0].Name)
		assert.Equal(t, []float32{1, 2, 3}, embeddings[0].Value)
	}

	embeddings = ConvertEmbeddings(map[string]any{
		"embedding1": []float32{1, 2, 3},
		"a": map[string]any{
			"embedding2": []float64{4, 5, 6},
		},
		"no_embedding": "test",
	})
	if assert.Len(t, embeddings, 2) {
		assert.ElementsMatch(t, []Embedding{
			{Name: "embedding1", Value: []float32{1, 2, 3}},
			{Name: "a.embedding2", Value: []float32{4, 5, 6}},
		}, embeddings)
	}
}

// TestLoadDataFromBuiltIn checks the sample counts of the built-in frappe
// libFM dataset.
func TestLoadDataFromBuiltIn(t *testing.T) {
	train, test, err := LoadDataFromBuiltIn("frappe")
	assert.NoError(t, err)
	assert.Equal(t, 202027, train.Count())
	assert.Equal(t, 28860, test.Count())
}

// TestDataset_Split builds a 5-user x 6-item dataset (one sample per pair,
// positive iff user+item > 4, i.e. 15 positives and 15 negatives), checks the
// feature layout produced by Get, then checks a seeded 20% split.
func TestDataset_Split(t *testing.T) {
	// create dataset: user i has labels 2i, 2i+1; item i has labels 3i..3i+2
	unifiedIndex := dataset.NewUnifiedMapIndexBuilder()
	dataSet := NewMapIndexDataset()
	numUsers, numItems := 5, 6
	for i := 0; i < numUsers; i++ {
		unifiedIndex.AddUser(fmt.Sprintf("user%v", i))
		unifiedIndex.AddUserLabel(fmt.Sprintf("user_label%v", 2*i))
		unifiedIndex.AddUserLabel(fmt.Sprintf("user_label%v", 2*i+1))
		dataSet.UserLabels = append(dataSet.UserLabels, []lo.Tuple2[int32, float32]{
			{A: int32(2 * i), B: 1},
			{A: int32(2*i + 1), B: 1},
		})
	}
	for i := 0; i < numItems; i++ {
		unifiedIndex.AddItem(fmt.Sprintf("item%v", i))
		unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i))
		unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i+1))
		unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i+2))
		dataSet.ItemLabels = append(dataSet.ItemLabels, []lo.Tuple2[int32, float32]{
			{A: int32(3 * i), B: 1},
			{A: int32(3*i + 1), B: 1},
			{A: int32(3*i + 2), B: 1},
		})
		dataSet.ItemEmbeddings = append(dataSet.ItemEmbeddings, [][]float32{
			{float32(i), float32(i) + 0.1, float32(i) + 0.2},
		})
	}
	// samples are appended user-major, so sample k is (user k/6, item k%6);
	// each carries one context label with index user*item and value 0.5
	for i := 0; i < numUsers; i++ {
		for j := 0; j < numItems; j++ {
			if i+j > 4 {
				dataSet.Users = append(dataSet.Users, int32(i))
				dataSet.Items = append(dataSet.Items, int32(j))
				dataSet.ContextLabels = append(dataSet.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}})
				dataSet.Target = append(dataSet.Target, 1)
				dataSet.PositiveCount++
			} else {
				dataSet.Users = append(dataSet.Users, int32(i))
				dataSet.Items = append(dataSet.Items, int32(j))
				dataSet.ContextLabels = append(dataSet.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}})
				dataSet.Target = append(dataSet.Target, -1)
				dataSet.NegativeCount++
			}
		}
	}
	dataSet.Index = unifiedIndex.Build()
	assert.Equal(t, numUsers*numItems, dataSet.Count())
	assert.Equal(t, numUsers, dataSet.CountUsers())
	assert.Equal(t, numItems, dataSet.CountItems())
	assert.Equal(t, numUsers*numItems/2, dataSet.PositiveCount)
	assert.Equal(t, numUsers*numItems/2, dataSet.NegativeCount)

	// sample 2 is (user 0, item 2): user id, item id, user labels 0-1,
	// item labels 6-8, then the context label 0*2 = 0 with no offset
	features, values, embeddings, target := dataSet.Get(2)
	assert.Equal(t, []int32{
		0,
		dataSet.Index.CountUsers() + 2,
		dataSet.Index.CountUsers() + dataSet.Index.CountItems() + 0,
		dataSet.Index.CountUsers() + dataSet.Index.CountItems() + 1,
		dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 6,
		dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 7,
		dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 8,
		0,
	}, features)
	assert.Equal(t, [][]float32{{2, 2.1, 2.2}}, embeddings)
	assert.Equal(t, []float32{1, 1, 1, 1, 1, 1, 1, 0.5}, values)
	assert.Equal(t, float32(-1), target)

	// split: 20% of 30 samples (6) go to the test set; seed 0 keeps the
	// positive/negative balance asserted below
	train, test := dataSet.Split(0.2, 0)
	assert.Equal(t, numUsers, train.CountUsers())
	assert.Equal(t, numItems, train.CountItems())
	assert.Equal(t, 24, train.Count())
	assert.Equal(t, 12, train.PositiveCount)
	assert.Equal(t, 12, train.NegativeCount)
	assert.Equal(t, numUsers, test.CountUsers())
	assert.Equal(t, numItems, test.CountItems())
	assert.Equal(t, 6, test.Count())
	assert.Equal(t, 3, test.PositiveCount)
	assert.Equal(t, 3, test.NegativeCount)
}

================================================
FILE: model/ctr/evaluator.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ctr

import (
	"sort"

	"github.com/chewxy/math32"
	"github.com/gorse-io/gorse/dataset"
	"github.com/samber/lo"
	"modernc.org/sortutil"
)

// EvaluateRegression evaluates factorization machines in regression task.
//
// It returns the root-mean-square error sqrt(mean((target - prediction)^2))
// over every sample of testSet, or an RMSE of 0 for an empty test set.
// Embeddings returned by testSet.Get are not passed to the estimator.
func EvaluateRegression(estimator FactorizationMachines, testSet *Dataset) Score {
	sum := float32(0)
	// accumulate the squared error of every sample
	for i := 0; i < testSet.Count(); i++ {
		features, values, _, target := testSet.Get(i)
		prediction := estimator.InternalPredict(features, values)
		sum += (target - prediction) * (target - prediction)
	}
	// guard against dividing by zero on an empty test set
	if testSet.Count() == 0 {
		return Score{
			RMSE: 0,
		}
	}
	return Score{
		RMSE: math32.Sqrt(sum / float32(testSet.Count())),
	}
}

// EvaluateClassification evaluates factorization machines in classification task.
func EvaluateClassification(estimator FactorizationMachines, testSet dataset.CTRSplit, jobs int) Score { // For all UserFeedback var posFeatures, negFeatures []lo.Tuple2[[]int32, []float32] var posEmbeddings, negEmbeddings [][][]float32 for i := 0; i < testSet.Count(); i++ { indices, values, embeddings, target := testSet.Get(i) if target > 0 { posFeatures = append(posFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) posEmbeddings = append(posEmbeddings, embeddings) } else { negFeatures = append(negFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) negEmbeddings = append(negEmbeddings, embeddings) } } var posPrediction, negPrediction []float32 if batchInference, ok := estimator.(BatchInference); ok { posPrediction = batchInference.BatchInternalPredict(posFeatures, posEmbeddings, jobs) negPrediction = batchInference.BatchInternalPredict(negFeatures, negEmbeddings, jobs) } else { for _, features := range posFeatures { posPrediction = append(posPrediction, estimator.InternalPredict(features.A, features.B)) } for _, features := range negFeatures { negPrediction = append(negPrediction, estimator.InternalPredict(features.A, features.B)) } } if testSet.Count() == 0 { return Score{ Precision: 0, } } return Score{ Precision: Precision(posPrediction, negPrediction), Recall: Recall(posPrediction, negPrediction), Accuracy: Accuracy(posPrediction, negPrediction), AUC: AUC(posPrediction, negPrediction), } } func Precision(posPrediction, negPrediction []float32) float32 { var tp, fp float32 for _, p := range posPrediction { if p > 0 { // true positive tp++ } } for _, p := range negPrediction { if p > 0 { // false positive fp++ } } if tp+fp == 0 { return 0 } return tp / (tp + fp) } func Recall(posPrediction, _ []float32) float32 { var tp, fn float32 for _, p := range posPrediction { if p > 0 { // true positive tp++ } else { // false negative fn++ } } if tp+fn == 0 { return 0 } return tp / (tp + fn) } func Accuracy(posPrediction, negPrediction []float32) 
float32 { var correct float32 for _, p := range posPrediction { if p > 0 { correct++ } } for _, p := range negPrediction { if p < 0 { correct++ } } if len(posPrediction)+len(negPrediction) == 0 { return 0 } return correct / float32(len(posPrediction)+len(negPrediction)) } func AUC(posPrediction, negPrediction []float32) float32 { sort.Sort(sortutil.Float32Slice(posPrediction)) sort.Sort(sortutil.Float32Slice(negPrediction)) var sum float32 var nPos int for pPos := range posPrediction { // find the negative sample with the greatest prediction less than current positive sample for nPos < len(negPrediction) && negPrediction[nPos] < posPrediction[pPos] { nPos++ } // add the number of negative samples have less prediction than current positive sample sum += float32(nPos) } if len(posPrediction)*len(negPrediction) == 0 { return 0 } return sum / float32(len(posPrediction)*len(negPrediction)) } ================================================ FILE: model/ctr/evaluator_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package ctr import ( "github.com/stretchr/testify/assert" "testing" ) func TestPrecision(t *testing.T) { posPrediction := []float32{1, 1, 1} negPrediction := []float32{1} precision := Precision(posPrediction, negPrediction) assert.Equal(t, float32(0.75), precision) precision = Precision(nil, nil) assert.Zero(t, precision) } func TestRecall(t *testing.T) { posPrediction := []float32{1, -1, -1, -1} recall := Recall(posPrediction, nil) assert.Equal(t, float32(0.25), recall) recall = Recall(nil, nil) assert.Zero(t, recall) } func TestAccuracy(t *testing.T) { posPrediction := []float32{1, 1, -1, -1} negPrediction := []float32{1, 1, -1, -1} accuracy := Accuracy(posPrediction, negPrediction) assert.Equal(t, float32(0.5), accuracy) accuracy = Accuracy(nil, nil) assert.Zero(t, accuracy) } ================================================ FILE: model/ctr/fm.go ================================================ //go:build !cgo || !xla // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package ctr

import (
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/c-bata/goptuna"
	"github.com/chewxy/math32"
	"github.com/gorse-io/gorse/common/encoding"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/nn"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model"
	"github.com/juju/errors"
	"github.com/samber/lo"
	"go.uber.org/zap"
	"modernc.org/mathutil"
)

// headerAFM is the name tag of the AFM model (not referenced in this file).
const headerAFM = "AFM"

// AFM is a factorization machine whose score is extended with external item
// embeddings, each passed through an attention layer and a linear projection.
// This implementation runs on the pure-Go nn package and is compiled when cgo
// or the xla build tag is absent.
type AFM struct {
	BaseFactorizationMachines

	// mu is read-locked by BatchInternalPredict while it runs the network.
	mu sync.RWMutex

	// parameters
	B *nn.Tensor // bias added to every logit
	W nn.Layer   // per-feature linear weight: embedding of width 1
	V nn.Layer   // per-feature latent factors: embedding of width nFactors
	A []nn.Layer // attention layer per external embedding slot
	E []nn.Layer // linear projection (dim -> nFactors) per external embedding slot

	// hyper parameters
	batchSize  int
	nFactors   int
	nEpochs    int
	lr         float32
	reg        float32 // weight decay handed to the optimizer
	initMean   float32
	initStdDev float32
	optimizer  string // model.SGD or model.Adam

	// dataset stats
	numFeatures    int            // size of the unified feature index
	numDimension   int            // widest feature row seen in the training set
	embeddingDim   []int          // dimension of each external embedding slot
	embeddingIndex *dataset.Index // embedding name -> slot in embeddingDim
}

// NewAFM creates an AFM configured from params; layers are allocated by Init
// (via Fit) or Unmarshal.
func NewAFM(params model.Params) *AFM {
	fm := new(AFM)
	fm.SetParams(params)
	return fm
}

// SuggestParams draws a hyper-parameter set for a goptuna trial: nFactors is
// fixed at 16 and initMean at 0, while lr, reg and initStdDev are sampled
// log-uniformly from [0.001, 0.1].
func (fm *AFM) SuggestParams(trial goptuna.Trial) model.Params {
	return model.Params{
		model.NFactors:   16,
		model.Lr:         lo.Must(trial.SuggestLogFloat(string(model.Lr), 0.001, 0.1)),
		model.Reg:        lo.Must(trial.SuggestLogFloat(string(model.Reg), 0.001, 0.1)),
		model.InitMean:   0,
		model.InitStdDev: lo.Must(trial.SuggestLogFloat(string(model.InitStdDev), 0.001, 0.1)),
	}
}

// SetParams stores params and reads each hyper-parameter, falling back to the
// defaults written inline below.
func (fm *AFM) SetParams(params model.Params) {
	fm.BaseFactorizationMachines.SetParams(params)
	fm.batchSize = fm.Params.GetInt(model.BatchSize, 1024)
	fm.nFactors = fm.Params.GetInt(model.NFactors, 16)
	fm.nEpochs = fm.Params.GetInt(model.NEpochs, 50)
	fm.lr = fm.Params.GetFloat32(model.Lr, 0.001)
	fm.reg = fm.Params.GetFloat32(model.Reg, 0.0002)
	fm.initMean = fm.Params.GetFloat32(model.InitMean, 0)
	fm.initStdDev = fm.Params.GetFloat32(model.InitStdDev, 0.01)
	fm.optimizer = fm.Params.GetString(model.Optimizer, model.Adam)
}

// Clear drops the feature index, which makes Invalid report true.
func (fm *AFM) Clear() {
	fm.Index = nil
}

// Invalid reports whether the model is nil or has no feature index.
func (fm *AFM) Invalid() bool {
	return fm == nil || fm.Index == nil
}

// Forward computes one raw logit (fed to BCEWithLogits in Fit) per sample.
//
// indices and values carry numDimension feature slots per sample, zero-padded
// by convertToTensors. embeddings carries one batch tensor per external
// embedding slot. Assuming nn.BMM(a, b, true, false, jobs) multiplies a^T by b
// per sample (TODO confirm against common/nn), the result is the FM score
//
//	B + sum_j w_j*x_j + 0.5 * sum_f [(sum_j v_jf*x_j)^2 - sum_j v_jf^2*x_j^2]
//
// plus, for every slot k, the dot product of V^T x with E_k(A_k(embedding_k)).
func (fm *AFM) Forward(indices, values *nn.Tensor, embeddings []*nn.Tensor, jobs int) *nn.Tensor {
	batchSize := indices.Shape()[0]
	v := fm.V.Forward(indices)
	x := nn.Reshape(values, batchSize, fm.numDimension, 1)
	vx := nn.BMM(v, x, true, false, jobs)
	sumSquare := nn.Square(vx)
	e2 := nn.Square(v)
	x2 := nn.Square(x)
	squareSum := nn.BMM(e2, x2, true, false, jobs)
	sum := nn.Sub(sumSquare, squareSum)
	sum = nn.Sum(sum, 1)
	sum = nn.Mul(sum, nn.NewScalar(0.5))
	w := fm.W.Forward(indices)
	linear := nn.BMM(w, x, true, false, jobs)
	fmOutput := nn.Add(nn.Reshape(linear, batchSize), nn.Reshape(sum, batchSize), fm.B)
	// Encode the embedding
	for i, embedding := range embeddings {
		encodedNorm := fm.E[i].Forward(fm.A[i].Forward(embedding))
		encodedNorm = nn.Reshape(encodedNorm, batchSize, fm.nFactors, 1)
		fmOutput = nn.Add(fmOutput, nn.Reshape(nn.BMM(vx, encodedNorm, true, false, jobs), batchSize))
	}
	return nn.Flatten(fmOutput)
}

// Parameters lists every trainable tensor. Marshal and Unmarshal pass this
// list to nn.Save / nn.Load, so keep the order stable (presumably the
// serialization is positional).
func (fm *AFM) Parameters() []*nn.Tensor {
	var params []*nn.Tensor
	params = append(params, fm.B)
	params = append(params, fm.V.Parameters()...)
	params = append(params, fm.W.Parameters()...)
	for i := range fm.embeddingDim {
		params = append(params, fm.A[i].Parameters()...)
		params = append(params, fm.E[i].Parameters()...)
	}
	return params
}

// Predict is not supported by this model; use BatchPredict.
func (fm *AFM) Predict(_, _ string, _, _ []Label) float32 {
	panic("Predict is unsupported for deep learning models")
}

// InternalPredict is not supported by this model; use BatchInternalPredict.
func (fm *AFM) InternalPredict(_ []int32, _ []float32) float32 {
	panic("InternalPredict is unsupported for deep learning models")
}

// BatchInternalPredict returns one logit per row of x. Each x[i] pairs encoded
// feature ids (A) with their values (B); e[i][k] is the embedding for slot k
// (missing or mismatched entries are zero-filled). Rows are scored in chunks
// of batchSize.
func (fm *AFM) BatchInternalPredict(x []lo.Tuple2[[]int32, []float32], e [][][]float32, jobs int) []float32 {
	fm.mu.RLock()
	defer fm.mu.RUnlock()
	if len(x) == 0 {
		// Nothing to score; avoid building zero-row tensors.
		return []float32{}
	}
	indicesTensor, valuesTensor, embeddingTensor, _ := fm.convertToTensors(x, e, nil)
	predictions := make([]float32, 0, len(x))
	for i := 0; i < len(x); i += fm.batchSize {
		j := mathutil.Min(i+fm.batchSize, len(x))
		embeddingTensorSlice := make([]*nn.Tensor, len(fm.embeddingDim))
		for k := range fm.embeddingDim {
			embeddingTensorSlice[k] = embeddingTensor[k].Slice(i, j)
		}
		output := fm.Forward(indicesTensor.Slice(i, j), valuesTensor.Slice(i, j), embeddingTensorSlice, jobs)
		predictions = append(predictions, output.Data()...)
	}
	return predictions[:len(x)]
}

// BatchPredict scores raw (user, item, user labels, item labels) tuples. Users,
// items and labels that the index does not know are skipped. embeddings[i]
// supplies the external embeddings for inputs[i]. Unknown names, out-of-range
// slots and dimension mismatches are ignored, so those slots stay zero.
func (fm *AFM) BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label], embeddings [][]Embedding, jobs int) []float32 {
	x := make([]lo.Tuple2[[]int32, []float32], len(inputs))
	for i, input := range inputs {
		// encode user
		if userIndex := fm.Index.EncodeUser(input.A); userIndex != dataset.NotId {
			x[i].A = append(x[i].A, userIndex)
			x[i].B = append(x[i].B, 1)
		}
		// encode item
		if itemIndex := fm.Index.EncodeItem(input.B); itemIndex != dataset.NotId {
			x[i].A = append(x[i].A, itemIndex)
			x[i].B = append(x[i].B, 1)
		}
		// encode user labels
		for _, userFeature := range input.C {
			if userFeatureIndex := fm.Index.EncodeUserLabel(userFeature.Name); userFeatureIndex != dataset.NotId {
				x[i].A = append(x[i].A, userFeatureIndex)
				x[i].B = append(x[i].B, userFeature.Value)
			}
		}
		// encode item labels
		for _, itemFeature := range input.D {
			if itemFeatureIndex := fm.Index.EncodeItemLabel(itemFeature.Name); itemFeatureIndex != dataset.NotId {
				x[i].A = append(x[i].A, itemFeatureIndex)
				x[i].B = append(x[i].B, itemFeature.Value)
			}
		}
	}
	e := make([][][]float32, len(inputs))
	for i := range inputs {
		e[i] = make([][]float32, len(fm.embeddingDim))
		// Tolerate an embeddings slice shorter than inputs (which would otherwise
		// index out of range) and a model without an embedding index.
		if i >= len(embeddings) || fm.embeddingIndex == nil {
			continue
		}
		for _, embedding := range embeddings[i] {
			itemIndex := fm.embeddingIndex.ToNumber(embedding.Name)
			if itemIndex == dataset.NotId {
				// unknown embedding
				continue
			}
			index := int(itemIndex)
			if index < 0 || index >= len(fm.embeddingDim) || len(embedding.Value) != fm.embeddingDim[index] {
				// slot out of range or dimension mismatch
				continue
			}
			e[i][index] = embedding.Value
		}
	}
	return fm.BatchInternalPredict(x, e, jobs)
}

// buildLayers allocates freshly initialized layers sized from numFeatures,
// nFactors and embeddingDim. Init and Unmarshal both use it, so the two always
// produce the same Parameters layout.
func (fm *AFM) buildLayers() {
	fm.B = nn.Zeros()
	fm.W = nn.NewEmbedding(fm.numFeatures, 1)
	fm.V = nn.NewEmbedding(fm.numFeatures, fm.nFactors)
	fm.A = make([]nn.Layer, len(fm.embeddingDim))
	fm.E = make([]nn.Layer, len(fm.embeddingDim))
	for i, dim := range fm.embeddingDim {
		fm.A[i] = nn.NewAttention(dim, fm.nFactors)
		fm.E[i] = nn.NewLinear(dim, fm.nFactors)
	}
}

// Init sizes the model from trainSet and allocates fresh layers. numFeatures
// comes from the unified index, numDimension is the widest sample, and the
// embedding slots come from the split. It then initializes the embedded base.
func (fm *AFM) Init(trainSet dataset.CTRSplit) {
	fm.numFeatures = int(trainSet.GetIndex().Len())
	fm.numDimension = 0
	for i := 0; i < trainSet.Count(); i++ {
		_, x, _, _ := trainSet.Get(i)
		fm.numDimension = mathutil.MaxVal(fm.numDimension, len(x))
	}
	fm.embeddingDim = trainSet.GetItemEmbeddingDim()
	fm.embeddingIndex = trainSet.GetItemEmbeddingIndex()
	fm.buildLayers()
	fm.BaseFactorizationMachines.Init(trainSet)
}

// Fit trains the model on trainSet with binary cross-entropy on logits and
// returns the last evaluated score on testSet.
//
// It evaluates on testSet before training, every config.Verbose epochs, and
// at the final epoch. Training stops early when the loss or score becomes NaN,
// or when the best AUC is at least config.Patience epochs old. If ctx is
// cancelled mid-training, Fit returns a zero Score. A nil config falls back to
// NewFitConfig defaults.
func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, config *FitConfig) Score {
	// A nil config would otherwise panic on the first field access below.
	config = config.LoadDefaultIfNil()
	log.Logger().Info("fit AFM",
		zap.Int("train_set_size", trainSet.Count()),
		zap.Int("test_set_size", testSet.Count()),
		zap.Any("params", fm.GetParams()),
		zap.Any("config", config))
	fm.Init(trainSet)
	fm.W.SetJobs(config.Jobs)
	fm.V.SetJobs(config.Jobs)
	for i := range fm.embeddingDim {
		fm.A[i].SetJobs(config.Jobs)
		fm.E[i].SetJobs(config.Jobs)
	}
	evalStart := time.Now()
	score := EvaluateClassification(fm, testSet, config.Jobs)
	scores := []lo.Tuple2[int, float32]{{A: 0, B: score.AUC}}
	evalTime := time.Since(evalStart)
	fields := append([]zap.Field{zap.String("eval_time", evalTime.String())}, score.ZapFields()...)
	log.Logger().Info(fmt.Sprintf("fit AFM %v/%v", 0, fm.nEpochs), fields...)

	// Materialize the whole training set as dense tensors once; each batch is
	// a row slice of them.
	n := trainSet.Count()
	x := make([]lo.Tuple2[[]int32, []float32], 0, n)
	e := make([][][]float32, 0, n)
	y := make([]float32, 0, n)
	for i := 0; i < n; i++ {
		indices, values, embeddings, target := trainSet.Get(i)
		x = append(x, lo.Tuple2[[]int32, []float32]{A: indices, B: values})
		e = append(e, embeddings)
		y = append(y, target)
	}
	indices, values, embeddings, target := fm.convertToTensors(x, e, y)

	var optimizer nn.Optimizer
	switch fm.optimizer {
	case model.SGD:
		optimizer = nn.NewSGD(fm.Parameters(), fm.lr)
	case model.Adam:
		optimizer = nn.NewAdam(fm.Parameters(), fm.lr)
	default:
		panic("unknown optimizer")
	}
	optimizer.SetWeightDecay(fm.reg)
	optimizer.SetJobs(config.Jobs)

	_, span := monitor.Start(ctx, "FM.Fit", fm.nEpochs)
	defer span.End()
	for epoch := 1; epoch <= fm.nEpochs; epoch++ {
		fitStart := time.Now()
		cost := float32(0)
		for i := 0; i < n; i += fm.batchSize {
			if ctx.Err() != nil {
				log.Logger().Info("fit AFM canceled", zap.Error(ctx.Err()))
				return Score{}
			}
			j := mathutil.Min(i+fm.batchSize, n)
			batchIndices := indices.Slice(i, j)
			batchValues := values.Slice(i, j)
			batchEmbedding := make([]*nn.Tensor, len(fm.embeddingDim))
			for k := range fm.embeddingDim {
				batchEmbedding[k] = embeddings[k].Slice(i, j)
			}
			batchTarget := target.Slice(i, j)
			batchOutput := fm.Forward(batchIndices, batchValues, batchEmbedding, config.Jobs)
			batchLoss := nn.BCEWithLogits(batchTarget, batchOutput, nil)
			cost += batchLoss.Data()[0]
			optimizer.ZeroGrad()
			batchLoss.Backward()
			optimizer.Step()
		}
		fitTime := time.Since(fitStart)

		// Cross validation. Verbose <= 0 used to divide by zero; now it means
		// "evaluate only at the final epoch".
		if (config.Verbose > 0 && epoch%config.Verbose == 0) || epoch == fm.nEpochs {
			evalStart = time.Now()
			score = EvaluateClassification(fm, testSet, config.Jobs)
			scores = append(scores, lo.Tuple2[int, float32]{A: epoch, B: score.AUC})
			evalTime = time.Since(evalStart)
			fields = append([]zap.Field{
				zap.String("fit_time", fitTime.String()),
				zap.String("eval_time", evalTime.String()),
				zap.Float32("loss", cost),
			}, score.ZapFields()...)
			log.Logger().Info(fmt.Sprintf("fit AFM %v/%v", epoch, fm.nEpochs), fields...)
			// check NaN
			if math32.IsNaN(cost) || math32.IsNaN(score.GetValue()) {
				log.Logger().Warn("model diverged", zap.Float32("lr", fm.lr))
				break
			}
			// early stopping if no improvement in last `patience` epochs
			if config.Patience > 0 && epoch > config.Patience {
				epochScore := lo.MaxBy(scores, func(a, b lo.Tuple2[int, float32]) bool {
					return a.B > b.B
				})
				if epochScore.A <= epoch-config.Patience {
					log.Logger().Info("early stopping",
						zap.Int("best_epoch", epochScore.A),
						zap.Float32("best_auc", epochScore.B),
						zap.Int("patience", config.Patience))
					break
				}
			}
		}
		span.Add(1)
	}
	return score
}

// Marshal writes the model in this order: params, unified index, dataset
// stats, the embedding index (only when embedding slots exist), then the
// parameter tensors.
func (fm *AFM) Marshal(w io.Writer) error {
	// write params
	if err := encoding.WriteGob(w, fm.Params); err != nil {
		return errors.Trace(err)
	}
	// write index
	if err := dataset.MarshalUnifiedIndex(w, fm.Index); err != nil {
		return errors.Trace(err)
	}
	// write dataset stats
	if err := encoding.WriteGob(w, fm.numFeatures); err != nil {
		return errors.Trace(err)
	}
	if err := encoding.WriteGob(w, fm.numDimension); err != nil {
		return errors.Trace(err)
	}
	if err := encoding.WriteGob(w, fm.embeddingDim); err != nil {
		return errors.Trace(err)
	}
	if len(fm.embeddingDim) > 0 {
		if err := dataset.MarshalIndex(w, fm.embeddingIndex); err != nil {
			return errors.Trace(err)
		}
	}
	// write parameters
	if err := nn.Save(fm.Parameters(), w); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// Unmarshal reads a model written by Marshal, rebuilding the layers before
// loading their tensors.
func (fm *AFM) Unmarshal(r io.Reader) error {
	// read params
	err := encoding.ReadGob(r, &fm.Params)
	if err != nil {
		return errors.Trace(err)
	}
	fm.SetParams(fm.Params)
	// read index
	fm.Index, err = dataset.UnmarshalUnifiedIndex(r)
	if err != nil {
		return errors.Trace(err)
	}
	// read dataset stats
	if err = encoding.ReadGob(r, &fm.numFeatures); err != nil {
		return errors.Trace(err)
	}
	if err = encoding.ReadGob(r, &fm.numDimension); err != nil {
		return errors.Trace(err)
	}
	if err = encoding.ReadGob(r, &fm.embeddingDim); err != nil {
		return errors.Trace(err)
	}
	if len(fm.embeddingDim) > 0 {
		fm.embeddingIndex, err = dataset.UnmarshalIndex(r)
		if err != nil {
			return errors.Trace(err)
		}
	} else {
		// Do not keep a stale index from a previously loaded/trained model.
		fm.embeddingIndex = nil
	}
	// read parameters
	fm.buildLayers()
	if err = nn.Load(fm.Parameters(), r); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// convertToTensors packs x (and optionally e, y) into dense row-major tensors.
// Each sample occupies a fixed-width row of numDimension slots, zero-padded.
// Each embedding slot k becomes a (len(x), embeddingDim[k]) tensor, with
// missing or mismatched entries left as zeros. targetTensor is nil when y is
// nil.
func (fm *AFM) convertToTensors(x []lo.Tuple2[[]int32, []float32], e [][][]float32, y []float32) (
	indicesTensor, valuesTensor *nn.Tensor,
	embeddingTensor []*nn.Tensor,
	targetTensor *nn.Tensor,
) {
	if y != nil && len(x) != len(y) {
		panic("length of x and y must be equal")
	}
	alignedIndices := make([]float32, len(x)*fm.numDimension)
	alignedValues := make([]float32, len(x)*fm.numDimension)
	alignedEmbeddings := make([][]float32, len(fm.embeddingDim))
	for k, dim := range fm.embeddingDim {
		// Zero-filled up front: missing or mismatched embeddings stay all-zero.
		alignedEmbeddings[k] = make([]float32, len(x)*dim)
	}
	alignedTarget := make([]float32, len(x))
	for i := range x {
		if len(x[i].A) != len(x[i].B) {
			panic("length of indices and values must be equal")
		}
		// Rows are fixed-width. numDimension is the widest *training* row, so an
		// inference sample can be wider. Writing all of its features would spill
		// into the next row, or past the buffer on the last row, so the trailing
		// features are dropped.
		offset := i * fm.numDimension
		width := mathutil.Min(len(x[i].A), fm.numDimension)
		for j := 0; j < width; j++ {
			// NOTE(review): feature ids travel as float32, exact only up to 2^24.
			alignedIndices[offset+j] = float32(x[i].A[j])
			alignedValues[offset+j] = x[i].B[j]
		}
		if i < len(e) {
			for k, dim := range fm.embeddingDim {
				if len(e[i]) > k && len(e[i][k]) == dim {
					copy(alignedEmbeddings[k][i*dim:], e[i][k])
				}
			}
		}
		if y != nil {
			alignedTarget[i] = y[i]
		}
	}
	indicesTensor = nn.NewTensor(alignedIndices, len(x), fm.numDimension)
	valuesTensor = nn.NewTensor(alignedValues, len(x), fm.numDimension)
	embeddingTensor = make([]*nn.Tensor, len(fm.embeddingDim))
	for k, dim := range fm.embeddingDim {
		embeddingTensor[k] = nn.NewTensor(alignedEmbeddings[k], len(x), dim)
	}
	if y != nil {
		targetTensor = nn.NewTensor(alignedTarget, len(x))
	}
	return
}

================================================ FILE: model/ctr/fm_xla.go ================================================
//go:build cgo && xla

// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ctr

import (
	std_context "context"
	"fmt"
	"io"
	"math"
	"sync"
	"time"

	"github.com/c-bata/goptuna"
	"github.com/gomlx/gomlx/backends"
	_ "github.com/gomlx/gomlx/backends/xla"
	"github.com/gomlx/gomlx/pkg/core/dtypes"
	"github.com/gomlx/gomlx/pkg/core/graph"
	"github.com/gomlx/gomlx/pkg/core/shapes"
	"github.com/gomlx/gomlx/pkg/core/tensors"
	mlx_context "github.com/gomlx/gomlx/pkg/ml/context"
	"github.com/gomlx/gomlx/pkg/ml/context/initializers"
	"github.com/gomlx/gomlx/pkg/ml/layers"
	"github.com/gomlx/gomlx/pkg/ml/layers/activations"
	"github.com/gomlx/gomlx/pkg/ml/train"
	"github.com/gomlx/gomlx/pkg/ml/train/losses"
	"github.com/gomlx/gomlx/pkg/ml/train/optimizers"
	"github.com/gorse-io/gorse/common/encoding"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model"
	"github.com/juju/errors"
	"github.com/samber/lo"
	"go.uber.org/zap"
	"modernc.org/mathutil"
)

// headerAFM is the name tag of the AFM model (not referenced in this file).
const headerAFM = "AFM"

// AFM is the GoMLX/XLA implementation of the attention-augmented
// factorization machine, compiled when both cgo and the xla build tag are
// enabled.
type AFM struct {
	BaseFactorizationMachines

	// mu is write-locked while predictExecutor is lazily built and read-locked
	// while BatchInternalPredict runs.
	mu      sync.RWMutex
	ctx     *mlx_context.Context // GoMLX variable store shared by training and inference
	backend backends.Backend

	// hyper parameters
	batchSize  int
	nFactors   int
	nEpochs    int
	lr         float32
	reg        float32 // NOTE(review): read in SetParams but not applied by Fit in this file
	initMean   float32
	initStdDev float32
	optimizer  string // NOTE(review): Fit in this file always uses Adam

	// dataset stats
	numFeatures    int            // size of the unified feature index
	numDimension   int            // widest feature row seen in the training set
	embeddingDim   []int          // dimension of each external embedding slot
	embeddingIndex *dataset.Index // embedding name -> slot in embeddingDim

	// compiled executors
	predictExecutor *mlx_context.Exec // built lazily by predictExec
}

// NewAFM creates an AFM configured from params; the GoMLX context and backend
// are created by Init (via Fit) or Unmarshal.
func NewAFM(params model.Params) *AFM {
	fm := new(AFM)
	fm.SetParams(params)
	return fm
}

// SuggestParams draws a hyper-parameter set for a goptuna trial: nFactors is
// fixed at 16 and initMean at 0, while lr, reg and initStdDev are sampled
// log-uniformly from [0.001, 0.1].
func (fm *AFM) SuggestParams(trial goptuna.Trial) model.Params {
	return model.Params{
		model.NFactors:   16,
		model.Lr:         lo.Must(trial.SuggestLogFloat(string(model.Lr), 0.001, 0.1)),
		model.Reg:        lo.Must(trial.SuggestLogFloat(string(model.Reg), 0.001, 0.1)),
		model.InitMean:   0,
		model.InitStdDev: lo.Must(trial.SuggestLogFloat(string(model.InitStdDev), 0.001, 0.1)),
	}
}

// SetParams stores params and reads each hyper-parameter, falling back to the
// defaults written inline below.
func (fm *AFM) SetParams(params model.Params) {
	fm.BaseFactorizationMachines.SetParams(params)
	fm.batchSize = fm.Params.GetInt(model.BatchSize, 1024)
	fm.nFactors = fm.Params.GetInt(model.NFactors, 16)
	fm.nEpochs = fm.Params.GetInt(model.NEpochs, 50)
	fm.lr = fm.Params.GetFloat32(model.Lr, 0.001)
	fm.reg = fm.Params.GetFloat32(model.Reg, 0.0002)
	fm.initMean = fm.Params.GetFloat32(model.InitMean, 0)
	fm.initStdDev = fm.Params.GetFloat32(model.InitStdDev, 0.01)
	fm.optimizer = fm.Params.GetString(model.Optimizer, model.Adam)
}

// Clear drops the feature index, which makes Invalid report true.
func (fm *AFM) Clear() {
	fm.Index = nil
}

// Invalid reports whether the model is nil or has no feature index.
func (fm *AFM) Invalid() bool {
	return fm == nil || fm.Index == nil
}

// Predict is not supported by this model; use BatchPredict.
func (fm *AFM) Predict(_, _ string, _, _ []Label) float32 {
	panic("Predict is unsupported for deep learning models")
}

// InternalPredict is not supported by this model; use BatchInternalPredict.
func (fm *AFM) InternalPredict(_ []int32, _ []float32) float32 {
	panic("InternalPredict is unsupported for deep learning models")
}

// attentionForward re-weights x feature-wise:
// softmax(relu(Dense_k(x)) . H, axis 1) * x, where H is a learned [k, dimensions]
// matrix stored under ctx.
func (fm *AFM) attentionForward(ctx *mlx_context.Context, x *graph.Node, dimensions, k int) *graph.Node {
	g := x.Graph()
	// W: Linear(dimensions -> k)
	wCtx := ctx.In("attention_w")
	w := layers.Dense(wCtx, x, true, k)
	w = activations.Relu(w)
	// H: [k, dimensions]
	hCtx := ctx.In("attention_h")
	hVar := hCtx.VariableWithShape("H", shapes.Make(dtypes.F32, k, dimensions))
	h := hVar.ValueGraph(g)
	// Softmax(W * H, 1)
	// w: [batchSize, k]
	// h: [k, dimensions]
	// score: [batchSize, dimensions]
	score := graph.Dot(w, h)
	score = graph.Softmax(score, 1)
	// score * x
	return graph.Mul(score, x)
}

// forwardGraph builds the graph that maps (indices, values, embeddings...) to
// one logit per sample. It mirrors the pure-Go AFM.Forward: bias + linear
// term + FM pairwise interaction + one attention-encoded term per external
// embedding slot.
func (fm *AFM) forwardGraph(ctx *mlx_context.Context, indices, values *graph.Node, additionalEmbeddings []*graph.Node) *graph.Node {
	// Disable variable checks to allow reuse
	ctx = ctx.Checked(false)
	g := indices.Graph()
	batchSize := indices.Shape().Dimensions[0]

	// V: Embedding(numFeatures, nFactors)
	vCtx := ctx.In("V")
	v := layers.Embedding(vCtx, indices, dtypes.F32, fm.numFeatures, fm.nFactors) // [batchSize, numDimension, nFactors]

	// x: values [batchSize, numDimension, 1]
	x := graph.Reshape(values, batchSize, fm.numDimension, 1)

	// vx: BMM(v, x, true, false) -> [batchSize, nFactors, 1]
	// contracting axes: [1] (numDimension), batch axes: [0]
	vx := graph.DotGeneral(v, []int{1}, []int{0}, x, []int{1}, []int{0})

	// Interaction part: 0.5 * sum(vx^2 - sum(v^2 * x^2))
	sumSquare := graph.Square(vx)
	e2 := graph.Square(v)
	x2 := graph.Square(x)
	squareSum := graph.DotGeneral(e2, []int{1}, []int{0}, x2, []int{1}, []int{0})
	interaction := graph.Sub(sumSquare, squareSum)
	interaction = graph.ReduceSum(interaction, 1) // [batchSize, 1]
	interaction = graph.Mul(interaction, graph.Scalar(g, dtypes.F32, 0.5))

	// Linear part: sum(W[indices] * values)
	wCtx := ctx.In("W")
	w := layers.Embedding(wCtx, indices, dtypes.F32, fm.numFeatures, 1) // [batchSize, numDimension, 1]
	linear := graph.DotGeneral(w, []int{1}, []int{0}, x, []int{1}, []int{0})
	linear = graph.Reshape(linear, batchSize, 1)

	// Bias
	bCtx := ctx.In("B")
	bVar := bCtx.VariableWithShape("bias", shapes.Make(dtypes.F32, 1))
	bias := bVar.ValueGraph(g)
	bias = graph.Reshape(bias, 1, 1) // Reshape to [1, 1] for broadcasting with [batchSize, 1]

	fmOutput := graph.Add(graph.Add(linear, interaction), bias) // [batchSize, 1]

	// Additional embeddings with attention
	for i, embedding := range additionalEmbeddings {
		// A: Attention
		aCtx := ctx.In(fmt.Sprintf("A_%d", i))
		attended := fm.attentionForward(aCtx, embedding, fm.embeddingDim[i], fm.nFactors)

		// E: Linear(dim -> nFactors)
		eCtx := ctx.In(fmt.Sprintf("E_%d", i))
		encoded := layers.Dense(eCtx, attended, true, fm.nFactors)
		encoded = graph.Reshape(encoded, batchSize, fm.nFactors, 1)

		// Output: vx^T * encoded -> [batchSize, 1, 1]
		// vx: [batch, nFactors, 1], encoded: [batch, nFactors, 1]
		// contracting axes: [1] (nFactors), batch axes: [0]
		term := graph.DotGeneral(vx, []int{1}, []int{0}, encoded, []int{1}, []int{0})
		fmOutput = graph.Add(fmOutput, graph.Reshape(term, batchSize, 1))
	}

	return graph.Reshape(fmOutput, batchSize)
}

// predictExec returns the inference executor, compiling it on first use.
// BatchInternalPredict only holds the read lock, so the lazy assignment is
// done under the write lock (with a re-check). That keeps concurrent
// predictors from racing on the field.
func (fm *AFM) predictExec() *mlx_context.Exec {
	fm.mu.RLock()
	exec := fm.predictExecutor
	fm.mu.RUnlock()
	if exec != nil {
		return exec
	}
	fm.mu.Lock()
	defer fm.mu.Unlock()
	if fm.predictExecutor == nil {
		var err error
		fm.predictExecutor, err = mlx_context.NewExec(fm.backend, fm.ctx, func(ctx *mlx_context.Context, nodes []*graph.Node) *graph.Node {
			return fm.forwardGraph(ctx, nodes[0], nodes[1], nodes[2:])
		})
		if err != nil {
			panic(err)
		}
	}
	return fm.predictExecutor
}

// BatchInternalPredict returns one logit per row of x, scoring in chunks of
// batchSize. Each x[i] pairs encoded feature ids (A) with their values (B);
// e[i][k] is the embedding for slot k (missing or mismatched entries stay
// zero). jobs is unused by this implementation.
func (fm *AFM) BatchInternalPredict(x []lo.Tuple2[[]int32, []float32], e [][][]float32, jobs int) []float32 {
	exec := fm.predictExec()
	fm.mu.RLock()
	defer fm.mu.RUnlock()

	predictions := make([]float32, 0, len(x))
	for start := 0; start < len(x); start += fm.batchSize {
		end := mathutil.Min(start+fm.batchSize, len(x))
		batchSize := end - start
		indicesData := make([]int32, batchSize*fm.numDimension)
		valuesData := make([]float32, batchSize*fm.numDimension)
		additionalData := make([][]float32, len(fm.embeddingDim))
		for i := range additionalData {
			additionalData[i] = make([]float32, batchSize*fm.embeddingDim[i])
		}
		for i := 0; i < batchSize; i++ {
			row := x[start+i]
			// Rows are fixed-width (numDimension = widest *training* row). A wider
			// inference row would spill into the next row or past the buffer, so
			// its trailing features are dropped.
			width := mathutil.Min(len(row.A), fm.numDimension)
			for j := 0; j < width; j++ {
				indicesData[i*fm.numDimension+j] = row.A[j]
				valuesData[i*fm.numDimension+j] = row.B[j]
			}
			if start+i < len(e) {
				for j, dim := range fm.embeddingDim {
					if len(e[start+i]) > j && len(e[start+i][j]) == dim {
						copy(additionalData[j][i*dim:], e[start+i][j])
					}
				}
			}
		}
		inputs := []any{
			tensors.FromFlatDataAndDimensions(indicesData, batchSize, fm.numDimension),
			tensors.FromFlatDataAndDimensions(valuesData, batchSize, fm.numDimension),
		}
		for i := range additionalData {
			inputs = append(inputs, tensors.FromFlatDataAndDimensions(additionalData[i], batchSize, fm.embeddingDim[i]))
		}
		outputs := exec.MustExec(inputs...)
		batchPreds := outputs[0].Value().([]float32)
		predictions = append(predictions, batchPreds...)
	}
	return predictions[:len(x)]
}

// BatchPredict scores raw (user, item, user labels, item labels) tuples. Users,
// items and labels that the index does not know are skipped. embeddings[i]
// supplies the external embeddings for inputs[i]. Unknown names, out-of-range
// slots and dimension mismatches are ignored, so those slots stay zero.
func (fm *AFM) BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label], embeddings [][]Embedding, jobs int) []float32 {
	x := make([]lo.Tuple2[[]int32, []float32], len(inputs))
	for i, input := range inputs {
		// encode user
		if userIndex := fm.Index.EncodeUser(input.A); userIndex != dataset.NotId {
			x[i].A = append(x[i].A, userIndex)
			x[i].B = append(x[i].B, 1)
		}
		// encode item
		if itemIndex := fm.Index.EncodeItem(input.B); itemIndex != dataset.NotId {
			x[i].A = append(x[i].A, itemIndex)
			x[i].B = append(x[i].B, 1)
		}
		// encode user labels
		for _, userFeature := range input.C {
			if userFeatureIndex := fm.Index.EncodeUserLabel(userFeature.Name); userFeatureIndex != dataset.NotId {
				x[i].A = append(x[i].A, userFeatureIndex)
				x[i].B = append(x[i].B, userFeature.Value)
			}
		}
		// encode item labels
		for _, itemFeature := range input.D {
			if itemFeatureIndex := fm.Index.EncodeItemLabel(itemFeature.Name); itemFeatureIndex != dataset.NotId {
				x[i].A = append(x[i].A, itemFeatureIndex)
				x[i].B = append(x[i].B, itemFeature.Value)
			}
		}
	}
	e := make([][][]float32, len(inputs))
	for i := range inputs {
		e[i] = make([][]float32, len(fm.embeddingDim))
		// Tolerate an embeddings slice shorter than inputs (which would otherwise
		// index out of range) and a model without an embedding index.
		if i >= len(embeddings) || fm.embeddingIndex == nil {
			continue
		}
		for _, embedding := range embeddings[i] {
			itemIndex := fm.embeddingIndex.ToNumber(embedding.Name)
			if itemIndex == dataset.NotId {
				// unknown embedding
				continue
			}
			index := int(itemIndex)
			if index < 0 || index >= len(fm.embeddingDim) || len(embedding.Value) != fm.embeddingDim[index] {
				// slot out of range or dimension mismatch
				continue
			}
			e[i][index] = embedding.Value
		}
	}
	return fm.BatchInternalPredict(x, e, jobs)
}

// Init sizes the model from trainSet. numFeatures comes from the unified
// index, numDimension is the widest sample, and the embedding slots come from
// the split. The GoMLX context and backend are created only if they do not
// exist yet. NOTE(review): a second Fit on the same AFM therefore presumably
// continues from the existing variables rather than reinitializing them.
func (fm *AFM) Init(trainSet dataset.CTRSplit) {
	fm.numFeatures = int(trainSet.GetIndex().Len())
	fm.numDimension = 0
	for i := 0; i < trainSet.Count(); i++ {
		_, x, _, _ := trainSet.Get(i)
		fm.numDimension = mathutil.MaxVal(fm.numDimension, len(x))
	}
	fm.embeddingDim = trainSet.GetItemEmbeddingDim()
	fm.embeddingIndex = trainSet.GetItemEmbeddingIndex()
	if fm.ctx == nil {
		fm.ctx = mlx_context.New()
		fm.ctx.SetParam("initializers_seed", int64(42))
		// Default initializer: normal with stddev initStdDev (0.01 unless
		// overridden), matching the nn package.
		fm.ctx = fm.ctx.WithInitializer(initializers.RandomNormalFn(fm.ctx, float64(fm.initStdDev)))
	}
	if fm.backend == nil {
		var err error
		fm.backend, err = backends.New()
		if err != nil {
			panic(err)
		}
	}
	fm.BaseFactorizationMachines.Init(trainSet)
}

// ctrDataset feeds a CTRSplit to the GoMLX training loop in fixed-width
// batches: indices, values and one tensor per embedding slot as inputs, and
// {0, 1} targets as labels.
type ctrDataset struct {
	trainSet      dataset.CTRSplit
	numFeatures   int
	numDimension  int
	embeddingDim  []int
	batchSize     int
	currentOffset int
}

// Name identifies the dataset to the GoMLX training loop.
func (d *ctrDataset) Name() string {
	return "CTRDataset"
}

// Reset rewinds to the first sample so the next Yield starts a new epoch.
func (d *ctrDataset) Reset() {
	d.currentOffset = 0
}

// Yield returns the next batch (at most batchSize samples), or io.EOF once
// every sample has been consumed.
func (d *ctrDataset) Yield() (spec any, inputs []*tensors.Tensor, labels []*tensors.Tensor, err error) {
	if d.currentOffset >= d.trainSet.Count() {
		return nil, nil, nil, io.EOF
	}
	batchSize := mathutil.Min(d.batchSize, d.trainSet.Count()-d.currentOffset)
	indicesData := make([]int32, batchSize*d.numDimension)
	valuesData := make([]float32, batchSize*d.numDimension)
	additionalData := make([][]float32, len(d.embeddingDim))
	for i := range additionalData {
		additionalData[i] = make([]float32, batchSize*d.embeddingDim[i])
	}
	labelsData := make([]float32, batchSize)
	for i := 0; i < batchSize; i++ {
		indices, values, embeddings, target := d.trainSet.Get(d.currentOffset + i)
		for j := 0; j < len(indices); j++ {
			indicesData[i*d.numDimension+j] = indices[j]
			valuesData[i*d.numDimension+j] = values[j]
		}
		for j := range d.embeddingDim {
			if len(embeddings) > j && len(embeddings[j]) == d.embeddingDim[j] {
				copy(additionalData[j][i*d.embeddingDim[j]:], embeddings[j])
			}
		}
		// Convert target from {-1, 1} to {0, 1} for GoMLX BinaryCrossentropy
		labelsData[i] = (target + 1) / 2
	}
	d.currentOffset += batchSize
	inputs = []*tensors.Tensor{
		tensors.FromFlatDataAndDimensions(indicesData, batchSize, d.numDimension),
		tensors.FromFlatDataAndDimensions(valuesData, batchSize, d.numDimension),
	}
	for i := range additionalData {
		inputs = append(inputs, tensors.FromFlatDataAndDimensions(additionalData[i], batchSize, d.embeddingDim[i]))
	}
	labels = []*tensors.Tensor{tensors.FromFlatDataAndDimensions(labelsData, batchSize)}
	return nil, inputs, labels, nil
}

// Fit trains the model with Adam on mean binary cross-entropy over logits and
// returns the last evaluated score on testSet.
//
// It evaluates on testSet before training, every config.Verbose epochs, and
// at the final epoch. Training stops early when the score becomes NaN or when
// the best AUC is at least config.Patience epochs old. Cancellation of ctx is
// checked between epochs, and Fit then returns a zero Score. A nil config
// falls back to NewFitConfig defaults.
func (fm *AFM) Fit(ctx std_context.Context, trainSet, testSet dataset.CTRSplit, config *FitConfig) Score {
	// A nil config would otherwise panic on the first field access below.
	config = config.LoadDefaultIfNil()
	log.Logger().Info("fit AFM (mlx)",
		zap.Int("train_set_size", trainSet.Count()),
		zap.Int("test_set_size", testSet.Count()),
		zap.Any("params", fm.GetParams()),
		zap.Any("config", config))
	fm.Init(trainSet)
	evalStart := time.Now()
	score := EvaluateClassification(fm, testSet, config.Jobs)
	scores := []lo.Tuple2[int, float32]{{A: 0, B: score.AUC}}
	evalTime := time.Since(evalStart)
	fields := append([]zap.Field{zap.String("eval_time", evalTime.String())}, score.ZapFields()...)
	log.Logger().Info(fmt.Sprintf("fit AFM %v/%v", 0, fm.nEpochs), fields...)

	modelFn := func(ctx *mlx_context.Context, spec any, inputs []*graph.Node) []*graph.Node {
		return []*graph.Node{fm.forwardGraph(ctx, inputs[0], inputs[1], inputs[2:])}
	}
	lossFn := func(labels, predictions []*graph.Node) *graph.Node {
		return graph.ReduceAllMean(losses.BinaryCrossentropyLogits(labels, predictions))
	}
	optimizer := optimizers.Adam().LearningRate(float64(fm.lr)).Done()
	trainer := train.NewTrainer(fm.backend, fm.ctx, modelFn, lossFn, optimizer, nil, nil)
	loop := train.NewLoop(trainer)
	ds := &ctrDataset{
		trainSet:     trainSet,
		numFeatures:  fm.numFeatures,
		numDimension: fm.numDimension,
		embeddingDim: fm.embeddingDim,
		batchSize:    fm.batchSize,
	}
	// One epoch is ceil(n / batchSize) steps.
	stepsPerEpoch := (trainSet.Count() + fm.batchSize - 1) / fm.batchSize

	_, span := monitor.Start(ctx, "FM.Fit", fm.nEpochs)
	defer span.End()
	for epoch := 1; epoch <= fm.nEpochs; epoch++ {
		// Honor cancellation between epochs, like the pure-Go implementation.
		if ctx.Err() != nil {
			log.Logger().Info("fit AFM canceled", zap.Error(ctx.Err()))
			return Score{}
		}
		fitStart := time.Now()
		ds.Reset()
		if _, err := loop.RunSteps(ds, stepsPerEpoch); err != nil {
			panic(err)
		}
		fitTime := time.Since(fitStart)

		// Verbose <= 0 used to divide by zero; now it means "evaluate only at
		// the final epoch".
		if (config.Verbose > 0 && epoch%config.Verbose == 0) || epoch == fm.nEpochs {
			evalStart = time.Now()
			score = EvaluateClassification(fm, testSet, config.Jobs)
			scores = append(scores, lo.Tuple2[int, float32]{A: epoch, B: score.AUC})
			evalTime = time.Since(evalStart)
			fields = append([]zap.Field{
				zap.String("fit_time", fitTime.String()),
				zap.String("eval_time", evalTime.String()),
			}, score.ZapFields()...)
			log.Logger().Info(fmt.Sprintf("fit AFM %v/%v", epoch, fm.nEpochs), fields...)
			// check NaN
			if math.IsNaN(float64(score.GetValue())) {
				log.Logger().Warn("model diverged", zap.Float32("lr", fm.lr))
				break
			}
			// early stopping if no improvement in last `patience` epochs
			if config.Patience > 0 && epoch > config.Patience {
				epochScore := lo.MaxBy(scores, func(a, b lo.Tuple2[int, float32]) bool {
					return a.B > b.B
				})
				if epochScore.A <= epoch-config.Patience {
					log.Logger().Info("early stopping",
						zap.Int("best_epoch", epochScore.A),
						zap.Float32("best_auc", epochScore.B),
						zap.Int("patience", config.Patience))
					break
				}
			}
		}
		span.Add(1)
	}
	return score
}

// savedVariable is the gob-serialized form of one GoMLX variable: its shape,
// flat data and the scope/name it is restored under.
type savedVariable struct {
	Dimensions []int
	Data       any
	Scope      string
	Name       string
}

// Marshal writes the model in this order: params, unified index, dataset
// stats, the embedding index (only when embedding slots exist), then every
// variable in the GoMLX context as a map keyed by scope-and-name.
func (fm *AFM) Marshal(w io.Writer) error {
	// write params
	if err := encoding.WriteGob(w, fm.Params); err != nil {
		return errors.Trace(err)
	}
	// write index
	if err := dataset.MarshalUnifiedIndex(w, fm.Index); err != nil {
		return errors.Trace(err)
	}
	// write dataset stats
	if err := encoding.WriteGob(w, fm.numFeatures); err != nil {
		return errors.Trace(err)
	}
	if err := encoding.WriteGob(w, fm.numDimension); err != nil {
		return errors.Trace(err)
	}
	if err := encoding.WriteGob(w, fm.embeddingDim); err != nil {
		return errors.Trace(err)
	}
	if len(fm.embeddingDim) > 0 {
		if err := dataset.MarshalIndex(w, fm.embeddingIndex); err != nil {
			return errors.Trace(err)
		}
	}
	// write parameters (GoMLX variables)
	variables := make(map[string]savedVariable)
	fm.ctx.EnumerateVariables(func(v *mlx_context.Variable) {
		val, err := v.Value()
		if err != nil {
			panic(err)
		}
		var flatData any
		val.MustConstFlatData(func(flat any) {
			flatData = flat
		})
		variables[v.ScopeAndName()] = savedVariable{
			Dimensions: val.Shape().Dimensions,
			Data:       flatData,
			Scope:      v.Scope(),
			Name:       v.Name(),
		}
	})
	if err := encoding.WriteGob(w, variables); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// Unmarshal reads a model written by Marshal. It creates the GoMLX context
// and backend if missing, then restores each saved variable under its
// original scope. Variables of an unsupported element type are logged and
// skipped.
func (fm *AFM) Unmarshal(r io.Reader) error {
	// read params
	err := encoding.ReadGob(r, &fm.Params)
	if err != nil {
		return errors.Trace(err)
	}
	fm.SetParams(fm.Params)
	// read index
	fm.Index, err = dataset.UnmarshalUnifiedIndex(r)
	if err != nil {
		return errors.Trace(err)
	}
	// read dataset stats
	if err = encoding.ReadGob(r, &fm.numFeatures); err != nil {
		return errors.Trace(err)
	}
	if err = encoding.ReadGob(r, &fm.numDimension); err != nil {
		return errors.Trace(err)
	}
	if err = encoding.ReadGob(r, &fm.embeddingDim); err != nil {
		return errors.Trace(err)
	}
	if len(fm.embeddingDim) > 0 {
		fm.embeddingIndex, err = dataset.UnmarshalIndex(r)
		if err != nil {
			return errors.Trace(err)
		}
	} else {
		// Do not keep a stale index from a previously loaded/trained model.
		fm.embeddingIndex = nil
	}
	// read parameters
	var variables map[string]savedVariable
	if err = encoding.ReadGob(r, &variables); err != nil {
		return errors.Trace(err)
	}
	if fm.ctx == nil {
		fm.ctx = mlx_context.New()
	}
	if fm.backend == nil {
		if fm.backend, err = backends.New(); err != nil {
			return errors.Trace(err)
		}
	}
	for _, data := range variables {
		// Use a type switch to handle different data types in tensors
		var t *tensors.Tensor
		switch d := data.Data.(type) {
		case []float32:
			t = tensors.FromFlatDataAndDimensions(d, data.Dimensions...)
		case []float64:
			t = tensors.FromFlatDataAndDimensions(d, data.Dimensions...)
		case []int32:
			t = tensors.FromFlatDataAndDimensions(d, data.Dimensions...)
		case []int64:
			t = tensors.FromFlatDataAndDimensions(d, data.Dimensions...)
		case []uint64:
			t = tensors.FromFlatDataAndDimensions(d, data.Dimensions...)
		default:
			log.Logger().Warn("unknown variable type",
				zap.String("scope", data.Scope),
				zap.String("name", data.Name),
				zap.Any("type", fmt.Sprintf("%T", d)))
			continue
		}
		fm.ctx.InAbsPath(data.Scope).VariableWithValue(data.Name, t)
	}
	return nil
}

================================================ FILE: model/ctr/model.go ================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ctr import ( "context" "fmt" "io" "reflect" "github.com/gorse-io/gorse/common/encoding" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/juju/errors" "github.com/samber/lo" "go.uber.org/zap" ) type Score struct { RMSE float32 Precision float32 Recall float32 Accuracy float32 AUC float32 } func (score Score) ZapFields() []zap.Field { return []zap.Field{ zap.Float32("Accuracy", score.Accuracy), zap.Float32("Precision", score.Precision), zap.Float32("Recall", score.Recall), zap.Float32("AUC", score.AUC), } } func (score Score) GetValue() float32 { return score.Precision } func (score Score) BetterThan(s Score) bool { return score.AUC > s.AUC } type FitConfig struct { Jobs int Verbose int Patience int } func NewFitConfig() *FitConfig { return &FitConfig{ Jobs: 1, Verbose: 10, Patience: 10, } } func (config *FitConfig) SetVerbose(verbose int) *FitConfig { config.Verbose = verbose return config } func (config *FitConfig) SetJobs(jobs int) *FitConfig { config.Jobs = jobs return config } func (config *FitConfig) SetPatience(patience int) *FitConfig { config.Patience = patience return config } func (config *FitConfig) LoadDefaultIfNil() *FitConfig { if config == nil { return NewFitConfig() } return config } type FactorizationMachines interface { model.Model Predict(userId, itemId string, userFeatures, itemFeatures []Label) float32 InternalPredict(x []int32, values []float32) float32 Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, config *FitConfig) Score Marshal(w io.Writer) error } type BatchInference interface 
{ BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label], e [][]Embedding, jobs int) []float32 BatchInternalPredict(x []lo.Tuple2[[]int32, []float32], e [][][]float32, jobs int) []float32 } type BaseFactorizationMachines struct { model.BaseModel Index dataset.UnifiedIndex } func (b *BaseFactorizationMachines) Init(trainSet dataset.CTRSplit) { b.Index = trainSet.GetIndex() } func MarshalModel(w io.Writer, m FactorizationMachines) error { // write header var err error switch m.(type) { case *AFM: err = encoding.WriteString(w, headerAFM) default: return fmt.Errorf("unknown model: %v", reflect.TypeOf(m)) } if err != nil { return err } return m.Marshal(w) } func UnmarshalModel(r io.Reader) (FactorizationMachines, error) { // read header header, err := encoding.ReadString(r) if err != nil { return nil, err } switch header { case headerAFM: var fm AFM if err := fm.Unmarshal(r); err != nil { return nil, errors.Trace(err) } return &fm, nil } return nil, fmt.Errorf("unknown model: %v", header) } ================================================ FILE: model/ctr/model.py ================================================ import os from pathlib import Path from typing import Dict, List, Tuple import click import torch from tqdm import tqdm class Dataset: def __init__(self): self.indices = [] self.values = [] self.targets = [] self.num_fields = 0 self.num_features = 0 def __len__(self): assert len(self.indices) == len(self.targets) return len(self.indices) def aligned(self) -> Tuple[List[List[int]], List[List[float]]]: aligned_indices = [] aligned_values = [] for i in range(len(self)): aligned_indices_row = self.indices[i] aligned_values_row = self.values[i] if len(aligned_indices_row) < self.num_fields: aligned_indices_row += [0] * ( self.num_fields - len(aligned_indices_row) ) aligned_values_row += [0] * (self.num_fields - len(aligned_values_row)) aligned_indices.append(aligned_indices_row) aligned_values.append(aligned_values_row) return aligned_indices, 
aligned_values def load_libfm(path: str) -> Dataset: dataset = Dataset() with open(path, 'r') as f: for line in f.readlines(): splits = line.strip().split(' ') indices = [int(v.split(':')[0]) for v in splits[1:]] values = [float(v.split(':')[1]) for v in splits[1:]] target = 1 if float(splits[0]) == 1 else 0 dataset.indices.append(indices) dataset.values.append(values) dataset.targets.append(target) dataset.num_fields = max(dataset.num_fields, len(indices)) dataset.num_features = max(dataset.num_features, max(indices) + 1) return dataset def load_dataset(name: str) -> Tuple[Dataset, Dataset]: dataset_dir = os.path.join(Path.home(), '.gorse', 'dataset', name) return load_libfm(os.path.join(dataset_dir, "train.libfm")), load_libfm(os.path.join(dataset_dir, "test.libfm")) def accuracy(positive_predictions: List[float], negative_predictions: List[float]) -> float: num_pos = len(positive_predictions) num_neg = len(negative_predictions) num_correct = 0 for pos in positive_predictions: if pos > 0: num_correct += 1 for neg in negative_predictions: if neg <= 0: num_correct += 1 return num_correct / (num_pos + num_neg) def auc(positive_predictions: List[float], negative_predictions: List[float]) -> float: sorted_positive_predictions = sorted(positive_predictions) sorted_negative_predictions = sorted(negative_predictions) sum = 0.0 num_pos = 0 for pos in sorted_positive_predictions: while ( num_pos < len(sorted_negative_predictions) and sorted_negative_predictions[num_pos] < pos ): num_pos += 1 sum += num_pos return sum / (len(positive_predictions) * len(negative_predictions)) class Evaluator: def __init__(self, test: Dataset, device: str = "cpu") -> None: align_indices, align_values = test.aligned() self.indices = torch.tensor(align_indices, dtype=torch.long).to(device) self.values = torch.tensor(align_values, dtype=torch.float).to(device) self.positive_samples = [] self.negative_samples = [] for i in range(len(test)): if test.targets[i] > 0: self.positive_samples.append(i) 
else: self.negative_samples.append(i) def evaluate(self, model) -> Dict[str, float]: positive_predictions = ( model.predict( self.indices[self.positive_samples], self.values[self.positive_samples], ) .cpu() .tolist() ) negative_predictions = ( model.predict( self.indices[self.negative_samples], self.values[self.negative_samples], ) .cpu() .tolist() ) return { "Accuracy": accuracy(positive_predictions, negative_predictions), "AUC": auc(positive_predictions, negative_predictions), } class FactorizationMachine(torch.nn.Module): def __init__( self, num_fields: int, num_features: int, k: int = 8, init_stddev=0.01 ) -> None: super().__init__() self.b = torch.nn.Parameter(torch.zeros(1)) self.w = torch.nn.Embedding(num_features, 1) self.v = torch.nn.Embedding(num_features, k) torch.nn.init.normal_(self.w.weight, mean=0.0, std=init_stddev) torch.nn.init.normal_(self.v.weight, mean=0.0, std=init_stddev) def forward(self, indices, values): linear = self.b + torch.sum(self.w(indices) * values.unsqueeze(-1), dim=1) x = self.v(indices) * values.unsqueeze(-1) square_of_sum = torch.sum(x, dim=1) ** 2 sum_of_square = torch.sum(x ** 2, dim=1) interaction = 0.5 * torch.sum( square_of_sum - sum_of_square, dim=1, keepdim=True ) fm = linear + interaction return fm.squeeze() def predict(self, indices, values): return self.forward(indices, values) def fit( self, train: Dataset, test: Dataset, epochs: int = 20, lr: float = 0.01, reg: float = 0.0001, batch_size: int = 1024, device: str = "cpu", verbose: bool = True, ) -> dict[str, float]: self.to(device) aligned_indices, aligned_values = train.aligned() indices = torch.tensor(aligned_indices, dtype=torch.long).to(device) values = torch.tensor(aligned_values, dtype=torch.float).to(device) targets = torch.tensor(train.targets, dtype=torch.float).to(device) optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=reg) criterion = torch.nn.BCEWithLogitsLoss() evaluator = Evaluator(test, device) score = evaluator.evaluate(self) for 
epoch in range(epochs): for i in tqdm( range(0, len(indices), batch_size), disable=not verbose, desc=f"Epoch {epoch + 1}/{epochs}", postfix=score, ): batch_indices = indices[i: i + batch_size] batch_values = values[i: i + batch_size] batch_targets = targets[i: i + batch_size] optimizer.zero_grad() output = self.forward(batch_indices, batch_values) loss = criterion(output, batch_targets) loss.backward() optimizer.step() score = evaluator.evaluate(self) return score @click.command() @click.argument("dataset") @click.option('-dim', default=8, help="Number of factors.") @click.option('-iter', default=20, help="Number of epochs.") @click.option('-learn_rate', default=0.01, help="Learning rate.") @click.option('-regular', default=0.0001, help="Regularization.") @click.option('-device', default="cpu", help="Device to use.") def main(dataset: str, dim: int, iter: int, learn_rate: float, regular: float, device: str): train, test = load_dataset(dataset) model = FactorizationMachine(train.num_fields, train.num_features, dim) model.fit(train, test, iter, learn_rate, regular, device=device) if __name__ == '__main__': main() ================================================ FILE: model/ctr/model_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package ctr

import (
	"bytes"
	"runtime"
	"testing"

	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model"
	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
)

// classificationDelta is the tolerated accuracy gap against the reference
// numbers produced by model.py.
const classificationDelta = 0.01

// newFitConfigWithTestTracker returns a FitConfig with Verbose 1 that uses
// every CPU.
func newFitConfigWithTestTracker() *FitConfig {
	cfg := NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU())
	return cfg
}

func TestFactorizationMachines_Classification_Frappe(t *testing.T) {
	// python .\model.py frappe -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001
	train, test, err := LoadDataFromBuiltIn("frappe")
	// Stop on load failure rather than fitting on (presumably nil) splits.
	if !assert.NoError(t, err) {
		return
	}
	m := NewAFM(model.Params{
		model.NFactors:  8,
		model.NEpochs:   10,
		model.Lr:        0.01,
		model.Reg:       0.0001,
		model.BatchSize: 1024,
	})
	fitConfig := newFitConfigWithTestTracker()
	score := m.Fit(t.Context(), train, test, fitConfig)
	assert.InDelta(t, 0.919, score.Accuracy, classificationDelta)
}

func TestFactorizationMachines_Classification_MovieLens(t *testing.T) {
	t.Skip("Skip time-consuming test")
	// python .\model.py ml-tag -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001
	train, test, err := LoadDataFromBuiltIn("ml-tag")
	if !assert.NoError(t, err) {
		return
	}
	m := NewAFM(model.Params{
		model.InitStdDev: 0.01,
		model.NFactors:   8,
		model.NEpochs:    10,
		model.Lr:         0.001,
		model.Reg:        0.0001,
		model.BatchSize:  1024,
	})
	fitConfig := newFitConfigWithTestTracker()
	score := m.Fit(t.Context(), train, test, fitConfig)
	assert.InDelta(t, 0.815, score.Accuracy, classificationDelta)
}

func TestFactorizationMachines_Classification_Criteo(t *testing.T) {
	// python .\model.py criteo -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001
	train, test, err := LoadDataFromBuiltIn("criteo")
	if !assert.NoError(t, err) {
		return
	}
	m := NewAFM(model.Params{
		model.NFactors:  8,
		model.NEpochs:   10,
		model.Lr:        0.01,
		model.Reg:       0.0001,
		model.BatchSize: 1024,
	})
	fitConfig := newFitConfigWithTestTracker()
	score := m.Fit(t.Context(), train, test, fitConfig)
	assert.InDelta(t, 0.77, score.Accuracy, 0.025)

	// test prediction: index-based and name-based batch prediction must agree
	assert.Equal(t,
		m.BatchInternalPredict(
			[]lo.Tuple2[[]int32, []float32]{{A: []int32{1, 2, 3, 4, 5, 6}, B: []float32{1, 1, 0.3, 0.4, 0.5, 0.6}}},
			make([][][]float32, 2), fitConfig.Jobs),
		m.BatchPredict([]lo.Tuple4[string, string, []Label, []Label]{{
			A: "1",
			B: "2",
			C: []Label{
				{Name: "3", Value: 0.3},
				{Name: "4", Value: 0.4},
			},
			D: []Label{
				{Name: "5", Value: 0.5},
				{Name: "6", Value: 0.6},
			}}}, make([][]Embedding, 2), fitConfig.Jobs))

	// test marshal and unmarshal
	buf := bytes.NewBuffer(nil)
	err = MarshalModel(buf, m)
	assert.NoError(t, err)
	tmp, err := UnmarshalModel(buf)
	// A nil model would make EvaluateClassification panic.
	if !assert.NoError(t, err) {
		return
	}
	scoreClone := EvaluateClassification(tmp, test, fitConfig.Jobs)
	assert.InDelta(t, 0.77, scoreClone.Accuracy, 0.02)

	// test clear
	assert.False(t, m.Invalid())
	m.Clear()
	assert.True(t, m.Invalid())
}

// newSynthesisDataset builds a tiny 2-user x 2-item dataset with three user
// labels, three item labels and two item embeddings (dimensions 3 and 4).
// Pairs (u0,i0) and (u1,i1) are positive; the other two are negative.
func newSynthesisDataset() *Dataset {
	builder := dataset.NewUnifiedMapIndexBuilder()
	builder.AddUser("u0")
	builder.AddUser("u1")
	builder.AddUserLabel("ul0")
	builder.AddUserLabel("ul1")
	builder.AddUserLabel("ul2")
	builder.AddItem("i0")
	builder.AddItem("i1")
	builder.AddItemLabel("il0")
	builder.AddItemLabel("il1")
	builder.AddItemLabel("il2")

	dataSet := NewMapIndexDataset()
	dataSet.Index = builder.Build()
	dataSet.UserLabels = [][]lo.Tuple2[int32, float32]{
		{{A: 0, B: 1.0}, {A: 1, B: 0.5}, {A: 2, B: -1.0}},
		{{A: 0, B: -1.0}, {A: 1, B: -0.5}, {A: 2, B: 1.0}},
	}
	dataSet.ItemLabels = [][]lo.Tuple2[int32, float32]{
		{{A: 0, B: 1.0}, {A: 1, B: 0.5}, {A: 2, B: -1.0}},
		{{A: 0, B: -1.0}, {A: 1, B: -0.5}, {A: 2, B: 1.0}},
	}
	dataSet.ItemEmbeddingIndex = dataset.NewMapIndex()
	dataSet.ItemEmbeddingIndex.Add("e1")
	dataSet.ItemEmbeddingIndex.Add("e2")
	dataSet.ItemEmbeddingDimension = []int{3, 4}
	dataSet.ItemEmbeddings = [][][]float32{
		{{0.8, 0.8, 0.8}, {0.1, 0.1, 0.1, 0.1}},
		{{-0.8, -0.8, -0.8}, {-0.1, -0.1, -0.1, -0.1}},
	}
	dataSet.Users = []int32{0, 0, 1, 1}
	dataSet.Items = []int32{0, 1, 0, 1}
	dataSet.Target = []float32{1, -1, -1, 1}
	dataSet.PositiveCount = 2
	dataSet.NegativeCount = 2
	return dataSet
}

func TestFactorizationMachines_Classification_Synthesis(t *testing.T) {
	dataSet := newSynthesisDataset()
	fitConfig := newFitConfigWithTestTracker()
	m := NewAFM(nil)
	score := m.Fit(t.Context(), dataSet, dataSet, fitConfig)
	assert.GreaterOrEqual(t, score.Accuracy, float32(0.5))

	buf := bytes.NewBuffer(nil)
	err := MarshalModel(buf, m)
	assert.NoError(t, err)
	clone, err := UnmarshalModel(buf)
	// A nil clone would make EvaluateClassification panic.
	if !assert.NoError(t, err) {
		return
	}
	cloneScore := EvaluateClassification(clone, dataSet, fitConfig.Jobs)
	assert.InDelta(t, score.Accuracy, cloneScore.Accuracy, 0.05)

	// Index-based and name-based batch prediction must agree.
	indicesPos, valuesPos, embeddingsPos, _ := dataSet.Get(0)
	indicesNeg, valuesNeg, embeddingsNeg, _ := dataSet.Get(1)
	assert.Equal(t,
		m.BatchInternalPredict(
			[]lo.Tuple2[[]int32, []float32]{
				{A: indicesPos, B: valuesPos},
				{A: indicesNeg, B: valuesNeg},
			},
			[][][]float32{embeddingsPos, embeddingsNeg},
			fitConfig.Jobs,
		),
		m.BatchPredict(
			[]lo.Tuple4[string, string, []Label, []Label]{
				{
					A: "u0",
					B: "i0",
					C: []Label{{Name: "ul0", Value: 1.0}, {Name: "ul1", Value: 0.5}, {Name: "ul2", Value: -1.0}},
					D: []Label{{Name: "il0", Value: 1.0}, {Name: "il1", Value: 0.5}, {Name: "il2", Value: -1.0}},
				},
				{
					A: "u0",
					B: "i1",
					C: []Label{{Name: "ul0", Value: 1.0}, {Name: "ul1", Value: 0.5}, {Name: "ul2", Value: -1.0}},
					D: []Label{{Name: "il0", Value: -1.0}, {Name: "il1", Value: -0.5}, {Name: "il2", Value: 1.0}},
				},
			},
			[][]Embedding{
				{{Name: "e1", Value: embeddingsPos[0]}, {Name: "e2", Value: embeddingsPos[1]}},
				{{Name: "e1", Value: embeddingsNeg[0]}, {Name: "e2", Value: embeddingsNeg[1]}},
			},
			fitConfig.Jobs,
		))

	// Missing or unknown labels/embeddings must still yield one score per input.
	assert.Len(t, m.BatchPredict(
		[]lo.Tuple4[string, string, []Label, []Label]{
			{A: "u0", B: "i0"},
			{A: "u0", B: "i1"},
			{
				A: "u0",
				B: "i0",
				C: []Label{{Name: "ul_unknown", Value: 1}},
				D: []Label{{Name: "il_unknown", Value: 1}},
			},
			{
				A: "u0",
				B: "i1",
				C: []Label{{Name: "ul_unknown", Value: 1}},
				D: []Label{{Name: "il_unknown", Value: 1}},
			},
		},
		[][]Embedding{
			{},
			{},
			{{Name: "unknown_embedding", Value: make([]float32, 3)}},
			{{Name: "unknown_embedding", Value: make([]float32, 3)}},
		},
		fitConfig.Jobs,
	), 4)
}
================================================ FILE: model/ctr/optimize.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ctr import ( "context" "github.com/c-bata/goptuna" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/storage/meta" "github.com/juju/errors" "golang.org/x/exp/maps" ) type ModelCreator func() FactorizationMachines type ModelSearch struct { modelCreators map[string]ModelCreator modelTypes []string trainSet dataset.CTRSplit testSet dataset.CTRSplit config *FitConfig ctx context.Context span *monitor.Span result meta.Model[Score] } func NewModelSearch(models map[string]ModelCreator, trainSet, testSet dataset.CTRSplit, config *FitConfig) *ModelSearch { return &ModelSearch{ modelCreators: models, modelTypes: maps.Keys(models), trainSet: trainSet, testSet: testSet, config: config, } } func (ms *ModelSearch) WithContext(ctx context.Context) *ModelSearch { ms.ctx = ctx return ms } func (ms *ModelSearch) WithSpan(span *monitor.Span) *ModelSearch { ms.span = span return ms } func (ms *ModelSearch) Objective(trial goptuna.Trial) (float64, error) { if len(ms.modelCreators) == 0 { return 0, errors.New("no model to search") } modelType, err := trial.SuggestCategorical("Model", ms.modelTypes) if err != nil { return 0, errors.Trace(err) } m := ms.modelCreators[modelType]() m.SetParams(m.SuggestParams(trial)) 
score := m.Fit(ms.ctx, ms.trainSet, ms.testSet, ms.config) if score.AUC > ms.result.Score.AUC { ms.result = meta.Model[Score]{ Type: modelType, Params: m.GetParams(), Score: score, } } if ms.span != nil { ms.span.Add(1) } return float64(score.AUC), nil } func (ms *ModelSearch) Result() meta.Model[Score] { return ms.result } ================================================ FILE: model/ctr/optimize_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ctr import ( "context" "io" "testing" "github.com/c-bata/goptuna" "github.com/c-bata/goptuna/tpe" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/samber/lo" "github.com/stretchr/testify/assert" ) // NewMapIndexDataset creates a data set. 
func NewMapIndexDataset() *Dataset {
	s := new(Dataset)
	s.Index = dataset.NewUnifiedDirectIndex(0)
	return s
}

// mockFactorizationMachineForSearch is a FactorizationMachines stub for
// TestTPE. Only Fit, SuggestParams, GetParamsGrid, Clear and the embedded
// BaseModel methods are usable; the rest panic.
type mockFactorizationMachineForSearch struct {
	model.BaseModel
}

func (m *mockFactorizationMachineForSearch) Marshal(_ io.Writer) error {
	panic("implement me")
}

func (m *mockFactorizationMachineForSearch) Invalid() bool {
	panic("implement me")
}

func (m *mockFactorizationMachineForSearch) GetUserIndex() dataset.Index {
	panic("don't call me")
}

func (m *mockFactorizationMachineForSearch) GetItemIndex() dataset.Index {
	panic("don't call me")
}

// Fit ignores its arguments and returns AUC = NFactors + InitMean + InitStdDev,
// so with the ranges in SuggestParams the objective peaks at 12 when
// NFactors = InitMean = InitStdDev = 4.
func (m *mockFactorizationMachineForSearch) Fit(_ context.Context, _, _ dataset.CTRSplit, cfg *FitConfig) Score {
	score := float32(0)
	score += m.Params.GetFloat32(model.NFactors, 0.0)
	score += m.Params.GetFloat32(model.InitMean, 0.0)
	score += m.Params.GetFloat32(model.InitStdDev, 0.0)
	return Score{AUC: score}
}

func (m *mockFactorizationMachineForSearch) Predict(_, _ string, _, _ []Label) float32 {
	panic("don't call me")
}

func (m *mockFactorizationMachineForSearch) InternalPredict(_ []int32, _ []float32) float32 {
	panic("don't call me")
}

func (m *mockFactorizationMachineForSearch) Clear() {
	// do nothing
}

func (m *mockFactorizationMachineForSearch) GetParamsGrid(_ bool) model.ParamsGrid {
	return model.ParamsGrid{
		model.NFactors:   []interface{}{1, 2, 3, 4},
		model.InitMean:   []interface{}{4, 3, 2, 1},
		model.InitStdDev: []interface{}{4, 4, 4, 4},
	}
}

// SuggestParams samples NFactors and InitMean from 1..4 and fixes InitStdDev
// at 4. goptuna returns float64 values, which TestTPE expects in Result().
func (m *mockFactorizationMachineForSearch) SuggestParams(trial goptuna.Trial) model.Params {
	return model.Params{
		model.NFactors:   lo.Must(trial.SuggestDiscreteFloat(string(model.NFactors), 1, 4, 1)),
		model.InitMean:   lo.Must(trial.SuggestDiscreteFloat(string(model.InitMean), 1, 4, 1)),
		model.InitStdDev: lo.Must(trial.SuggestDiscreteFloat(string(model.InitStdDev), 4, 4, 1)),
	}
}

// TestTPE runs 10 TPE trials over the mock and expects the search to find the
// maximum objective value (12) and to report it through Result().
func TestTPE(t *testing.T) {
	search := NewModelSearch(map[string]ModelCreator{
		"mock": func() FactorizationMachines {
			return &mockFactorizationMachineForSearch{}
		},
	}, nil, nil, nil)
	study, err := goptuna.CreateStudy("TestTPE",
		goptuna.StudyOptionDirection(goptuna.StudyDirectionMaximize),
		goptuna.StudyOptionSampler(tpe.NewSampler()))
	assert.NoError(t, err)
	err = study.Optimize(search.Objective, 10)
	assert.NoError(t, err)
	v, _ := study.GetBestValue()
	assert.Equal(t, float64(12), v)
	result := search.Result()
	assert.Equal(t, "mock", result.Type)
	assert.Equal(t, model.Params{
		model.NFactors:   float64(4),
		model.InitMean:   float64(4),
		model.InitStdDev: float64(4),
	}, result.Params)
	assert.Equal(t, Score{AUC: 12}, result.Score)
}

================================================
FILE: model/model.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package model

import (
	"github.com/c-bata/goptuna"
	"github.com/gorse-io/gorse/common/util"
)

// Model is the interface for all models. Any model in this
// package should implement it.
type Model interface {
	// SetParams replaces the hyper-parameters.
	SetParams(params Params)
	// GetParams returns the current hyper-parameters.
	GetParams() Params
	// SuggestParams samples hyper-parameters from a goptuna trial.
	SuggestParams(trial goptuna.Trial) Params
	// Clear resets the model; afterwards Invalid is expected to report true.
	Clear()
	// Invalid reports whether the model is unusable.
	Invalid() bool
}

// BaseModel must be embedded by every recommendation model. It manages the
// hyper-parameters and the random generator together with its seed.
type BaseModel struct {
	Params    Params               // Hyper-parameters
	rng       util.RandomGenerator // Random generator
	randState int64                // Random seed
}

// SetParams sets hyper-parameters and reseeds the random generator from
// RandomState (0 when absent).
func (model *BaseModel) SetParams(params Params) { model.Params = params model.randState = model.Params.GetInt64(RandomState, 0) model.rng = util.NewRandomGenerator(model.randState) } // GetParams returns all hyper-parameters. func (model *BaseModel) GetParams() Params { return model.Params } func (model *BaseModel) GetRandomGenerator() util.RandomGenerator { return model.rng } ================================================ FILE: model/params.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package model import ( "reflect" "github.com/gorse-io/gorse/common/log" "go.uber.org/zap" ) /* ParamName */ // ParamName is the type of hyper-parameter names. 
type ParamName string // Predefined hyper-parameter names const ( Lr ParamName = "Lr" // learning rate Reg ParamName = "Reg" // regularization strength NEpochs ParamName = "NEpochs" // number of epochs NFactors ParamName = "NFactors" // number of factors RandomState ParamName = "RandomState" // random state (seed) InitMean ParamName = "InitMean" // mean of gaussian initial parameter InitStdDev ParamName = "InitStdDev" // standard deviation of gaussian initial parameter Alpha ParamName = "Alpha" // weight for negative samples in ALS Similarity ParamName = "Similarity" UseFeature ParamName = "UseFeature" BatchSize ParamName = "BatchSize" HiddenLayers ParamName = "HiddenLayers" Optimizer ParamName = "Optimizer" SGD = "sgd" Adam = "adam" ) // Params stores hyper-parameters for an model. It is a map between strings // (names) and interface{}s (values). For example, hyper-parameters for SVD // is given by: // // base.Params{ // base.Lr: 0.007, // base.NEpochs: 100, // base.NFactors: 80, // base.Reg: 0.1, // } type Params map[ParamName]interface{} // Copy hyper-parameters. func (parameters Params) Copy() Params { newParams := make(Params) for k, v := range parameters { newParams[k] = v } return newParams } // GetBool gets a boolean parameter by name. Returns _default if not exists or type doesn't match. func (parameters Params) GetBool(name ParamName, _default bool) bool { if val, exist := parameters[name]; exist { switch val := val.(type) { case bool: return val default: log.Logger().Error("type mismatch", zap.String("param_name", string(name)), zap.String("actual_type", reflect.TypeOf(name).Name())) } } return _default } // GetInt gets a integer parameter by name. Returns _default if not exists or type doesn't match. 
func (parameters Params) GetInt(name ParamName, _default int) int { if val, exist := parameters[name]; exist { switch val := val.(type) { case int: return val default: log.Logger().Error("type mismatch", zap.String("param_name", string(name)), zap.String("actual_type", reflect.TypeOf(name).Name())) } } return _default } // GetInt64 gets a int64 parameter by name. Returns _default if not exists or type doesn't match. The // type will be converted if given int. func (parameters Params) GetInt64(name ParamName, _default int64) int64 { if val, exist := parameters[name]; exist { switch val := val.(type) { case int64: return val case int: return int64(val) default: log.Logger().Error("type mismatch", zap.String("param_name", string(name)), zap.String("actual_type", reflect.TypeOf(name).Name())) } } return _default } func (parameters Params) GetFloat32(name ParamName, _default float32) float32 { if val, exist := parameters[name]; exist { switch val := val.(type) { case float32: return val case float64: return float32(val) case int: return float32(val) default: log.Logger().Error("type mismatch", zap.String("param_name", string(name)), zap.String("actual_type", reflect.TypeOf(name).Name())) } } return _default } // GetString gets a string parameter func (parameters Params) GetString(name ParamName, _default string) string { if val, exist := parameters[name]; exist { return val.(string) } return _default } func (parameters Params) GetIntSlice(name ParamName, _default []int) []int { if val, exist := parameters[name]; exist { switch val := val.(type) { case []int: return val default: log.Logger().Error("type mismatch", zap.String("param_name", string(name)), zap.String("actual_type", reflect.TypeOf(name).Name())) } } return _default } func (parameters Params) Overwrite(params Params) Params { merged := make(Params) for k, v := range parameters { merged[k] = v } for k, v := range params { merged[k] = v } return merged } // ParamsGrid contains candidate for grid search. 
type ParamsGrid map[ParamName][]interface{} func (grid ParamsGrid) Len() int { return len(grid) } func (grid ParamsGrid) NumCombinations() int { count := 1 for _, values := range grid { count *= len(values) } return count } func (grid ParamsGrid) Fill(_default ParamsGrid) { for param, values := range _default { if _, exist := grid[param]; !exist { grid[param] = values } } } ================================================ FILE: model/params_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package model

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestParams_Copy(t *testing.T) {
	// Create parameters
	a := Params{
		NFactors:    1,
		Lr:          0.1,
		RandomState: 0,
	}
	// Create copy
	b := a.Copy()
	b[NFactors] = 2
	b[Lr] = 0.2
	b[RandomState] = 1
	// Check original parameters
	assert.Equal(t, 1, a.GetInt(NFactors, -1))
	assert.Equal(t, float32(0.1), a.GetFloat32(Lr, -0.1))
	assert.Equal(t, int64(0), a.GetInt64(RandomState, -1))
	// Check copy parameters
	assert.Equal(t, 2, b.GetInt(NFactors, -1))
	assert.Equal(t, float32(0.2), b.GetFloat32(Lr, -0.1))
	assert.Equal(t, int64(1), b.GetInt64(RandomState, -1))
}

func TestParams_GetFloat32(t *testing.T) {
	p := Params{}
	// Empty case
	assert.Equal(t, float32(0.1), p.GetFloat32(Lr, 0.1))
	// Normal case
	p[Lr] = float32(1.0)
	assert.Equal(t, float32(1.0), p.GetFloat32(Lr, 0.1))
	// Convertible case: float64 and int are converted
	p[Lr] = 2.0
	assert.Equal(t, float32(2.0), p.GetFloat32(Lr, 0.1))
	p[Lr] = int(3)
	assert.Equal(t, float32(3.0), p.GetFloat32(Lr, 0.1))
	// Untyped integer constants are stored as int and are converted too
	p[Lr] = 1
	assert.Equal(t, float32(1.0), p.GetFloat32(Lr, 0.1))
	// Wrong type case
	p[Lr] = "hello"
	assert.Equal(t, float32(0.1), p.GetFloat32(Lr, 0.1))
}

func TestParams_GetBool(t *testing.T) {
	p := Params{}
	// Empty case
	assert.True(t, p.GetBool(UseFeature, true))
	// Normal case
	p[UseFeature] = false
	assert.False(t, p.GetBool(UseFeature, true))
	// Wrong type case
	p[UseFeature] = 1
	assert.True(t, p.GetBool(UseFeature, true))
}

func TestParams_GetInt(t *testing.T) {
	p := Params{}
	// Empty case
	assert.Equal(t, -1, p.GetInt(NFactors, -1))
	// Normal case
	p[NFactors] = 0
	assert.Equal(t, 0, p.GetInt(NFactors, -1))
	// Wrong type case
	p[NFactors] = "hello"
	assert.Equal(t, -1, p.GetInt(NFactors, -1))
}

func TestParams_GetInt64(t *testing.T) {
	p := Params{}
	// Empty case
	assert.Equal(t, int64(-1), p.GetInt64(RandomState, -1))
	// Normal case
	p[RandomState] = int64(0)
	assert.Equal(t, int64(0), p.GetInt64(RandomState, -1))
	// Convertible case: int is converted to int64
	p[RandomState] = 0
	assert.Equal(t, int64(0), p.GetInt64(RandomState, -1))
	// Wrong type case
	p[RandomState] = "hello"
	assert.Equal(t, int64(-1), p.GetInt64(RandomState, -1))
}

func TestParams_GetString(t *testing.T) {
	p := Params{}
	// Empty case
	assert.Equal(t, "xyz", p.GetString(Similarity, "xyz"))
	// Normal case: use a default that differs from the stored value so the
	// assertion cannot pass by returning the default.
	p[Similarity] = "abc"
	assert.Equal(t, "abc", p.GetString(Similarity, "xyz"))
}

func TestParams_GetIntSlice(t *testing.T) {
	p := Params{}
	// Empty case
	assert.Equal(t, []int{1, 2, 3}, p.GetIntSlice(HiddenLayers, []int{1, 2, 3}))
	// Normal case
	p[HiddenLayers] = []int{4, 5, 6}
	assert.Equal(t, []int{4, 5, 6}, p.GetIntSlice(HiddenLayers, []int{1, 2, 3}))
	// Wrong type case
	p[HiddenLayers] = []string{"hello"}
	assert.Equal(t, []int{1, 2, 3}, p.GetIntSlice(HiddenLayers, []int{1, 2, 3}))
}

func TestParams_Overwrite(t *testing.T) {
	a := Params{
		NFactors: 10,
		Lr:       0.5,
	}
	b := Params{
		NEpochs:  100,
		NFactors: 20,
	}
	c := a.Overwrite(b)
	assert.Equal(t, 20, c[NFactors])
	assert.Equal(t, 0.5, c[Lr])
	assert.Equal(t, 100, c[NEpochs])
}

func TestParamsGrid(t *testing.T) {
	grid := ParamsGrid{}
	grid["a"] = []interface{}{0, 1}
	defaultGrid := ParamsGrid{}
	defaultGrid["a"] = []interface{}{2, 3}
	defaultGrid["b"] = []interface{}{4, 5}
	assert.Equal(t, 1, grid.Len())
	grid.Fill(defaultGrid)
	assert.Equal(t, []interface{}{0, 1}, grid["a"])
	assert.Equal(t, []interface{}{4, 5}, grid["b"])
	assert.Equal(t, 4, grid.NumCombinations())
}

================================================
FILE: protocol/cache_store.pb.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.36.10
// 	protoc        v6.33.1
// source: cache_store.proto

// NOTE(review): this file is generated from cache_store.proto; hand edits
// (including comments) are lost on regeneration, so change the .proto
// definition and re-run protoc instead.

package protocol

import (
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	timestamppb "google.golang.org/protobuf/types/known/timestamppb"
	reflect "reflect"
	sync "sync"
	unsafe "unsafe"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Value is a single Name/Value string pair; SetRequest carries a batch of them.
type Value struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	Value         string                 `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Value) Reset() {
	*x = Value{}
	mi := &file_cache_store_proto_msgTypes[0]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Value) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Value) ProtoMessage() {}

func (x *Value) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[0]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Value.ProtoReflect.Descriptor instead.
func (*Value) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{0}
}

func (x *Value) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *Value) GetValue() string {
	if x != nil {
		return x.Value
	}
	return ""
}

// Score is an entry identified by Id with a float64 Score, a hidden flag, a
// list of categories and a Timestamp (a message pointer, so it may be nil).
type Score struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Id            string                 `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
	Score         float64                `protobuf:"fixed64,2,opt,name=score,proto3" json:"score,omitempty"`
	IsHidden      bool                   `protobuf:"varint,3,opt,name=is_hidden,json=isHidden,proto3" json:"is_hidden,omitempty"`
	Categories    []string               `protobuf:"bytes,4,rep,name=categories,proto3" json:"categories,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Score) Reset() {
	*x = Score{}
	mi := &file_cache_store_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Score) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Score) ProtoMessage() {}

func (x *Score) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Score.ProtoReflect.Descriptor instead.
func (*Score) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{1}
}

func (x *Score) GetId() string {
	if x != nil {
		return x.Id
	}
	return ""
}

func (x *Score) GetScore() float64 {
	if x != nil {
		return x.Score
	}
	return 0
}

func (x *Score) GetIsHidden() bool {
	if x != nil {
		return x.IsHidden
	}
	return false
}

func (x *Score) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

func (x *Score) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

// ScoreCondition is a set of optional filters. Subset, Id and Before are nil
// when unset, and their getters then return the zero value ("" or nil).
// DeleteScoresRequest carries one as its Condition.
type ScoreCondition struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Subset        *string                `protobuf:"bytes,1,opt,name=subset,proto3,oneof" json:"subset,omitempty"`
	Id            *string                `protobuf:"bytes,2,opt,name=id,proto3,oneof" json:"id,omitempty"`
	Before        *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=before,proto3,oneof" json:"before,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *ScoreCondition) Reset() {
	*x = ScoreCondition{}
	mi := &file_cache_store_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ScoreCondition) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ScoreCondition) ProtoMessage() {}

func (x *ScoreCondition) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ScoreCondition.ProtoReflect.Descriptor instead.
func (*ScoreCondition) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{2}
}

func (x *ScoreCondition) GetSubset() string {
	if x != nil && x.Subset != nil {
		return *x.Subset
	}
	return ""
}

func (x *ScoreCondition) GetId() string {
	if x != nil && x.Id != nil {
		return *x.Id
	}
	return ""
}

func (x *ScoreCondition) GetBefore() *timestamppb.Timestamp {
	if x != nil {
		return x.Before
	}
	return nil
}

// ScorePatch is a partial update. IsHidden and Score are optional pointers
// (nil when unset). Categories is a plain repeated field with no oneof marker
// in its tag.
// NOTE(review): how the server treats an empty Categories (clear vs. leave
// unchanged) is not visible here — confirm before relying on it.
type ScorePatch struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	IsHidden      *bool                  `protobuf:"varint,1,opt,name=is_hidden,json=isHidden,proto3,oneof" json:"is_hidden,omitempty"`
	Categories    []string               `protobuf:"bytes,2,rep,name=categories,proto3" json:"categories,omitempty"`
	Score         *float64               `protobuf:"fixed64,3,opt,name=score,proto3,oneof" json:"score,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *ScorePatch) Reset() {
	*x = ScorePatch{}
	mi := &file_cache_store_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ScorePatch) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ScorePatch) ProtoMessage() {}

func (x *ScorePatch) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ScorePatch.ProtoReflect.Descriptor instead.
func (*ScorePatch) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{3}
}

func (x *ScorePatch) GetIsHidden() bool {
	if x != nil && x.IsHidden != nil {
		return *x.IsHidden
	}
	return false
}

func (x *ScorePatch) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

func (x *ScorePatch) GetScore() float64 {
	if x != nil && x.Score != nil {
		return *x.Score
	}
	return 0
}

// TimeSeriesPoint is one float64 Value of the series Name at Timestamp;
// AddTimeSeriesPointsRequest carries a batch of them.
type TimeSeriesPoint struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	Value         float64                `protobuf:"fixed64,3,opt,name=value,proto3" json:"value,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *TimeSeriesPoint) Reset() {
	*x = TimeSeriesPoint{}
	mi := &file_cache_store_proto_msgTypes[4]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *TimeSeriesPoint) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*TimeSeriesPoint) ProtoMessage() {}

func (x *TimeSeriesPoint) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[4]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use TimeSeriesPoint.ProtoReflect.Descriptor instead.
func (*TimeSeriesPoint) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{4}
}

func (x *TimeSeriesPoint) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *TimeSeriesPoint) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

func (x *TimeSeriesPoint) GetValue() float64 {
	if x != nil {
		return x.Value
	}
	return 0
}

// GetRequest is the request message for a single lookup by Name.
type GetRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *GetRequest) Reset() {
	*x = GetRequest{}
	mi := &file_cache_store_proto_msgTypes[5]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *GetRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*GetRequest) ProtoMessage() {}

func (x *GetRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[5]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead.
func (*GetRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{5}
}

func (x *GetRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// GetResponse carries an optional Value; GetValue returns "" when it is nil.
// NOTE(review): presumably a nil Value means the name was not found — confirm
// against the server implementation.
type GetResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Value         *string                `protobuf:"bytes,1,opt,name=value,proto3,oneof" json:"value,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *GetResponse) Reset() {
	*x = GetResponse{}
	mi := &file_cache_store_proto_msgTypes[6]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *GetResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*GetResponse) ProtoMessage() {}

func (x *GetResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[6]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetResponse.ProtoReflect.Descriptor instead.
func (*GetResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{6}
}

func (x *GetResponse) GetValue() string {
	if x != nil && x.Value != nil {
		return *x.Value
	}
	return ""
}

// SetRequest carries a batch of Value (Name/Value) entries.
type SetRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Values        []*Value               `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *SetRequest) Reset() {
	*x = SetRequest{}
	mi := &file_cache_store_proto_msgTypes[7]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *SetRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*SetRequest) ProtoMessage() {}

func (x *SetRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[7]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead.
func (*SetRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{7}
}

func (x *SetRequest) GetValues() []*Value {
	if x != nil {
		return x.Values
	}
	return nil
}

// SetResponse is an empty reply message (it has no user-visible fields).
type SetResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *SetResponse) Reset() {
	*x = SetResponse{}
	mi := &file_cache_store_proto_msgTypes[8]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *SetResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*SetResponse) ProtoMessage() {}

func (x *SetResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[8]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use SetResponse.ProtoReflect.Descriptor instead.
func (*SetResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{8}
}

// DeleteRequest identifies the entry to delete by Name.
type DeleteRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DeleteRequest) Reset() {
	*x = DeleteRequest{}
	mi := &file_cache_store_proto_msgTypes[9]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteRequest) ProtoMessage() {}

func (x *DeleteRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[9]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteRequest.ProtoReflect.Descriptor instead.
func (*DeleteRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{9}
}

func (x *DeleteRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// DeleteResponse is an empty reply message (it has no user-visible fields).
type DeleteResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DeleteResponse) Reset() {
	*x = DeleteResponse{}
	mi := &file_cache_store_proto_msgTypes[10]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteResponse) ProtoMessage() {}

func (x *DeleteResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[10]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteResponse.ProtoReflect.Descriptor instead.
func (*DeleteResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{10}
}

// PushRequest carries a Name and the string Value to push under it.
type PushRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	Value         string                 `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PushRequest) Reset() {
	*x = PushRequest{}
	mi := &file_cache_store_proto_msgTypes[11]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PushRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PushRequest) ProtoMessage() {}

func (x *PushRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[11]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PushRequest.ProtoReflect.Descriptor instead.
func (*PushRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{11}
}

func (x *PushRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *PushRequest) GetValue() string {
	if x != nil {
		return x.Value
	}
	return ""
}

// PushResponse is an empty reply message (it has no user-visible fields).
type PushResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PushResponse) Reset() {
	*x = PushResponse{}
	mi := &file_cache_store_proto_msgTypes[12]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PushResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PushResponse) ProtoMessage() {}

func (x *PushResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[12]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PushResponse.ProtoReflect.Descriptor instead.
func (*PushResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{12}
}

// PopRequest names the entry to pop from.
type PopRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PopRequest) Reset() {
	*x = PopRequest{}
	mi := &file_cache_store_proto_msgTypes[13]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PopRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PopRequest) ProtoMessage() {}

func (x *PopRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[13]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PopRequest.ProtoReflect.Descriptor instead.
func (*PopRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{13}
}

func (x *PopRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// PopResponse carries an optional Value; GetValue returns "" when it is nil.
// NOTE(review): presumably a nil Value means there was nothing to pop —
// confirm against the server implementation.
type PopResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Value         *string                `protobuf:"bytes,1,opt,name=value,proto3,oneof" json:"value,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PopResponse) Reset() {
	*x = PopResponse{}
	mi := &file_cache_store_proto_msgTypes[14]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PopResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PopResponse) ProtoMessage() {}

func (x *PopResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[14]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PopResponse.ProtoReflect.Descriptor instead.
func (*PopResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{14}
}

func (x *PopResponse) GetValue() string {
	if x != nil && x.Value != nil {
		return *x.Value
	}
	return ""
}

// RemainRequest names the entry whose remaining count is requested
// (RemainResponse carries the resulting Count).
type RemainRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *RemainRequest) Reset() {
	*x = RemainRequest{}
	mi := &file_cache_store_proto_msgTypes[15]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *RemainRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RemainRequest) ProtoMessage() {}

func (x *RemainRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[15]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemainRequest.ProtoReflect.Descriptor instead.
func (*RemainRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{15}
}

func (x *RemainRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// RemainResponse carries the int64 Count reported for a RemainRequest.
type RemainResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Count         int64                  `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *RemainResponse) Reset() {
	*x = RemainResponse{}
	mi := &file_cache_store_proto_msgTypes[16]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *RemainResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RemainResponse) ProtoMessage() {}

func (x *RemainResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[16]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemainResponse.ProtoReflect.Descriptor instead.
func (*RemainResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{16}
}

func (x *RemainResponse) GetCount() int64 {
	if x != nil {
		return x.Count
	}
	return 0
}

// AddScoresRequest carries a batch of Score documents for a single Collection
// and Subset.
type AddScoresRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collection    string                 `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	Subset        string                 `protobuf:"bytes,2,opt,name=subset,proto3" json:"subset,omitempty"`
	Documents     []*Score               `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *AddScoresRequest) Reset() {
	*x = AddScoresRequest{}
	mi := &file_cache_store_proto_msgTypes[17]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *AddScoresRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*AddScoresRequest) ProtoMessage() {}

func (x *AddScoresRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[17]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddScoresRequest.ProtoReflect.Descriptor instead.
func (*AddScoresRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{17}
}

func (x *AddScoresRequest) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

func (x *AddScoresRequest) GetSubset() string {
	if x != nil {
		return x.Subset
	}
	return ""
}

func (x *AddScoresRequest) GetDocuments() []*Score {
	if x != nil {
		return x.Documents
	}
	return nil
}

// AddScoresResponse is an empty reply message (it has no user-visible fields).
type AddScoresResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *AddScoresResponse) Reset() {
	*x = AddScoresResponse{}
	mi := &file_cache_store_proto_msgTypes[18]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *AddScoresResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*AddScoresResponse) ProtoMessage() {}

func (x *AddScoresResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[18]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddScoresResponse.ProtoReflect.Descriptor instead.
func (*AddScoresResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{18}
}

// SearchScoresRequest targets one Collection and Subset with a repeated Query
// and int32 Begin/End bounds.
// NOTE(review): whether End is inclusive and how negative bounds are treated is
// decided by the server, not visible here — confirm before relying on it.
type SearchScoresRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collection    string                 `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	Subset        string                 `protobuf:"bytes,2,opt,name=subset,proto3" json:"subset,omitempty"`
	Query         []string               `protobuf:"bytes,3,rep,name=query,proto3" json:"query,omitempty"`
	Begin         int32                  `protobuf:"varint,4,opt,name=begin,proto3" json:"begin,omitempty"`
	End           int32                  `protobuf:"varint,5,opt,name=end,proto3" json:"end,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *SearchScoresRequest) Reset() {
	*x = SearchScoresRequest{}
	mi := &file_cache_store_proto_msgTypes[19]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *SearchScoresRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*SearchScoresRequest) ProtoMessage() {}

func (x *SearchScoresRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[19]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use SearchScoresRequest.ProtoReflect.Descriptor instead.
func (*SearchScoresRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{19}
}

func (x *SearchScoresRequest) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

func (x *SearchScoresRequest) GetSubset() string {
	if x != nil {
		return x.Subset
	}
	return ""
}

func (x *SearchScoresRequest) GetQuery() []string {
	if x != nil {
		return x.Query
	}
	return nil
}

func (x *SearchScoresRequest) GetBegin() int32 {
	if x != nil {
		return x.Begin
	}
	return 0
}

func (x *SearchScoresRequest) GetEnd() int32 {
	if x != nil {
		return x.End
	}
	return 0
}

// SearchScoresResponse carries the Score documents returned for a search.
type SearchScoresResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Documents     []*Score               `protobuf:"bytes,1,rep,name=documents,proto3" json:"documents,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *SearchScoresResponse) Reset() {
	*x = SearchScoresResponse{}
	mi := &file_cache_store_proto_msgTypes[20]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *SearchScoresResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*SearchScoresResponse) ProtoMessage() {}

func (x *SearchScoresResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[20]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use SearchScoresResponse.ProtoReflect.Descriptor instead.
func (*SearchScoresResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{20}
}

func (x *SearchScoresResponse) GetDocuments() []*Score {
	if x != nil {
		return x.Documents
	}
	return nil
}

// DeleteScoresRequest lists one or more Collections together with a
// ScoreCondition; GetCondition returns nil when Condition is unset.
type DeleteScoresRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collection    []string               `protobuf:"bytes,1,rep,name=collection,proto3" json:"collection,omitempty"`
	Condition     *ScoreCondition        `protobuf:"bytes,2,opt,name=condition,proto3" json:"condition,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DeleteScoresRequest) Reset() {
	*x = DeleteScoresRequest{}
	mi := &file_cache_store_proto_msgTypes[21]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteScoresRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteScoresRequest) ProtoMessage() {}

func (x *DeleteScoresRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[21]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteScoresRequest.ProtoReflect.Descriptor instead.
func (*DeleteScoresRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{21}
}

func (x *DeleteScoresRequest) GetCollection() []string {
	if x != nil {
		return x.Collection
	}
	return nil
}

func (x *DeleteScoresRequest) GetCondition() *ScoreCondition {
	if x != nil {
		return x.Condition
	}
	return nil
}

// DeleteScoresResponse is an empty reply message (it has no user-visible fields).
type DeleteScoresResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DeleteScoresResponse) Reset() {
	*x = DeleteScoresResponse{}
	mi := &file_cache_store_proto_msgTypes[22]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteScoresResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteScoresResponse) ProtoMessage() {}

func (x *DeleteScoresResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[22]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteScoresResponse.ProtoReflect.Descriptor instead.
func (*DeleteScoresResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{22}
}

// UpdateScoresRequest pairs a ScorePatch with the target Id, a list of
// Collections and an optional Subset (nil when unset; GetSubset then returns "").
type UpdateScoresRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collection    []string               `protobuf:"bytes,1,rep,name=collection,proto3" json:"collection,omitempty"`
	Subset        *string                `protobuf:"bytes,2,opt,name=subset,proto3,oneof" json:"subset,omitempty"`
	Id            string                 `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"`
	Patch         *ScorePatch            `protobuf:"bytes,4,opt,name=patch,proto3" json:"patch,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *UpdateScoresRequest) Reset() {
	*x = UpdateScoresRequest{}
	mi := &file_cache_store_proto_msgTypes[23]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *UpdateScoresRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*UpdateScoresRequest) ProtoMessage() {}

func (x *UpdateScoresRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[23]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use UpdateScoresRequest.ProtoReflect.Descriptor instead.
func (*UpdateScoresRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{23}
}

func (x *UpdateScoresRequest) GetCollection() []string {
	if x != nil {
		return x.Collection
	}
	return nil
}

func (x *UpdateScoresRequest) GetSubset() string {
	if x != nil && x.Subset != nil {
		return *x.Subset
	}
	return ""
}

func (x *UpdateScoresRequest) GetId() string {
	if x != nil {
		return x.Id
	}
	return ""
}

func (x *UpdateScoresRequest) GetPatch() *ScorePatch {
	if x != nil {
		return x.Patch
	}
	return nil
}

// UpdateScoresResponse is an empty reply message (it has no user-visible fields).
type UpdateScoresResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *UpdateScoresResponse) Reset() {
	*x = UpdateScoresResponse{}
	mi := &file_cache_store_proto_msgTypes[24]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *UpdateScoresResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*UpdateScoresResponse) ProtoMessage() {}

func (x *UpdateScoresResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[24]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use UpdateScoresResponse.ProtoReflect.Descriptor instead.
func (*UpdateScoresResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{24}
}

// ScanScoresRequest is the field-less request message of the server-streaming
// CacheStore.ScanScores RPC.
type ScanScoresRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[25] as its
// message info.
func (x *ScanScoresRequest) Reset() {
	*x = ScanScoresRequest{}
	mi := &file_cache_store_proto_msgTypes[25]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *ScanScoresRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *ScanScoresRequest the
// ProtoMessage method.
func (*ScanScoresRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *ScanScoresRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[25]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ScanScoresRequest.ProtoReflect.Descriptor instead.
func (*ScanScoresRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{25}
}

// ScanScoresResponse is one element of the CacheStore.ScanScores response
// stream. It carries Collection (field 1), Id (field 2), Subset (field 3) and
// Timestamp (field 4).
type ScanScoresResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collection    string                 `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	Id            string                 `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
	Subset        string                 `protobuf:"bytes,3,opt,name=subset,proto3" json:"subset,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[26] as its
// message info.
func (x *ScanScoresResponse) Reset() {
	*x = ScanScoresResponse{}
	mi := &file_cache_store_proto_msgTypes[26]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *ScanScoresResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *ScanScoresResponse the
// ProtoMessage method.
func (*ScanScoresResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *ScanScoresResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[26]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ScanScoresResponse.ProtoReflect.Descriptor instead.
func (*ScanScoresResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{26}
}

// GetCollection returns x.Collection, or "" if x is nil.
func (x *ScanScoresResponse) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

// GetId returns x.Id, or "" if x is nil.
func (x *ScanScoresResponse) GetId() string {
	if x != nil {
		return x.Id
	}
	return ""
}

// GetSubset returns x.Subset, or "" if x is nil.
func (x *ScanScoresResponse) GetSubset() string {
	if x != nil {
		return x.Subset
	}
	return ""
}

// GetTimestamp returns x.Timestamp, or nil if x is nil.
func (x *ScanScoresResponse) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

// AddTimeSeriesPointsRequest is the request message of the
// CacheStore.AddTimeSeriesPoints RPC; Points (field 1, repeated) holds the
// points to add.
type AddTimeSeriesPointsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Points        []*TimeSeriesPoint     `protobuf:"bytes,1,rep,name=points,proto3" json:"points,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[27] as its
// message info.
func (x *AddTimeSeriesPointsRequest) Reset() {
	*x = AddTimeSeriesPointsRequest{}
	mi := &file_cache_store_proto_msgTypes[27]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *AddTimeSeriesPointsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *AddTimeSeriesPointsRequest
// the ProtoMessage method.
func (*AddTimeSeriesPointsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *AddTimeSeriesPointsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[27]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddTimeSeriesPointsRequest.ProtoReflect.Descriptor instead.
func (*AddTimeSeriesPointsRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{27}
}

// GetPoints returns x.Points, or nil if x is nil.
func (x *AddTimeSeriesPointsRequest) GetPoints() []*TimeSeriesPoint {
	if x != nil {
		return x.Points
	}
	return nil
}

// AddTimeSeriesPointsResponse is the field-less response message of the
// CacheStore.AddTimeSeriesPoints RPC.
type AddTimeSeriesPointsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[28] as its
// message info.
func (x *AddTimeSeriesPointsResponse) Reset() {
	*x = AddTimeSeriesPointsResponse{}
	mi := &file_cache_store_proto_msgTypes[28]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *AddTimeSeriesPointsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *AddTimeSeriesPointsResponse
// the ProtoMessage method.
func (*AddTimeSeriesPointsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *AddTimeSeriesPointsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[28]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddTimeSeriesPointsResponse.ProtoReflect.Descriptor instead.
func (*AddTimeSeriesPointsResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{28}
}

// GetTimeSeriesPointsRequest is the request message of the
// CacheStore.GetTimeSeriesPoints RPC. It carries a series Name (field 1), a
// Begin/End timestamp pair (fields 2 and 3) and a Duration (field 4).
//
// NOTE(review): the unit of Duration (an int64) and whether Begin/End are
// inclusive are not specified in this file; confirm against the server.
type GetTimeSeriesPointsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	Begin         *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=begin,proto3" json:"begin,omitempty"`
	End           *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=end,proto3" json:"end,omitempty"`
	Duration      int64                  `protobuf:"varint,4,opt,name=duration,proto3" json:"duration,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[29] as its
// message info.
func (x *GetTimeSeriesPointsRequest) Reset() {
	*x = GetTimeSeriesPointsRequest{}
	mi := &file_cache_store_proto_msgTypes[29]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *GetTimeSeriesPointsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *GetTimeSeriesPointsRequest
// the ProtoMessage method.
func (*GetTimeSeriesPointsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *GetTimeSeriesPointsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[29]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetTimeSeriesPointsRequest.ProtoReflect.Descriptor instead.
func (*GetTimeSeriesPointsRequest) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{29}
}

// GetName returns x.Name, or "" if x is nil.
func (x *GetTimeSeriesPointsRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// GetBegin returns x.Begin, or nil if x is nil.
func (x *GetTimeSeriesPointsRequest) GetBegin() *timestamppb.Timestamp {
	if x != nil {
		return x.Begin
	}
	return nil
}

// GetEnd returns x.End, or nil if x is nil.
func (x *GetTimeSeriesPointsRequest) GetEnd() *timestamppb.Timestamp {
	if x != nil {
		return x.End
	}
	return nil
}

// GetDuration returns x.Duration, or 0 if x is nil.
func (x *GetTimeSeriesPointsRequest) GetDuration() int64 {
	if x != nil {
		return x.Duration
	}
	return 0
}

// GetTimeSeriesPointsResponse is the response message of the
// CacheStore.GetTimeSeriesPoints RPC; Points (field 1, repeated) holds the
// returned points.
type GetTimeSeriesPointsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Points        []*TimeSeriesPoint     `protobuf:"bytes,1,rep,name=points,proto3" json:"points,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores file_cache_store_proto_msgTypes[30] as its
// message info.
func (x *GetTimeSeriesPointsResponse) Reset() {
	*x = GetTimeSeriesPointsResponse{}
	mi := &file_cache_store_proto_msgTypes[30]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the protoimpl text rendering of x.
func (x *GetTimeSeriesPointsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it only gives *GetTimeSeriesPointsResponse
// the ProtoMessage method.
func (*GetTimeSeriesPointsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x; a nil x is handled by mi.MessageOf.
func (x *GetTimeSeriesPointsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_cache_store_proto_msgTypes[30]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetTimeSeriesPointsResponse.ProtoReflect.Descriptor instead.
func (*GetTimeSeriesPointsResponse) Descriptor() ([]byte, []int) {
	return file_cache_store_proto_rawDescGZIP(), []int{30}
}

// GetPoints returns x.Points, or nil if x is nil.
func (x *GetTimeSeriesPointsResponse) GetPoints() []*TimeSeriesPoint {
	if x != nil {
		return x.Points
	}
	return nil
}

// File_cache_store_proto is the descriptor of cache_store.proto. It is set
// by file_cache_store_proto_init from the TypeBuilder output.
var File_cache_store_proto protoreflect.FileDescriptor

// file_cache_store_proto_rawDesc is the serialized descriptor handed to
// protoimpl.DescBuilder as RawDescriptor. It must match the Go types and
// index tables in this file exactly, so regenerate it from cache_store.proto
// rather than editing it by hand.
const file_cache_store_proto_rawDesc = "" +
	"\n" +
	"\x11cache_store.proto\x12\bprotocol\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0eprotocol.proto\"1\n" +
	"\x05Value\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" +
	"\x05value\x18\x02 \x01(\tR\x05value\"\xa4\x01\n" +
	"\x05Score\x12\x0e\n" +
	"\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" +
	"\x05score\x18\x02 \x01(\x01R\x05score\x12\x1b\n" +
	"\tis_hidden\x18\x03 \x01(\bR\bisHidden\x12\x1e\n" +
	"\n" +
	"categories\x18\x04 \x03(\tR\n" +
	"categories\x128\n" +
	"\ttimestamp\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\x98\x01\n" +
	"\x0eScoreCondition\x12\x1b\n" +
	"\x06subset\x18\x01 \x01(\tH\x00R\x06subset\x88\x01\x01\x12\x13\n" +
	"\x02id\x18\x02 \x01(\tH\x01R\x02id\x88\x01\x01\x127\n" +
	"\x06before\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampH\x02R\x06before\x88\x01\x01B\t\n" +
	"\a_subsetB\x05\n" +
	"\x03_idB\t\n" +
	"\a_before\"\x81\x01\n" +
	"\n" +
	"ScorePatch\x12 \n" +
	"\tis_hidden\x18\x01 \x01(\bH\x00R\bisHidden\x88\x01\x01\x12\x1e\n" +
	"\n" +
	"categories\x18\x02 \x03(\tR\n" +
	"categories\x12\x19\n" +
	"\x05score\x18\x03 \x01(\x01H\x01R\x05score\x88\x01\x01B\f\n" +
	"\n" +
	"_is_hiddenB\b\n" +
	"\x06_score\"u\n" +
	"\x0fTimeSeriesPoint\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x128\n" +
	"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x14\n" +
	"\x05value\x18\x03 \x01(\x01R\x05value\" \n" +
	"\n" +
	"GetRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"2\n" +
	"\vGetResponse\x12\x19\n" +
	"\x05value\x18\x01 \x01(\tH\x00R\x05value\x88\x01\x01B\b\n" +
	"\x06_value\"5\n" +
	"\n" +
	"SetRequest\x12'\n" +
	"\x06values\x18\x01 \x03(\v2\x0f.protocol.ValueR\x06values\"\r\n" +
	"\vSetResponse\"#\n" +
	"\rDeleteRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"\x10\n" +
	"\x0eDeleteResponse\"7\n" +
	"\vPushRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" +
	"\x05value\x18\x02 \x01(\tR\x05value\"\x0e\n" +
	"\fPushResponse\" \n" +
	"\n" +
	"PopRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"2\n" +
	"\vPopResponse\x12\x19\n" +
	"\x05value\x18\x01 \x01(\tH\x00R\x05value\x88\x01\x01B\b\n" +
	"\x06_value\"#\n" +
	"\rRemainRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"&\n" +
	"\x0eRemainResponse\x12\x14\n" +
	"\x05count\x18\x01 \x01(\x03R\x05count\"y\n" +
	"\x10AddScoresRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x12\x16\n" +
	"\x06subset\x18\x02 \x01(\tR\x06subset\x12-\n" +
	"\tdocuments\x18\x03 \x03(\v2\x0f.protocol.ScoreR\tdocuments\"\x13\n" +
	"\x11AddScoresResponse\"\x8b\x01\n" +
	"\x13SearchScoresRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x12\x16\n" +
	"\x06subset\x18\x02 \x01(\tR\x06subset\x12\x14\n" +
	"\x05query\x18\x03 \x03(\tR\x05query\x12\x14\n" +
	"\x05begin\x18\x04 \x01(\x05R\x05begin\x12\x10\n" +
	"\x03end\x18\x05 \x01(\x05R\x03end\"E\n" +
	"\x14SearchScoresResponse\x12-\n" +
	"\tdocuments\x18\x01 \x03(\v2\x0f.protocol.ScoreR\tdocuments\"m\n" +
	"\x13DeleteScoresRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x03(\tR\n" +
	"collection\x126\n" +
	"\tcondition\x18\x02 \x01(\v2\x18.protocol.ScoreConditionR\tcondition\"\x16\n" +
	"\x14DeleteScoresResponse\"\x99\x01\n" +
	"\x13UpdateScoresRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x03(\tR\n" +
	"collection\x12\x1b\n" +
	"\x06subset\x18\x02 \x01(\tH\x00R\x06subset\x88\x01\x01\x12\x0e\n" +
	"\x02id\x18\x03 \x01(\tR\x02id\x12*\n" +
	"\x05patch\x18\x04 \x01(\v2\x14.protocol.ScorePatchR\x05patchB\t\n" +
	"\a_subset\"\x16\n" +
	"\x14UpdateScoresResponse\"\x13\n" +
	"\x11ScanScoresRequest\"\x96\x01\n" +
	"\x12ScanScoresResponse\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x12\x0e\n" +
	"\x02id\x18\x02 \x01(\tR\x02id\x12\x16\n" +
	"\x06subset\x18\x03 \x01(\tR\x06subset\x128\n" +
	"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"O\n" +
	"\x1aAddTimeSeriesPointsRequest\x121\n" +
	"\x06points\x18\x01 \x03(\v2\x19.protocol.TimeSeriesPointR\x06points\"\x1d\n" +
	"\x1bAddTimeSeriesPointsResponse\"\xac\x01\n" +
	"\x1aGetTimeSeriesPointsRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x120\n" +
	"\x05begin\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x05begin\x12,\n" +
	"\x03end\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x03end\x12\x1a\n" +
	"\bduration\x18\x04 \x01(\x03R\bduration\"P\n" +
	"\x1bGetTimeSeriesPointsResponse\x121\n" +
	"\x06points\x18\x01 \x03(\v2\x19.protocol.TimeSeriesPointR\x06points2\xf2\a\n" +
	"\n" +
	"CacheStore\x127\n" +
	"\x04Ping\x12\x15.protocol.PingRequest\x1a\x16.protocol.PingResponse\"\x00\x124\n" +
	"\x03Get\x12\x14.protocol.GetRequest\x1a\x15.protocol.GetResponse\"\x00\x124\n" +
	"\x03Set\x12\x14.protocol.SetRequest\x1a\x15.protocol.SetResponse\"\x00\x12=\n" +
	"\x06Delete\x12\x17.protocol.DeleteRequest\x1a\x18.protocol.DeleteResponse\"\x00\x127\n" +
	"\x04Push\x12\x15.protocol.PushRequest\x1a\x16.protocol.PushResponse\"\x00\x124\n" +
	"\x03Pop\x12\x14.protocol.PopRequest\x1a\x15.protocol.PopResponse\"\x00\x12=\n" +
	"\x06Remain\x12\x17.protocol.RemainRequest\x1a\x18.protocol.RemainResponse\"\x00\x12F\n" +
	"\tAddScores\x12\x1a.protocol.AddScoresRequest\x1a\x1b.protocol.AddScoresResponse\"\x00\x12O\n" +
	"\fSearchScores\x12\x1d.protocol.SearchScoresRequest\x1a\x1e.protocol.SearchScoresResponse\"\x00\x12O\n" +
	"\fDeleteScores\x12\x1d.protocol.DeleteScoresRequest\x1a\x1e.protocol.DeleteScoresResponse\"\x00\x12O\n" +
	"\fUpdateScores\x12\x1d.protocol.UpdateScoresRequest\x1a\x1e.protocol.UpdateScoresResponse\"\x00\x12K\n" +
	"\n" +
	"ScanScores\x12\x1b.protocol.ScanScoresRequest\x1a\x1c.protocol.ScanScoresResponse\"\x000\x01\x12d\n" +
	"\x13AddTimeSeriesPoints\x12$.protocol.AddTimeSeriesPointsRequest\x1a%.protocol.AddTimeSeriesPointsResponse\"\x00\x12d\n" +
	"\x13GetTimeSeriesPoints\x12$.protocol.GetTimeSeriesPointsRequest\x1a%.protocol.GetTimeSeriesPointsResponse\"\x00B$Z\"github.com/gorse-io/gorse/protocolb\x06proto3"

// The gzip-compressed copy of the raw descriptor is computed at most once
// (guarded by the sync.Once) and cached for the deprecated Descriptor methods.
var (
	file_cache_store_proto_rawDescOnce sync.Once
	file_cache_store_proto_rawDescData []byte
)

// file_cache_store_proto_rawDescGZIP returns the gzip-compressed raw
// descriptor, compressing it on the first call and reusing the cached bytes
// afterwards. The unsafe.Slice/unsafe.StringData pair views the string
// constant as a []byte without copying it.
func file_cache_store_proto_rawDescGZIP() []byte {
	file_cache_store_proto_rawDescOnce.Do(func() {
		file_cache_store_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_cache_store_proto_rawDesc), len(file_cache_store_proto_rawDesc)))
	})
	return file_cache_store_proto_rawDescData
}

// file_cache_store_proto_msgTypes holds one MessageInfo per message declared
// in cache_store.proto (31 of them). Its indices match goTypes[0:31], e.g.
// index 23 is UpdateScoresRequest.
var file_cache_store_proto_msgTypes = make([]protoimpl.MessageInfo, 31)

// file_cache_store_proto_goTypes lists the Go types the descriptor refers to:
// this file's 31 messages first, then dependencies declared in other files
// (Timestamp, PingRequest, PingResponse).
var file_cache_store_proto_goTypes = []any{
	(*Value)(nil),                       // 0: protocol.Value
	(*Score)(nil),                       // 1: protocol.Score
	(*ScoreCondition)(nil),              // 2: protocol.ScoreCondition
	(*ScorePatch)(nil),                  // 3: protocol.ScorePatch
	(*TimeSeriesPoint)(nil),             // 4: protocol.TimeSeriesPoint
	(*GetRequest)(nil),                  // 5: protocol.GetRequest
	(*GetResponse)(nil),                 // 6: protocol.GetResponse
	(*SetRequest)(nil),                  // 7: protocol.SetRequest
	(*SetResponse)(nil),                 // 8: protocol.SetResponse
	(*DeleteRequest)(nil),               // 9: protocol.DeleteRequest
	(*DeleteResponse)(nil),              // 10: protocol.DeleteResponse
	(*PushRequest)(nil),                 // 11: protocol.PushRequest
	(*PushResponse)(nil),                // 12: protocol.PushResponse
	(*PopRequest)(nil),                  // 13: protocol.PopRequest
	(*PopResponse)(nil),                 // 14: protocol.PopResponse
	(*RemainRequest)(nil),               // 15: protocol.RemainRequest
	(*RemainResponse)(nil),              // 16: protocol.RemainResponse
	(*AddScoresRequest)(nil),            // 17: protocol.AddScoresRequest
	(*AddScoresResponse)(nil),           // 18: protocol.AddScoresResponse
	(*SearchScoresRequest)(nil),         // 19: protocol.SearchScoresRequest
	(*SearchScoresResponse)(nil),        // 20: protocol.SearchScoresResponse
	(*DeleteScoresRequest)(nil),         // 21: protocol.DeleteScoresRequest
	(*DeleteScoresResponse)(nil),        // 22: protocol.DeleteScoresResponse
	(*UpdateScoresRequest)(nil),         // 23: protocol.UpdateScoresRequest
	(*UpdateScoresResponse)(nil),        // 24: protocol.UpdateScoresResponse
	(*ScanScoresRequest)(nil),           // 25: protocol.ScanScoresRequest
	(*ScanScoresResponse)(nil),          // 26: protocol.ScanScoresResponse
	(*AddTimeSeriesPointsRequest)(nil),  // 27: protocol.AddTimeSeriesPointsRequest
	(*AddTimeSeriesPointsResponse)(nil), // 28: protocol.AddTimeSeriesPointsResponse
	(*GetTimeSeriesPointsRequest)(nil),  // 29: protocol.GetTimeSeriesPointsRequest
	(*GetTimeSeriesPointsResponse)(nil), // 30: protocol.GetTimeSeriesPointsResponse
	(*timestamppb.Timestamp)(nil),       // 31: google.protobuf.Timestamp
	(*PingRequest)(nil),                 // 32: protocol.PingRequest
	(*PingResponse)(nil),                // 33: protocol.PingResponse
}

// file_cache_store_proto_depIdxs maps every message-typed field and every RPC
// input/output type to an index in file_cache_store_proto_goTypes. The final
// five entries record where each sub-list starts, as annotated on each line.
var file_cache_store_proto_depIdxs = []int32{
	31, // 0: protocol.Score.timestamp:type_name -> google.protobuf.Timestamp
	31, // 1: protocol.ScoreCondition.before:type_name -> google.protobuf.Timestamp
	31, // 2: protocol.TimeSeriesPoint.timestamp:type_name -> google.protobuf.Timestamp
	0,  // 3: protocol.SetRequest.values:type_name -> protocol.Value
	1,  // 4: protocol.AddScoresRequest.documents:type_name -> protocol.Score
	1,  // 5: protocol.SearchScoresResponse.documents:type_name -> protocol.Score
	2,  // 6: protocol.DeleteScoresRequest.condition:type_name -> protocol.ScoreCondition
	3,  // 7: protocol.UpdateScoresRequest.patch:type_name -> protocol.ScorePatch
	31, // 8: protocol.ScanScoresResponse.timestamp:type_name -> google.protobuf.Timestamp
	4,  // 9: protocol.AddTimeSeriesPointsRequest.points:type_name -> protocol.TimeSeriesPoint
	31, // 10: protocol.GetTimeSeriesPointsRequest.begin:type_name -> google.protobuf.Timestamp
	31, // 11: protocol.GetTimeSeriesPointsRequest.end:type_name -> google.protobuf.Timestamp
	4,  // 12: protocol.GetTimeSeriesPointsResponse.points:type_name -> protocol.TimeSeriesPoint
	32, // 13: protocol.CacheStore.Ping:input_type -> protocol.PingRequest
	5,  // 14: protocol.CacheStore.Get:input_type -> protocol.GetRequest
	7,  // 15: protocol.CacheStore.Set:input_type -> protocol.SetRequest
	9,  // 16: protocol.CacheStore.Delete:input_type -> protocol.DeleteRequest
	11, // 17: protocol.CacheStore.Push:input_type -> protocol.PushRequest
	13, // 18: protocol.CacheStore.Pop:input_type -> protocol.PopRequest
	15, // 19: protocol.CacheStore.Remain:input_type -> protocol.RemainRequest
	17, // 20: protocol.CacheStore.AddScores:input_type -> protocol.AddScoresRequest
	19, // 21: protocol.CacheStore.SearchScores:input_type -> protocol.SearchScoresRequest
	21, // 22: protocol.CacheStore.DeleteScores:input_type -> protocol.DeleteScoresRequest
	23, // 23: protocol.CacheStore.UpdateScores:input_type -> protocol.UpdateScoresRequest
	25, // 24: protocol.CacheStore.ScanScores:input_type -> protocol.ScanScoresRequest
	27, // 25: protocol.CacheStore.AddTimeSeriesPoints:input_type -> protocol.AddTimeSeriesPointsRequest
	29, // 26: protocol.CacheStore.GetTimeSeriesPoints:input_type -> protocol.GetTimeSeriesPointsRequest
	33, // 27: protocol.CacheStore.Ping:output_type -> protocol.PingResponse
	6,  // 28: protocol.CacheStore.Get:output_type -> protocol.GetResponse
	8,  // 29: protocol.CacheStore.Set:output_type -> protocol.SetResponse
	10, // 30: protocol.CacheStore.Delete:output_type -> protocol.DeleteResponse
	12, // 31: protocol.CacheStore.Push:output_type -> protocol.PushResponse
	14, // 32: protocol.CacheStore.Pop:output_type -> protocol.PopResponse
	16, // 33: protocol.CacheStore.Remain:output_type -> protocol.RemainResponse
	18, // 34: protocol.CacheStore.AddScores:output_type -> protocol.AddScoresResponse
	20, // 35: protocol.CacheStore.SearchScores:output_type -> protocol.SearchScoresResponse
	22, // 36: protocol.CacheStore.DeleteScores:output_type -> protocol.DeleteScoresResponse
	24, // 37: protocol.CacheStore.UpdateScores:output_type -> protocol.UpdateScoresResponse
	26, // 38: protocol.CacheStore.ScanScores:output_type -> protocol.ScanScoresResponse
	28, // 39: protocol.CacheStore.AddTimeSeriesPoints:output_type -> protocol.AddTimeSeriesPointsResponse
	30, // 40: protocol.CacheStore.GetTimeSeriesPoints:output_type -> protocol.GetTimeSeriesPointsResponse
	27, // [27:41] is the sub-list for method output_type
	13, // [13:27] is the sub-list for method input_type
	13, // [13:13] is the sub-list for extension type_name
	13, // [13:13] is the sub-list for extension extendee
	0,  // [0:13] is the sub-list for field type_name
}

func init() { file_cache_store_proto_init() }

// file_cache_store_proto_init builds the cache_store.proto file descriptor and
// its message infos. It is a no-op once File_cache_store_proto has been set.
func file_cache_store_proto_init() {
	if File_cache_store_proto != nil {
		return
	}
	// PingRequest and PingResponse (goTypes 32 and 33) come from
	// protocol.proto, so that file is initialized first.
	file_protocol_proto_init()
	// Indices 2, 3, 6, 14 and 23 are ScoreCondition, ScorePatch, GetResponse,
	// PopResponse and UpdateScoresRequest: the messages with proto3
	// `optional` fields, which the descriptor encodes as synthetic oneofs.
	file_cache_store_proto_msgTypes[2].OneofWrappers = []any{}
	file_cache_store_proto_msgTypes[3].OneofWrappers = []any{}
	file_cache_store_proto_msgTypes[6].OneofWrappers = []any{}
	file_cache_store_proto_msgTypes[14].OneofWrappers = []any{}
	file_cache_store_proto_msgTypes[23].OneofWrappers = []any{}
	// x is a throwaway local type used only to recover this package's import
	// path via reflection.
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: unsafe.Slice(unsafe.StringData(file_cache_store_proto_rawDesc), len(file_cache_store_proto_rawDesc)),
			NumEnums:      0,
			NumMessages:   31,
			NumExtensions: 0,
			NumServices:   1,
		},
		GoTypes:           file_cache_store_proto_goTypes,
		DependencyIndexes: file_cache_store_proto_depIdxs,
		MessageInfos:      file_cache_store_proto_msgTypes,
	}.Build()
	File_cache_store_proto = out.File
	// The lookup tables are not referenced again after Build; dropping them
	// presumably lets them be garbage-collected.
	file_cache_store_proto_goTypes = nil
	file_cache_store_proto_depIdxs = nil
}

================================================
FILE: protocol/cache_store.proto
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

option go_package = "github.com/gorse-io/gorse/protocol";

package protocol;

import "google/protobuf/timestamp.proto";
import "protocol.proto";

// Value is a name/value string pair; SetRequest carries a list of them.
message Value {
  string name = 1;
  string value = 2;
}

// Score is a scored document. It is sent in AddScoresRequest.documents and
// returned in SearchScoresResponse.documents.
message Score {
  string id = 1;
  double score = 2;
  bool is_hidden = 3;
  repeated string categories = 4;
  google.protobuf.Timestamp timestamp = 5;
}

// ScoreCondition selects scores for DeleteScoresRequest. Every field is
// proto3 `optional`, so an unset field can be told apart from an empty or
// zero one.
// NOTE(review): how the server combines the set fields (presumably AND) is
// not visible in this file.
message ScoreCondition {
  optional string subset = 1;
  optional string id = 2;
  optional google.protobuf.Timestamp before = 3;
}

// ScorePatch describes the changes UpdateScoresRequest applies. The optional
// scalar fields track presence, so "not provided" differs from false/0.
// NOTE(review): categories is a plain repeated field without presence, so an
// empty list cannot be distinguished from "not provided" on the wire; confirm
// how the server interprets it.
message ScorePatch {
  optional bool is_hidden = 1;
  repeated string categories = 2;
  optional double score = 3;
}

// TimeSeriesPoint is a single (name, timestamp, value) sample, used by the
// AddTimeSeriesPoints and GetTimeSeriesPoints RPCs.
message TimeSeriesPoint {
  string name = 1;
  google.protobuf.Timestamp timestamp = 2;
  double value = 3;
}

// GetRequest is the request of the Get RPC.
message GetRequest {
  string name = 1;
}

// GetResponse is the response of the Get RPC. value is `optional` so that an
// absent value can be distinguished from an empty string.
message GetResponse {
  optional string value = 1;
}

// SetRequest is the request of the Set RPC.
message SetRequest {
  repeated Value values = 1;
}

// SetResponse is the empty response of the Set RPC.
message SetResponse {}

// DeleteRequest is the request of the Delete RPC.
message DeleteRequest {
  string name = 1;
}

// DeleteResponse is the empty response of the Delete RPC.
message DeleteResponse {}

// PushRequest is the request of the Push RPC.
message PushRequest {
  string name = 1;
  string value = 2;
}

// PushResponse is the empty response of the Push RPC.
message PushResponse {}

// PopRequest is the request of the Pop RPC.
message PopRequest {
  string name = 1;
}

// PopResponse is the response of the Pop RPC. value is `optional` so that
// "nothing returned" can be distinguished from an empty string.
message PopResponse {
  optional string value = 1;
}

// RemainRequest is the request of the Remain RPC.
message RemainRequest {
  string name = 1;
}

// RemainResponse is the response of the Remain RPC.
message RemainResponse {
  int64 count = 1;
}

// AddScoresRequest is the request of the AddScores RPC.
message AddScoresRequest {
  string collection = 1;
  string subset = 2;
  repeated Score documents = 3;
}

// AddScoresResponse is the empty response of the AddScores RPC.
message AddScoresResponse {}

// SearchScoresRequest is the request of the SearchScores RPC.
// NOTE(review): begin/end look like a result window (presumably
// [begin, end)); confirm the bounds convention against the server.
message SearchScoresRequest {
  string collection = 1;
  string subset = 2;
  repeated string query = 3;
  int32 begin = 4;
  int32 end = 5;
}

// SearchScoresResponse is the response of the SearchScores RPC.
message SearchScoresResponse {
  repeated Score documents = 1;
}

// DeleteScoresRequest is the request of the DeleteScores RPC.
message DeleteScoresRequest {
  repeated string collection = 1;
  ScoreCondition condition = 2;
}

// DeleteScoresResponse is the empty response of the DeleteScores RPC.
message DeleteScoresResponse {}

// UpdateScoresRequest is the request of the UpdateScores RPC; subset is
// `optional` so that an unset subset differs from an empty one.
message UpdateScoresRequest {
  repeated string collection = 1;
  optional string subset = 2;
  string id = 3;
  ScorePatch patch = 4;
}

// UpdateScoresResponse is the empty response of the UpdateScores RPC.
message UpdateScoresResponse {}

// ScanScoresRequest is the (field-less) request of the ScanScores RPC.
message ScanScoresRequest {}

// ScanScoresResponse is one element of the ScanScores response stream.
message ScanScoresResponse {
  string collection = 1;
  string id = 2;
  string subset = 3;
  google.protobuf.Timestamp timestamp = 4;
}

// AddTimeSeriesPointsRequest is the request of the AddTimeSeriesPoints RPC.
message AddTimeSeriesPointsRequest {
  repeated TimeSeriesPoint points = 1;
}

// AddTimeSeriesPointsResponse is the empty response of the
// AddTimeSeriesPoints RPC.
message AddTimeSeriesPointsResponse {}

// GetTimeSeriesPointsRequest is the request of the GetTimeSeriesPoints RPC.
// NOTE(review): the unit of duration is not stated in this file; confirm
// against the server.
message GetTimeSeriesPointsRequest {
  string name = 1;
  google.protobuf.Timestamp begin = 2;
  google.protobuf.Timestamp end = 3;
  int64 duration = 4;
}

// GetTimeSeriesPointsResponse is the response of the GetTimeSeriesPoints RPC.
message GetTimeSeriesPointsResponse {
  repeated TimeSeriesPoint points = 1;
}

// CacheStore is the gRPC service defined by this file. Every RPC is unary
// except ScanScores, which streams ScanScoresResponse messages to the client.
service CacheStore {
  rpc Ping(PingRequest) returns (PingResponse) {}
  rpc Get(GetRequest) returns (GetResponse) {}
  rpc Set(SetRequest) returns (SetResponse) {}
  rpc Delete(DeleteRequest) returns (DeleteResponse) {}
  rpc Push(PushRequest) returns (PushResponse) {}
  rpc Pop(PopRequest) returns (PopResponse) {}
  rpc Remain(RemainRequest) returns (RemainResponse) {}
  rpc AddScores(AddScoresRequest) returns (AddScoresResponse) {}
  rpc SearchScores(SearchScoresRequest) returns (SearchScoresResponse) {}
  rpc DeleteScores(DeleteScoresRequest) returns (DeleteScoresResponse) {}
  rpc UpdateScores(UpdateScoresRequest) returns (UpdateScoresResponse) {}
  rpc ScanScores(ScanScoresRequest) returns (stream ScanScoresResponse) {}
  rpc AddTimeSeriesPoints(AddTimeSeriesPointsRequest) returns (AddTimeSeriesPointsResponse) {}
  rpc GetTimeSeriesPoints(GetTimeSeriesPointsRequest) returns (GetTimeSeriesPointsResponse) {}
}

================================================
FILE: protocol/cache_store_grpc.pb.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc             v6.33.1
// source: cache_store.proto

package protocol

import (
	context "context"
	grpc "google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	status "google.golang.org/grpc/status"
)

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9

// Fully-qualified method names ("/package.Service/Method") of every
// CacheStore RPC. The cacheStoreClient stubs pass these to Invoke/NewStream.
const (
	CacheStore_Ping_FullMethodName                = "/protocol.CacheStore/Ping"
	CacheStore_Get_FullMethodName                 = "/protocol.CacheStore/Get"
	CacheStore_Set_FullMethodName                 = "/protocol.CacheStore/Set"
	CacheStore_Delete_FullMethodName              = "/protocol.CacheStore/Delete"
	CacheStore_Push_FullMethodName                = "/protocol.CacheStore/Push"
	CacheStore_Pop_FullMethodName                 = "/protocol.CacheStore/Pop"
	CacheStore_Remain_FullMethodName              = "/protocol.CacheStore/Remain"
	CacheStore_AddScores_FullMethodName           = "/protocol.CacheStore/AddScores"
	CacheStore_SearchScores_FullMethodName        = "/protocol.CacheStore/SearchScores"
	CacheStore_DeleteScores_FullMethodName        = "/protocol.CacheStore/DeleteScores"
	CacheStore_UpdateScores_FullMethodName        = "/protocol.CacheStore/UpdateScores"
	CacheStore_ScanScores_FullMethodName          = "/protocol.CacheStore/ScanScores"
	CacheStore_AddTimeSeriesPoints_FullMethodName = "/protocol.CacheStore/AddTimeSeriesPoints"
	CacheStore_GetTimeSeriesPoints_FullMethodName = "/protocol.CacheStore/GetTimeSeriesPoints"
)

// CacheStoreClient is the client API for CacheStore service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type CacheStoreClient interface {
	Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error)
	Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error)
	Set(ctx context.Context, in *SetRequest, opts ...grpc.CallOption) (*SetResponse, error)
	Delete(ctx context.Context, in *DeleteRequest, opts ...grpc.CallOption) (*DeleteResponse, error)
	Push(ctx context.Context, in *PushRequest, opts ...grpc.CallOption) (*PushResponse, error)
	Pop(ctx context.Context, in *PopRequest, opts ...grpc.CallOption) (*PopResponse, error)
	Remain(ctx context.Context, in *RemainRequest, opts ...grpc.CallOption) (*RemainResponse, error)
	AddScores(ctx context.Context, in *AddScoresRequest, opts ...grpc.CallOption) (*AddScoresResponse, error)
	SearchScores(ctx context.Context, in *SearchScoresRequest, opts ...grpc.CallOption) (*SearchScoresResponse, error)
	DeleteScores(ctx context.Context, in *DeleteScoresRequest, opts ...grpc.CallOption) (*DeleteScoresResponse, error)
	UpdateScores(ctx context.Context, in *UpdateScoresRequest, opts ...grpc.CallOption) (*UpdateScoresResponse, error)
	// ScanScores is the only server-streaming RPC: the returned stream yields
	// ScanScoresResponse messages.
	ScanScores(ctx context.Context, in *ScanScoresRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ScanScoresResponse], error)
	AddTimeSeriesPoints(ctx context.Context, in *AddTimeSeriesPointsRequest, opts ...grpc.CallOption) (*AddTimeSeriesPointsResponse, error)
	GetTimeSeriesPoints(ctx context.Context, in *GetTimeSeriesPointsRequest, opts ...grpc.CallOption) (*GetTimeSeriesPointsResponse, error)
}

// cacheStoreClient implements CacheStoreClient on top of a
// grpc.ClientConnInterface.
type cacheStoreClient struct {
	cc grpc.ClientConnInterface
}

// NewCacheStoreClient returns a CacheStoreClient that issues every call over cc.
func NewCacheStoreClient(cc grpc.ClientConnInterface) CacheStoreClient {
	return &cacheStoreClient{cc}
}

// Ping invokes CacheStore/Ping. Every unary stub in this file follows the same
// pattern: prepend grpc.StaticMethod() to the caller's options (it marks the
// method name as fixed at compile time), Invoke the fully-qualified method
// into a fresh response, and return nil plus the error on failure.
func (c *cacheStoreClient) Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(PingResponse)
	err := c.cc.Invoke(ctx, CacheStore_Ping_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Get invokes CacheStore/Get (unary).
func (c *cacheStoreClient) Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetResponse)
	err := c.cc.Invoke(ctx, CacheStore_Get_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Set invokes CacheStore/Set (unary).
func (c *cacheStoreClient) Set(ctx context.Context, in *SetRequest, opts ...grpc.CallOption) (*SetResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(SetResponse)
	err := c.cc.Invoke(ctx, CacheStore_Set_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Delete invokes CacheStore/Delete (unary).
func (c *cacheStoreClient) Delete(ctx context.Context, in *DeleteRequest, opts ...grpc.CallOption) (*DeleteResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteResponse)
	err := c.cc.Invoke(ctx, CacheStore_Delete_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Push invokes CacheStore/Push (unary).
func (c *cacheStoreClient) Push(ctx context.Context, in *PushRequest, opts ...grpc.CallOption) (*PushResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(PushResponse)
	err := c.cc.Invoke(ctx, CacheStore_Push_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Pop invokes CacheStore/Pop (unary).
func (c *cacheStoreClient) Pop(ctx context.Context, in *PopRequest, opts ...grpc.CallOption) (*PopResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(PopResponse)
	err := c.cc.Invoke(ctx, CacheStore_Pop_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Remain invokes CacheStore/Remain (unary).
func (c *cacheStoreClient) Remain(ctx context.Context, in *RemainRequest, opts ...grpc.CallOption) (*RemainResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(RemainResponse)
	err := c.cc.Invoke(ctx, CacheStore_Remain_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// AddScores invokes CacheStore/AddScores (unary).
func (c *cacheStoreClient) AddScores(ctx context.Context, in *AddScoresRequest, opts ...grpc.CallOption) (*AddScoresResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(AddScoresResponse)
	err := c.cc.Invoke(ctx, CacheStore_AddScores_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// SearchScores invokes CacheStore/SearchScores (unary).
func (c *cacheStoreClient) SearchScores(ctx context.Context, in *SearchScoresRequest, opts ...grpc.CallOption) (*SearchScoresResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(SearchScoresResponse)
	err := c.cc.Invoke(ctx, CacheStore_SearchScores_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// DeleteScores invokes CacheStore/DeleteScores (unary).
func (c *cacheStoreClient) DeleteScores(ctx context.Context, in *DeleteScoresRequest, opts ...grpc.CallOption) (*DeleteScoresResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteScoresResponse)
	err := c.cc.Invoke(ctx, CacheStore_DeleteScores_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// UpdateScores invokes CacheStore/UpdateScores (unary).
func (c *cacheStoreClient) UpdateScores(ctx context.Context, in *UpdateScoresRequest, opts ...grpc.CallOption) (*UpdateScoresResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(UpdateScoresResponse)
	err := c.cc.Invoke(ctx, CacheStore_UpdateScores_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// ScanScores opens a server stream described by CacheStore_ServiceDesc.Streams[0],
// sends the single request, closes the send direction, and returns the stream
// for the caller to receive ScanScoresResponse messages from. Any failure
// along the way returns nil and the error.
func (c *cacheStoreClient) ScanScores(ctx context.Context, in *ScanScoresRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ScanScoresResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &CacheStore_ServiceDesc.Streams[0], CacheStore_ScanScores_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[ScanScoresRequest, ScanScoresResponse]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	// Server-streaming: the client sends exactly one request, so the send
	// side is closed immediately.
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type CacheStore_ScanScoresClient = grpc.ServerStreamingClient[ScanScoresResponse]

// AddTimeSeriesPoints invokes CacheStore/AddTimeSeriesPoints (unary).
func (c *cacheStoreClient) AddTimeSeriesPoints(ctx context.Context, in *AddTimeSeriesPointsRequest, opts ...grpc.CallOption) (*AddTimeSeriesPointsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(AddTimeSeriesPointsResponse)
	err := c.cc.Invoke(ctx, CacheStore_AddTimeSeriesPoints_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// GetTimeSeriesPoints invokes CacheStore/GetTimeSeriesPoints (unary).
func (c *cacheStoreClient) GetTimeSeriesPoints(ctx context.Context, in *GetTimeSeriesPointsRequest, opts ...grpc.CallOption) (*GetTimeSeriesPointsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetTimeSeriesPointsResponse)
	err := c.cc.Invoke(ctx, CacheStore_GetTimeSeriesPoints_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// CacheStoreServer is the server API for CacheStore service.
// All implementations must embed UnimplementedCacheStoreServer
// for forward compatibility.
type CacheStoreServer interface {
	// Every method except ScanScores is a unary RPC: it takes one request
	// and returns one response or an error.
	Ping(context.Context, *PingRequest) (*PingResponse, error)
	Get(context.Context, *GetRequest) (*GetResponse, error)
	Set(context.Context, *SetRequest) (*SetResponse, error)
	Delete(context.Context, *DeleteRequest) (*DeleteResponse, error)
	Push(context.Context, *PushRequest) (*PushResponse, error)
	Pop(context.Context, *PopRequest) (*PopResponse, error)
	Remain(context.Context, *RemainRequest) (*RemainResponse, error)
	AddScores(context.Context, *AddScoresRequest) (*AddScoresResponse, error)
	SearchScores(context.Context, *SearchScoresRequest) (*SearchScoresResponse, error)
	DeleteScores(context.Context, *DeleteScoresRequest) (*DeleteScoresResponse, error)
	UpdateScores(context.Context, *UpdateScoresRequest) (*UpdateScoresResponse, error)
	// ScanScores is server-streaming. It receives a single request and sends
	// ScanScoresResponse messages on the provided stream. Returning from it
	// ends the call.
	ScanScores(*ScanScoresRequest, grpc.ServerStreamingServer[ScanScoresResponse]) error
	AddTimeSeriesPoints(context.Context, *AddTimeSeriesPointsRequest) (*AddTimeSeriesPointsResponse, error)
	GetTimeSeriesPoints(context.Context, *GetTimeSeriesPointsRequest) (*GetTimeSeriesPointsResponse, error)
	// mustEmbedUnimplementedCacheStoreServer is unexported. Types outside this
	// package can therefore satisfy CacheStoreServer only by embedding
	// UnimplementedCacheStoreServer (or UnsafeCacheStoreServer).
	mustEmbedUnimplementedCacheStoreServer()
}

// UnimplementedCacheStoreServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
//
// Each RPC method below returns a status error with codes.Unimplemented.
type UnimplementedCacheStoreServer struct{}

func (UnimplementedCacheStoreServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Ping not implemented")
}
func (UnimplementedCacheStoreServer) Get(context.Context, *GetRequest) (*GetResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Get not implemented")
}
func (UnimplementedCacheStoreServer) Set(context.Context, *SetRequest) (*SetResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Set not implemented")
}
func (UnimplementedCacheStoreServer) Delete(context.Context, *DeleteRequest) (*DeleteResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Delete not implemented")
}
func (UnimplementedCacheStoreServer) Push(context.Context, *PushRequest) (*PushResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Push not implemented")
}
func (UnimplementedCacheStoreServer) Pop(context.Context, *PopRequest) (*PopResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Pop not implemented")
}
func (UnimplementedCacheStoreServer) Remain(context.Context, *RemainRequest) (*RemainResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Remain not implemented")
}
func (UnimplementedCacheStoreServer) AddScores(context.Context, *AddScoresRequest) (*AddScoresResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method AddScores not implemented")
}
func (UnimplementedCacheStoreServer) SearchScores(context.Context, *SearchScoresRequest) (*SearchScoresResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method SearchScores not implemented")
}
func (UnimplementedCacheStoreServer) DeleteScores(context.Context, *DeleteScoresRequest) (*DeleteScoresResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteScores not implemented")
}
func (UnimplementedCacheStoreServer) UpdateScores(context.Context, *UpdateScoresRequest) (*UpdateScoresResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method UpdateScores not implemented")
}
func (UnimplementedCacheStoreServer) ScanScores(*ScanScoresRequest, grpc.ServerStreamingServer[ScanScoresResponse]) error {
	return status.Error(codes.Unimplemented, "method ScanScores not implemented")
}
func (UnimplementedCacheStoreServer) AddTimeSeriesPoints(context.Context, *AddTimeSeriesPointsRequest) (*AddTimeSeriesPointsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method AddTimeSeriesPoints not implemented")
}
func (UnimplementedCacheStoreServer) GetTimeSeriesPoints(context.Context, *GetTimeSeriesPointsRequest) (*GetTimeSeriesPointsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetTimeSeriesPoints not implemented")
}

// mustEmbedUnimplementedCacheStoreServer satisfies the unexported marker method
// of CacheStoreServer. testEmbeddedByValue is a no-op with a value receiver.
// RegisterCacheStoreServer calls testEmbeddedByValue so that a nil embedded
// *UnimplementedCacheStoreServer panics at registration, not at request time.
func (UnimplementedCacheStoreServer) mustEmbedUnimplementedCacheStoreServer() {}
func (UnimplementedCacheStoreServer) testEmbeddedByValue()                    {}

// UnsafeCacheStoreServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to CacheStoreServer will
// result in compilation errors.
type UnsafeCacheStoreServer interface {
	mustEmbedUnimplementedCacheStoreServer()
}

// RegisterCacheStoreServer registers srv on s under CacheStore_ServiceDesc.
// It first probes the embedded UnimplementedCacheStoreServer (see the check
// below).
func RegisterCacheStoreServer(s grpc.ServiceRegistrar, srv CacheStoreServer) {
	// If the following call panics, it indicates UnimplementedCacheStoreServer was
	// embedded by pointer and is nil. This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&CacheStore_ServiceDesc, srv)
}

// _CacheStore_Ping_Handler decodes a PingRequest with dec and dispatches it to
// srv.Ping. When interceptor is non-nil the call is routed through it, with
// FullMethod set to CacheStore_Ping_FullMethodName.
func _CacheStore_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PingRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Ping(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Ping_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Ping(ctx, req.(*PingRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Get_Handler decodes a GetRequest with dec and dispatches it to
// srv.Get, routing through interceptor when one is configured.
func _CacheStore_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Get(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Get_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Get(ctx, req.(*GetRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Set_Handler decodes a SetRequest with dec and dispatches it to
// srv.Set, routing through interceptor when one is configured.
func _CacheStore_Set_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(SetRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Set(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Set_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Set(ctx, req.(*SetRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Delete_Handler decodes a DeleteRequest with dec and dispatches it
// to srv.Delete, routing through interceptor when one is configured.
func _CacheStore_Delete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Delete(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Delete_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Delete(ctx, req.(*DeleteRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Push_Handler decodes a PushRequest with dec and dispatches it to
// srv.Push, routing through interceptor when one is configured.
func _CacheStore_Push_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PushRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Push(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Push_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Push(ctx, req.(*PushRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Pop_Handler decodes a PopRequest with dec and dispatches it to
// srv.Pop, routing through interceptor when one is configured.
func _CacheStore_Pop_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PopRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Pop(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Pop_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Pop(ctx, req.(*PopRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_Remain_Handler decodes a RemainRequest with dec and dispatches it
// to srv.Remain, routing through interceptor when one is configured.
func _CacheStore_Remain_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(RemainRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).Remain(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_Remain_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).Remain(ctx, req.(*RemainRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_AddScores_Handler decodes an AddScoresRequest with dec and
// dispatches it to srv.AddScores, routing through interceptor when one is
// configured.
func _CacheStore_AddScores_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(AddScoresRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).AddScores(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_AddScores_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).AddScores(ctx, req.(*AddScoresRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_SearchScores_Handler decodes a SearchScoresRequest with dec and
// dispatches it to srv.SearchScores, routing through interceptor when one is
// configured.
func _CacheStore_SearchScores_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(SearchScoresRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).SearchScores(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_SearchScores_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).SearchScores(ctx, req.(*SearchScoresRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_DeleteScores_Handler decodes a DeleteScoresRequest with dec and
// dispatches it to srv.DeleteScores, routing through interceptor when one is
// configured.
func _CacheStore_DeleteScores_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteScoresRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).DeleteScores(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_DeleteScores_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).DeleteScores(ctx, req.(*DeleteScoresRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_UpdateScores_Handler decodes an UpdateScoresRequest with dec and
// dispatches it to srv.UpdateScores, routing through interceptor when one is
// configured.
func _CacheStore_UpdateScores_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(UpdateScoresRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).UpdateScores(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_UpdateScores_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).UpdateScores(ctx, req.(*UpdateScoresRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_ScanScores_Handler reads the single ScanScoresRequest from the
// stream and hands srv.ScanScores a typed GenericServerStream wrapper around
// the same stream.
func _CacheStore_ScanScores_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(ScanScoresRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(CacheStoreServer).ScanScores(m, &grpc.GenericServerStream[ScanScoresRequest, ScanScoresResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type CacheStore_ScanScoresServer = grpc.ServerStreamingServer[ScanScoresResponse]

// _CacheStore_AddTimeSeriesPoints_Handler decodes an AddTimeSeriesPointsRequest
// with dec and dispatches it to srv.AddTimeSeriesPoints. When interceptor is
// non-nil the call is routed through it, with FullMethod set to
// CacheStore_AddTimeSeriesPoints_FullMethodName.
func _CacheStore_AddTimeSeriesPoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(AddTimeSeriesPointsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).AddTimeSeriesPoints(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_AddTimeSeriesPoints_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).AddTimeSeriesPoints(ctx, req.(*AddTimeSeriesPointsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _CacheStore_GetTimeSeriesPoints_Handler decodes a GetTimeSeriesPointsRequest
// with dec and dispatches it to srv.GetTimeSeriesPoints. When interceptor is
// non-nil the call is routed through it, with FullMethod set to
// CacheStore_GetTimeSeriesPoints_FullMethodName.
func _CacheStore_GetTimeSeriesPoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetTimeSeriesPointsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(CacheStoreServer).GetTimeSeriesPoints(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: CacheStore_GetTimeSeriesPoints_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(CacheStoreServer).GetTimeSeriesPoints(ctx, req.(*GetTimeSeriesPointsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// CacheStore_ServiceDesc is the grpc.ServiceDesc for CacheStore service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var CacheStore_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "protocol.CacheStore",
	HandlerType: (*CacheStoreServer)(nil),
	// Unary RPCs, each bound to its _CacheStore_*_Handler dispatcher.
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Ping",
			Handler:    _CacheStore_Ping_Handler,
		},
		{
			MethodName: "Get",
			Handler:    _CacheStore_Get_Handler,
		},
		{
			MethodName: "Set",
			Handler:    _CacheStore_Set_Handler,
		},
		{
			MethodName: "Delete",
			Handler:    _CacheStore_Delete_Handler,
		},
		{
			MethodName: "Push",
			Handler:    _CacheStore_Push_Handler,
		},
		{
			MethodName: "Pop",
			Handler:    _CacheStore_Pop_Handler,
		},
		{
			MethodName: "Remain",
			Handler:    _CacheStore_Remain_Handler,
		},
		{
			MethodName: "AddScores",
			Handler:    _CacheStore_AddScores_Handler,
		},
		{
			MethodName: "SearchScores",
			Handler:    _CacheStore_SearchScores_Handler,
		},
		{
			MethodName: "DeleteScores",
			Handler:    _CacheStore_DeleteScores_Handler,
		},
		{
			MethodName: "UpdateScores",
			Handler:    _CacheStore_UpdateScores_Handler,
		},
		{
			MethodName: "AddTimeSeriesPoints",
			Handler:    _CacheStore_AddTimeSeriesPoints_Handler,
		},
		{
			MethodName: "GetTimeSeriesPoints",
			Handler:    _CacheStore_GetTimeSeriesPoints_Handler,
		},
	},
	// Order matters: cacheStoreClient.ScanScores opens
	// &CacheStore_ServiceDesc.Streams[0], which must remain ScanScores.
	Streams: []grpc.StreamDesc{
		{
			StreamName:    "ScanScores",
			Handler:       _CacheStore_ScanScores_Handler,
			ServerStreams: true,
		},
	},
	Metadata: "cache_store.proto",
}

================================================
FILE: protocol/data_store.pb.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.36.10
// 	protoc        v6.33.1
// source: data_store.proto

package protocol

import (
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	timestamppb "google.golang.org/protobuf/types/known/timestamppb"
	reflect "reflect"
	sync "sync"
	unsafe "unsafe"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// ExpressionType is the Go form of the protocol.ExpressionType enum. It is the
// operator of a FeedbackTypeExpression. The names suggest None (no comparison),
// <, <=, > and >= against FeedbackTypeExpression.Value.
// NOTE(review): that operator mapping is inferred from the names; confirm it
// against data_store.proto and the callers.
type ExpressionType int32

const (
	ExpressionType_None           ExpressionType = 0
	ExpressionType_Less           ExpressionType = 1
	ExpressionType_LessOrEqual    ExpressionType = 2
	ExpressionType_Greater        ExpressionType = 3
	ExpressionType_GreaterOrEqual ExpressionType = 4
)

// Enum value maps for ExpressionType.
var (
	ExpressionType_name = map[int32]string{
		0: "None",
		1: "Less",
		2: "LessOrEqual",
		3: "Greater",
		4: "GreaterOrEqual",
	}
	ExpressionType_value = map[string]int32{
		"None":           0,
		"Less":           1,
		"LessOrEqual":    2,
		"Greater":        3,
		"GreaterOrEqual": 4,
	}
)

// Enum returns a pointer to a newly allocated copy of x.
func (x ExpressionType) Enum() *ExpressionType {
	p := new(ExpressionType)
	*p = x
	return p
}

// String returns the enum value's name as resolved through its descriptor.
func (x ExpressionType) String() string {
	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}

// Descriptor and Type expose entry 0 of file_data_store_proto_enumTypes.
func (ExpressionType) Descriptor() protoreflect.EnumDescriptor {
	return file_data_store_proto_enumTypes[0].Descriptor()
}

func (ExpressionType) Type() protoreflect.EnumType {
	return &file_data_store_proto_enumTypes[0]
}

// Number returns x as a protoreflect.EnumNumber.
func (x ExpressionType) Number() protoreflect.EnumNumber {
	return protoreflect.EnumNumber(x)
}

// Deprecated: Use ExpressionType.Descriptor instead.
func (ExpressionType) EnumDescriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{0}
}

// FeedbackTypeExpression is generated from data_store.proto (msgTypes[0]).
// It pairs a feedback type name with an ExpressionType and a float64 Value.
// Presumably it filters feedback of FeedbackType by comparing a value against
// Value; TODO confirm with callers.
// Field order and tags are generator-owned, because protoimpl resolves message
// state from the struct pointer. Regenerate rather than hand-edit.
type FeedbackTypeExpression struct {
	state          protoimpl.MessageState `protogen:"open.v1"`
	FeedbackType   string                 `protobuf:"bytes,1,opt,name=feedback_type,json=feedbackType,proto3" json:"feedback_type,omitempty"`
	ExpressionType ExpressionType         `protobuf:"varint,2,opt,name=expression_type,json=expressionType,proto3,enum=protocol.ExpressionType" json:"expression_type,omitempty"`
	Value          float64                `protobuf:"fixed64,3,opt,name=value,proto3" json:"value,omitempty"`
	unknownFields  protoimpl.UnknownFields
	sizeCache      protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[0].
func (x *FeedbackTypeExpression) Reset() {
	*x = FeedbackTypeExpression{}
	mi := &file_data_store_proto_msgTypes[0]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *FeedbackTypeExpression) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*FeedbackTypeExpression) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *FeedbackTypeExpression) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[0]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use FeedbackTypeExpression.ProtoReflect.Descriptor instead.
func (*FeedbackTypeExpression) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{0}
}

// The getters below are nil-safe: a nil receiver yields the zero value
// (ExpressionType_None for GetExpressionType).
func (x *FeedbackTypeExpression) GetFeedbackType() string {
	if x != nil {
		return x.FeedbackType
	}
	return ""
}

func (x *FeedbackTypeExpression) GetExpressionType() ExpressionType {
	if x != nil {
		return x.ExpressionType
	}
	return ExpressionType_None
}

func (x *FeedbackTypeExpression) GetValue() float64 {
	if x != nil {
		return x.Value
	}
	return 0
}

// UserPatch is generated from data_store.proto (msgTypes[1]) and presumably
// carries a partial update for a user.
// Comment is proto3 optional: nil means "not set". Labels and Subscribe have no
// presence tracking, so an empty value looks the same as an omitted one.
// NOTE(review): confirm how the server treats empty Labels/Subscribe.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type UserPatch struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Labels        []byte                 `protobuf:"bytes,1,opt,name=labels,proto3" json:"labels,omitempty"`
	Comment       *string                `protobuf:"bytes,2,opt,name=comment,proto3,oneof" json:"comment,omitempty"`
	Subscribe     []string               `protobuf:"bytes,3,rep,name=subscribe,proto3" json:"subscribe,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[1].
func (x *UserPatch) Reset() {
	*x = UserPatch{}
	mi := &file_data_store_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *UserPatch) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*UserPatch) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *UserPatch) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use UserPatch.ProtoReflect.Descriptor instead.
func (*UserPatch) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{1}
}

// GetLabels returns x.Labels, or nil for a nil receiver.
func (x *UserPatch) GetLabels() []byte {
	if x != nil {
		return x.Labels
	}
	return nil
}

// GetComment returns "" when x is nil or Comment is unset, so it cannot
// distinguish "unset" from "set to empty". Check x.Comment != nil for that.
func (x *UserPatch) GetComment() string {
	if x != nil && x.Comment != nil {
		return *x.Comment
	}
	return ""
}

// GetSubscribe returns x.Subscribe, or nil for a nil receiver.
func (x *UserPatch) GetSubscribe() []string {
	if x != nil {
		return x.Subscribe
	}
	return nil
}

// ItemPatch is generated from data_store.proto (msgTypes[2]) and presumably
// carries a partial update for an item.
// IsHidden, Timestamp and Comment are proto3 optional: nil means "not set".
// Categories and Labels have no presence tracking, so empty looks the same as
// omitted.
// NOTE(review): confirm how the server treats empty Categories/Labels.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type ItemPatch struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	IsHidden      *bool                  `protobuf:"varint,1,opt,name=is_hidden,json=isHidden,proto3,oneof" json:"is_hidden,omitempty"`
	Categories    []string               `protobuf:"bytes,2,rep,name=categories,proto3" json:"categories,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp,proto3,oneof" json:"timestamp,omitempty"`
	Labels        []byte                 `protobuf:"bytes,4,opt,name=labels,proto3" json:"labels,omitempty"`
	Comment       *string                `protobuf:"bytes,5,opt,name=comment,proto3,oneof" json:"comment,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[2].
func (x *ItemPatch) Reset() {
	*x = ItemPatch{}
	mi := &file_data_store_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ItemPatch) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ItemPatch) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *ItemPatch) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ItemPatch.ProtoReflect.Descriptor instead.
func (*ItemPatch) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{2}
}

// GetIsHidden returns false when x is nil or IsHidden is unset. Check
// x.IsHidden != nil to tell "unset" from "false".
func (x *ItemPatch) GetIsHidden() bool {
	if x != nil && x.IsHidden != nil {
		return *x.IsHidden
	}
	return false
}

// GetCategories returns x.Categories, or nil for a nil receiver.
func (x *ItemPatch) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

// GetTimestamp returns x.Timestamp, which is nil when unset or when x is nil.
func (x *ItemPatch) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

// GetLabels returns x.Labels, or nil for a nil receiver.
func (x *ItemPatch) GetLabels() []byte {
	if x != nil {
		return x.Labels
	}
	return nil
}

// GetComment returns "" when x is nil or Comment is unset.
func (x *ItemPatch) GetComment() string {
	if x != nil && x.Comment != nil {
		return *x.Comment
	}
	return ""
}

// ScanOptions is generated from data_store.proto (msgTypes[3]).
// The begin/end user ID, item ID and time bounds are each proto3 optional:
// nil means "no bound". FeedbackTypes is a list of FeedbackTypeExpression
// filters. OrderByItemId is a plain bool.
// NOTE(review): whether bounds are inclusive or exclusive, and how multiple
// FeedbackTypes combine, is decided by the data-store implementation, not
// visible here. Confirm before relying on it.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type ScanOptions struct {
	state         protoimpl.MessageState    `protogen:"open.v1"`
	BeginUserId   *string                   `protobuf:"bytes,1,opt,name=begin_user_id,json=beginUserId,proto3,oneof" json:"begin_user_id,omitempty"`
	EndUserId     *string                   `protobuf:"bytes,2,opt,name=end_user_id,json=endUserId,proto3,oneof" json:"end_user_id,omitempty"`
	BeginItemId   *string                   `protobuf:"bytes,3,opt,name=begin_item_id,json=beginItemId,proto3,oneof" json:"begin_item_id,omitempty"`
	EndItemId     *string                   `protobuf:"bytes,4,opt,name=end_item_id,json=endItemId,proto3,oneof" json:"end_item_id,omitempty"`
	BeginTime     *timestamppb.Timestamp    `protobuf:"bytes,5,opt,name=begin_time,json=beginTime,proto3,oneof" json:"begin_time,omitempty"`
	EndTime       *timestamppb.Timestamp    `protobuf:"bytes,6,opt,name=end_time,json=endTime,proto3,oneof" json:"end_time,omitempty"`
	FeedbackTypes []*FeedbackTypeExpression `protobuf:"bytes,7,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	OrderByItemId bool                      `protobuf:"varint,8,opt,name=order_by_item_id,json=orderByItemId,proto3" json:"order_by_item_id,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[3].
func (x *ScanOptions) Reset() {
	*x = ScanOptions{}
	mi := &file_data_store_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ScanOptions) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ScanOptions) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *ScanOptions) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ScanOptions.ProtoReflect.Descriptor instead.
func (*ScanOptions) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{3}
}

// The string-bound getters return "" when x is nil or the bound is unset, so
// use the pointer fields directly to detect "no bound".
func (x *ScanOptions) GetBeginUserId() string {
	if x != nil && x.BeginUserId != nil {
		return *x.BeginUserId
	}
	return ""
}

func (x *ScanOptions) GetEndUserId() string {
	if x != nil && x.EndUserId != nil {
		return *x.EndUserId
	}
	return ""
}

func (x *ScanOptions) GetBeginItemId() string {
	if x != nil && x.BeginItemId != nil {
		return *x.BeginItemId
	}
	return ""
}

func (x *ScanOptions) GetEndItemId() string {
	if x != nil && x.EndItemId != nil {
		return *x.EndItemId
	}
	return ""
}

// GetBeginTime and GetEndTime return nil when the bound is unset or x is nil.
func (x *ScanOptions) GetBeginTime() *timestamppb.Timestamp {
	if x != nil {
		return x.BeginTime
	}
	return nil
}

func (x *ScanOptions) GetEndTime() *timestamppb.Timestamp {
	if x != nil {
		return x.EndTime
	}
	return nil
}

// GetFeedbackTypes returns x.FeedbackTypes, or nil for a nil receiver.
func (x *ScanOptions) GetFeedbackTypes() []*FeedbackTypeExpression {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// GetOrderByItemId returns x.OrderByItemId, or false for a nil receiver.
func (x *ScanOptions) GetOrderByItemId() bool {
	if x != nil {
		return x.OrderByItemId
	}
	return false
}

// BatchInsertItemsRequest is generated from data_store.proto (msgTypes[4]) and
// carries the Items to insert.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type BatchInsertItemsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Items         []*Item                `protobuf:"bytes,1,rep,name=items,proto3" json:"items,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[4].
func (x *BatchInsertItemsRequest) Reset() {
	*x = BatchInsertItemsRequest{}
	mi := &file_data_store_proto_msgTypes[4]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *BatchInsertItemsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BatchInsertItemsRequest) ProtoMessage() {}
// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *BatchInsertItemsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[4]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BatchInsertItemsRequest.ProtoReflect.Descriptor instead.
func (*BatchInsertItemsRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{4}
}

// GetItems returns x.Items, or nil for a nil receiver.
func (x *BatchInsertItemsRequest) GetItems() []*Item {
	if x != nil {
		return x.Items
	}
	return nil
}

// BatchInsertItemsResponse is generated from data_store.proto (msgTypes[5]).
// It has no fields of its own.
type BatchInsertItemsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[5].
func (x *BatchInsertItemsResponse) Reset() {
	*x = BatchInsertItemsResponse{}
	mi := &file_data_store_proto_msgTypes[5]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *BatchInsertItemsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BatchInsertItemsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *BatchInsertItemsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[5]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BatchInsertItemsResponse.ProtoReflect.Descriptor instead.
func (*BatchInsertItemsResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{5}
}

// BatchGetItemsRequest is generated from data_store.proto (msgTypes[6]) and
// carries the ItemIds to look up.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type BatchGetItemsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	ItemIds       []string               `protobuf:"bytes,1,rep,name=item_ids,json=itemIds,proto3" json:"item_ids,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[6].
func (x *BatchGetItemsRequest) Reset() {
	*x = BatchGetItemsRequest{}
	mi := &file_data_store_proto_msgTypes[6]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *BatchGetItemsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BatchGetItemsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *BatchGetItemsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[6]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BatchGetItemsRequest.ProtoReflect.Descriptor instead.
func (*BatchGetItemsRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{6}
}

// GetItemIds returns x.ItemIds, or nil for a nil receiver.
func (x *BatchGetItemsRequest) GetItemIds() []string {
	if x != nil {
		return x.ItemIds
	}
	return nil
}

// BatchGetItemsResponse is generated from data_store.proto (msgTypes[7]) and
// carries the returned Items.
// NOTE(review): whether Items preserves request order or omits missing IDs is
// decided by the server; confirm before relying on positional matching.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type BatchGetItemsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Items         []*Item                `protobuf:"bytes,1,rep,name=items,proto3" json:"items,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[7].
func (x *BatchGetItemsResponse) Reset() {
	*x = BatchGetItemsResponse{}
	mi := &file_data_store_proto_msgTypes[7]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *BatchGetItemsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BatchGetItemsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *BatchGetItemsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[7]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BatchGetItemsResponse.ProtoReflect.Descriptor instead.
func (*BatchGetItemsResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{7}
}

// GetItems returns x.Items, or nil for a nil receiver.
func (x *BatchGetItemsResponse) GetItems() []*Item {
	if x != nil {
		return x.Items
	}
	return nil
}

// DeleteItemRequest is generated from data_store.proto (msgTypes[8]) and
// carries the ItemId to delete.
// Field order and tags are generator-owned; regenerate rather than hand-edit.
type DeleteItemRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	ItemId        string                 `protobuf:"bytes,1,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[8].
func (x *DeleteItemRequest) Reset() {
	*x = DeleteItemRequest{}
	mi := &file_data_store_proto_msgTypes[8]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteItemRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteItemRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *DeleteItemRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[8]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteItemRequest.ProtoReflect.Descriptor instead.
func (*DeleteItemRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{8}
}

// GetItemId returns x.ItemId, or "" for a nil receiver.
func (x *DeleteItemRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// DeleteItemResponse is generated from data_store.proto (msgTypes[9]).
// It has no fields of its own.
type DeleteItemResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and re-attaches the message info for msgTypes[9].
func (x *DeleteItemResponse) Reset() {
	*x = DeleteItemResponse{}
	mi := &file_data_store_proto_msgTypes[9]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DeleteItemResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DeleteItemResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored lazily on first use; a nil x falls back to mi.MessageOf(x).
func (x *DeleteItemResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[9]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteItemResponse.ProtoReflect.Descriptor instead.
func (*DeleteItemResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{9}
}

// GetItemRequest is the Go type for the GetItemRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[10]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetItemRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	ItemId        string                 `protobuf:"bytes,1,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[10] message info in its state.
func (x *GetItemRequest) Reset() {
	*x = GetItemRequest{}
	mi := &file_data_store_proto_msgTypes[10]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetItemRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetItemRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetItemRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[10]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{10}.
//
// Deprecated: Use GetItemRequest.ProtoReflect.Descriptor instead.
func (*GetItemRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{10}
}

// GetItemId returns the item_id field, or "" if x is nil.
func (x *GetItemRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// GetItemResponse is the Go type for the GetItemResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[11]).
// Item is tagged proto3,oneof — presumably a proto3 `optional` field; confirm
// against the .proto. Generated code (see the protogen tag); regenerate
// rather than hand-edit.
type GetItemResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Item          *Item                  `protobuf:"bytes,1,opt,name=item,proto3,oneof" json:"item,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[11] message info in its state.
func (x *GetItemResponse) Reset() {
	*x = GetItemResponse{}
	mi := &file_data_store_proto_msgTypes[11]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetItemResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetItemResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetItemResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[11]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{11}.
//
// Deprecated: Use GetItemResponse.ProtoReflect.Descriptor instead.
func (*GetItemResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{11}
}

// GetItem returns the item field, or nil if x is nil.
func (x *GetItemResponse) GetItem() *Item {
	if x != nil {
		return x.Item
	}
	return nil
}

// ModifyItemRequest is the Go type for the ModifyItemRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[12]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type ModifyItemRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	ItemId        string                 `protobuf:"bytes,1,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	Patch         *ItemPatch             `protobuf:"bytes,2,opt,name=patch,proto3" json:"patch,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[12] message info in its state.
func (x *ModifyItemRequest) Reset() {
	*x = ModifyItemRequest{}
	mi := &file_data_store_proto_msgTypes[12]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *ModifyItemRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*ModifyItemRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *ModifyItemRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[12]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{12}.
//
// Deprecated: Use ModifyItemRequest.ProtoReflect.Descriptor instead.
func (*ModifyItemRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{12}
}

// GetItemId returns the item_id field, or "" if x is nil.
func (x *ModifyItemRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// GetPatch returns the patch field, or nil if x is nil.
func (x *ModifyItemRequest) GetPatch() *ItemPatch {
	if x != nil {
		return x.Patch
	}
	return nil
}

// ModifyItemResponse is the Go type for the ModifyItemResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[13]); it has
// no exported fields. Generated code (see the protogen tag); regenerate
// rather than hand-edit.
type ModifyItemResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[13] message info in its state.
func (x *ModifyItemResponse) Reset() {
	*x = ModifyItemResponse{}
	mi := &file_data_store_proto_msgTypes[13]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *ModifyItemResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*ModifyItemResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *ModifyItemResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[13]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{13}.
//
// Deprecated: Use ModifyItemResponse.ProtoReflect.Descriptor instead.
func (*ModifyItemResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{13}
}

// GetItemsRequest is the Go type for the GetItemsRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[14]).
// NOTE(review): Cursor/N look like pagination parameters and BeginTime like a
// lower time bound — confirm against the server implementation.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetItemsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Cursor        string                 `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	N             int32                  `protobuf:"varint,2,opt,name=n,proto3" json:"n,omitempty"`
	BeginTime     *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=begin_time,json=beginTime,proto3" json:"begin_time,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[14] message info in its state.
func (x *GetItemsRequest) Reset() {
	*x = GetItemsRequest{}
	mi := &file_data_store_proto_msgTypes[14]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetItemsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetItemsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetItemsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[14]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{14}.
//
// Deprecated: Use GetItemsRequest.ProtoReflect.Descriptor instead.
func (*GetItemsRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{14}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetItemsRequest) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetN returns the n field, or 0 if x is nil.
func (x *GetItemsRequest) GetN() int32 {
	if x != nil {
		return x.N
	}
	return 0
}

// GetBeginTime returns the begin_time field, or nil if x is nil.
func (x *GetItemsRequest) GetBeginTime() *timestamppb.Timestamp {
	if x != nil {
		return x.BeginTime
	}
	return nil
}

// GetItemsResponse is the Go type for the GetItemsResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[15]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetItemsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Cursor        string                 `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	Items         []*Item                `protobuf:"bytes,2,rep,name=items,proto3" json:"items,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[15] message info in its state.
func (x *GetItemsResponse) Reset() {
	*x = GetItemsResponse{}
	mi := &file_data_store_proto_msgTypes[15]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetItemsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetItemsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetItemsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[15]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{15}.
//
// Deprecated: Use GetItemsResponse.ProtoReflect.Descriptor instead.
func (*GetItemsResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{15}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetItemsResponse) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetItems returns the items field, or nil if x is nil.
func (x *GetItemsResponse) GetItems() []*Item {
	if x != nil {
		return x.Items
	}
	return nil
}

// GetItemFeedbackRequest is the Go type for the GetItemFeedbackRequest message
// of data_store.proto (message info: file_data_store_proto_msgTypes[16]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetItemFeedbackRequest struct {
	state         protoimpl.MessageState    `protogen:"open.v1"`
	ItemId        string                    `protobuf:"bytes,1,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	FeedbackTypes []*FeedbackTypeExpression `protobuf:"bytes,2,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[16] message info in its state.
func (x *GetItemFeedbackRequest) Reset() {
	*x = GetItemFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[16]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetItemFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetItemFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetItemFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[16]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{16}.
//
// Deprecated: Use GetItemFeedbackRequest.ProtoReflect.Descriptor instead.
func (*GetItemFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{16}
}

// GetItemId returns the item_id field, or "" if x is nil.
func (x *GetItemFeedbackRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// GetFeedbackTypes returns the feedback_types field, or nil if x is nil.
func (x *GetItemFeedbackRequest) GetFeedbackTypes() []*FeedbackTypeExpression {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// BatchInsertUsersRequest is the Go type for the BatchInsertUsersRequest
// message of data_store.proto (message info: file_data_store_proto_msgTypes[17]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type BatchInsertUsersRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Users         []*User                `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[17] message info in its state.
func (x *BatchInsertUsersRequest) Reset() {
	*x = BatchInsertUsersRequest{}
	mi := &file_data_store_proto_msgTypes[17]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *BatchInsertUsersRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*BatchInsertUsersRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *BatchInsertUsersRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[17]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{17}.
//
// Deprecated: Use BatchInsertUsersRequest.ProtoReflect.Descriptor instead.
func (*BatchInsertUsersRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{17}
}

// GetUsers returns the users field, or nil if x is nil.
func (x *BatchInsertUsersRequest) GetUsers() []*User {
	if x != nil {
		return x.Users
	}
	return nil
}

// BatchInsertUsersResponse is the Go type for the BatchInsertUsersResponse
// message of data_store.proto (message info: file_data_store_proto_msgTypes[18]);
// it has no exported fields. Generated code (see the protogen tag);
// regenerate rather than hand-edit.
type BatchInsertUsersResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[18] message info in its state.
func (x *BatchInsertUsersResponse) Reset() {
	*x = BatchInsertUsersResponse{}
	mi := &file_data_store_proto_msgTypes[18]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *BatchInsertUsersResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*BatchInsertUsersResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *BatchInsertUsersResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[18]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{18}.
//
// Deprecated: Use BatchInsertUsersResponse.ProtoReflect.Descriptor instead.
func (*BatchInsertUsersResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{18}
}

// DeleteUserRequest is the Go type for the DeleteUserRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[19]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type DeleteUserRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	UserId        string                 `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[19] message info in its state.
func (x *DeleteUserRequest) Reset() {
	*x = DeleteUserRequest{}
	mi := &file_data_store_proto_msgTypes[19]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *DeleteUserRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*DeleteUserRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *DeleteUserRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[19]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{19}.
//
// Deprecated: Use DeleteUserRequest.ProtoReflect.Descriptor instead.
func (*DeleteUserRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{19}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *DeleteUserRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// DeleteUserResponse is the Go type for the DeleteUserResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[20]); it has
// no exported fields. Generated code (see the protogen tag); regenerate
// rather than hand-edit.
type DeleteUserResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[20] message info in its state.
func (x *DeleteUserResponse) Reset() {
	*x = DeleteUserResponse{}
	mi := &file_data_store_proto_msgTypes[20]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *DeleteUserResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*DeleteUserResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *DeleteUserResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[20]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{20}.
//
// Deprecated: Use DeleteUserResponse.ProtoReflect.Descriptor instead.
func (*DeleteUserResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{20}
}

// GetUserRequest is the Go type for the GetUserRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[21]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUserRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	UserId        string                 `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[21] message info in its state.
func (x *GetUserRequest) Reset() {
	*x = GetUserRequest{}
	mi := &file_data_store_proto_msgTypes[21]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUserRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUserRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUserRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[21]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{21}.
//
// Deprecated: Use GetUserRequest.ProtoReflect.Descriptor instead.
func (*GetUserRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{21}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *GetUserRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// GetUserResponse is the Go type for the GetUserResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[22]).
// User is tagged proto3,oneof — presumably a proto3 `optional` field; confirm
// against the .proto. Generated code (see the protogen tag); regenerate
// rather than hand-edit.
type GetUserResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	User          *User                  `protobuf:"bytes,1,opt,name=user,proto3,oneof" json:"user,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[22] message info in its state.
func (x *GetUserResponse) Reset() {
	*x = GetUserResponse{}
	mi := &file_data_store_proto_msgTypes[22]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUserResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUserResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUserResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[22]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{22}.
//
// Deprecated: Use GetUserResponse.ProtoReflect.Descriptor instead.
func (*GetUserResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{22}
}

// GetUser returns the user field, or nil if x is nil.
func (x *GetUserResponse) GetUser() *User {
	if x != nil {
		return x.User
	}
	return nil
}

// ModifyUserRequest is the Go type for the ModifyUserRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[23]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type ModifyUserRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	UserId        string                 `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	Patch         *UserPatch             `protobuf:"bytes,2,opt,name=patch,proto3" json:"patch,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[23] message info in its state.
func (x *ModifyUserRequest) Reset() {
	*x = ModifyUserRequest{}
	mi := &file_data_store_proto_msgTypes[23]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *ModifyUserRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*ModifyUserRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *ModifyUserRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[23]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{23}.
//
// Deprecated: Use ModifyUserRequest.ProtoReflect.Descriptor instead.
func (*ModifyUserRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{23}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *ModifyUserRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// GetPatch returns the patch field, or nil if x is nil.
func (x *ModifyUserRequest) GetPatch() *UserPatch {
	if x != nil {
		return x.Patch
	}
	return nil
}

// ModifyUserResponse is the Go type for the ModifyUserResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[24]); it has
// no exported fields. Generated code (see the protogen tag); regenerate
// rather than hand-edit.
type ModifyUserResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[24] message info in its state.
func (x *ModifyUserResponse) Reset() {
	*x = ModifyUserResponse{}
	mi := &file_data_store_proto_msgTypes[24]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *ModifyUserResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*ModifyUserResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *ModifyUserResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[24]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{24}.
//
// Deprecated: Use ModifyUserResponse.ProtoReflect.Descriptor instead.
func (*ModifyUserResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{24}
}

// GetUsersRequest is the Go type for the GetUsersRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[25]).
// NOTE(review): Cursor/N look like pagination parameters — confirm against
// the server implementation.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUsersRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Cursor        string                 `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	N             int32                  `protobuf:"varint,2,opt,name=n,proto3" json:"n,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[25] message info in its state.
func (x *GetUsersRequest) Reset() {
	*x = GetUsersRequest{}
	mi := &file_data_store_proto_msgTypes[25]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUsersRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUsersRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUsersRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[25]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{25}.
//
// Deprecated: Use GetUsersRequest.ProtoReflect.Descriptor instead.
func (*GetUsersRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{25}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetUsersRequest) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetN returns the n field, or 0 if x is nil.
func (x *GetUsersRequest) GetN() int32 {
	if x != nil {
		return x.N
	}
	return 0
}

// GetUsersResponse is the Go type for the GetUsersResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[26]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUsersResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Cursor        string                 `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	Users         []*User                `protobuf:"bytes,2,rep,name=users,proto3" json:"users,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[26] message info in its state.
func (x *GetUsersResponse) Reset() {
	*x = GetUsersResponse{}
	mi := &file_data_store_proto_msgTypes[26]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUsersResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUsersResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUsersResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[26]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{26}.
//
// Deprecated: Use GetUsersResponse.ProtoReflect.Descriptor instead.
func (*GetUsersResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{26}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetUsersResponse) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetUsers returns the users field, or nil if x is nil.
func (x *GetUsersResponse) GetUsers() []*User {
	if x != nil {
		return x.Users
	}
	return nil
}

// GetUserFeedbackRequest is the Go type for the GetUserFeedbackRequest message
// of data_store.proto (message info: file_data_store_proto_msgTypes[27]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUserFeedbackRequest struct {
	state         protoimpl.MessageState    `protogen:"open.v1"`
	UserId        string                    `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	EndTime       *timestamppb.Timestamp    `protobuf:"bytes,2,opt,name=end_time,json=endTime,proto3" json:"end_time,omitempty"`
	FeedbackTypes []*FeedbackTypeExpression `protobuf:"bytes,3,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[27] message info in its state.
func (x *GetUserFeedbackRequest) Reset() {
	*x = GetUserFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[27]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUserFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUserFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUserFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[27]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{27}.
//
// Deprecated: Use GetUserFeedbackRequest.ProtoReflect.Descriptor instead.
func (*GetUserFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{27}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *GetUserFeedbackRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// GetEndTime returns the end_time field, or nil if x is nil.
func (x *GetUserFeedbackRequest) GetEndTime() *timestamppb.Timestamp {
	if x != nil {
		return x.EndTime
	}
	return nil
}

// GetFeedbackTypes returns the feedback_types field, or nil if x is nil.
func (x *GetUserFeedbackRequest) GetFeedbackTypes() []*FeedbackTypeExpression {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// GetUserItemFeedbackRequest is the Go type for the GetUserItemFeedbackRequest
// message of data_store.proto (message info: file_data_store_proto_msgTypes[28]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUserItemFeedbackRequest struct {
	state         protoimpl.MessageState    `protogen:"open.v1"`
	UserId        string                    `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	ItemId        string                    `protobuf:"bytes,2,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	FeedbackTypes []*FeedbackTypeExpression `protobuf:"bytes,3,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[28] message info in its state.
func (x *GetUserItemFeedbackRequest) Reset() {
	*x = GetUserItemFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[28]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUserItemFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUserItemFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUserItemFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[28]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{28}.
//
// Deprecated: Use GetUserItemFeedbackRequest.ProtoReflect.Descriptor instead.
func (*GetUserItemFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{28}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *GetUserItemFeedbackRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// GetItemId returns the item_id field, or "" if x is nil.
func (x *GetUserItemFeedbackRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// GetFeedbackTypes returns the feedback_types field, or nil if x is nil.
func (x *GetUserItemFeedbackRequest) GetFeedbackTypes() []*FeedbackTypeExpression {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// DeleteUserItemFeedbackRequest is the Go type for the
// DeleteUserItemFeedbackRequest message of data_store.proto (message info:
// file_data_store_proto_msgTypes[29]). Unlike the Get*FeedbackRequest types,
// FeedbackTypes here is a plain []string.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type DeleteUserItemFeedbackRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	UserId        string                 `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	ItemId        string                 `protobuf:"bytes,2,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	FeedbackTypes []string               `protobuf:"bytes,3,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[29] message info in its state.
func (x *DeleteUserItemFeedbackRequest) Reset() {
	*x = DeleteUserItemFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[29]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *DeleteUserItemFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*DeleteUserItemFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *DeleteUserItemFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[29]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{29}.
//
// Deprecated: Use DeleteUserItemFeedbackRequest.ProtoReflect.Descriptor instead.
func (*DeleteUserItemFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{29}
}

// GetUserId returns the user_id field, or "" if x is nil.
func (x *DeleteUserItemFeedbackRequest) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

// GetItemId returns the item_id field, or "" if x is nil.
func (x *DeleteUserItemFeedbackRequest) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

// GetFeedbackTypes returns the feedback_types field, or nil if x is nil.
func (x *DeleteUserItemFeedbackRequest) GetFeedbackTypes() []string {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// DeleteUserItemFeedbackResponse is the Go type for the
// DeleteUserItemFeedbackResponse message of data_store.proto (message info:
// file_data_store_proto_msgTypes[30]).
// NOTE(review): Count presumably reports how many feedback records were
// deleted — confirm against the server implementation.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type DeleteUserItemFeedbackResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Count         int32                  `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[30] message info in its state.
func (x *DeleteUserItemFeedbackResponse) Reset() {
	*x = DeleteUserItemFeedbackResponse{}
	mi := &file_data_store_proto_msgTypes[30]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *DeleteUserItemFeedbackResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*DeleteUserItemFeedbackResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *DeleteUserItemFeedbackResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[30]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{30}.
//
// Deprecated: Use DeleteUserItemFeedbackResponse.ProtoReflect.Descriptor instead.
func (*DeleteUserItemFeedbackResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{30}
}

// GetCount returns the count field, or 0 if x is nil.
func (x *DeleteUserItemFeedbackResponse) GetCount() int32 {
	if x != nil {
		return x.Count
	}
	return 0
}

// BatchInsertFeedbackRequest is the Go type for the BatchInsertFeedbackRequest
// message of data_store.proto (message info: file_data_store_proto_msgTypes[31]).
// NOTE(review): InsertUser/InsertItem/Overwrite look like insert-mode flags;
// their exact semantics live in the server — confirm there.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type BatchInsertFeedbackRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Feedback      []*Feedback            `protobuf:"bytes,1,rep,name=feedback,proto3" json:"feedback,omitempty"`
	InsertUser    bool                   `protobuf:"varint,2,opt,name=insert_user,json=insertUser,proto3" json:"insert_user,omitempty"`
	InsertItem    bool                   `protobuf:"varint,3,opt,name=insert_item,json=insertItem,proto3" json:"insert_item,omitempty"`
	Overwrite     bool                   `protobuf:"varint,4,opt,name=overwrite,proto3" json:"overwrite,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[31] message info in its state.
func (x *BatchInsertFeedbackRequest) Reset() {
	*x = BatchInsertFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[31]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *BatchInsertFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*BatchInsertFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *BatchInsertFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[31]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{31}.
//
// Deprecated: Use BatchInsertFeedbackRequest.ProtoReflect.Descriptor instead.
func (*BatchInsertFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{31}
}

// GetFeedback returns the feedback field, or nil if x is nil.
func (x *BatchInsertFeedbackRequest) GetFeedback() []*Feedback {
	if x != nil {
		return x.Feedback
	}
	return nil
}

// GetInsertUser returns the insert_user field, or false if x is nil.
func (x *BatchInsertFeedbackRequest) GetInsertUser() bool {
	if x != nil {
		return x.InsertUser
	}
	return false
}

// GetInsertItem returns the insert_item field, or false if x is nil.
func (x *BatchInsertFeedbackRequest) GetInsertItem() bool {
	if x != nil {
		return x.InsertItem
	}
	return false
}

// GetOverwrite returns the overwrite field, or false if x is nil.
func (x *BatchInsertFeedbackRequest) GetOverwrite() bool {
	if x != nil {
		return x.Overwrite
	}
	return false
}

// BatchInsertFeedbackResponse is the Go type for the BatchInsertFeedbackResponse
// message of data_store.proto (message info: file_data_store_proto_msgTypes[32]);
// it has no exported fields. Generated code (see the protogen tag);
// regenerate rather than hand-edit.
type BatchInsertFeedbackResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[32] message info in its state.
func (x *BatchInsertFeedbackResponse) Reset() {
	*x = BatchInsertFeedbackResponse{}
	mi := &file_data_store_proto_msgTypes[32]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *BatchInsertFeedbackResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*BatchInsertFeedbackResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *BatchInsertFeedbackResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[32]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{32}.
//
// Deprecated: Use BatchInsertFeedbackResponse.ProtoReflect.Descriptor instead.
func (*BatchInsertFeedbackResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{32}
}

// GetFeedbackRequest is the Go type for the GetFeedbackRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[33]).
// NOTE(review): Cursor/N look like pagination parameters and
// BeginTime/EndTime like a time window — confirm against the server.
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetFeedbackRequest struct {
	state         protoimpl.MessageState    `protogen:"open.v1"`
	Cursor        string                    `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	N             int32                     `protobuf:"varint,2,opt,name=n,proto3" json:"n,omitempty"`
	BeginTime     *timestamppb.Timestamp    `protobuf:"bytes,3,opt,name=begin_time,json=beginTime,proto3" json:"begin_time,omitempty"`
	EndTime       *timestamppb.Timestamp    `protobuf:"bytes,4,opt,name=end_time,json=endTime,proto3" json:"end_time,omitempty"`
	FeedbackTypes []*FeedbackTypeExpression `protobuf:"bytes,5,rep,name=feedback_types,json=feedbackTypes,proto3" json:"feedback_types,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[33] message info in its state.
func (x *GetFeedbackRequest) Reset() {
	*x = GetFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[33]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[33]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{33}.
//
// Deprecated: Use GetFeedbackRequest.ProtoReflect.Descriptor instead.
func (*GetFeedbackRequest) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{33}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetFeedbackRequest) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetN returns the n field, or 0 if x is nil.
func (x *GetFeedbackRequest) GetN() int32 {
	if x != nil {
		return x.N
	}
	return 0
}

// GetBeginTime returns the begin_time field, or nil if x is nil.
func (x *GetFeedbackRequest) GetBeginTime() *timestamppb.Timestamp {
	if x != nil {
		return x.BeginTime
	}
	return nil
}

// GetEndTime returns the end_time field, or nil if x is nil.
func (x *GetFeedbackRequest) GetEndTime() *timestamppb.Timestamp {
	if x != nil {
		return x.EndTime
	}
	return nil
}

// GetFeedbackTypes returns the feedback_types field, or nil if x is nil.
func (x *GetFeedbackRequest) GetFeedbackTypes() []*FeedbackTypeExpression {
	if x != nil {
		return x.FeedbackTypes
	}
	return nil
}

// GetFeedbackResponse is the Go type for the GetFeedbackResponse message of
// data_store.proto (message info: file_data_store_proto_msgTypes[34]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetFeedbackResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Cursor        string                 `protobuf:"bytes,1,opt,name=cursor,proto3" json:"cursor,omitempty"`
	Feedback      []*Feedback            `protobuf:"bytes,2,rep,name=feedback,proto3" json:"feedback,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[34] message info in its state.
func (x *GetFeedbackResponse) Reset() {
	*x = GetFeedbackResponse{}
	mi := &file_data_store_proto_msgTypes[34]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetFeedbackResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetFeedbackResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetFeedbackResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[34]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{34}.
//
// Deprecated: Use GetFeedbackResponse.ProtoReflect.Descriptor instead.
func (*GetFeedbackResponse) Descriptor() ([]byte, []int) {
	return file_data_store_proto_rawDescGZIP(), []int{34}
}

// GetCursor returns the cursor field, or "" if x is nil.
func (x *GetFeedbackResponse) GetCursor() string {
	if x != nil {
		return x.Cursor
	}
	return ""
}

// GetFeedback returns the feedback field, or nil if x is nil.
func (x *GetFeedbackResponse) GetFeedback() []*Feedback {
	if x != nil {
		return x.Feedback
	}
	return nil
}

// GetUserStreamRequest is the Go type for the GetUserStreamRequest message of
// data_store.proto (message info: file_data_store_proto_msgTypes[35]).
// Generated code (see the protogen tag); regenerate rather than hand-edit.
type GetUserStreamRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	BatchSize     int32                  `protobuf:"varint,1,opt,name=batch_size,json=batchSize,proto3" json:"batch_size,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores the msgTypes[35] message info in its state.
func (x *GetUserStreamRequest) Reset() {
	*x = GetUserStreamRequest{}
	mi := &file_data_store_proto_msgTypes[35]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering of x produced by protoimpl.
func (x *GetUserStreamRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; it marks the type as a protobuf message.
func (*GetUserStreamRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For non-nil x the message
// info is stored in x's state on first use; nil x goes through mi.MessageOf.
func (x *GetUserStreamRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[35]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Store the message info only if none is attached yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Descriptor returns file_data_store_proto_rawDescGZIP() and this message's
// index, []int{35}.
//
// Deprecated: Use GetUserStreamRequest.ProtoReflect.Descriptor instead.
func (*GetUserStreamRequest) Descriptor() ([]byte, []int) {
	// 35 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{35}
}

// GetBatchSize returns x.BatchSize, or 0 if x is nil.
func (x *GetUserStreamRequest) GetBatchSize() int32 {
	if x != nil {
		return x.BatchSize
	}
	return 0
}

// GetUserStreamResponse is the output message of the DataStore.GetUserStream
// RPC (per the service method index table in this file); it carries a list of
// User.
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetUserStreamResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Users is proto field 1 (repeated).
	Users         []*User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[36]) back into its state. x must be non-nil.
func (x *GetUserStreamResponse) Reset() {
	*x = GetUserStreamResponse{}
	mi := &file_data_store_proto_msgTypes[36]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetUserStreamResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetUserStreamResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetUserStreamResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[36]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetUserStreamResponse.ProtoReflect.Descriptor instead.
func (*GetUserStreamResponse) Descriptor() ([]byte, []int) {
	// 36 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{36}
}

// GetUsers returns x.Users, or nil if x is nil.
func (x *GetUserStreamResponse) GetUsers() []*User {
	if x != nil {
		return x.Users
	}
	return nil
}

// GetItemStreamRequest is the input message of the DataStore.GetItemStream RPC
// (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetItemStreamRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// BatchSize is proto field 1; presumably the number of items per streamed
	// GetItemStreamResponse — TODO confirm against the server implementation.
	BatchSize int32 `protobuf:"varint,1,opt,name=batch_size,json=batchSize,proto3" json:"batch_size,omitempty"`
	// TimeLimit is proto field 2; how the server applies it is not visible
	// here — TODO confirm.
	TimeLimit     *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=time_limit,json=timeLimit,proto3" json:"time_limit,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[37]) back into its state. x must be non-nil.
func (x *GetItemStreamRequest) Reset() {
	*x = GetItemStreamRequest{}
	mi := &file_data_store_proto_msgTypes[37]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetItemStreamRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetItemStreamRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetItemStreamRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[37]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetItemStreamRequest.ProtoReflect.Descriptor instead.
func (*GetItemStreamRequest) Descriptor() ([]byte, []int) {
	// 37 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{37}
}

// GetBatchSize returns x.BatchSize, or 0 if x is nil.
func (x *GetItemStreamRequest) GetBatchSize() int32 {
	if x != nil {
		return x.BatchSize
	}
	return 0
}

// GetTimeLimit returns x.TimeLimit, or nil if x is nil.
func (x *GetItemStreamRequest) GetTimeLimit() *timestamppb.Timestamp {
	if x != nil {
		return x.TimeLimit
	}
	return nil
}

// GetItemStreamResponse is the output message of the DataStore.GetItemStream
// RPC (per the service method index table in this file); it carries a list of
// Item.
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetItemStreamResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Items is proto field 1 (repeated).
	Items         []*Item `protobuf:"bytes,1,rep,name=items,proto3" json:"items,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[38]) back into its state. x must be non-nil.
func (x *GetItemStreamResponse) Reset() {
	*x = GetItemStreamResponse{}
	mi := &file_data_store_proto_msgTypes[38]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetItemStreamResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetItemStreamResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetItemStreamResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[38]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetItemStreamResponse.ProtoReflect.Descriptor instead.
func (*GetItemStreamResponse) Descriptor() ([]byte, []int) {
	// 38 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{38}
}

// GetItems returns x.Items, or nil if x is nil.
func (x *GetItemStreamResponse) GetItems() []*Item {
	if x != nil {
		return x.Items
	}
	return nil
}

// GetFeedbackStreamRequest is the input message of the
// DataStore.GetFeedbackStream RPC (per the service method index table in this
// file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetFeedbackStreamRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// BatchSize is proto field 1; presumably the number of feedback records per
	// streamed GetFeedbackStreamResponse — TODO confirm.
	BatchSize int32 `protobuf:"varint,1,opt,name=batch_size,json=batchSize,proto3" json:"batch_size,omitempty"`
	// ScanOptions is proto field 2. Per the file descriptor, ScanOptions holds
	// optional begin/end user IDs, begin/end item IDs, begin/end times,
	// feedback_types and order_by_item_id.
	ScanOptions   *ScanOptions `protobuf:"bytes,2,opt,name=scan_options,json=scanOptions,proto3" json:"scan_options,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[39]) back into its state. x must be non-nil.
func (x *GetFeedbackStreamRequest) Reset() {
	*x = GetFeedbackStreamRequest{}
	mi := &file_data_store_proto_msgTypes[39]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetFeedbackStreamRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetFeedbackStreamRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetFeedbackStreamRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[39]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetFeedbackStreamRequest.ProtoReflect.Descriptor instead.
func (*GetFeedbackStreamRequest) Descriptor() ([]byte, []int) {
	// 39 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{39}
}

// GetBatchSize returns x.BatchSize, or 0 if x is nil.
func (x *GetFeedbackStreamRequest) GetBatchSize() int32 {
	if x != nil {
		return x.BatchSize
	}
	return 0
}

// GetScanOptions returns x.ScanOptions, or nil if x is nil.
func (x *GetFeedbackStreamRequest) GetScanOptions() *ScanOptions {
	if x != nil {
		return x.ScanOptions
	}
	return nil
}

// GetFeedbackStreamResponse is the output message of the
// DataStore.GetFeedbackStream RPC (per the service method index table in this
// file); it carries a list of Feedback.
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetFeedbackStreamResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Feedback is proto field 1 (repeated).
	Feedback      []*Feedback `protobuf:"bytes,1,rep,name=feedback,proto3" json:"feedback,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[40]) back into its state. x must be non-nil.
func (x *GetFeedbackStreamResponse) Reset() {
	*x = GetFeedbackStreamResponse{}
	mi := &file_data_store_proto_msgTypes[40]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetFeedbackStreamResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetFeedbackStreamResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetFeedbackStreamResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[40]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetFeedbackStreamResponse.ProtoReflect.Descriptor instead.
func (*GetFeedbackStreamResponse) Descriptor() ([]byte, []int) {
	// 40 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{40}
}

// GetFeedback returns x.Feedback, or nil if x is nil.
func (x *GetFeedbackStreamResponse) GetFeedback() []*Feedback {
	if x != nil {
		return x.Feedback
	}
	return nil
}

// CountUsersRequest is the field-less input message of the DataStore.CountUsers
// RPC (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountUsersRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[41]) back into its state. x must be non-nil.
func (x *CountUsersRequest) Reset() {
	*x = CountUsersRequest{}
	mi := &file_data_store_proto_msgTypes[41]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountUsersRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountUsersRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountUsersRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[41]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountUsersRequest.ProtoReflect.Descriptor instead.
func (*CountUsersRequest) Descriptor() ([]byte, []int) {
	// 41 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{41}
}

// CountUsersResponse is the output message of the DataStore.CountUsers RPC
// (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountUsersResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Count is proto field 1. It is an int32, so it cannot represent more than
	// math.MaxInt32 users.
	Count         int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[42]) back into its state. x must be non-nil.
func (x *CountUsersResponse) Reset() {
	*x = CountUsersResponse{}
	mi := &file_data_store_proto_msgTypes[42]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountUsersResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountUsersResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountUsersResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[42]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountUsersResponse.ProtoReflect.Descriptor instead.
func (*CountUsersResponse) Descriptor() ([]byte, []int) {
	// 42 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{42}
}

// GetCount returns x.Count, or 0 if x is nil.
func (x *CountUsersResponse) GetCount() int32 {
	if x != nil {
		return x.Count
	}
	return 0
}

// CountItemsRequest is the field-less input message of the DataStore.CountItems
// RPC (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountItemsRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[43]) back into its state. x must be non-nil.
func (x *CountItemsRequest) Reset() {
	*x = CountItemsRequest{}
	mi := &file_data_store_proto_msgTypes[43]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountItemsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountItemsRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountItemsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[43]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountItemsRequest.ProtoReflect.Descriptor instead.
func (*CountItemsRequest) Descriptor() ([]byte, []int) {
	// 43 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{43}
}

// CountItemsResponse is the output message of the DataStore.CountItems RPC
// (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountItemsResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Count is proto field 1. It is an int32, so it cannot represent more than
	// math.MaxInt32 items.
	Count         int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[44]) back into its state. x must be non-nil.
func (x *CountItemsResponse) Reset() {
	*x = CountItemsResponse{}
	mi := &file_data_store_proto_msgTypes[44]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountItemsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountItemsResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountItemsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[44]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountItemsResponse.ProtoReflect.Descriptor instead.
func (*CountItemsResponse) Descriptor() ([]byte, []int) {
	// 44 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{44}
}

// GetCount returns x.Count, or 0 if x is nil.
func (x *CountItemsResponse) GetCount() int32 {
	if x != nil {
		return x.Count
	}
	return 0
}

// CountFeedbackRequest is the field-less input message of the
// DataStore.CountFeedback RPC (per the service method index table in this
// file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountFeedbackRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[45]) back into its state. x must be non-nil.
func (x *CountFeedbackRequest) Reset() {
	*x = CountFeedbackRequest{}
	mi := &file_data_store_proto_msgTypes[45]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountFeedbackRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountFeedbackRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountFeedbackRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[45]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountFeedbackRequest.ProtoReflect.Descriptor instead.
func (*CountFeedbackRequest) Descriptor() ([]byte, []int) {
	// 45 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{45}
}

// CountFeedbackResponse is the output message of the DataStore.CountFeedback
// RPC (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type CountFeedbackResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Count is proto field 1. It is an int32, so it cannot represent more than
	// math.MaxInt32 feedback records.
	Count         int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[46]) back into its state. x must be non-nil.
func (x *CountFeedbackResponse) Reset() {
	*x = CountFeedbackResponse{}
	mi := &file_data_store_proto_msgTypes[46]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *CountFeedbackResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*CountFeedbackResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *CountFeedbackResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[46]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CountFeedbackResponse.ProtoReflect.Descriptor instead.
func (*CountFeedbackResponse) Descriptor() ([]byte, []int) {
	// 46 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{46}
}

// GetCount returns x.Count, or 0 if x is nil.
func (x *CountFeedbackResponse) GetCount() int32 {
	if x != nil {
		return x.Count
	}
	return 0
}

// GetLatestItemsRequest is the input message of the DataStore.GetLatestItems
// RPC (per the service method index table in this file).
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetLatestItemsRequest struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// N is proto field 1; presumably the maximum number of items to return —
	// TODO confirm against the server implementation.
	N int32 `protobuf:"varint,1,opt,name=n,proto3" json:"n,omitempty"`
	// Categories is proto field 2 (repeated); presumably restricts results to
	// these categories — TODO confirm.
	Categories    []string `protobuf:"bytes,2,rep,name=categories,proto3" json:"categories,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[47]) back into its state. x must be non-nil.
func (x *GetLatestItemsRequest) Reset() {
	*x = GetLatestItemsRequest{}
	mi := &file_data_store_proto_msgTypes[47]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetLatestItemsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetLatestItemsRequest) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetLatestItemsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[47]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetLatestItemsRequest.ProtoReflect.Descriptor instead.
func (*GetLatestItemsRequest) Descriptor() ([]byte, []int) {
	// 47 is this message's position among the messages declared in data_store.proto.
	return file_data_store_proto_rawDescGZIP(), []int{47}
}

// GetN returns x.N, or 0 if x is nil.
func (x *GetLatestItemsRequest) GetN() int32 {
	if x != nil {
		return x.N
	}
	return 0
}

// GetCategories returns x.Categories, or nil if x is nil.
func (x *GetLatestItemsRequest) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

// GetLatestItemsResponse is the output message of the DataStore.GetLatestItems
// RPC (per the service method index table in this file); it carries a list of
// Item.
//
// NOTE(review): this file appears to be protoc-gen-go output (protogen tag,
// protoimpl calls); change data_store.proto and regenerate rather than
// editing this code by hand.
type GetLatestItemsResponse struct {
	// state is kept first: Reset and ProtoReflect pass protoimpl a pointer to
	// the whole message, which presumably relies on state sitting at offset 0.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Items is proto field 1 (repeated).
	Items         []*Item `protobuf:"bytes,1,rep,name=items,proto3" json:"items,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset zeroes every field of x and stores this message's MessageInfo
// (file_data_store_proto_msgTypes[48]) back into its state. x must be non-nil.
func (x *GetLatestItemsResponse) Reset() {
	*x = GetLatestItemsResponse{}
	mi := &file_data_store_proto_msgTypes[48]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the text rendering produced by protoimpl.X.MessageStringOf.
func (x *GetLatestItemsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage has an empty body; presumably it only marks the type as a
// protobuf message for interface satisfaction.
func (*GetLatestItemsResponse) ProtoMessage() {}

// ProtoReflect returns a protoreflect.Message view of x. For a non-nil x the
// MessageInfo is stored into the message state on first use; a nil x is
// delegated to mi.MessageOf.
func (x *GetLatestItemsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_data_store_proto_msgTypes[48]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GetLatestItemsResponse.ProtoReflect.Descriptor instead.
func (*GetLatestItemsResponse) Descriptor() ([]byte, []int) { return file_data_store_proto_rawDescGZIP(), []int{48} } func (x *GetLatestItemsResponse) GetItems() []*Item { if x != nil { return x.Items } return nil } var File_data_store_proto protoreflect.FileDescriptor const file_data_store_proto_rawDesc = "" + "\n" + "\x10data_store.proto\x12\bprotocol\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0eprotocol.proto\"\x96\x01\n" + "\x16FeedbackTypeExpression\x12#\n" + "\rfeedback_type\x18\x01 \x01(\tR\ffeedbackType\x12A\n" + "\x0fexpression_type\x18\x02 \x01(\x0e2\x18.protocol.ExpressionTypeR\x0eexpressionType\x12\x14\n" + "\x05value\x18\x03 \x01(\x01R\x05value\"l\n" + "\tUserPatch\x12\x16\n" + "\x06labels\x18\x01 \x01(\fR\x06labels\x12\x1d\n" + "\acomment\x18\x02 \x01(\tH\x00R\acomment\x88\x01\x01\x12\x1c\n" + "\tsubscribe\x18\x03 \x03(\tR\tsubscribeB\n" + "\n" + "\b_comment\"\xeb\x01\n" + "\tItemPatch\x12 \n" + "\tis_hidden\x18\x01 \x01(\bH\x00R\bisHidden\x88\x01\x01\x12\x1e\n" + "\n" + "categories\x18\x02 \x03(\tR\n" + "categories\x12=\n" + "\ttimestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampH\x01R\ttimestamp\x88\x01\x01\x12\x16\n" + "\x06labels\x18\x04 \x01(\fR\x06labels\x12\x1d\n" + "\acomment\x18\x05 \x01(\tH\x02R\acomment\x88\x01\x01B\f\n" + "\n" + "_is_hiddenB\f\n" + "\n" + "_timestampB\n" + "\n" + "\b_comment\"\xf7\x03\n" + "\vScanOptions\x12'\n" + "\rbegin_user_id\x18\x01 \x01(\tH\x00R\vbeginUserId\x88\x01\x01\x12#\n" + "\vend_user_id\x18\x02 \x01(\tH\x01R\tendUserId\x88\x01\x01\x12'\n" + "\rbegin_item_id\x18\x03 \x01(\tH\x02R\vbeginItemId\x88\x01\x01\x12#\n" + "\vend_item_id\x18\x04 \x01(\tH\x03R\tendItemId\x88\x01\x01\x12>\n" + "\n" + "begin_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampH\x04R\tbeginTime\x88\x01\x01\x12:\n" + "\bend_time\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampH\x05R\aendTime\x88\x01\x01\x12G\n" + "\x0efeedback_types\x18\a \x03(\v2 .protocol.FeedbackTypeExpressionR\rfeedbackTypes\x12'\n" + "\x10order_by_item_id\x18\b 
\x01(\bR\rorderByItemIdB\x10\n" + "\x0e_begin_user_idB\x0e\n" + "\f_end_user_idB\x10\n" + "\x0e_begin_item_idB\x0e\n" + "\f_end_item_idB\r\n" + "\v_begin_timeB\v\n" + "\t_end_time\"?\n" + "\x17BatchInsertItemsRequest\x12$\n" + "\x05items\x18\x01 \x03(\v2\x0e.protocol.ItemR\x05items\"\x1a\n" + "\x18BatchInsertItemsResponse\"1\n" + "\x14BatchGetItemsRequest\x12\x19\n" + "\bitem_ids\x18\x01 \x03(\tR\aitemIds\"=\n" + "\x15BatchGetItemsResponse\x12$\n" + "\x05items\x18\x01 \x03(\v2\x0e.protocol.ItemR\x05items\",\n" + "\x11DeleteItemRequest\x12\x17\n" + "\aitem_id\x18\x01 \x01(\tR\x06itemId\"\x14\n" + "\x12DeleteItemResponse\")\n" + "\x0eGetItemRequest\x12\x17\n" + "\aitem_id\x18\x01 \x01(\tR\x06itemId\"C\n" + "\x0fGetItemResponse\x12'\n" + "\x04item\x18\x01 \x01(\v2\x0e.protocol.ItemH\x00R\x04item\x88\x01\x01B\a\n" + "\x05_item\"W\n" + "\x11ModifyItemRequest\x12\x17\n" + "\aitem_id\x18\x01 \x01(\tR\x06itemId\x12)\n" + "\x05patch\x18\x02 \x01(\v2\x13.protocol.ItemPatchR\x05patch\"\x14\n" + "\x12ModifyItemResponse\"r\n" + "\x0fGetItemsRequest\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12\f\n" + "\x01n\x18\x02 \x01(\x05R\x01n\x129\n" + "\n" + "begin_time\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\tbeginTime\"P\n" + "\x10GetItemsResponse\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12$\n" + "\x05items\x18\x02 \x03(\v2\x0e.protocol.ItemR\x05items\"z\n" + "\x16GetItemFeedbackRequest\x12\x17\n" + "\aitem_id\x18\x01 \x01(\tR\x06itemId\x12G\n" + "\x0efeedback_types\x18\x02 \x03(\v2 .protocol.FeedbackTypeExpressionR\rfeedbackTypes\"?\n" + "\x17BatchInsertUsersRequest\x12$\n" + "\x05users\x18\x01 \x03(\v2\x0e.protocol.UserR\x05users\"\x1a\n" + "\x18BatchInsertUsersResponse\",\n" + "\x11DeleteUserRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\"\x14\n" + "\x12DeleteUserResponse\")\n" + "\x0eGetUserRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\"C\n" + "\x0fGetUserResponse\x12'\n" + "\x04user\x18\x01 
\x01(\v2\x0e.protocol.UserH\x00R\x04user\x88\x01\x01B\a\n" + "\x05_user\"W\n" + "\x11ModifyUserRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x12)\n" + "\x05patch\x18\x02 \x01(\v2\x13.protocol.UserPatchR\x05patch\"\x14\n" + "\x12ModifyUserResponse\"7\n" + "\x0fGetUsersRequest\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12\f\n" + "\x01n\x18\x02 \x01(\x05R\x01n\"P\n" + "\x10GetUsersResponse\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12$\n" + "\x05users\x18\x02 \x03(\v2\x0e.protocol.UserR\x05users\"\xb1\x01\n" + "\x16GetUserFeedbackRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x125\n" + "\bend_time\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\aendTime\x12G\n" + "\x0efeedback_types\x18\x03 \x03(\v2 .protocol.FeedbackTypeExpressionR\rfeedbackTypes\"\x97\x01\n" + "\x1aGetUserItemFeedbackRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x17\n" + "\aitem_id\x18\x02 \x01(\tR\x06itemId\x12G\n" + "\x0efeedback_types\x18\x03 \x03(\v2 .protocol.FeedbackTypeExpressionR\rfeedbackTypes\"x\n" + "\x1dDeleteUserItemFeedbackRequest\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x17\n" + "\aitem_id\x18\x02 \x01(\tR\x06itemId\x12%\n" + "\x0efeedback_types\x18\x03 \x03(\tR\rfeedbackTypes\"6\n" + "\x1eDeleteUserItemFeedbackResponse\x12\x14\n" + "\x05count\x18\x01 \x01(\x05R\x05count\"\xac\x01\n" + "\x1aBatchInsertFeedbackRequest\x12.\n" + "\bfeedback\x18\x01 \x03(\v2\x12.protocol.FeedbackR\bfeedback\x12\x1f\n" + "\vinsert_user\x18\x02 \x01(\bR\n" + "insertUser\x12\x1f\n" + "\vinsert_item\x18\x03 \x01(\bR\n" + "insertItem\x12\x1c\n" + "\toverwrite\x18\x04 \x01(\bR\toverwrite\"\x1d\n" + "\x1bBatchInsertFeedbackResponse\"\xf5\x01\n" + "\x12GetFeedbackRequest\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12\f\n" + "\x01n\x18\x02 \x01(\x05R\x01n\x129\n" + "\n" + "begin_time\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\tbeginTime\x125\n" + "\bend_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\aendTime\x12G\n" + 
"\x0efeedback_types\x18\x05 \x03(\v2 .protocol.FeedbackTypeExpressionR\rfeedbackTypes\"]\n" + "\x13GetFeedbackResponse\x12\x16\n" + "\x06cursor\x18\x01 \x01(\tR\x06cursor\x12.\n" + "\bfeedback\x18\x02 \x03(\v2\x12.protocol.FeedbackR\bfeedback\"5\n" + "\x14GetUserStreamRequest\x12\x1d\n" + "\n" + "batch_size\x18\x01 \x01(\x05R\tbatchSize\"=\n" + "\x15GetUserStreamResponse\x12$\n" + "\x05users\x18\x01 \x03(\v2\x0e.protocol.UserR\x05users\"p\n" + "\x14GetItemStreamRequest\x12\x1d\n" + "\n" + "batch_size\x18\x01 \x01(\x05R\tbatchSize\x129\n" + "\n" + "time_limit\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimeLimit\"=\n" + "\x15GetItemStreamResponse\x12$\n" + "\x05items\x18\x01 \x03(\v2\x0e.protocol.ItemR\x05items\"s\n" + "\x18GetFeedbackStreamRequest\x12\x1d\n" + "\n" + "batch_size\x18\x01 \x01(\x05R\tbatchSize\x128\n" + "\fscan_options\x18\x02 \x01(\v2\x15.protocol.ScanOptionsR\vscanOptions\"K\n" + "\x19GetFeedbackStreamResponse\x12.\n" + "\bfeedback\x18\x01 \x03(\v2\x12.protocol.FeedbackR\bfeedback\"\x13\n" + "\x11CountUsersRequest\"*\n" + "\x12CountUsersResponse\x12\x14\n" + "\x05count\x18\x01 \x01(\x05R\x05count\"\x13\n" + "\x11CountItemsRequest\"*\n" + "\x12CountItemsResponse\x12\x14\n" + "\x05count\x18\x01 \x01(\x05R\x05count\"\x16\n" + "\x14CountFeedbackRequest\"-\n" + "\x15CountFeedbackResponse\x12\x14\n" + "\x05count\x18\x01 \x01(\x05R\x05count\"E\n" + "\x15GetLatestItemsRequest\x12\f\n" + "\x01n\x18\x01 \x01(\x05R\x01n\x12\x1e\n" + "\n" + "categories\x18\x02 \x03(\tR\n" + "categories\">\n" + "\x16GetLatestItemsResponse\x12$\n" + "\x05items\x18\x01 \x03(\v2\x0e.protocol.ItemR\x05items*V\n" + "\x0eExpressionType\x12\b\n" + "\x04None\x10\x00\x12\b\n" + "\x04Less\x10\x01\x12\x0f\n" + "\vLessOrEqual\x10\x02\x12\v\n" + "\aGreater\x10\x03\x12\x12\n" + "\x0eGreaterOrEqual\x10\x042\x88\x10\n" + "\tDataStore\x127\n" + "\x04Ping\x12\x15.protocol.PingRequest\x1a\x16.protocol.PingResponse\"\x00\x12[\n" + 
"\x10BatchInsertItems\x12!.protocol.BatchInsertItemsRequest\x1a\".protocol.BatchInsertItemsResponse\"\x00\x12R\n" + "\rBatchGetItems\x12\x1e.protocol.BatchGetItemsRequest\x1a\x1f.protocol.BatchGetItemsResponse\"\x00\x12I\n" + "\n" + "DeleteItem\x12\x1b.protocol.DeleteItemRequest\x1a\x1c.protocol.DeleteItemResponse\"\x00\x12@\n" + "\aGetItem\x12\x18.protocol.GetItemRequest\x1a\x19.protocol.GetItemResponse\"\x00\x12I\n" + "\n" + "ModifyItem\x12\x1b.protocol.ModifyItemRequest\x1a\x1c.protocol.ModifyItemResponse\"\x00\x12C\n" + "\bGetItems\x12\x19.protocol.GetItemsRequest\x1a\x1a.protocol.GetItemsResponse\"\x00\x12T\n" + "\x0fGetItemFeedback\x12 .protocol.GetItemFeedbackRequest\x1a\x1d.protocol.GetFeedbackResponse\"\x00\x12[\n" + "\x10BatchInsertUsers\x12!.protocol.BatchInsertUsersRequest\x1a\".protocol.BatchInsertUsersResponse\"\x00\x12I\n" + "\n" + "DeleteUser\x12\x1b.protocol.DeleteUserRequest\x1a\x1c.protocol.DeleteUserResponse\"\x00\x12@\n" + "\aGetUser\x12\x18.protocol.GetUserRequest\x1a\x19.protocol.GetUserResponse\"\x00\x12I\n" + "\n" + "ModifyUser\x12\x1b.protocol.ModifyUserRequest\x1a\x1c.protocol.ModifyUserResponse\"\x00\x12C\n" + "\bGetUsers\x12\x19.protocol.GetUsersRequest\x1a\x1a.protocol.GetUsersResponse\"\x00\x12T\n" + "\x0fGetUserFeedback\x12 .protocol.GetUserFeedbackRequest\x1a\x1d.protocol.GetFeedbackResponse\"\x00\x12\\\n" + "\x13GetUserItemFeedback\x12$.protocol.GetUserItemFeedbackRequest\x1a\x1d.protocol.GetFeedbackResponse\"\x00\x12m\n" + "\x16DeleteUserItemFeedback\x12'.protocol.DeleteUserItemFeedbackRequest\x1a(.protocol.DeleteUserItemFeedbackResponse\"\x00\x12d\n" + "\x13BatchInsertFeedback\x12$.protocol.BatchInsertFeedbackRequest\x1a%.protocol.BatchInsertFeedbackResponse\"\x00\x12L\n" + "\vGetFeedback\x12\x1c.protocol.GetFeedbackRequest\x1a\x1d.protocol.GetFeedbackResponse\"\x00\x12T\n" + "\rGetUserStream\x12\x1e.protocol.GetUserStreamRequest\x1a\x1f.protocol.GetUserStreamResponse\"\x000\x01\x12T\n" + 
"\rGetItemStream\x12\x1e.protocol.GetItemStreamRequest\x1a\x1f.protocol.GetItemStreamResponse\"\x000\x01\x12`\n" + "\x11GetFeedbackStream\x12\".protocol.GetFeedbackStreamRequest\x1a#.protocol.GetFeedbackStreamResponse\"\x000\x01\x12I\n" + "\n" + "CountUsers\x12\x1b.protocol.CountUsersRequest\x1a\x1c.protocol.CountUsersResponse\"\x00\x12I\n" + "\n" + "CountItems\x12\x1b.protocol.CountItemsRequest\x1a\x1c.protocol.CountItemsResponse\"\x00\x12R\n" + "\rCountFeedback\x12\x1e.protocol.CountFeedbackRequest\x1a\x1f.protocol.CountFeedbackResponse\"\x00\x12U\n" + "\x0eGetLatestItems\x12\x1f.protocol.GetLatestItemsRequest\x1a .protocol.GetLatestItemsResponse\"\x00B$Z\"github.com/gorse-io/gorse/protocolb\x06proto3" var ( file_data_store_proto_rawDescOnce sync.Once file_data_store_proto_rawDescData []byte ) func file_data_store_proto_rawDescGZIP() []byte { file_data_store_proto_rawDescOnce.Do(func() { file_data_store_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_data_store_proto_rawDesc), len(file_data_store_proto_rawDesc))) }) return file_data_store_proto_rawDescData } var file_data_store_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_data_store_proto_msgTypes = make([]protoimpl.MessageInfo, 49) var file_data_store_proto_goTypes = []any{ (ExpressionType)(0), // 0: protocol.ExpressionType (*FeedbackTypeExpression)(nil), // 1: protocol.FeedbackTypeExpression (*UserPatch)(nil), // 2: protocol.UserPatch (*ItemPatch)(nil), // 3: protocol.ItemPatch (*ScanOptions)(nil), // 4: protocol.ScanOptions (*BatchInsertItemsRequest)(nil), // 5: protocol.BatchInsertItemsRequest (*BatchInsertItemsResponse)(nil), // 6: protocol.BatchInsertItemsResponse (*BatchGetItemsRequest)(nil), // 7: protocol.BatchGetItemsRequest (*BatchGetItemsResponse)(nil), // 8: protocol.BatchGetItemsResponse (*DeleteItemRequest)(nil), // 9: protocol.DeleteItemRequest (*DeleteItemResponse)(nil), // 10: protocol.DeleteItemResponse (*GetItemRequest)(nil), // 11: 
protocol.GetItemRequest (*GetItemResponse)(nil), // 12: protocol.GetItemResponse (*ModifyItemRequest)(nil), // 13: protocol.ModifyItemRequest (*ModifyItemResponse)(nil), // 14: protocol.ModifyItemResponse (*GetItemsRequest)(nil), // 15: protocol.GetItemsRequest (*GetItemsResponse)(nil), // 16: protocol.GetItemsResponse (*GetItemFeedbackRequest)(nil), // 17: protocol.GetItemFeedbackRequest (*BatchInsertUsersRequest)(nil), // 18: protocol.BatchInsertUsersRequest (*BatchInsertUsersResponse)(nil), // 19: protocol.BatchInsertUsersResponse (*DeleteUserRequest)(nil), // 20: protocol.DeleteUserRequest (*DeleteUserResponse)(nil), // 21: protocol.DeleteUserResponse (*GetUserRequest)(nil), // 22: protocol.GetUserRequest (*GetUserResponse)(nil), // 23: protocol.GetUserResponse (*ModifyUserRequest)(nil), // 24: protocol.ModifyUserRequest (*ModifyUserResponse)(nil), // 25: protocol.ModifyUserResponse (*GetUsersRequest)(nil), // 26: protocol.GetUsersRequest (*GetUsersResponse)(nil), // 27: protocol.GetUsersResponse (*GetUserFeedbackRequest)(nil), // 28: protocol.GetUserFeedbackRequest (*GetUserItemFeedbackRequest)(nil), // 29: protocol.GetUserItemFeedbackRequest (*DeleteUserItemFeedbackRequest)(nil), // 30: protocol.DeleteUserItemFeedbackRequest (*DeleteUserItemFeedbackResponse)(nil), // 31: protocol.DeleteUserItemFeedbackResponse (*BatchInsertFeedbackRequest)(nil), // 32: protocol.BatchInsertFeedbackRequest (*BatchInsertFeedbackResponse)(nil), // 33: protocol.BatchInsertFeedbackResponse (*GetFeedbackRequest)(nil), // 34: protocol.GetFeedbackRequest (*GetFeedbackResponse)(nil), // 35: protocol.GetFeedbackResponse (*GetUserStreamRequest)(nil), // 36: protocol.GetUserStreamRequest (*GetUserStreamResponse)(nil), // 37: protocol.GetUserStreamResponse (*GetItemStreamRequest)(nil), // 38: protocol.GetItemStreamRequest (*GetItemStreamResponse)(nil), // 39: protocol.GetItemStreamResponse (*GetFeedbackStreamRequest)(nil), // 40: protocol.GetFeedbackStreamRequest 
(*GetFeedbackStreamResponse)(nil), // 41: protocol.GetFeedbackStreamResponse (*CountUsersRequest)(nil), // 42: protocol.CountUsersRequest (*CountUsersResponse)(nil), // 43: protocol.CountUsersResponse (*CountItemsRequest)(nil), // 44: protocol.CountItemsRequest (*CountItemsResponse)(nil), // 45: protocol.CountItemsResponse (*CountFeedbackRequest)(nil), // 46: protocol.CountFeedbackRequest (*CountFeedbackResponse)(nil), // 47: protocol.CountFeedbackResponse (*GetLatestItemsRequest)(nil), // 48: protocol.GetLatestItemsRequest (*GetLatestItemsResponse)(nil), // 49: protocol.GetLatestItemsResponse (*timestamppb.Timestamp)(nil), // 50: google.protobuf.Timestamp (*Item)(nil), // 51: protocol.Item (*User)(nil), // 52: protocol.User (*Feedback)(nil), // 53: protocol.Feedback (*PingRequest)(nil), // 54: protocol.PingRequest (*PingResponse)(nil), // 55: protocol.PingResponse } var file_data_store_proto_depIdxs = []int32{ 0, // 0: protocol.FeedbackTypeExpression.expression_type:type_name -> protocol.ExpressionType 50, // 1: protocol.ItemPatch.timestamp:type_name -> google.protobuf.Timestamp 50, // 2: protocol.ScanOptions.begin_time:type_name -> google.protobuf.Timestamp 50, // 3: protocol.ScanOptions.end_time:type_name -> google.protobuf.Timestamp 1, // 4: protocol.ScanOptions.feedback_types:type_name -> protocol.FeedbackTypeExpression 51, // 5: protocol.BatchInsertItemsRequest.items:type_name -> protocol.Item 51, // 6: protocol.BatchGetItemsResponse.items:type_name -> protocol.Item 51, // 7: protocol.GetItemResponse.item:type_name -> protocol.Item 3, // 8: protocol.ModifyItemRequest.patch:type_name -> protocol.ItemPatch 50, // 9: protocol.GetItemsRequest.begin_time:type_name -> google.protobuf.Timestamp 51, // 10: protocol.GetItemsResponse.items:type_name -> protocol.Item 1, // 11: protocol.GetItemFeedbackRequest.feedback_types:type_name -> protocol.FeedbackTypeExpression 52, // 12: protocol.BatchInsertUsersRequest.users:type_name -> protocol.User 52, // 13: 
protocol.GetUserResponse.user:type_name -> protocol.User 2, // 14: protocol.ModifyUserRequest.patch:type_name -> protocol.UserPatch 52, // 15: protocol.GetUsersResponse.users:type_name -> protocol.User 50, // 16: protocol.GetUserFeedbackRequest.end_time:type_name -> google.protobuf.Timestamp 1, // 17: protocol.GetUserFeedbackRequest.feedback_types:type_name -> protocol.FeedbackTypeExpression 1, // 18: protocol.GetUserItemFeedbackRequest.feedback_types:type_name -> protocol.FeedbackTypeExpression 53, // 19: protocol.BatchInsertFeedbackRequest.feedback:type_name -> protocol.Feedback 50, // 20: protocol.GetFeedbackRequest.begin_time:type_name -> google.protobuf.Timestamp 50, // 21: protocol.GetFeedbackRequest.end_time:type_name -> google.protobuf.Timestamp 1, // 22: protocol.GetFeedbackRequest.feedback_types:type_name -> protocol.FeedbackTypeExpression 53, // 23: protocol.GetFeedbackResponse.feedback:type_name -> protocol.Feedback 52, // 24: protocol.GetUserStreamResponse.users:type_name -> protocol.User 50, // 25: protocol.GetItemStreamRequest.time_limit:type_name -> google.protobuf.Timestamp 51, // 26: protocol.GetItemStreamResponse.items:type_name -> protocol.Item 4, // 27: protocol.GetFeedbackStreamRequest.scan_options:type_name -> protocol.ScanOptions 53, // 28: protocol.GetFeedbackStreamResponse.feedback:type_name -> protocol.Feedback 51, // 29: protocol.GetLatestItemsResponse.items:type_name -> protocol.Item 54, // 30: protocol.DataStore.Ping:input_type -> protocol.PingRequest 5, // 31: protocol.DataStore.BatchInsertItems:input_type -> protocol.BatchInsertItemsRequest 7, // 32: protocol.DataStore.BatchGetItems:input_type -> protocol.BatchGetItemsRequest 9, // 33: protocol.DataStore.DeleteItem:input_type -> protocol.DeleteItemRequest 11, // 34: protocol.DataStore.GetItem:input_type -> protocol.GetItemRequest 13, // 35: protocol.DataStore.ModifyItem:input_type -> protocol.ModifyItemRequest 15, // 36: protocol.DataStore.GetItems:input_type -> 
protocol.GetItemsRequest 17, // 37: protocol.DataStore.GetItemFeedback:input_type -> protocol.GetItemFeedbackRequest 18, // 38: protocol.DataStore.BatchInsertUsers:input_type -> protocol.BatchInsertUsersRequest 20, // 39: protocol.DataStore.DeleteUser:input_type -> protocol.DeleteUserRequest 22, // 40: protocol.DataStore.GetUser:input_type -> protocol.GetUserRequest 24, // 41: protocol.DataStore.ModifyUser:input_type -> protocol.ModifyUserRequest 26, // 42: protocol.DataStore.GetUsers:input_type -> protocol.GetUsersRequest 28, // 43: protocol.DataStore.GetUserFeedback:input_type -> protocol.GetUserFeedbackRequest 29, // 44: protocol.DataStore.GetUserItemFeedback:input_type -> protocol.GetUserItemFeedbackRequest 30, // 45: protocol.DataStore.DeleteUserItemFeedback:input_type -> protocol.DeleteUserItemFeedbackRequest 32, // 46: protocol.DataStore.BatchInsertFeedback:input_type -> protocol.BatchInsertFeedbackRequest 34, // 47: protocol.DataStore.GetFeedback:input_type -> protocol.GetFeedbackRequest 36, // 48: protocol.DataStore.GetUserStream:input_type -> protocol.GetUserStreamRequest 38, // 49: protocol.DataStore.GetItemStream:input_type -> protocol.GetItemStreamRequest 40, // 50: protocol.DataStore.GetFeedbackStream:input_type -> protocol.GetFeedbackStreamRequest 42, // 51: protocol.DataStore.CountUsers:input_type -> protocol.CountUsersRequest 44, // 52: protocol.DataStore.CountItems:input_type -> protocol.CountItemsRequest 46, // 53: protocol.DataStore.CountFeedback:input_type -> protocol.CountFeedbackRequest 48, // 54: protocol.DataStore.GetLatestItems:input_type -> protocol.GetLatestItemsRequest 55, // 55: protocol.DataStore.Ping:output_type -> protocol.PingResponse 6, // 56: protocol.DataStore.BatchInsertItems:output_type -> protocol.BatchInsertItemsResponse 8, // 57: protocol.DataStore.BatchGetItems:output_type -> protocol.BatchGetItemsResponse 10, // 58: protocol.DataStore.DeleteItem:output_type -> protocol.DeleteItemResponse 12, // 59: 
protocol.DataStore.GetItem:output_type -> protocol.GetItemResponse 14, // 60: protocol.DataStore.ModifyItem:output_type -> protocol.ModifyItemResponse 16, // 61: protocol.DataStore.GetItems:output_type -> protocol.GetItemsResponse 35, // 62: protocol.DataStore.GetItemFeedback:output_type -> protocol.GetFeedbackResponse 19, // 63: protocol.DataStore.BatchInsertUsers:output_type -> protocol.BatchInsertUsersResponse 21, // 64: protocol.DataStore.DeleteUser:output_type -> protocol.DeleteUserResponse 23, // 65: protocol.DataStore.GetUser:output_type -> protocol.GetUserResponse 25, // 66: protocol.DataStore.ModifyUser:output_type -> protocol.ModifyUserResponse 27, // 67: protocol.DataStore.GetUsers:output_type -> protocol.GetUsersResponse 35, // 68: protocol.DataStore.GetUserFeedback:output_type -> protocol.GetFeedbackResponse 35, // 69: protocol.DataStore.GetUserItemFeedback:output_type -> protocol.GetFeedbackResponse 31, // 70: protocol.DataStore.DeleteUserItemFeedback:output_type -> protocol.DeleteUserItemFeedbackResponse 33, // 71: protocol.DataStore.BatchInsertFeedback:output_type -> protocol.BatchInsertFeedbackResponse 35, // 72: protocol.DataStore.GetFeedback:output_type -> protocol.GetFeedbackResponse 37, // 73: protocol.DataStore.GetUserStream:output_type -> protocol.GetUserStreamResponse 39, // 74: protocol.DataStore.GetItemStream:output_type -> protocol.GetItemStreamResponse 41, // 75: protocol.DataStore.GetFeedbackStream:output_type -> protocol.GetFeedbackStreamResponse 43, // 76: protocol.DataStore.CountUsers:output_type -> protocol.CountUsersResponse 45, // 77: protocol.DataStore.CountItems:output_type -> protocol.CountItemsResponse 47, // 78: protocol.DataStore.CountFeedback:output_type -> protocol.CountFeedbackResponse 49, // 79: protocol.DataStore.GetLatestItems:output_type -> protocol.GetLatestItemsResponse 55, // [55:80] is the sub-list for method output_type 30, // [30:55] is the sub-list for method input_type 30, // [30:30] is the sub-list for 
extension type_name 30, // [30:30] is the sub-list for extension extendee 0, // [0:30] is the sub-list for field type_name } func init() { file_data_store_proto_init() } func file_data_store_proto_init() { if File_data_store_proto != nil { return } file_protocol_proto_init() file_data_store_proto_msgTypes[1].OneofWrappers = []any{} file_data_store_proto_msgTypes[2].OneofWrappers = []any{} file_data_store_proto_msgTypes[3].OneofWrappers = []any{} file_data_store_proto_msgTypes[11].OneofWrappers = []any{} file_data_store_proto_msgTypes[22].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_data_store_proto_rawDesc), len(file_data_store_proto_rawDesc)), NumEnums: 1, NumMessages: 49, NumExtensions: 0, NumServices: 1, }, GoTypes: file_data_store_proto_goTypes, DependencyIndexes: file_data_store_proto_depIdxs, EnumInfos: file_data_store_proto_enumTypes, MessageInfos: file_data_store_proto_msgTypes, }.Build() File_data_store_proto = out.File file_data_store_proto_goTypes = nil file_data_store_proto_depIdxs = nil } ================================================ FILE: protocol/data_store.proto ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
syntax = "proto3";

option go_package = "github.com/gorse-io/gorse/protocol";

package protocol;

import "google/protobuf/timestamp.proto";
import "protocol.proto";

// Messages and service of the DataStore gRPC API. Item, User, Feedback,
// PingRequest and PingResponse are not declared in this file; they are
// presumably provided by the imported protocol.proto.

// ExpressionType is the comparison operator carried by a FeedbackTypeExpression.
// NOTE(review): `None` presumably means "match on feedback_type only, without
// comparing `value`" — confirm against the code that evaluates expressions.
// NOTE(review): proto3 enum values are siblings of their enum type, so these
// short names (None, Less, ...) must be unique across the whole `protocol`
// package. The protobuf style guide recommends prefixed names, but renaming
// would change the generated identifiers.
enum ExpressionType {
  None = 0;
  Less = 1;
  LessOrEqual = 2;
  Greater = 3;
  GreaterOrEqual = 4;
}

// FeedbackTypeExpression selects feedback of type `feedback_type`. When
// `expression_type` is a comparison, the feedback value is presumably compared
// against `value` with that operator — confirm.
message FeedbackTypeExpression {
  string feedback_type = 1;
  ExpressionType expression_type = 2;
  double value = 3;
}

// UserPatch is the partial update carried by ModifyUserRequest.
// `comment` is `optional`, so "not set" is distinguishable from "".
// `labels` (bytes) and `subscribe` (repeated) have no presence in proto3:
// an empty value looks the same on the wire as an unset one.
// NOTE(review): the encoding of the `labels` bytes is not defined here
// (presumably JSON) — confirm against the server implementation.
message UserPatch {
  bytes labels = 1;
  optional string comment = 2;
  repeated string subscribe = 3;
}

// ItemPatch is the partial update carried by ModifyItemRequest.
// `is_hidden` and `comment` are `optional`, so "not set" is distinguishable
// from false/"". `categories` (repeated) and `labels` (bytes) have no presence:
// an empty value looks the same on the wire as an unset one. `timestamp` is a
// message field, which always tracks presence, so its `optional` is redundant.
message ItemPatch {
  optional bool is_hidden = 1;
  repeated string categories = 2;
  optional google.protobuf.Timestamp timestamp = 3;
  bytes labels = 4;
  optional string comment = 5;
}

// ScanOptions narrows a feedback scan (see GetFeedbackStreamRequest.scan_options).
// Its fields presumably bound the scan by user ID range, item ID range, time
// range and feedback type. An unset bound presumably means "unbounded", and
// order_by_item_id presumably switches the scan order to item ID. Confirm the
// inclusive/exclusive semantics in the server.
message ScanOptions {
  optional string begin_user_id = 1;
  optional string end_user_id = 2;
  optional string begin_item_id = 3;
  optional string end_item_id = 4;
  optional google.protobuf.Timestamp begin_time = 5;
  optional google.protobuf.Timestamp end_time = 6;
  repeated FeedbackTypeExpression feedback_types = 7;
  bool order_by_item_id = 8;
}

// ---- Item operations ----

message BatchInsertItemsRequest {
  repeated Item items = 1;
}

message BatchInsertItemsResponse {}

message BatchGetItemsRequest {
  repeated string item_ids = 1;
}

message BatchGetItemsResponse {
  repeated Item items = 1;
}

message DeleteItemRequest {
  string item_id = 1;
}

message DeleteItemResponse {}

message GetItemRequest {
  string item_id = 1;
}

// `item` is a message field, so it can be left unset.
// NOTE(review): presumably unset when the item does not exist — confirm.
message GetItemResponse {
  optional Item item = 1;
}

message ModifyItemRequest {
  string item_id = 1;
  ItemPatch patch = 2;
}

message ModifyItemResponse {}

// GetItemsRequest/GetItemsResponse page through items.
// NOTE(review): presumably `cursor` is an opaque token (empty for the first
// page) that is handed back in GetItemsResponse.cursor, `n` is the page size,
// and `begin_time` filters by item timestamp — confirm.
message GetItemsRequest {
  string cursor = 1;
  int32 n = 2;
  google.protobuf.Timestamp begin_time = 3;
}

message GetItemsResponse {
  string cursor = 1;
  repeated Item items = 2;
}

// GetItemFeedbackRequest selects feedback on one item, filtered by
// `feedback_types`. It is answered with a GetFeedbackResponse.
message GetItemFeedbackRequest {
  string item_id = 1;
  repeated FeedbackTypeExpression feedback_types = 2;
}

// ---- User operations ----

message BatchInsertUsersRequest {
  repeated User users = 1;
}

message BatchInsertUsersResponse {}

message DeleteUserRequest {
  string user_id = 1;
}

message DeleteUserResponse {}

message GetUserRequest {
  string user_id = 1;
}

// `user` is a message field, so it can be left unset.
// NOTE(review): presumably unset when the user does not exist — confirm.
message GetUserResponse {
  optional User user = 1;
}

message ModifyUserRequest {
  string user_id = 1;
  UserPatch patch = 2;
}

message ModifyUserResponse {}

// GetUsersRequest/GetUsersResponse page through users. The cursor is
// presumably used the same way as in GetItemsRequest — confirm.
message GetUsersRequest {
  string cursor = 1;
  int32 n = 2;
}

message GetUsersResponse {
  string cursor = 1;
  repeated User users = 2;
}

// ---- Feedback operations ----

// GetUserFeedbackRequest selects one user's feedback, filtered by
// `feedback_types`. It is answered with a GetFeedbackResponse.
// NOTE(review): `end_time` presumably excludes feedback timestamped after it — confirm.
message GetUserFeedbackRequest {
  string user_id = 1;
  google.protobuf.Timestamp end_time = 2;
  repeated FeedbackTypeExpression feedback_types = 3;
}

// GetUserItemFeedbackRequest selects feedback between one user and one item,
// filtered by `feedback_types`. It is answered with a GetFeedbackResponse.
message GetUserItemFeedbackRequest {
  string user_id = 1;
  string item_id = 2;
  repeated FeedbackTypeExpression feedback_types = 3;
}

// Unlike the Get*Feedback requests, `feedback_types` here is a list of plain
// feedback type names, not FeedbackTypeExpression values.
message DeleteUserItemFeedbackRequest {
  string user_id = 1;
  string item_id = 2;
  repeated string feedback_types = 3;
}

// NOTE(review): `count` is presumably the number of deleted feedback records — confirm.
message DeleteUserItemFeedbackResponse {
  int32 count = 1;
}

// NOTE(review): presumably `insert_user`/`insert_item` create users/items that
// the feedback references but that do not exist yet. `overwrite` presumably
// replaces existing feedback with the same key instead of keeping it — confirm.
message BatchInsertFeedbackRequest {
  repeated Feedback feedback = 1;
  bool insert_user = 2;
  bool insert_item = 3;
  bool overwrite = 4;
}

message BatchInsertFeedbackResponse {}

// GetFeedbackRequest pages through feedback, optionally bounded by time and
// filtered by `feedback_types`. The cursor is presumably used the same way as
// in GetItemsRequest — confirm.
message GetFeedbackRequest {
  string cursor = 1;
  int32 n = 2;
  google.protobuf.Timestamp begin_time = 3;
  google.protobuf.Timestamp end_time = 4;
  repeated FeedbackTypeExpression feedback_types = 5;
}

// GetFeedbackResponse is shared by the GetItemFeedback, GetUserFeedback,
// GetUserItemFeedback and GetFeedback RPCs.
message GetFeedbackResponse {
  string cursor = 1;
  repeated Feedback feedback = 2;
}

// ---- Server-streaming scans ----
// NOTE(review): `batch_size` presumably caps how many records each streamed
// response message carries — confirm.

message GetUserStreamRequest {
  int32 batch_size = 1;
}

message GetUserStreamResponse {
  repeated User users = 1;
}

// NOTE(review): `time_limit` presumably restricts the stream by item
// timestamp — confirm the direction of the bound in the server.
message GetItemStreamRequest {
  int32 batch_size = 1;
  google.protobuf.Timestamp time_limit = 2;
}

message GetItemStreamResponse {
  repeated Item items = 1;
}

message GetFeedbackStreamRequest {
  int32 batch_size = 1;
  ScanOptions scan_options = 2;
}

message GetFeedbackStreamResponse {
  repeated Feedback feedback = 1;
}

// ---- Counters ----
// NOTE(review): the Count*Response counts (and DeleteUserItemFeedbackResponse.count)
// are int32 and top out at 2,147,483,647. int32 and int64 are wire-compatible,
// so widening would not break existing peers. It would, however, change the
// generated Go field types and therefore every caller.

message CountUsersRequest {}

message CountUsersResponse {
  int32 count = 1;
}

message CountItemsRequest {}

message CountItemsResponse {
  int32 count = 1;
}

message CountFeedbackRequest {}

message CountFeedbackResponse {
  int32 count = 1;
}

// NOTE(review): presumably returns the `n` newest items, optionally restricted
// to `categories` — confirm.
message GetLatestItemsRequest {
  int32 n = 1;
  repeated string categories = 2;
}

message GetLatestItemsResponse {
  repeated Item items = 1;
}

// DataStore exposes the data store over gRPC. GetUserStream, GetItemStream and
// GetFeedbackStream are server-streaming; every other RPC is unary.
service DataStore {
  rpc Ping(PingRequest) returns (PingResponse) {}
  rpc BatchInsertItems(BatchInsertItemsRequest) returns (BatchInsertItemsResponse) {}
  rpc BatchGetItems(BatchGetItemsRequest) returns (BatchGetItemsResponse) {}
  rpc DeleteItem(DeleteItemRequest) returns (DeleteItemResponse) {}
  rpc GetItem(GetItemRequest) returns (GetItemResponse) {}
  rpc ModifyItem(ModifyItemRequest) returns (ModifyItemResponse) {}
  rpc GetItems(GetItemsRequest) returns (GetItemsResponse) {}
  rpc GetItemFeedback(GetItemFeedbackRequest) returns (GetFeedbackResponse) {}
  rpc BatchInsertUsers(BatchInsertUsersRequest) returns (BatchInsertUsersResponse) {}
  rpc DeleteUser(DeleteUserRequest) returns (DeleteUserResponse) {}
  rpc GetUser(GetUserRequest) returns (GetUserResponse) {}
  rpc ModifyUser(ModifyUserRequest) returns (ModifyUserResponse) {}
  rpc GetUsers(GetUsersRequest) returns (GetUsersResponse) {}
  rpc GetUserFeedback(GetUserFeedbackRequest) returns (GetFeedbackResponse) {}
  rpc GetUserItemFeedback(GetUserItemFeedbackRequest) returns (GetFeedbackResponse) {}
  rpc DeleteUserItemFeedback(DeleteUserItemFeedbackRequest) returns (DeleteUserItemFeedbackResponse) {}
  rpc BatchInsertFeedback(BatchInsertFeedbackRequest) returns (BatchInsertFeedbackResponse) {}
  rpc GetFeedback(GetFeedbackRequest) returns (GetFeedbackResponse) {}
  rpc GetUserStream(GetUserStreamRequest) returns (stream GetUserStreamResponse) {}
  rpc GetItemStream(GetItemStreamRequest) returns (stream GetItemStreamResponse) {}
  rpc GetFeedbackStream(GetFeedbackStreamRequest) returns (stream GetFeedbackStreamResponse) {}
  rpc CountUsers(CountUsersRequest) returns (CountUsersResponse) {}
  rpc CountItems(CountItemsRequest) returns (CountItemsResponse) {}
  rpc CountFeedback(CountFeedbackRequest) returns (CountFeedbackResponse) {}
  rpc GetLatestItems(GetLatestItemsRequest) returns (GetLatestItemsResponse) {}
}
================================================ FILE: protocol/data_store_grpc.pb.go ================================================ // 
Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.6.0 // - protoc v6.33.1 // source: data_store.proto package protocol import ( context "context" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" ) // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. // Requires gRPC-Go v1.64.0 or later. 
const _ = grpc.SupportPackageIsVersion9

// NOTE(review): this code is emitted by protoc-gen-go-grpc from
// data_store.proto. Change the .proto and regenerate instead of editing here;
// otherwise the next regeneration silently discards the change.

// Full gRPC method names ("/<package>.<service>/<method>") for every DataStore
// RPC. The client methods below pass them to Invoke/NewStream.
const (
	DataStore_Ping_FullMethodName                   = "/protocol.DataStore/Ping"
	DataStore_BatchInsertItems_FullMethodName       = "/protocol.DataStore/BatchInsertItems"
	DataStore_BatchGetItems_FullMethodName          = "/protocol.DataStore/BatchGetItems"
	DataStore_DeleteItem_FullMethodName             = "/protocol.DataStore/DeleteItem"
	DataStore_GetItem_FullMethodName                = "/protocol.DataStore/GetItem"
	DataStore_ModifyItem_FullMethodName             = "/protocol.DataStore/ModifyItem"
	DataStore_GetItems_FullMethodName               = "/protocol.DataStore/GetItems"
	DataStore_GetItemFeedback_FullMethodName        = "/protocol.DataStore/GetItemFeedback"
	DataStore_BatchInsertUsers_FullMethodName       = "/protocol.DataStore/BatchInsertUsers"
	DataStore_DeleteUser_FullMethodName             = "/protocol.DataStore/DeleteUser"
	DataStore_GetUser_FullMethodName                = "/protocol.DataStore/GetUser"
	DataStore_ModifyUser_FullMethodName             = "/protocol.DataStore/ModifyUser"
	DataStore_GetUsers_FullMethodName               = "/protocol.DataStore/GetUsers"
	DataStore_GetUserFeedback_FullMethodName        = "/protocol.DataStore/GetUserFeedback"
	DataStore_GetUserItemFeedback_FullMethodName    = "/protocol.DataStore/GetUserItemFeedback"
	DataStore_DeleteUserItemFeedback_FullMethodName = "/protocol.DataStore/DeleteUserItemFeedback"
	DataStore_BatchInsertFeedback_FullMethodName    = "/protocol.DataStore/BatchInsertFeedback"
	DataStore_GetFeedback_FullMethodName            = "/protocol.DataStore/GetFeedback"
	DataStore_GetUserStream_FullMethodName          = "/protocol.DataStore/GetUserStream"
	DataStore_GetItemStream_FullMethodName          = "/protocol.DataStore/GetItemStream"
	DataStore_GetFeedbackStream_FullMethodName      = "/protocol.DataStore/GetFeedbackStream"
	DataStore_CountUsers_FullMethodName             = "/protocol.DataStore/CountUsers"
	DataStore_CountItems_FullMethodName             = "/protocol.DataStore/CountItems"
	DataStore_CountFeedback_FullMethodName          = "/protocol.DataStore/CountFeedback"
	DataStore_GetLatestItems_FullMethodName         = "/protocol.DataStore/GetLatestItems"
)

// DataStoreClient is the client API for DataStore service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type DataStoreClient interface {
	Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error)
	BatchInsertItems(ctx context.Context, in *BatchInsertItemsRequest, opts ...grpc.CallOption) (*BatchInsertItemsResponse, error)
	BatchGetItems(ctx context.Context, in *BatchGetItemsRequest, opts ...grpc.CallOption) (*BatchGetItemsResponse, error)
	DeleteItem(ctx context.Context, in *DeleteItemRequest, opts ...grpc.CallOption) (*DeleteItemResponse, error)
	GetItem(ctx context.Context, in *GetItemRequest, opts ...grpc.CallOption) (*GetItemResponse, error)
	ModifyItem(ctx context.Context, in *ModifyItemRequest, opts ...grpc.CallOption) (*ModifyItemResponse, error)
	GetItems(ctx context.Context, in *GetItemsRequest, opts ...grpc.CallOption) (*GetItemsResponse, error)
	GetItemFeedback(ctx context.Context, in *GetItemFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error)
	BatchInsertUsers(ctx context.Context, in *BatchInsertUsersRequest, opts ...grpc.CallOption) (*BatchInsertUsersResponse, error)
	DeleteUser(ctx context.Context, in *DeleteUserRequest, opts ...grpc.CallOption) (*DeleteUserResponse, error)
	GetUser(ctx context.Context, in *GetUserRequest, opts ...grpc.CallOption) (*GetUserResponse, error)
	ModifyUser(ctx context.Context, in *ModifyUserRequest, opts ...grpc.CallOption) (*ModifyUserResponse, error)
	GetUsers(ctx context.Context, in *GetUsersRequest, opts ...grpc.CallOption) (*GetUsersResponse, error)
	GetUserFeedback(ctx context.Context, in *GetUserFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error)
	GetUserItemFeedback(ctx context.Context, in *GetUserItemFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error)
	DeleteUserItemFeedback(ctx context.Context, in *DeleteUserItemFeedbackRequest, opts ...grpc.CallOption) (*DeleteUserItemFeedbackResponse, error)
	BatchInsertFeedback(ctx context.Context, in *BatchInsertFeedbackRequest, opts ...grpc.CallOption) (*BatchInsertFeedbackResponse, error)
	GetFeedback(ctx context.Context, in *GetFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error)
	GetUserStream(ctx context.Context, in *GetUserStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetUserStreamResponse], error)
	GetItemStream(ctx context.Context, in *GetItemStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetItemStreamResponse], error)
	GetFeedbackStream(ctx context.Context, in *GetFeedbackStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetFeedbackStreamResponse], error)
	CountUsers(ctx context.Context, in *CountUsersRequest, opts ...grpc.CallOption) (*CountUsersResponse, error)
	CountItems(ctx context.Context, in *CountItemsRequest, opts ...grpc.CallOption) (*CountItemsResponse, error)
	CountFeedback(ctx context.Context, in *CountFeedbackRequest, opts ...grpc.CallOption) (*CountFeedbackResponse, error)
	GetLatestItems(ctx context.Context, in *GetLatestItemsRequest, opts ...grpc.CallOption) (*GetLatestItemsResponse, error)
}

// dataStoreClient implements DataStoreClient on top of a gRPC client connection.
type dataStoreClient struct {
	cc grpc.ClientConnInterface
}

// NewDataStoreClient returns a DataStoreClient that issues every RPC over cc.
func NewDataStoreClient(cc grpc.ClientConnInterface) DataStoreClient {
	return &dataStoreClient{cc}
}

// Every unary method below follows the same generated shape:
//   - prepend grpc.StaticMethod() to the caller's options, marking the method
//     name as a compile-time constant that stats handlers may use as a metric key;
//   - Invoke the RPC into a freshly allocated response;
//   - return nil rather than a partially filled response when Invoke fails.
func (c *dataStoreClient) Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(PingResponse)
	err := c.cc.Invoke(ctx, DataStore_Ping_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) BatchInsertItems(ctx context.Context, in *BatchInsertItemsRequest, opts ...grpc.CallOption) (*BatchInsertItemsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(BatchInsertItemsResponse)
	err := c.cc.Invoke(ctx, DataStore_BatchInsertItems_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) BatchGetItems(ctx context.Context, in *BatchGetItemsRequest, opts ...grpc.CallOption) (*BatchGetItemsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(BatchGetItemsResponse)
	err := c.cc.Invoke(ctx, DataStore_BatchGetItems_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) DeleteItem(ctx context.Context, in *DeleteItemRequest, opts ...grpc.CallOption) (*DeleteItemResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteItemResponse)
	err := c.cc.Invoke(ctx, DataStore_DeleteItem_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetItem(ctx context.Context, in *GetItemRequest, opts ...grpc.CallOption) (*GetItemResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetItemResponse)
	err := c.cc.Invoke(ctx, DataStore_GetItem_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) ModifyItem(ctx context.Context, in *ModifyItemRequest, opts ...grpc.CallOption) (*ModifyItemResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ModifyItemResponse)
	err := c.cc.Invoke(ctx, DataStore_ModifyItem_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetItems(ctx context.Context, in *GetItemsRequest, opts ...grpc.CallOption) (*GetItemsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetItemsResponse)
	err := c.cc.Invoke(ctx, DataStore_GetItems_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetItemFeedback(ctx context.Context, in *GetItemFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_GetItemFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) BatchInsertUsers(ctx context.Context, in *BatchInsertUsersRequest, opts ...grpc.CallOption) (*BatchInsertUsersResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(BatchInsertUsersResponse)
	err := c.cc.Invoke(ctx, DataStore_BatchInsertUsers_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) DeleteUser(ctx context.Context, in *DeleteUserRequest, opts ...grpc.CallOption) (*DeleteUserResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteUserResponse)
	err := c.cc.Invoke(ctx, DataStore_DeleteUser_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetUser(ctx context.Context, in *GetUserRequest, opts ...grpc.CallOption) (*GetUserResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetUserResponse)
	err := c.cc.Invoke(ctx, DataStore_GetUser_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) ModifyUser(ctx context.Context, in *ModifyUserRequest, opts ...grpc.CallOption) (*ModifyUserResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ModifyUserResponse)
	err := c.cc.Invoke(ctx, DataStore_ModifyUser_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetUsers(ctx context.Context, in *GetUsersRequest, opts ...grpc.CallOption) (*GetUsersResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetUsersResponse)
	err := c.cc.Invoke(ctx, DataStore_GetUsers_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetUserFeedback(ctx context.Context, in *GetUserFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_GetUserFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetUserItemFeedback(ctx context.Context, in *GetUserItemFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_GetUserItemFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) DeleteUserItemFeedback(ctx context.Context, in *DeleteUserItemFeedbackRequest, opts ...grpc.CallOption) (*DeleteUserItemFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteUserItemFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_DeleteUserItemFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) BatchInsertFeedback(ctx context.Context, in *BatchInsertFeedbackRequest, opts ...grpc.CallOption) (*BatchInsertFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(BatchInsertFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_BatchInsertFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetFeedback(ctx context.Context, in *GetFeedbackRequest, opts ...grpc.CallOption) (*GetFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_GetFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// GetUserStream is server-streaming. The single request is sent and the send
// side is closed immediately; the caller then reads GetUserStreamResponse
// messages from the returned stream. Streams[0] indexes DataStore_ServiceDesc,
// which is declared outside this excerpt. Presumably entry 0 describes
// GetUserStream — confirm.
func (c *dataStoreClient) GetUserStream(ctx context.Context, in *GetUserStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetUserStreamResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &DataStore_ServiceDesc.Streams[0], DataStore_GetUserStream_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[GetUserStreamRequest, GetUserStreamResponse]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetUserStreamClient = grpc.ServerStreamingClient[GetUserStreamResponse]

// GetItemStream is server-streaming and follows the same send-then-close
// pattern as GetUserStream. Streams[1] presumably describes GetItemStream in
// DataStore_ServiceDesc — confirm.
func (c *dataStoreClient) GetItemStream(ctx context.Context, in *GetItemStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetItemStreamResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &DataStore_ServiceDesc.Streams[1], DataStore_GetItemStream_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[GetItemStreamRequest, GetItemStreamResponse]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetItemStreamClient = grpc.ServerStreamingClient[GetItemStreamResponse]

// GetFeedbackStream is server-streaming and follows the same send-then-close
// pattern as GetUserStream. Streams[2] presumably describes GetFeedbackStream
// in DataStore_ServiceDesc — confirm.
func (c *dataStoreClient) GetFeedbackStream(ctx context.Context, in *GetFeedbackStreamRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GetFeedbackStreamResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &DataStore_ServiceDesc.Streams[2], DataStore_GetFeedbackStream_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[GetFeedbackStreamRequest, GetFeedbackStreamResponse]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetFeedbackStreamClient = grpc.ServerStreamingClient[GetFeedbackStreamResponse]

func (c *dataStoreClient) CountUsers(ctx context.Context, in *CountUsersRequest, opts ...grpc.CallOption) (*CountUsersResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(CountUsersResponse)
	err := c.cc.Invoke(ctx, DataStore_CountUsers_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) CountItems(ctx context.Context, in *CountItemsRequest, opts ...grpc.CallOption) (*CountItemsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(CountItemsResponse)
	err := c.cc.Invoke(ctx, DataStore_CountItems_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) CountFeedback(ctx context.Context, in *CountFeedbackRequest, opts ...grpc.CallOption) (*CountFeedbackResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(CountFeedbackResponse)
	err := c.cc.Invoke(ctx, DataStore_CountFeedback_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *dataStoreClient) GetLatestItems(ctx context.Context, in *GetLatestItemsRequest, opts ...grpc.CallOption) (*GetLatestItemsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetLatestItemsResponse)
	err := c.cc.Invoke(ctx, DataStore_GetLatestItems_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// DataStoreServer is the server API for DataStore service.
// All implementations must embed UnimplementedDataStoreServer
// for forward compatibility.
//
// Unary methods take the call's context and the decoded request message and
// return a response message or an error. The three server-streaming methods
// (GetUserStream, GetItemStream, GetFeedbackStream) instead take the request
// and a grpc.ServerStreamingServer for their response type.
type DataStoreServer interface {
	Ping(context.Context, *PingRequest) (*PingResponse, error)
	BatchInsertItems(context.Context, *BatchInsertItemsRequest) (*BatchInsertItemsResponse, error)
	BatchGetItems(context.Context, *BatchGetItemsRequest) (*BatchGetItemsResponse, error)
	DeleteItem(context.Context, *DeleteItemRequest) (*DeleteItemResponse, error)
	GetItem(context.Context, *GetItemRequest) (*GetItemResponse, error)
	ModifyItem(context.Context, *ModifyItemRequest) (*ModifyItemResponse, error)
	GetItems(context.Context, *GetItemsRequest) (*GetItemsResponse, error)
	GetItemFeedback(context.Context, *GetItemFeedbackRequest) (*GetFeedbackResponse, error)
	BatchInsertUsers(context.Context, *BatchInsertUsersRequest) (*BatchInsertUsersResponse, error)
	DeleteUser(context.Context, *DeleteUserRequest) (*DeleteUserResponse, error)
	GetUser(context.Context, *GetUserRequest) (*GetUserResponse, error)
	ModifyUser(context.Context, *ModifyUserRequest) (*ModifyUserResponse, error)
	GetUsers(context.Context, *GetUsersRequest) (*GetUsersResponse, error)
	GetUserFeedback(context.Context, *GetUserFeedbackRequest) (*GetFeedbackResponse, error)
	GetUserItemFeedback(context.Context, *GetUserItemFeedbackRequest) (*GetFeedbackResponse, error)
	DeleteUserItemFeedback(context.Context, *DeleteUserItemFeedbackRequest) (*DeleteUserItemFeedbackResponse, error)
	BatchInsertFeedback(context.Context, *BatchInsertFeedbackRequest) (*BatchInsertFeedbackResponse, error)
	GetFeedback(context.Context, *GetFeedbackRequest) (*GetFeedbackResponse, error)
	GetUserStream(*GetUserStreamRequest, grpc.ServerStreamingServer[GetUserStreamResponse]) error
	GetItemStream(*GetItemStreamRequest, grpc.ServerStreamingServer[GetItemStreamResponse]) error
	GetFeedbackStream(*GetFeedbackStreamRequest, grpc.ServerStreamingServer[GetFeedbackStreamResponse]) error
	CountUsers(context.Context, *CountUsersRequest) (*CountUsersResponse, error)
	CountItems(context.Context, *CountItemsRequest) (*CountItemsResponse, error)
	CountFeedback(context.Context, *CountFeedbackRequest) (*CountFeedbackResponse, error)
	GetLatestItems(context.Context, *GetLatestItemsRequest) (*GetLatestItemsResponse, error)
	// mustEmbedUnimplementedDataStoreServer is unexported, so a type outside
	// this package can only satisfy DataStoreServer by embedding
	// UnimplementedDataStoreServer (or UnsafeDataStoreServer).
	mustEmbedUnimplementedDataStoreServer()
}

// UnimplementedDataStoreServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
//
// Every method below rejects the call with a codes.Unimplemented status, so an
// implementation only needs to override the RPCs it actually supports.
type UnimplementedDataStoreServer struct{}

func (UnimplementedDataStoreServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method Ping not implemented")
}
func (UnimplementedDataStoreServer) BatchInsertItems(context.Context, *BatchInsertItemsRequest) (*BatchInsertItemsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method BatchInsertItems not implemented")
}
func (UnimplementedDataStoreServer) BatchGetItems(context.Context, *BatchGetItemsRequest) (*BatchGetItemsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method BatchGetItems not implemented")
}
func (UnimplementedDataStoreServer) DeleteItem(context.Context, *DeleteItemRequest) (*DeleteItemResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteItem not implemented")
}
func (UnimplementedDataStoreServer) GetItem(context.Context, *GetItemRequest) (*GetItemResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetItem not implemented")
}
func (UnimplementedDataStoreServer) ModifyItem(context.Context, *ModifyItemRequest) (*ModifyItemResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method ModifyItem not implemented")
}
func (UnimplementedDataStoreServer) GetItems(context.Context, *GetItemsRequest) (*GetItemsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetItems not implemented")
}
func (UnimplementedDataStoreServer) GetItemFeedback(context.Context, *GetItemFeedbackRequest) (*GetFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetItemFeedback not implemented")
}
func (UnimplementedDataStoreServer) BatchInsertUsers(context.Context, *BatchInsertUsersRequest) (*BatchInsertUsersResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method BatchInsertUsers not implemented")
}
func (UnimplementedDataStoreServer) DeleteUser(context.Context, *DeleteUserRequest) (*DeleteUserResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteUser not implemented")
}
func (UnimplementedDataStoreServer) GetUser(context.Context, *GetUserRequest) (*GetUserResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetUser not implemented")
}
func (UnimplementedDataStoreServer) ModifyUser(context.Context, *ModifyUserRequest) (*ModifyUserResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method ModifyUser not implemented")
}
func (UnimplementedDataStoreServer) GetUsers(context.Context, *GetUsersRequest) (*GetUsersResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetUsers not implemented")
}
func (UnimplementedDataStoreServer) GetUserFeedback(context.Context, *GetUserFeedbackRequest) (*GetFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetUserFeedback not implemented")
}
func (UnimplementedDataStoreServer) GetUserItemFeedback(context.Context, *GetUserItemFeedbackRequest) (*GetFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetUserItemFeedback not implemented")
}
func (UnimplementedDataStoreServer) DeleteUserItemFeedback(context.Context, *DeleteUserItemFeedbackRequest) (*DeleteUserItemFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteUserItemFeedback not implemented")
}
func (UnimplementedDataStoreServer) BatchInsertFeedback(context.Context, *BatchInsertFeedbackRequest) (*BatchInsertFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method BatchInsertFeedback not implemented")
}
func (UnimplementedDataStoreServer) GetFeedback(context.Context, *GetFeedbackRequest) (*GetFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetFeedback not implemented")
}
func (UnimplementedDataStoreServer) GetUserStream(*GetUserStreamRequest, grpc.ServerStreamingServer[GetUserStreamResponse]) error {
	return status.Error(codes.Unimplemented, "method GetUserStream not implemented")
}
func (UnimplementedDataStoreServer) GetItemStream(*GetItemStreamRequest, grpc.ServerStreamingServer[GetItemStreamResponse]) error {
	return status.Error(codes.Unimplemented, "method GetItemStream not implemented")
}
func (UnimplementedDataStoreServer) GetFeedbackStream(*GetFeedbackStreamRequest, grpc.ServerStreamingServer[GetFeedbackStreamResponse]) error {
	return status.Error(codes.Unimplemented, "method GetFeedbackStream not implemented")
}
func (UnimplementedDataStoreServer) CountUsers(context.Context, *CountUsersRequest) (*CountUsersResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method CountUsers not implemented")
}
func (UnimplementedDataStoreServer) CountItems(context.Context, *CountItemsRequest) (*CountItemsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method CountItems not implemented")
}
func (UnimplementedDataStoreServer) CountFeedback(context.Context, *CountFeedbackRequest) (*CountFeedbackResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method CountFeedback not implemented")
}
func (UnimplementedDataStoreServer) GetLatestItems(context.Context, *GetLatestItemsRequest) (*GetLatestItemsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method GetLatestItems not implemented")
}

// mustEmbedUnimplementedDataStoreServer provides the unexported marker method
// required by DataStoreServer.
func (UnimplementedDataStoreServer) mustEmbedUnimplementedDataStoreServer() {}

// testEmbeddedByValue is probed by RegisterDataStoreServer. If
// UnimplementedDataStoreServer was embedded as a nil pointer, calling it panics
// at registration time instead of on the first unimplemented RPC.
func (UnimplementedDataStoreServer) testEmbeddedByValue() {}

// UnsafeDataStoreServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DataStoreServer will
// result in compilation errors.
type UnsafeDataStoreServer interface {
	mustEmbedUnimplementedDataStoreServer()
}

// RegisterDataStoreServer registers srv as the implementation of the service
// described by DataStore_ServiceDesc on the registrar s.
func RegisterDataStoreServer(s grpc.ServiceRegistrar, srv DataStoreServer) {
	// If the following call panics, it indicates UnimplementedDataStoreServer was
	// embedded by pointer and is nil. This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&DataStore_ServiceDesc, srv)
}

// The unary handlers below all share one shape: allocate the request message,
// decode into it with dec, and call the matching DataStoreServer method
// directly when no interceptor is installed. Otherwise they pass the
// interceptor a grpc.UnaryServerInfo (server plus full method name) and a
// closure that performs the same call on the interceptor-supplied request.

// _DataStore_Ping_Handler serves DataStore.Ping.
func _DataStore_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PingRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).Ping(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_Ping_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).Ping(ctx, req.(*PingRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_BatchInsertItems_Handler serves DataStore.BatchInsertItems.
func _DataStore_BatchInsertItems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(BatchInsertItemsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).BatchInsertItems(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_BatchInsertItems_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).BatchInsertItems(ctx, req.(*BatchInsertItemsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_BatchGetItems_Handler serves DataStore.BatchGetItems.
func _DataStore_BatchGetItems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(BatchGetItemsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).BatchGetItems(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_BatchGetItems_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).BatchGetItems(ctx, req.(*BatchGetItemsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_DeleteItem_Handler serves DataStore.DeleteItem.
func _DataStore_DeleteItem_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteItemRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).DeleteItem(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_DeleteItem_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).DeleteItem(ctx, req.(*DeleteItemRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetItem_Handler serves DataStore.GetItem.
func _DataStore_GetItem_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetItemRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetItem(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetItem_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetItem(ctx, req.(*GetItemRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_ModifyItem_Handler serves DataStore.ModifyItem.
func _DataStore_ModifyItem_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ModifyItemRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).ModifyItem(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_ModifyItem_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).ModifyItem(ctx, req.(*ModifyItemRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetItems_Handler serves DataStore.GetItems.
func _DataStore_GetItems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetItemsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetItems(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetItems_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetItems(ctx, req.(*GetItemsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetItemFeedback_Handler serves DataStore.GetItemFeedback.
func _DataStore_GetItemFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetItemFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetItemFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetItemFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetItemFeedback(ctx, req.(*GetItemFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_BatchInsertUsers_Handler serves DataStore.BatchInsertUsers.
func _DataStore_BatchInsertUsers_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(BatchInsertUsersRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).BatchInsertUsers(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_BatchInsertUsers_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).BatchInsertUsers(ctx, req.(*BatchInsertUsersRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_DeleteUser_Handler serves DataStore.DeleteUser.
func _DataStore_DeleteUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteUserRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).DeleteUser(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_DeleteUser_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).DeleteUser(ctx, req.(*DeleteUserRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetUser_Handler serves DataStore.GetUser.
func _DataStore_GetUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetUserRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetUser(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetUser_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetUser(ctx, req.(*GetUserRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_ModifyUser_Handler serves DataStore.ModifyUser.
func _DataStore_ModifyUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ModifyUserRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).ModifyUser(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_ModifyUser_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).ModifyUser(ctx, req.(*ModifyUserRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetUsers_Handler serves DataStore.GetUsers.
func _DataStore_GetUsers_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetUsersRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetUsers(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetUsers_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetUsers(ctx, req.(*GetUsersRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetUserFeedback_Handler serves DataStore.GetUserFeedback.
func _DataStore_GetUserFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetUserFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetUserFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetUserFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetUserFeedback(ctx, req.(*GetUserFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetUserItemFeedback_Handler serves DataStore.GetUserItemFeedback.
func _DataStore_GetUserItemFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetUserItemFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetUserItemFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetUserItemFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetUserItemFeedback(ctx, req.(*GetUserItemFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_DeleteUserItemFeedback_Handler serves DataStore.DeleteUserItemFeedback.
func _DataStore_DeleteUserItemFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteUserItemFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).DeleteUserItemFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_DeleteUserItemFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).DeleteUserItemFeedback(ctx, req.(*DeleteUserItemFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_BatchInsertFeedback_Handler serves DataStore.BatchInsertFeedback.
func _DataStore_BatchInsertFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(BatchInsertFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).BatchInsertFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_BatchInsertFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).BatchInsertFeedback(ctx, req.(*BatchInsertFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetFeedback_Handler serves DataStore.GetFeedback.
func _DataStore_GetFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetFeedback(ctx, req.(*GetFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetUserStream_Handler receives the GetUserStreamRequest with a
// single RecvMsg call, then hands it to DataStoreServer.GetUserStream together
// with a typed wrapper around the same stream.
func _DataStore_GetUserStream_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(GetUserStreamRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(DataStoreServer).GetUserStream(m, &grpc.GenericServerStream[GetUserStreamRequest, GetUserStreamResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetUserStreamServer = grpc.ServerStreamingServer[GetUserStreamResponse]

// _DataStore_GetItemStream_Handler receives the GetItemStreamRequest with a
// single RecvMsg call, then hands it to DataStoreServer.GetItemStream together
// with a typed wrapper around the same stream.
func _DataStore_GetItemStream_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(GetItemStreamRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(DataStoreServer).GetItemStream(m, &grpc.GenericServerStream[GetItemStreamRequest, GetItemStreamResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetItemStreamServer = grpc.ServerStreamingServer[GetItemStreamResponse]

// _DataStore_GetFeedbackStream_Handler receives the GetFeedbackStreamRequest
// with a single RecvMsg call, then hands it to DataStoreServer.GetFeedbackStream
// together with a typed wrapper around the same stream.
func _DataStore_GetFeedbackStream_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(GetFeedbackStreamRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(DataStoreServer).GetFeedbackStream(m, &grpc.GenericServerStream[GetFeedbackStreamRequest, GetFeedbackStreamResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DataStore_GetFeedbackStreamServer = grpc.ServerStreamingServer[GetFeedbackStreamResponse]

// The four handlers below have the same shape as the other unary
// _DataStore_*_Handler functions: decode the request, then either call the
// DataStoreServer method directly or route the call through the interceptor.

// _DataStore_CountUsers_Handler serves DataStore.CountUsers.
func _DataStore_CountUsers_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(CountUsersRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).CountUsers(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_CountUsers_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).CountUsers(ctx, req.(*CountUsersRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_CountItems_Handler serves DataStore.CountItems.
func _DataStore_CountItems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(CountItemsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).CountItems(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_CountItems_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).CountItems(ctx, req.(*CountItemsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_CountFeedback_Handler serves DataStore.CountFeedback.
func _DataStore_CountFeedback_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(CountFeedbackRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).CountFeedback(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_CountFeedback_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).CountFeedback(ctx, req.(*CountFeedbackRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _DataStore_GetLatestItems_Handler serves DataStore.GetLatestItems.
func _DataStore_GetLatestItems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetLatestItemsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(DataStoreServer).GetLatestItems(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: DataStore_GetLatestItems_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(DataStoreServer).GetLatestItems(ctx, req.(*GetLatestItemsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// DataStore_ServiceDesc is the grpc.ServiceDesc for DataStore service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var DataStore_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "protocol.DataStore",
	HandlerType: (*DataStoreServer)(nil),
	// Unary RPCs, each bound to its _DataStore_*_Handler.
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Ping",
			Handler:    _DataStore_Ping_Handler,
		},
		{
			MethodName: "BatchInsertItems",
			Handler:    _DataStore_BatchInsertItems_Handler,
		},
		{
			MethodName: "BatchGetItems",
			Handler:    _DataStore_BatchGetItems_Handler,
		},
		{
			MethodName: "DeleteItem",
			Handler:    _DataStore_DeleteItem_Handler,
		},
		{
			MethodName: "GetItem",
			Handler:    _DataStore_GetItem_Handler,
		},
		{
			MethodName: "ModifyItem",
			Handler:    _DataStore_ModifyItem_Handler,
		},
		{
			MethodName: "GetItems",
			Handler:    _DataStore_GetItems_Handler,
		},
		{
			MethodName: "GetItemFeedback",
			Handler:    _DataStore_GetItemFeedback_Handler,
		},
		{
			MethodName: "BatchInsertUsers",
			Handler:    _DataStore_BatchInsertUsers_Handler,
		},
		{
			MethodName: "DeleteUser",
			Handler:    _DataStore_DeleteUser_Handler,
		},
		{
			MethodName: "GetUser",
			Handler:    _DataStore_GetUser_Handler,
		},
		{
			MethodName: "ModifyUser",
			Handler:    _DataStore_ModifyUser_Handler,
		},
		{
			MethodName: "GetUsers",
			Handler:    _DataStore_GetUsers_Handler,
		},
		{
			MethodName: "GetUserFeedback",
			Handler:    _DataStore_GetUserFeedback_Handler,
		},
		{
			MethodName: "GetUserItemFeedback",
			Handler:    _DataStore_GetUserItemFeedback_Handler,
		},
		{
			MethodName: "DeleteUserItemFeedback",
			Handler:    _DataStore_DeleteUserItemFeedback_Handler,
		},
		{
			MethodName: "BatchInsertFeedback",
			Handler:    _DataStore_BatchInsertFeedback_Handler,
		},
		{
			MethodName: "GetFeedback",
			Handler:    _DataStore_GetFeedback_Handler,
		},
		{
			MethodName: "CountUsers",
			Handler:    _DataStore_CountUsers_Handler,
		},
		{
			MethodName: "CountItems",
			Handler:    _DataStore_CountItems_Handler,
		},
		{
			MethodName: "CountFeedback",
			Handler:    _DataStore_CountFeedback_Handler,
		},
		{
			MethodName: "GetLatestItems",
			Handler:    _DataStore_GetLatestItems_Handler,
		},
	},
	// NOTE: the client stubs open streams via &DataStore_ServiceDesc.Streams[i]
	// by position (GetFeedbackStream uses index 2), so this order is load-bearing.
	Streams: []grpc.StreamDesc{
		{
			StreamName:    "GetUserStream",
			Handler:       _DataStore_GetUserStream_Handler,
			ServerStreams: true,
		},
		{
			StreamName:    "GetItemStream",
			Handler:       _DataStore_GetItemStream_Handler,
			ServerStreams: true,
		},
		{
			StreamName:    "GetFeedbackStream",
			Handler:       _DataStore_GetFeedbackStream_Handler,
			ServerStreams: true,
		},
	},
	Metadata: "data_store.proto",
}

================================================
FILE: protocol/encoding.pb.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.36.10
// 	protoc        v6.33.1
// source: encoding.proto

package protocol

import (
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	reflect "reflect"
	sync "sync"
	unsafe "unsafe"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Tensor is the Go type for the protocol.Tensor message in encoding.proto
// (repeated string key = 1; repeated int32 shape = 2; repeated float data = 3).
//
// NOTE(review): how Key relates to the entries and how Data is laid out
// relative to Shape is not defined here; confirm against the code that
// builds and reads Tensor values.
type Tensor struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Key           []string               `protobuf:"bytes,1,rep,name=key,proto3" json:"key,omitempty"`
	Shape         []int32                `protobuf:"varint,2,rep,packed,name=shape,proto3" json:"shape,omitempty"`
	Data          []float32              `protobuf:"fixed32,3,rep,packed,name=data,proto3" json:"data,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero Tensor and re-attaches Tensor's message info
// (entry 0 of file_encoding_proto_msgTypes) to its message state.
func (x *Tensor) Reset() {
	*x = Tensor{}
	mi := &file_encoding_proto_msgTypes[0]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String formats x via protoimpl.X.MessageStringOf.
func (x *Tensor) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method identifying *Tensor as a message.
func (*Tensor) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x it stores
// Tensor's message info in the message state on first use; for a nil x it
// falls back to mi.MessageOf(x).
func (x *Tensor) ProtoReflect() protoreflect.Message {
	mi := &file_encoding_proto_msgTypes[0]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Tensor.ProtoReflect.Descriptor instead.
func (*Tensor) Descriptor() ([]byte, []int) {
	// Path {0}: Tensor is message index 0 in encoding.proto's descriptor.
	return file_encoding_proto_rawDescGZIP(), []int{0}
}

// GetKey returns x.Key, or nil when x is nil.
func (x *Tensor) GetKey() []string {
	if x != nil {
		return x.Key
	}
	return nil
}

// GetShape returns x.Shape, or nil when x is nil.
func (x *Tensor) GetShape() []int32 {
	if x != nil {
		return x.Shape
	}
	return nil
}

// GetData returns x.Data, or nil when x is nil.
func (x *Tensor) GetData() []float32 {
	if x != nil {
		return x.Data
	}
	return nil
}

// LatentFactor is the Go type for the protocol.LatentFactor message in
// encoding.proto (string id = 1; repeated float data = 2).
//
// NOTE(review): the name suggests Data holds a latent-factor vector for the
// entity identified by Id; confirm against the code that fills it.
type LatentFactor struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Id            string                 `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
	Data          []float32              `protobuf:"fixed32,2,rep,packed,name=data,proto3" json:"data,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero LatentFactor and re-attaches LatentFactor's message
// info (entry 1 of file_encoding_proto_msgTypes) to its message state.
func (x *LatentFactor) Reset() {
	*x = LatentFactor{}
	mi := &file_encoding_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String formats x via protoimpl.X.MessageStringOf.
func (x *LatentFactor) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method identifying *LatentFactor as a message.
func (*LatentFactor) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x it stores
// LatentFactor's message info in the message state on first use; for a nil x
// it falls back to mi.MessageOf(x).
func (x *LatentFactor) ProtoReflect() protoreflect.Message {
	mi := &file_encoding_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use LatentFactor.ProtoReflect.Descriptor instead.
func (*LatentFactor) Descriptor() ([]byte, []int) {
	// Path {1}: LatentFactor is message index 1 in encoding.proto's descriptor.
	return file_encoding_proto_rawDescGZIP(), []int{1}
}

// GetId returns x.Id, or "" when x is nil.
func (x *LatentFactor) GetId() string {
	if x != nil {
		return x.Id
	}
	return ""
}

// GetData returns x.Data, or nil when x is nil.
func (x *LatentFactor) GetData() []float32 {
	if x != nil {
		return x.Data
	}
	return nil
}

// File_encoding_proto is the file descriptor for encoding.proto; it is
// assigned by file_encoding_proto_init.
var File_encoding_proto protoreflect.FileDescriptor

// file_encoding_proto_rawDesc is the raw wire-encoded descriptor of
// encoding.proto that file_encoding_proto_init passes to protoimpl.DescBuilder.
// Its bytes must not be edited by hand.
const file_encoding_proto_rawDesc = "" +
	"\n" +
	"\x0eencoding.proto\x12\bprotocol\"D\n" +
	"\x06Tensor\x12\x10\n" +
	"\x03key\x18\x01 \x03(\tR\x03key\x12\x14\n" +
	"\x05shape\x18\x02 \x03(\x05R\x05shape\x12\x12\n" +
	"\x04data\x18\x03 \x03(\x02R\x04data\"2\n" +
	"\fLatentFactor\x12\x0e\n" +
	"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
	"\x04data\x18\x02 \x03(\x02R\x04dataB$Z\"github.com/gorse-io/gorse/protocolb\x06proto3"

var (
	file_encoding_proto_rawDescOnce sync.Once
	file_encoding_proto_rawDescData []byte
)

// file_encoding_proto_rawDescGZIP returns the raw descriptor compressed with
// protoimpl.X.CompressGZIP. The compression runs once (sync.Once) and the
// result is cached for later calls.
func file_encoding_proto_rawDescGZIP() []byte {
	file_encoding_proto_rawDescOnce.Do(func() {
		// unsafe.Slice over unsafe.StringData views the constant's bytes without copying.
		file_encoding_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_encoding_proto_rawDesc), len(file_encoding_proto_rawDesc)))
	})
	return file_encoding_proto_rawDescData
}

// Per-message runtime info, indexed like file_encoding_proto_goTypes.
var file_encoding_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_encoding_proto_goTypes = []any{
	(*Tensor)(nil),       // 0: protocol.Tensor
	(*LatentFactor)(nil), // 1: protocol.LatentFactor
}
var file_encoding_proto_depIdxs = []int32{
	0, // [0:0] is the sub-list for method output_type
	0, // [0:0] is the sub-list for method input_type
	0, // [0:0] is the sub-list for extension type_name
	0, // [0:0] is the sub-list for extension extendee
	0, // [0:0] is the sub-list for field type_name
}

func init() { file_encoding_proto_init() }

// file_encoding_proto_init builds File_encoding_proto from the raw descriptor
// and the Go type tables above. It returns immediately if the descriptor has
// already been built. Afterwards it nils out goTypes and depIdxs, which are
// not referenced again in this file.
func file_encoding_proto_init() {
	if File_encoding_proto != nil {
		return
	}
	// x is a local type used only to obtain this package's import path.
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: unsafe.Slice(unsafe.StringData(file_encoding_proto_rawDesc), len(file_encoding_proto_rawDesc)),
			NumEnums:      0,
			NumMessages:   2,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_encoding_proto_goTypes,
		DependencyIndexes: file_encoding_proto_depIdxs,
		MessageInfos:      file_encoding_proto_msgTypes,
	}.Build()
	File_encoding_proto = out.File
	file_encoding_proto_goTypes = nil
	file_encoding_proto_depIdxs = nil
}

================================================
FILE: protocol/encoding.proto
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

option go_package = "github.com/gorse-io/gorse/protocol";

package protocol;

message Tensor {
  repeated string key = 1;
  repeated int32 shape = 2;
  repeated float data = 3;
}

message LatentFactor {
  string id = 1;
  repeated float data = 2;
}

================================================
FILE: protocol/generate.go
================================================
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package protocol

// `go generate` in this package re-runs protoc on each .proto file below,
// producing the *.pb.go messages (--go_out) and gRPC stubs (--go-grpc_out)
// next to their sources (paths=source_relative).

//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative cache_store.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative data_store.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative vector_store.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative encoding.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative protocol.proto

================================================
FILE: protocol/protocol.pb.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.36.10
// 	protoc        v6.33.1
// source: protocol.proto

package protocol

import (
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	timestamppb "google.golang.org/protobuf/types/known/timestamppb"
	reflect "reflect"
	sync "sync"
	unsafe "unsafe"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// NOTE(review): everything in this file is emitted by protoc-gen-go from
// protocol.proto (see header). Durable changes belong in the .proto followed
// by `go generate`; hand edits here are overwritten on regeneration.

// NodeType enumerates the node roles declared in protocol.proto:
// Server (0), Worker (1) and Client (2). The zero value is NodeType_Server.
type NodeType int32

const (
	NodeType_Server NodeType = 0
	NodeType_Worker NodeType = 1
	NodeType_Client NodeType = 2
)

// Enum value maps for NodeType.
var (
	NodeType_name = map[int32]string{
		0: "Server",
		1: "Worker",
		2: "Client",
	}
	NodeType_value = map[string]int32{
		"Server": 0,
		"Worker": 1,
		"Client": 2,
	}
)

// Enum returns a pointer to a copy of x.
func (x NodeType) Enum() *NodeType {
	p := new(NodeType)
	*p = x
	return p
}

func (x NodeType) String() string {
	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}

func (NodeType) Descriptor() protoreflect.EnumDescriptor {
	return file_protocol_proto_enumTypes[0].Descriptor()
}

func (NodeType) Type() protoreflect.EnumType {
	return &file_protocol_proto_enumTypes[0]
}

func (x NodeType) Number() protoreflect.EnumNumber {
	return protoreflect.EnumNumber(x)
}

// Deprecated: Use NodeType.Descriptor instead.
func (NodeType) EnumDescriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{0}
}

// User is the Go type for the protocol.User message (fields 1-4). Reset and
// ProtoReflect bind it to file_protocol_proto_msgTypes[0]. Its GetXxx
// accessors are nil-safe: they return the field's zero value for a nil receiver.
type User struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	UserId        string                 `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	Labels        []byte                 `protobuf:"bytes,2,opt,name=labels,proto3" json:"labels,omitempty"`
	Comment       string                 `protobuf:"bytes,3,opt,name=comment,proto3" json:"comment,omitempty"`
	Subscribe     []string               `protobuf:"bytes,4,rep,name=subscribe,proto3" json:"subscribe,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *User) Reset() {
	*x = User{}
	mi := &file_protocol_proto_msgTypes[0]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *User) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*User) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x it attaches
// the message info on first use; for a nil x it falls back to mi.MessageOf(x).
func (x *User) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[0]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use User.ProtoReflect.Descriptor instead.
func (*User) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{0}
}

func (x *User) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

func (x *User) GetLabels() []byte {
	if x != nil {
		return x.Labels
	}
	return nil
}

func (x *User) GetComment() string {
	if x != nil {
		return x.Comment
	}
	return ""
}

func (x *User) GetSubscribe() []string {
	if x != nil {
		return x.Subscribe
	}
	return nil
}

// Item is the Go type for the protocol.Item message (fields 1-7). Reset and
// ProtoReflect bind it to file_protocol_proto_msgTypes[1]. Its GetXxx
// accessors are nil-safe: they return the field's zero value for a nil receiver.
type Item struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Namespace     string                 `protobuf:"bytes,1,opt,name=namespace,proto3" json:"namespace,omitempty"`
	ItemId        string                 `protobuf:"bytes,2,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	IsHidden      bool                   `protobuf:"varint,3,opt,name=is_hidden,json=isHidden,proto3" json:"is_hidden,omitempty"`
	Categories    []string               `protobuf:"bytes,4,rep,name=categories,proto3" json:"categories,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	Labels        []byte                 `protobuf:"bytes,6,opt,name=labels,proto3" json:"labels,omitempty"`
	Comment       string                 `protobuf:"bytes,7,opt,name=comment,proto3" json:"comment,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Item) Reset() {
	*x = Item{}
	mi := &file_protocol_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Item) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Item) ProtoMessage() {}

func (x *Item) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Item.ProtoReflect.Descriptor instead.
func (*Item) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{1}
}

func (x *Item) GetNamespace() string {
	if x != nil {
		return x.Namespace
	}
	return ""
}

func (x *Item) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

func (x *Item) GetIsHidden() bool {
	if x != nil {
		return x.IsHidden
	}
	return false
}

func (x *Item) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

func (x *Item) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

func (x *Item) GetLabels() []byte {
	if x != nil {
		return x.Labels
	}
	return nil
}

func (x *Item) GetComment() string {
	if x != nil {
		return x.Comment
	}
	return ""
}

// Feedback is the Go type for the protocol.Feedback message (fields 1-8).
// Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[2]. Its
// GetXxx accessors are nil-safe: they return the field's zero value for a
// nil receiver.
type Feedback struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Namespace     string                 `protobuf:"bytes,1,opt,name=namespace,proto3" json:"namespace,omitempty"`
	FeedbackType  string                 `protobuf:"bytes,2,opt,name=feedback_type,json=feedbackType,proto3" json:"feedback_type,omitempty"`
	UserId        string                 `protobuf:"bytes,3,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
	ItemId        string                 `protobuf:"bytes,4,opt,name=item_id,json=itemId,proto3" json:"item_id,omitempty"`
	Value         float64                `protobuf:"fixed64,5,opt,name=value,proto3" json:"value,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	Comment       string                 `protobuf:"bytes,7,opt,name=comment,proto3" json:"comment,omitempty"`
	Updated       *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=updated,proto3" json:"updated,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Feedback) Reset() {
	*x = Feedback{}
	mi := &file_protocol_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Feedback) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Feedback) ProtoMessage() {}

func (x *Feedback) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Feedback.ProtoReflect.Descriptor instead.
func (*Feedback) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{2}
}

func (x *Feedback) GetNamespace() string {
	if x != nil {
		return x.Namespace
	}
	return ""
}

func (x *Feedback) GetFeedbackType() string {
	if x != nil {
		return x.FeedbackType
	}
	return ""
}

func (x *Feedback) GetUserId() string {
	if x != nil {
		return x.UserId
	}
	return ""
}

func (x *Feedback) GetItemId() string {
	if x != nil {
		return x.ItemId
	}
	return ""
}

func (x *Feedback) GetValue() float64 {
	if x != nil {
		return x.Value
	}
	return 0
}

func (x *Feedback) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

func (x *Feedback) GetComment() string {
	if x != nil {
		return x.Comment
	}
	return ""
}

func (x *Feedback) GetUpdated() *timestamppb.Timestamp {
	if x != nil {
		return x.Updated
	}
	return nil
}

// Meta is the Go type for the protocol.Meta message. Per the raw descriptor it
// is the response of the Master.GetMeta RPC, whose request is a NodeInfo.
// Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[3]; the
// GetXxx accessors return zero values for a nil receiver.
// NOTE(review): field number 2 is unused (presumably removed/reserved) — check
// protocol.proto before reusing it.
type Meta struct {
	state                         protoimpl.MessageState `protogen:"open.v1"`
	Config                        string                 `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
	CollaborativeFilteringModelId int64                  `protobuf:"varint,3,opt,name=collaborative_filtering_model_id,json=collaborativeFilteringModelId,proto3" json:"collaborative_filtering_model_id,omitempty"`
	ClickThroughRateModelId       int64                  `protobuf:"varint,4,opt,name=click_through_rate_model_id,json=clickThroughRateModelId,proto3" json:"click_through_rate_model_id,omitempty"`
	Me                            string                 `protobuf:"bytes,5,opt,name=me,proto3" json:"me,omitempty"`
	Servers                       []string               `protobuf:"bytes,6,rep,name=servers,proto3" json:"servers,omitempty"`
	Workers                       []string               `protobuf:"bytes,7,rep,name=workers,proto3" json:"workers,omitempty"`
	unknownFields                 protoimpl.UnknownFields
	sizeCache                     protoimpl.SizeCache
}

func (x *Meta) Reset() {
	*x = Meta{}
	mi := &file_protocol_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Meta) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Meta) ProtoMessage() {}

func (x *Meta) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Meta.ProtoReflect.Descriptor instead.
func (*Meta) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{3}
}

func (x *Meta) GetConfig() string {
	if x != nil {
		return x.Config
	}
	return ""
}

func (x *Meta) GetCollaborativeFilteringModelId() int64 {
	if x != nil {
		return x.CollaborativeFilteringModelId
	}
	return 0
}

func (x *Meta) GetClickThroughRateModelId() int64 {
	if x != nil {
		return x.ClickThroughRateModelId
	}
	return 0
}

func (x *Meta) GetMe() string {
	if x != nil {
		return x.Me
	}
	return ""
}

func (x *Meta) GetServers() []string {
	if x != nil {
		return x.Servers
	}
	return nil
}

func (x *Meta) GetWorkers() []string {
	if x != nil {
		return x.Workers
	}
	return nil
}

// Fragment is the Go type for the protocol.Fragment message: a single bytes
// field. Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[4].
type Fragment struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Data          []byte                 `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Fragment) Reset() {
	*x = Fragment{}
	mi := &file_protocol_proto_msgTypes[4]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Fragment) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Fragment) ProtoMessage() {}

func (x *Fragment) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[4]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Fragment.ProtoReflect.Descriptor instead.
func (*Fragment) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{4}
}

// GetData returns the data field, or nil when x is nil.
func (x *Fragment) GetData() []byte {
	if x != nil {
		return x.Data
	}
	return nil
}

// NodeInfo is the Go type for the protocol.NodeInfo message. Per the raw
// descriptor it is the request of the Master.GetMeta RPC. Reset and
// ProtoReflect bind it to file_protocol_proto_msgTypes[5].
// NOTE(review): field number 3 is unused (presumably removed/reserved) — check
// protocol.proto before reusing it.
type NodeInfo struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	NodeType      NodeType               `protobuf:"varint,1,opt,name=node_type,json=nodeType,proto3,enum=protocol.NodeType" json:"node_type,omitempty"`
	Uuid          string                 `protobuf:"bytes,2,opt,name=uuid,proto3" json:"uuid,omitempty"`
	BinaryVersion string                 `protobuf:"bytes,4,opt,name=binary_version,json=binaryVersion,proto3" json:"binary_version,omitempty"`
	Hostname      string                 `protobuf:"bytes,5,opt,name=hostname,proto3" json:"hostname,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *NodeInfo) Reset() {
	*x = NodeInfo{}
	mi := &file_protocol_proto_msgTypes[5]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *NodeInfo) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*NodeInfo) ProtoMessage() {}

func (x *NodeInfo) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[5]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use NodeInfo.ProtoReflect.Descriptor instead.
func (*NodeInfo) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{5}
}

// GetNodeType returns the node_type field, or NodeType_Server (the enum's
// zero value) when x is nil.
func (x *NodeInfo) GetNodeType() NodeType {
	if x != nil {
		return x.NodeType
	}
	return NodeType_Server
}

func (x *NodeInfo) GetUuid() string {
	if x != nil {
		return x.Uuid
	}
	return ""
}

func (x *NodeInfo) GetBinaryVersion() string {
	if x != nil {
		return x.BinaryVersion
	}
	return ""
}

func (x *NodeInfo) GetHostname() string {
	if x != nil {
		return x.Hostname
	}
	return ""
}

// Progress is the Go type for the protocol.Progress message (fields 1-8).
// Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[6]; the
// GetXxx accessors return zero values for a nil receiver.
// NOTE(review): StartTime/FinishTime are plain int64s; their unit (presumably
// a Unix timestamp) is not visible here — confirm with the producer.
type Progress struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Tracer        string                 `protobuf:"bytes,1,opt,name=tracer,proto3" json:"tracer,omitempty"`
	Name          string                 `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
	Status        string                 `protobuf:"bytes,3,opt,name=status,proto3" json:"status,omitempty"`
	Error         string                 `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"`
	Count         int64                  `protobuf:"varint,5,opt,name=count,proto3" json:"count,omitempty"`
	Total         int64                  `protobuf:"varint,6,opt,name=total,proto3" json:"total,omitempty"`
	StartTime     int64                  `protobuf:"varint,7,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"`
	FinishTime    int64                  `protobuf:"varint,8,opt,name=finish_time,json=finishTime,proto3" json:"finish_time,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *Progress) Reset() {
	*x = Progress{}
	mi := &file_protocol_proto_msgTypes[6]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *Progress) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Progress) ProtoMessage() {}

func (x *Progress) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[6]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Progress.ProtoReflect.Descriptor instead.
func (*Progress) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{6}
}

func (x *Progress) GetTracer() string {
	if x != nil {
		return x.Tracer
	}
	return ""
}

func (x *Progress) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *Progress) GetStatus() string {
	if x != nil {
		return x.Status
	}
	return ""
}

func (x *Progress) GetError() string {
	if x != nil {
		return x.Error
	}
	return ""
}

func (x *Progress) GetCount() int64 {
	if x != nil {
		return x.Count
	}
	return 0
}

func (x *Progress) GetTotal() int64 {
	if x != nil {
		return x.Total
	}
	return 0
}

func (x *Progress) GetStartTime() int64 {
	if x != nil {
		return x.StartTime
	}
	return 0
}

func (x *Progress) GetFinishTime() int64 {
	if x != nil {
		return x.FinishTime
	}
	return 0
}

// PushProgressRequest is the Go type for protocol.PushProgressRequest, a list
// of Progress entries. Per the raw descriptor it is the request of the
// Master.PushProgress RPC. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[7].
type PushProgressRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Progress      []*Progress            `protobuf:"bytes,1,rep,name=progress,proto3" json:"progress,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PushProgressRequest) Reset() {
	*x = PushProgressRequest{}
	mi := &file_protocol_proto_msgTypes[7]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PushProgressRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PushProgressRequest) ProtoMessage() {}

func (x *PushProgressRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[7]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PushProgressRequest.ProtoReflect.Descriptor instead.
func (*PushProgressRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{7}
}

// GetProgress returns the progress list, or nil when x is nil.
func (x *PushProgressRequest) GetProgress() []*Progress {
	if x != nil {
		return x.Progress
	}
	return nil
}

// PushProgressResponse is the Go type for protocol.PushProgressResponse, an
// empty message (no fields). Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[8].
type PushProgressResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PushProgressResponse) Reset() {
	*x = PushProgressResponse{}
	mi := &file_protocol_proto_msgTypes[8]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PushProgressResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PushProgressResponse) ProtoMessage() {}

func (x *PushProgressResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[8]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PushProgressResponse.ProtoReflect.Descriptor instead.
func (*PushProgressResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{8}
}

// PingRequest is the Go type for protocol.PingRequest, an empty message
// (no fields). Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[9].
type PingRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PingRequest) Reset() {
	*x = PingRequest{}
	mi := &file_protocol_proto_msgTypes[9]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PingRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PingRequest) ProtoMessage() {}

func (x *PingRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[9]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PingRequest.ProtoReflect.Descriptor instead.
func (*PingRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{9}
}

// PingResponse is the Go type for protocol.PingResponse, an empty message
// (no fields). Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[10].
type PingResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *PingResponse) Reset() {
	*x = PingResponse{}
	mi := &file_protocol_proto_msgTypes[10]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *PingResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*PingResponse) ProtoMessage() {}

func (x *PingResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[10]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PingResponse.ProtoReflect.Descriptor instead.
func (*PingResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{10}
}

// UploadBlobRequest is the Go type for protocol.UploadBlobRequest (name,
// timestamp, data). Per the raw descriptor it is the request of the
// client-streaming BlobStore.UploadBlob RPC, so a blob may arrive as several
// of these messages. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[11].
type UploadBlobRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	Data          []byte                 `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *UploadBlobRequest) Reset() {
	*x = UploadBlobRequest{}
	mi := &file_protocol_proto_msgTypes[11]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *UploadBlobRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*UploadBlobRequest) ProtoMessage() {}

func (x *UploadBlobRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[11]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use UploadBlobRequest.ProtoReflect.Descriptor instead.
func (*UploadBlobRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{11}
}

func (x *UploadBlobRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *UploadBlobRequest) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

func (x *UploadBlobRequest) GetData() []byte {
	if x != nil {
		return x.Data
	}
	return nil
}

// UploadBlobResponse is the Go type for protocol.UploadBlobResponse, an empty
// message returned by BlobStore.UploadBlob. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[12].
type UploadBlobResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *UploadBlobResponse) Reset() {
	*x = UploadBlobResponse{}
	mi := &file_protocol_proto_msgTypes[12]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *UploadBlobResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*UploadBlobResponse) ProtoMessage() {}

func (x *UploadBlobResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[12]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use UploadBlobResponse.ProtoReflect.Descriptor instead.
func (*UploadBlobResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{12}
}

// DownloadBlobRequest is the Go type for protocol.DownloadBlobRequest, which
// carries only a blob name. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[13].
type DownloadBlobRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DownloadBlobRequest) Reset() {
	*x = DownloadBlobRequest{}
	mi := &file_protocol_proto_msgTypes[13]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DownloadBlobRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DownloadBlobRequest) ProtoMessage() {}

func (x *DownloadBlobRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[13]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DownloadBlobRequest.ProtoReflect.Descriptor instead.
func (*DownloadBlobRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{13}
}

// GetName returns the name field, or "" when x is nil.
func (x *DownloadBlobRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// DownloadBlobResponse is the Go type for protocol.DownloadBlobResponse, which
// carries a single bytes payload. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[14].
type DownloadBlobResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Data          []byte                 `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *DownloadBlobResponse) Reset() {
	*x = DownloadBlobResponse{}
	mi := &file_protocol_proto_msgTypes[14]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *DownloadBlobResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*DownloadBlobResponse) ProtoMessage() {}

func (x *DownloadBlobResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[14]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DownloadBlobResponse.ProtoReflect.Descriptor instead.
func (*DownloadBlobResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{14}
}

// GetData returns the data field, or nil when x is nil.
func (x *DownloadBlobResponse) GetData() []byte {
	if x != nil {
		return x.Data
	}
	return nil
}

// ListBlobsRequest is the Go type for protocol.ListBlobsRequest, an empty
// message (no fields). Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[15].
type ListBlobsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *ListBlobsRequest) Reset() {
	*x = ListBlobsRequest{}
	mi := &file_protocol_proto_msgTypes[15]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ListBlobsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ListBlobsRequest) ProtoMessage() {}

func (x *ListBlobsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[15]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ListBlobsRequest.ProtoReflect.Descriptor instead.
func (*ListBlobsRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{15}
}

// ListBlobsResponse is the Go type for protocol.ListBlobsResponse, a list of
// blob names. Reset and ProtoReflect bind it to file_protocol_proto_msgTypes[16].
type ListBlobsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Names         []string               `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *ListBlobsResponse) Reset() {
	*x = ListBlobsResponse{}
	mi := &file_protocol_proto_msgTypes[16]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *ListBlobsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ListBlobsResponse) ProtoMessage() {}

func (x *ListBlobsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[16]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ListBlobsResponse.ProtoReflect.Descriptor instead.
func (*ListBlobsResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{16}
}

// GetNames returns the names list, or nil when x is nil.
func (x *ListBlobsResponse) GetNames() []string {
	if x != nil {
		return x.Names
	}
	return nil
}

// RemoveBlobRequest is the Go type for protocol.RemoveBlobRequest, which
// carries only a blob name. Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[17].
type RemoveBlobRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Name          string                 `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *RemoveBlobRequest) Reset() {
	*x = RemoveBlobRequest{}
	mi := &file_protocol_proto_msgTypes[17]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *RemoveBlobRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RemoveBlobRequest) ProtoMessage() {}

func (x *RemoveBlobRequest) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[17]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemoveBlobRequest.ProtoReflect.Descriptor instead.
func (*RemoveBlobRequest) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{17}
}

// GetName returns the name field, or "" when x is nil.
func (x *RemoveBlobRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// RemoveBlobResponse is the Go type for protocol.RemoveBlobResponse, an empty
// message (no fields). Reset and ProtoReflect bind it to
// file_protocol_proto_msgTypes[18].
type RemoveBlobResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

func (x *RemoveBlobResponse) Reset() {
	*x = RemoveBlobResponse{}
	mi := &file_protocol_proto_msgTypes[18]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *RemoveBlobResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RemoveBlobResponse) ProtoMessage() {}

func (x *RemoveBlobResponse) ProtoReflect() protoreflect.Message {
	mi := &file_protocol_proto_msgTypes[18]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemoveBlobResponse.ProtoReflect.Descriptor instead.
func (*RemoveBlobResponse) Descriptor() ([]byte, []int) {
	return file_protocol_proto_rawDescGZIP(), []int{18}
}

// File_protocol_proto is the file descriptor for protocol.proto. It is
// assigned by file_protocol_proto_init and is nil until that has run.
var File_protocol_proto protoreflect.FileDescriptor

// file_protocol_proto_rawDesc holds the raw descriptor bytes for
// protocol.proto that file_protocol_proto_init passes to
// protoimpl.DescBuilder.RawDescriptor. It encodes the messages, the NodeType
// enum and the Master / BlobStore services. These bytes are machine-generated:
// regenerate with protoc instead of editing them by hand.
const file_protocol_proto_rawDesc = "" +
	"\n" +
	"\x0eprotocol.proto\x12\bprotocol\x1a\x1fgoogle/protobuf/timestamp.proto\"o\n" +
	"\x04User\x12\x17\n" +
	"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x16\n" +
	"\x06labels\x18\x02 \x01(\fR\x06labels\x12\x18\n" +
	"\acomment\x18\x03 \x01(\tR\acomment\x12\x1c\n" +
	"\tsubscribe\x18\x04 \x03(\tR\tsubscribe\"\xe6\x01\n" +
	"\x04Item\x12\x1c\n" +
	"\tnamespace\x18\x01 \x01(\tR\tnamespace\x12\x17\n" +
	"\aitem_id\x18\x02 \x01(\tR\x06itemId\x12\x1b\n" +
	"\tis_hidden\x18\x03 \x01(\bR\bisHidden\x12\x1e\n" +
	"\n" +
	"categories\x18\x04 \x03(\tR\n" +
	"categories\x128\n" +
	"\ttimestamp\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x16\n" +
	"\x06labels\x18\x06 \x01(\fR\x06labels\x12\x18\n" +
	"\acomment\x18\a \x01(\tR\acomment\"\x9f\x02\n" +
	"\bFeedback\x12\x1c\n" +
	"\tnamespace\x18\x01 \x01(\tR\tnamespace\x12#\n" +
	"\rfeedback_type\x18\x02 \x01(\tR\ffeedbackType\x12\x17\n" +
	"\auser_id\x18\x03 \x01(\tR\x06userId\x12\x17\n" +
	"\aitem_id\x18\x04 \x01(\tR\x06itemId\x12\x14\n" +
	"\x05value\x18\x05 \x01(\x01R\x05value\x128\n" +
	"\ttimestamp\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x18\n" +
	"\acomment\x18\a \x01(\tR\acomment\x124\n" +
	"\aupdated\x18\b \x01(\v2\x1a.google.protobuf.TimestampR\aupdated\"\xe9\x01\n" +
	"\x04Meta\x12\x16\n" +
	"\x06config\x18\x01 \x01(\tR\x06config\x12G\n" +
	" collaborative_filtering_model_id\x18\x03 \x01(\x03R\x1dcollaborativeFilteringModelId\x12<\n" +
	"\x1bclick_through_rate_model_id\x18\x04 \x01(\x03R\x17clickThroughRateModelId\x12\x0e\n" +
	"\x02me\x18\x05 \x01(\tR\x02me\x12\x18\n" +
	"\aservers\x18\x06 \x03(\tR\aservers\x12\x18\n" +
	"\aworkers\x18\a \x03(\tR\aworkers\"\x1e\n" +
	"\bFragment\x12\x12\n" +
	"\x04data\x18\x01 \x01(\fR\x04data\"\x92\x01\n" +
	"\bNodeInfo\x12/\n" +
	"\tnode_type\x18\x01 \x01(\x0e2\x12.protocol.NodeTypeR\bnodeType\x12\x12\n" +
	"\x04uuid\x18\x02 \x01(\tR\x04uuid\x12%\n" +
	"\x0ebinary_version\x18\x04 \x01(\tR\rbinaryVersion\x12\x1a\n" +
	"\bhostname\x18\x05 \x01(\tR\bhostname\"\xd0\x01\n" +
	"\bProgress\x12\x16\n" +
	"\x06tracer\x18\x01 \x01(\tR\x06tracer\x12\x12\n" +
	"\x04name\x18\x02 \x01(\tR\x04name\x12\x16\n" +
	"\x06status\x18\x03 \x01(\tR\x06status\x12\x14\n" +
	"\x05error\x18\x04 \x01(\tR\x05error\x12\x14\n" +
	"\x05count\x18\x05 \x01(\x03R\x05count\x12\x14\n" +
	"\x05total\x18\x06 \x01(\x03R\x05total\x12\x1d\n" +
	"\n" +
	"start_time\x18\a \x01(\x03R\tstartTime\x12\x1f\n" +
	"\vfinish_time\x18\b \x01(\x03R\n" +
	"finishTime\"E\n" +
	"\x13PushProgressRequest\x12.\n" +
	"\bprogress\x18\x01 \x03(\v2\x12.protocol.ProgressR\bprogress\"\x16\n" +
	"\x14PushProgressResponse\"\r\n" +
	"\vPingRequest\"\x0e\n" +
	"\fPingResponse\"u\n" +
	"\x11UploadBlobRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x128\n" +
	"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x12\n" +
	"\x04data\x18\x03 \x01(\fR\x04data\"\x14\n" +
	"\x12UploadBlobResponse\")\n" +
	"\x13DownloadBlobRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"*\n" +
	"\x14DownloadBlobResponse\x12\x12\n" +
	"\x04data\x18\x01 \x01(\fR\x04data\"\x12\n" +
	"\x10ListBlobsRequest\")\n" +
	"\x11ListBlobsResponse\x12\x14\n" +
	"\x05names\x18\x01 \x03(\tR\x05names\"'\n" +
	"\x11RemoveBlobRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"\x14\n" +
	"\x12RemoveBlobResponse*.\n" +
	"\bNodeType\x12\n" +
	"\n" +
	"\x06Server\x10\x00\x12\n" +
	"\n" +
	"\x06Worker\x10\x01\x12\n" +
	"\n" +
	"\x06Client\x10\x022\x8a\x01\n" +
	"\x06Master\x12/\n" +
	"\aGetMeta\x12\x12.protocol.NodeInfo\x1a\x0e.protocol.Meta\"\x00\x12O\n" +
	"\fPushProgress\x12\x1d.protocol.PushProgressRequest\x1a\x1e.protocol.PushProgressResponse\"\x002\xbe\x02\n" +
	"\tBlobStore\x12K\n" +
	"\n" +
	"UploadBlob\x12\x1b.protocol.UploadBlobRequest\x1a\x1c.protocol.UploadBlobResponse\"\x00(\x01\x12Q\n" +
	"\fDownloadBlob\x12\x1d.protocol.DownloadBlobRequest\x1a\x1e.protocol.DownloadBlobResponse\"\x000\x01\x12F\n" +
	"\tListBlobs\x12\x1a.protocol.ListBlobsRequest\x1a\x1b.protocol.ListBlobsResponse\"\x00\x12I\n" +
	"\n" +
	"RemoveBlob\x12\x1b.protocol.RemoveBlobRequest\x1a\x1c.protocol.RemoveBlobResponse\"\x00B$Z\"github.com/gorse-io/gorse/protocolb\x06proto3"

// The Once guards the single gzip compression of file_protocol_proto_rawDesc
// into file_protocol_proto_rawDescData.
var (
	file_protocol_proto_rawDescOnce sync.Once
	file_protocol_proto_rawDescData []byte
)

// file_protocol_proto_rawDescGZIP returns the gzip-compressed raw descriptor.
// It compresses on the first call and returns the cached bytes afterwards.
// The deprecated Descriptor methods in this file return its result.
func file_protocol_proto_rawDescGZIP() []byte {
	file_protocol_proto_rawDescOnce.Do(func() {
		file_protocol_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_protocol_proto_rawDesc), len(file_protocol_proto_rawDesc)))
	})
	return file_protocol_proto_rawDescData
}

// Per-type runtime info handed to protoimpl.TypeBuilder in
// file_protocol_proto_init: one entry for the NodeType enum and 19 for the
// messages. Message methods index msgTypes by declaration order.
var file_protocol_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_protocol_proto_msgTypes = make([]protoimpl.MessageInfo, 19)

// file_protocol_proto_goTypes lists Go types in descriptor order: index 0 is
// the enum, 1-19 are the messages (msgTypes[i] corresponds to index i+1), and
// 20 is the imported google.protobuf.Timestamp.
var file_protocol_proto_goTypes = []any{
	(NodeType)(0),                 // 0: protocol.NodeType
	(*User)(nil),                  // 1: protocol.User
	(*Item)(nil),                  // 2: protocol.Item
	(*Feedback)(nil),              // 3: protocol.Feedback
	(*Meta)(nil),                  // 4: protocol.Meta
	(*Fragment)(nil),              // 5: protocol.Fragment
	(*NodeInfo)(nil),              // 6: protocol.NodeInfo
	(*Progress)(nil),              // 7: protocol.Progress
	(*PushProgressRequest)(nil),   // 8: protocol.PushProgressRequest
	(*PushProgressResponse)(nil),  // 9: protocol.PushProgressResponse
	(*PingRequest)(nil),           // 10: protocol.PingRequest
	(*PingResponse)(nil),          // 11: protocol.PingResponse
	(*UploadBlobRequest)(nil),     // 12: protocol.UploadBlobRequest
	(*UploadBlobResponse)(nil),    // 13: protocol.UploadBlobResponse
	(*DownloadBlobRequest)(nil),   // 14: protocol.DownloadBlobRequest
	(*DownloadBlobResponse)(nil),  // 15: protocol.DownloadBlobResponse
	(*ListBlobsRequest)(nil),      // 16: protocol.ListBlobsRequest
	(*ListBlobsResponse)(nil),     // 17: protocol.ListBlobsResponse
	(*RemoveBlobRequest)(nil),     // 18: protocol.RemoveBlobRequest
	(*RemoveBlobResponse)(nil),    // 19: protocol.RemoveBlobResponse
	(*timestamppb.Timestamp)(nil), // 20: google.protobuf.Timestamp
}

// file_protocol_proto_depIdxs resolves each field type / method input / method
// output to an index into file_protocol_proto_goTypes. The last five entries
// are the start offsets of each sub-list, as annotated.
var file_protocol_proto_depIdxs = []int32{
	20, // 0: protocol.Item.timestamp:type_name -> google.protobuf.Timestamp
	20, // 1: protocol.Feedback.timestamp:type_name -> google.protobuf.Timestamp
	20, // 2: protocol.Feedback.updated:type_name -> google.protobuf.Timestamp
	0,  // 3: protocol.NodeInfo.node_type:type_name -> protocol.NodeType
	7,  // 4: protocol.PushProgressRequest.progress:type_name -> protocol.Progress
	20, // 5: protocol.UploadBlobRequest.timestamp:type_name -> google.protobuf.Timestamp
	6,  // 6: protocol.Master.GetMeta:input_type -> protocol.NodeInfo
	8,  // 7: protocol.Master.PushProgress:input_type -> protocol.PushProgressRequest
	12, // 8: protocol.BlobStore.UploadBlob:input_type -> protocol.UploadBlobRequest
	14, // 9: protocol.BlobStore.DownloadBlob:input_type -> protocol.DownloadBlobRequest
	16, // 10: protocol.BlobStore.ListBlobs:input_type -> protocol.ListBlobsRequest
	18, // 11: protocol.BlobStore.RemoveBlob:input_type -> protocol.RemoveBlobRequest
	4,  // 12: protocol.Master.GetMeta:output_type -> protocol.Meta
	9,  // 13: protocol.Master.PushProgress:output_type -> protocol.PushProgressResponse
	13, // 14: protocol.BlobStore.UploadBlob:output_type -> protocol.UploadBlobResponse
	15, // 15: protocol.BlobStore.DownloadBlob:output_type -> protocol.DownloadBlobResponse
	17, // 16: protocol.BlobStore.ListBlobs:output_type -> protocol.ListBlobsResponse
	19, // 17: protocol.BlobStore.RemoveBlob:output_type -> protocol.RemoveBlobResponse
	12, // [12:18] is the sub-list for method output_type
	6,  // [6:12] is the sub-list for method input_type
	6,  // [6:6] is the sub-list for extension type_name
	6,  // [6:6] is the sub-list for extension extendee
	0,  // [0:6] is the sub-list for field type_name
}

// init builds protocol.proto's descriptor when the package is initialized.
func init() { file_protocol_proto_init() }

// file_protocol_proto_init builds File_protocol_proto and the enum/message
// infos from the raw descriptor and the tables above. It returns immediately
// if File_protocol_proto is already set, so repeated calls are no-ops.
func file_protocol_proto_init() {
	if File_protocol_proto != nil {
		return
	}
	// x exists only so reflect can report this package's import path.
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: unsafe.Slice(unsafe.StringData(file_protocol_proto_rawDesc), len(file_protocol_proto_rawDesc)),
			NumEnums:      1,
			NumMessages:   19,
			NumExtensions: 0,
			NumServices:   2,
		},
		GoTypes:           file_protocol_proto_goTypes,
		DependencyIndexes: file_protocol_proto_depIdxs,
		EnumInfos:         file_protocol_proto_enumTypes,
		MessageInfos:      file_protocol_proto_msgTypes,
	}.Build()
	File_protocol_proto = out.File
	// The lookup tables are no longer referenced after Build; dropping them
	// presumably lets them be garbage collected.
	file_protocol_proto_goTypes = nil
	file_protocol_proto_depIdxs = nil
}

================================================
FILE: protocol/protocol.proto
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

option go_package = "github.com/gorse-io/gorse/protocol";

package protocol;

import "google/protobuf/timestamp.proto";

message User {
  string user_id = 1;
  bytes labels = 2;
  string comment = 3;
  repeated string subscribe = 4;
}

message Item {
  string namespace = 1;
  string item_id = 2;
  bool is_hidden = 3;
  repeated string categories = 4;
  google.protobuf.Timestamp timestamp = 5;
  bytes labels = 6;
  string comment = 7;
}

message Feedback {
  string namespace = 1;
  string feedback_type = 2;
  string user_id = 3;
  string item_id = 4;
  double value = 5;
  google.protobuf.Timestamp timestamp = 6;
  string comment = 7;
  google.protobuf.Timestamp updated = 8;
}

enum NodeType {
  Server = 0;
  Worker = 1;
  Client = 2;
}

service Master {
  /* meta distribute */
  rpc GetMeta(NodeInfo) returns (Meta) {}
  rpc PushProgress(PushProgressRequest) returns (PushProgressResponse) {}
}

message Meta {
  string config = 1;
  // NOTE(review): field number 2 is not used here. If it belonged to a removed
  // field, consider `reserved 2;` so it can never be reused by accident.
  int64 collaborative_filtering_model_id = 3;
  int64 click_through_rate_model_id = 4;
  string me = 5;
  repeated string servers = 6;
  repeated string workers = 7;
}

// NOTE(review): not referenced by any RPC in this file; presumably used
// directly from Go code elsewhere — confirm before removing.
message Fragment {
  bytes data = 1;
}

message NodeInfo {
  NodeType node_type = 1;
  string uuid = 2;
  // NOTE(review): field number 3 is not used here; consider `reserved 3;`
  // if it belonged to a removed field.
  string binary_version = 4;
  string hostname = 5;
}

message Progress {
  string tracer = 1;
  string name = 2;
  string status = 3;
  string error = 4;
  int64 count = 5;
  int64 total = 6;
  int64 start_time = 7;
  int64 finish_time = 8;
}

message PushProgressRequest {
  repeated Progress progress = 1;
}

message PushProgressResponse {}

// NOTE(review): PingRequest / PingResponse are not referenced by any RPC in
// this file — confirm they are still needed.
message PingRequest {}

message PingResponse {}

message UploadBlobRequest {
  string name = 1;
  google.protobuf.Timestamp timestamp = 2;
  bytes data = 3;
}

message UploadBlobResponse {}

message DownloadBlobRequest {
  string name = 1;
}

message DownloadBlobResponse {
  bytes data = 1;
}

message ListBlobsRequest {}

message ListBlobsResponse {
  repeated string names = 1;
}

message RemoveBlobRequest {
  string name = 1;
}

message RemoveBlobResponse {}

// UploadBlob is client-streaming (the client sends a stream of
// UploadBlobRequest and receives one UploadBlobResponse); DownloadBlob is
// server-streaming (one request, a stream of DownloadBlobResponse).
service BlobStore {
  rpc UploadBlob(stream UploadBlobRequest) returns (UploadBlobResponse) {}
  rpc DownloadBlob(DownloadBlobRequest) returns (stream DownloadBlobResponse) {}
  rpc ListBlobs(ListBlobsRequest) returns (ListBlobsResponse) {}
  rpc RemoveBlob(RemoveBlobRequest) returns (RemoveBlobResponse) {}
}

================================================
FILE: protocol/protocol_grpc.pb.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc             v6.33.1
// source: protocol.proto

package protocol

import (
	context "context"
	grpc "google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	status "google.golang.org/grpc/status"
)

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9

// Fully-qualified method names ("/package.Service/Method") of the Master
// service, used by the client stubs and the server handler info.
const (
	Master_GetMeta_FullMethodName      = "/protocol.Master/GetMeta"
	Master_PushProgress_FullMethodName = "/protocol.Master/PushProgress"
)

// MasterClient is the client API for Master service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MasterClient interface {
	// meta distribute
	GetMeta(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Meta, error)
	PushProgress(ctx context.Context, in *PushProgressRequest, opts ...grpc.CallOption) (*PushProgressResponse, error)
}

// masterClient implements MasterClient by issuing unary calls on cc.
type masterClient struct {
	cc grpc.ClientConnInterface
}

// NewMasterClient returns a MasterClient that sends its calls over cc.
func NewMasterClient(cc grpc.ClientConnInterface) MasterClient {
	return &masterClient{cc}
}

// GetMeta invokes /protocol.Master/GetMeta. grpc.StaticMethod() is placed
// ahead of the caller's opts. On error the response is discarded and nil is
// returned.
func (c *masterClient) GetMeta(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Meta, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(Meta)
	err := c.cc.Invoke(ctx, Master_GetMeta_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// PushProgress invokes /protocol.Master/PushProgress, with the same option
// and error handling as GetMeta.
func (c *masterClient) PushProgress(ctx context.Context, in *PushProgressRequest, opts ...grpc.CallOption) (*PushProgressResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(PushProgressResponse)
	err := c.cc.Invoke(ctx, Master_PushProgress_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// MasterServer is the server API for Master service.
// All implementations must embed UnimplementedMasterServer
// for forward compatibility.
type MasterServer interface {
	// meta distribute
	GetMeta(context.Context, *NodeInfo) (*Meta, error)
	PushProgress(context.Context, *PushProgressRequest) (*PushProgressResponse, error)
	mustEmbedUnimplementedMasterServer()
}

// UnimplementedMasterServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedMasterServer struct{}

// GetMeta always fails with codes.Unimplemented.
func (UnimplementedMasterServer) GetMeta(context.Context, *NodeInfo) (*Meta, error) {
	return nil, status.Error(codes.Unimplemented, "method GetMeta not implemented")
}

// PushProgress always fails with codes.Unimplemented.
func (UnimplementedMasterServer) PushProgress(context.Context, *PushProgressRequest) (*PushProgressResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method PushProgress not implemented")
}
func (UnimplementedMasterServer) mustEmbedUnimplementedMasterServer() {}
func (UnimplementedMasterServer) testEmbeddedByValue()                {}

// UnsafeMasterServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to MasterServer will
// result in compilation errors.
type UnsafeMasterServer interface {
	mustEmbedUnimplementedMasterServer()
}

// RegisterMasterServer registers srv with s under Master_ServiceDesc.
func RegisterMasterServer(s grpc.ServiceRegistrar, srv MasterServer) {
	// If the following call panics, it indicates UnimplementedMasterServer was
	// embedded by pointer and is nil.  This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&Master_ServiceDesc, srv)
}

// _Master_GetMeta_Handler decodes a NodeInfo and calls srv.GetMeta, going
// through the unary interceptor when one is configured.
func _Master_GetMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(NodeInfo)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(MasterServer).GetMeta(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: Master_GetMeta_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(MasterServer).GetMeta(ctx, req.(*NodeInfo))
	}
	return interceptor(ctx, in, info, handler)
}

// _Master_PushProgress_Handler decodes a PushProgressRequest and calls
// srv.PushProgress, going through the unary interceptor when one is set.
func _Master_PushProgress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PushProgressRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(MasterServer).PushProgress(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: Master_PushProgress_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(MasterServer).PushProgress(ctx, req.(*PushProgressRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// Master_ServiceDesc is the grpc.ServiceDesc for Master service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Master_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "protocol.Master",
	HandlerType: (*MasterServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "GetMeta",
			Handler:    _Master_GetMeta_Handler,
		},
		{
			MethodName: "PushProgress",
			Handler:    _Master_PushProgress_Handler,
		},
	},
	Streams:  []grpc.StreamDesc{},
	Metadata: "protocol.proto",
}

// Fully-qualified method names of the BlobStore service.
const (
	BlobStore_UploadBlob_FullMethodName   = "/protocol.BlobStore/UploadBlob"
	BlobStore_DownloadBlob_FullMethodName = "/protocol.BlobStore/DownloadBlob"
	BlobStore_ListBlobs_FullMethodName    = "/protocol.BlobStore/ListBlobs"
	BlobStore_RemoveBlob_FullMethodName   = "/protocol.BlobStore/RemoveBlob"
)

// BlobStoreClient is the client API for BlobStore service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type BlobStoreClient interface {
	UploadBlob(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadBlobRequest, UploadBlobResponse], error)
	DownloadBlob(ctx context.Context, in *DownloadBlobRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DownloadBlobResponse], error)
	ListBlobs(ctx context.Context, in *ListBlobsRequest, opts ...grpc.CallOption) (*ListBlobsResponse, error)
	RemoveBlob(ctx context.Context, in *RemoveBlobRequest, opts ...grpc.CallOption) (*RemoveBlobResponse, error)
}

// blobStoreClient implements BlobStoreClient over cc.
type blobStoreClient struct {
	cc grpc.ClientConnInterface
}

// NewBlobStoreClient returns a BlobStoreClient that sends its calls over cc.
func NewBlobStoreClient(cc grpc.ClientConnInterface) BlobStoreClient {
	return &blobStoreClient{cc}
}

// UploadBlob opens the client-streaming UploadBlob call, described by
// BlobStore_ServiceDesc.Streams[0]. The caller sends requests on the returned
// stream and then reads the single response from it.
func (c *blobStoreClient) UploadBlob(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadBlobRequest, UploadBlobResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &BlobStore_ServiceDesc.Streams[0], BlobStore_UploadBlob_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[UploadBlobRequest, UploadBlobResponse]{ClientStream: stream}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type BlobStore_UploadBlobClient = grpc.ClientStreamingClient[UploadBlobRequest, UploadBlobResponse]

// DownloadBlob opens the server-streaming DownloadBlob call, described by
// BlobStore_ServiceDesc.Streams[1]. It sends in and closes the send side
// before returning, so the caller only receives on the stream.
func (c *blobStoreClient) DownloadBlob(ctx context.Context, in *DownloadBlobRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DownloadBlobResponse], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &BlobStore_ServiceDesc.Streams[1], BlobStore_DownloadBlob_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[DownloadBlobRequest, DownloadBlobResponse]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type BlobStore_DownloadBlobClient = grpc.ServerStreamingClient[DownloadBlobResponse]

// ListBlobs invokes the unary /protocol.BlobStore/ListBlobs call.
func (c *blobStoreClient) ListBlobs(ctx context.Context, in *ListBlobsRequest, opts ...grpc.CallOption) (*ListBlobsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ListBlobsResponse)
	err := c.cc.Invoke(ctx, BlobStore_ListBlobs_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// RemoveBlob invokes the unary /protocol.BlobStore/RemoveBlob call.
func (c *blobStoreClient) RemoveBlob(ctx context.Context, in *RemoveBlobRequest, opts ...grpc.CallOption) (*RemoveBlobResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(RemoveBlobResponse)
	err := c.cc.Invoke(ctx, BlobStore_RemoveBlob_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// BlobStoreServer is the server API for BlobStore service.
// All implementations must embed UnimplementedBlobStoreServer
// for forward compatibility.
type BlobStoreServer interface {
	UploadBlob(grpc.ClientStreamingServer[UploadBlobRequest, UploadBlobResponse]) error
	DownloadBlob(*DownloadBlobRequest, grpc.ServerStreamingServer[DownloadBlobResponse]) error
	ListBlobs(context.Context, *ListBlobsRequest) (*ListBlobsResponse, error)
	RemoveBlob(context.Context, *RemoveBlobRequest) (*RemoveBlobResponse, error)
	mustEmbedUnimplementedBlobStoreServer()
}

// UnimplementedBlobStoreServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedBlobStoreServer struct{}

// Each method below fails with codes.Unimplemented.
func (UnimplementedBlobStoreServer) UploadBlob(grpc.ClientStreamingServer[UploadBlobRequest, UploadBlobResponse]) error {
	return status.Error(codes.Unimplemented, "method UploadBlob not implemented")
}
func (UnimplementedBlobStoreServer) DownloadBlob(*DownloadBlobRequest, grpc.ServerStreamingServer[DownloadBlobResponse]) error {
	return status.Error(codes.Unimplemented, "method DownloadBlob not implemented")
}
func (UnimplementedBlobStoreServer) ListBlobs(context.Context, *ListBlobsRequest) (*ListBlobsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method ListBlobs not implemented")
}
func (UnimplementedBlobStoreServer) RemoveBlob(context.Context, *RemoveBlobRequest) (*RemoveBlobResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method RemoveBlob not implemented")
}
func (UnimplementedBlobStoreServer) mustEmbedUnimplementedBlobStoreServer() {}
func (UnimplementedBlobStoreServer) testEmbeddedByValue()                   {}

// UnsafeBlobStoreServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to BlobStoreServer will
// result in compilation errors.
type UnsafeBlobStoreServer interface {
	mustEmbedUnimplementedBlobStoreServer()
}

// RegisterBlobStoreServer registers srv with s under BlobStore_ServiceDesc.
func RegisterBlobStoreServer(s grpc.ServiceRegistrar, srv BlobStoreServer) {
	// If the following call panics, it indicates UnimplementedBlobStoreServer was
	// embedded by pointer and is nil.  This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&BlobStore_ServiceDesc, srv)
}

// _BlobStore_UploadBlob_Handler wraps the raw server stream in a typed
// client-streaming stream and hands it to srv.UploadBlob.
func _BlobStore_UploadBlob_Handler(srv interface{}, stream grpc.ServerStream) error {
	return srv.(BlobStoreServer).UploadBlob(&grpc.GenericServerStream[UploadBlobRequest, UploadBlobResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type BlobStore_UploadBlobServer = grpc.ClientStreamingServer[UploadBlobRequest, UploadBlobResponse]

// _BlobStore_DownloadBlob_Handler receives the single DownloadBlobRequest
// and then calls srv.DownloadBlob with a typed server-streaming stream.
func _BlobStore_DownloadBlob_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(DownloadBlobRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(BlobStoreServer).DownloadBlob(m, &grpc.GenericServerStream[DownloadBlobRequest, DownloadBlobResponse]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type BlobStore_DownloadBlobServer = grpc.ServerStreamingServer[DownloadBlobResponse]

// _BlobStore_ListBlobs_Handler decodes a ListBlobsRequest and calls
// srv.ListBlobs, going through the unary interceptor when one is set.
func _BlobStore_ListBlobs_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ListBlobsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(BlobStoreServer).ListBlobs(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: BlobStore_ListBlobs_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BlobStoreServer).ListBlobs(ctx, req.(*ListBlobsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _BlobStore_RemoveBlob_Handler decodes a RemoveBlobRequest and calls
// srv.RemoveBlob, going through the unary interceptor when one is set.
func _BlobStore_RemoveBlob_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(RemoveBlobRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(BlobStoreServer).RemoveBlob(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: BlobStore_RemoveBlob_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BlobStoreServer).RemoveBlob(ctx, req.(*RemoveBlobRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// BlobStore_ServiceDesc is the grpc.ServiceDesc for BlobStore service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var BlobStore_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "protocol.BlobStore",
	HandlerType: (*BlobStoreServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "ListBlobs",
			Handler:    _BlobStore_ListBlobs_Handler,
		},
		{
			MethodName: "RemoveBlob",
			Handler:    _BlobStore_RemoveBlob_Handler,
		},
	},
	// The order here matters: the client stubs refer to Streams[0]
	// (UploadBlob) and Streams[1] (DownloadBlob) by index.
	Streams: []grpc.StreamDesc{
		{
			StreamName:    "UploadBlob",
			Handler:       _BlobStore_UploadBlob_Handler,
			ClientStreams: true,
		},
		{
			StreamName:    "DownloadBlob",
			Handler:       _BlobStore_DownloadBlob_Handler,
			ServerStreams: true,
		},
	},
	Metadata: "protocol.proto",
}

================================================
FILE: protocol/vector_store.pb.go
================================================
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.36.11
// 	protoc        v7.34.0--rc2
// source: vector_store.proto

package protocol

// NOTE(review): this file was generated with a protoc release candidate
// (v7.34.0--rc2) while protocol_grpc.pb.go records protoc v6.33.1; consider
// regenerating all protocol files with one stable toolchain.

import (
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	timestamppb "google.golang.org/protobuf/types/known/timestamppb"
	reflect "reflect"
	sync "sync"
	unsafe "unsafe"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Distance is the generated Go type for the protocol.Distance enum.
// Distance_Unknown (0) is its zero value.
type Distance int32

const (
	Distance_Unknown   Distance = 0
	Distance_Cosine    Distance = 1
	Distance_Euclidean Distance = 2
	Distance_Dot       Distance = 3
)

// Enum value maps for Distance.
var (
	Distance_name = map[int32]string{
		0: "Unknown",
		1: "Cosine",
		2: "Euclidean",
		3: "Dot",
	}
	Distance_value = map[string]int32{
		"Unknown":   0,
		"Cosine":    1,
		"Euclidean": 2,
		"Dot":       3,
	}
)

// Enum returns a pointer to a newly allocated Distance holding x.
func (x Distance) Enum() *Distance {
	p := new(Distance)
	*p = x
	return p
}

// String returns the enum value name via protoimpl.X.EnumStringOf.
func (x Distance) String() string {
	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}

// Descriptor returns the enum descriptor from
// file_vector_store_proto_enumTypes[0].
func (Distance) Descriptor() protoreflect.EnumDescriptor {
	return file_vector_store_proto_enumTypes[0].Descriptor()
}

// Type returns the enum type info at file_vector_store_proto_enumTypes[0].
func (Distance) Type() protoreflect.EnumType {
	return &file_vector_store_proto_enumTypes[0]
}

// Number returns x as a protoreflect.EnumNumber.
func (x Distance) Number() protoreflect.EnumNumber {
	return protoreflect.EnumNumber(x)
}

// Deprecated: Use Distance.Descriptor instead.
func (Distance) EnumDescriptor() ([]byte, []int) {
	return file_vector_store_proto_rawDescGZIP(), []int{0}
}

// Vector is the generated Go type for the protocol.Vector message: an Id,
// packed float32 Values, repeated Categories, and a Timestamp.
type Vector struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Id            string                 `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
	Values        []float32              `protobuf:"fixed32,2,rep,packed,name=values,proto3" json:"values,omitempty"`
	Categories    []string               `protobuf:"bytes,3,rep,name=categories,proto3" json:"categories,omitempty"`
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero Vector and re-attaches the message info kept at
// file_vector_store_proto_msgTypes[0].
func (x *Vector) Reset() {
	*x = Vector{}
	mi := &file_vector_store_proto_msgTypes[0]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String delegates to protoimpl.X.MessageStringOf.
func (x *Vector) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Vector) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored on first use only; a nil x is wrapped via mi.MessageOf.
func (x *Vector) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[0]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Vector.ProtoReflect.Descriptor instead.
func (*Vector) Descriptor() ([]byte, []int) {
	return file_vector_store_proto_rawDescGZIP(), []int{0}
}

// GetId returns x.Id, or "" when x is nil.
func (x *Vector) GetId() string {
	if x != nil {
		return x.Id
	}
	return ""
}

// GetValues returns x.Values, or nil when x is nil.
func (x *Vector) GetValues() []float32 {
	if x != nil {
		return x.Values
	}
	return nil
}

// GetCategories returns x.Categories, or nil when x is nil.
func (x *Vector) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

// GetTimestamp returns x.Timestamp, or nil when x is nil.
func (x *Vector) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

// ListCollectionsRequest is the generated Go type for the field-less
// protocol.ListCollectionsRequest message.
type ListCollectionsRequest struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero ListCollectionsRequest and re-attaches the
// message info kept at file_vector_store_proto_msgTypes[1].
func (x *ListCollectionsRequest) Reset() {
	*x = ListCollectionsRequest{}
	mi := &file_vector_store_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String delegates to protoimpl.X.MessageStringOf.
func (x *ListCollectionsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ListCollectionsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, storing the message info
// lazily for a non-nil x.
func (x *ListCollectionsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ListCollectionsRequest.ProtoReflect.Descriptor instead.
func (*ListCollectionsRequest) Descriptor() ([]byte, []int) {
	return file_vector_store_proto_rawDescGZIP(), []int{1}
}

// ListCollectionsResponse is the generated Go type for the
// protocol.ListCollectionsResponse message. Collections carries repeated
// string field 1 ("collections").
type ListCollectionsResponse struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Collections   []string               `protobuf:"bytes,1,rep,name=collections,proto3" json:"collections,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero ListCollectionsResponse and re-attaches the
// message info kept at file_vector_store_proto_msgTypes[2].
func (x *ListCollectionsResponse) Reset() {
	*x = ListCollectionsResponse{}
	mi := &file_vector_store_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String delegates to protoimpl.X.MessageStringOf.
func (x *ListCollectionsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*ListCollectionsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored on first use only; a nil x is wrapped via mi.MessageOf.
func (x *ListCollectionsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ListCollectionsResponse.ProtoReflect.Descriptor instead.
func (*ListCollectionsResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 2).
	return file_vector_store_proto_rawDescGZIP(), []int{2}
}

// GetCollections returns x.Collections, or nil when x itself is nil.
func (x *ListCollectionsResponse) GetCollections() []string {
	if x != nil {
		return x.Collections
	}
	return nil
}

// AddCollectionRequest is the input message of the VectorStore.AddCollection RPC
// (protocol.AddCollectionRequest, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type AddCollectionRequest struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Name (field 1) is the name of the collection to add.
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	// Dimensions (field 2); presumably the length of every vector in the collection — confirm with the server.
	Dimensions int32 `protobuf:"varint,2,opt,name=dimensions,proto3" json:"dimensions,omitempty"`
	// Distance (field 3) selects the collection's metric from the Distance enum.
	Distance      Distance `protobuf:"varint,3,opt,name=distance,proto3,enum=protocol.Distance" json:"distance,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero AddCollectionRequest and re-attaches its message
// info (file_vector_store_proto_msgTypes[3]).
func (x *AddCollectionRequest) Reset() {
	*x = AddCollectionRequest{}
	mi := &file_vector_store_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *AddCollectionRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging AddCollectionRequest as a protobuf message.
func (*AddCollectionRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *AddCollectionRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddCollectionRequest.ProtoReflect.Descriptor instead.
func (*AddCollectionRequest) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 3).
	return file_vector_store_proto_rawDescGZIP(), []int{3}
}

// GetName returns x.Name, or "" when x itself is nil.
func (x *AddCollectionRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// GetDimensions returns x.Dimensions, or 0 when x itself is nil.
func (x *AddCollectionRequest) GetDimensions() int32 {
	if x != nil {
		return x.Dimensions
	}
	return 0
}

// GetDistance returns x.Distance, or Distance_Unknown (the enum's zero value)
// when x itself is nil.
func (x *AddCollectionRequest) GetDistance() Distance {
	if x != nil {
		return x.Distance
	}
	return Distance_Unknown
}

// AddCollectionResponse is the output message of the VectorStore.AddCollection RPC.
// It has no fields (protocol.AddCollectionResponse, generated from vector_store.proto).
type AddCollectionResponse struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero AddCollectionResponse and re-attaches its message
// info (file_vector_store_proto_msgTypes[4]).
func (x *AddCollectionResponse) Reset() {
	*x = AddCollectionResponse{}
	mi := &file_vector_store_proto_msgTypes[4]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *AddCollectionResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging AddCollectionResponse as a protobuf message.
func (*AddCollectionResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *AddCollectionResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[4]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddCollectionResponse.ProtoReflect.Descriptor instead.
func (*AddCollectionResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 4).
	return file_vector_store_proto_rawDescGZIP(), []int{4}
}

// DeleteCollectionRequest is the input message of the VectorStore.DeleteCollection RPC
// (protocol.DeleteCollectionRequest, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type DeleteCollectionRequest struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Name (field 1) is the name of the collection to delete.
	Name          string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero DeleteCollectionRequest and re-attaches its message
// info (file_vector_store_proto_msgTypes[5]).
func (x *DeleteCollectionRequest) Reset() {
	*x = DeleteCollectionRequest{}
	mi := &file_vector_store_proto_msgTypes[5]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *DeleteCollectionRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging DeleteCollectionRequest as a protobuf message.
func (*DeleteCollectionRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *DeleteCollectionRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[5]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteCollectionRequest.ProtoReflect.Descriptor instead.
func (*DeleteCollectionRequest) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 5).
	return file_vector_store_proto_rawDescGZIP(), []int{5}
}

// GetName returns x.Name, or "" when x itself is nil.
func (x *DeleteCollectionRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// DeleteCollectionResponse is the output message of the VectorStore.DeleteCollection RPC.
// It has no fields (protocol.DeleteCollectionResponse, generated from vector_store.proto).
type DeleteCollectionResponse struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero DeleteCollectionResponse and re-attaches its message
// info (file_vector_store_proto_msgTypes[6]).
func (x *DeleteCollectionResponse) Reset() {
	*x = DeleteCollectionResponse{}
	mi := &file_vector_store_proto_msgTypes[6]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *DeleteCollectionResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging DeleteCollectionResponse as a protobuf message.
func (*DeleteCollectionResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *DeleteCollectionResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[6]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteCollectionResponse.ProtoReflect.Descriptor instead.
func (*DeleteCollectionResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 6).
	return file_vector_store_proto_rawDescGZIP(), []int{6}
}

// AddVectorsRequest is the input message of the VectorStore.AddVectors RPC
// (protocol.AddVectorsRequest, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type AddVectorsRequest struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Collection (field 1) names the target collection.
	Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	// Vectors (field 2) are the vectors to add.
	Vectors       []*Vector `protobuf:"bytes,2,rep,name=vectors,proto3" json:"vectors,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero AddVectorsRequest and re-attaches its message
// info (file_vector_store_proto_msgTypes[7]).
func (x *AddVectorsRequest) Reset() {
	*x = AddVectorsRequest{}
	mi := &file_vector_store_proto_msgTypes[7]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *AddVectorsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging AddVectorsRequest as a protobuf message.
func (*AddVectorsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *AddVectorsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[7]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddVectorsRequest.ProtoReflect.Descriptor instead.
func (*AddVectorsRequest) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 7).
	return file_vector_store_proto_rawDescGZIP(), []int{7}
}

// GetCollection returns x.Collection, or "" when x itself is nil.
func (x *AddVectorsRequest) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

// GetVectors returns x.Vectors, or nil when x itself is nil.
func (x *AddVectorsRequest) GetVectors() []*Vector {
	if x != nil {
		return x.Vectors
	}
	return nil
}

// AddVectorsResponse is the output message of the VectorStore.AddVectors RPC.
// It has no fields (protocol.AddVectorsResponse, generated from vector_store.proto).
type AddVectorsResponse struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero AddVectorsResponse and re-attaches its message
// info (file_vector_store_proto_msgTypes[8]).
func (x *AddVectorsResponse) Reset() {
	*x = AddVectorsResponse{}
	mi := &file_vector_store_proto_msgTypes[8]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *AddVectorsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging AddVectorsResponse as a protobuf message.
func (*AddVectorsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *AddVectorsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[8]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use AddVectorsResponse.ProtoReflect.Descriptor instead.
func (*AddVectorsResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 8).
	return file_vector_store_proto_rawDescGZIP(), []int{8}
}

// DeleteVectorsRequest is the input message of the VectorStore.DeleteVectors RPC
// (protocol.DeleteVectorsRequest, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type DeleteVectorsRequest struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Collection (field 1) names the target collection.
	Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	// Timestamp (field 2); which vectors it selects (e.g. older than it) is decided by
	// the server implementation — NOTE(review): confirm the exact semantics.
	Timestamp     *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero DeleteVectorsRequest and re-attaches its message
// info (file_vector_store_proto_msgTypes[9]).
func (x *DeleteVectorsRequest) Reset() {
	*x = DeleteVectorsRequest{}
	mi := &file_vector_store_proto_msgTypes[9]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *DeleteVectorsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging DeleteVectorsRequest as a protobuf message.
func (*DeleteVectorsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *DeleteVectorsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[9]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteVectorsRequest.ProtoReflect.Descriptor instead.
func (*DeleteVectorsRequest) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 9).
	return file_vector_store_proto_rawDescGZIP(), []int{9}
}

// GetCollection returns x.Collection, or "" when x itself is nil.
func (x *DeleteVectorsRequest) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

// GetTimestamp returns x.Timestamp, or nil when x itself is nil.
func (x *DeleteVectorsRequest) GetTimestamp() *timestamppb.Timestamp {
	if x != nil {
		return x.Timestamp
	}
	return nil
}

// DeleteVectorsResponse is the output message of the VectorStore.DeleteVectors RPC.
// It has no fields (protocol.DeleteVectorsResponse, generated from vector_store.proto).
type DeleteVectorsResponse struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state         protoimpl.MessageState `protogen:"open.v1"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero DeleteVectorsResponse and re-attaches its message
// info (file_vector_store_proto_msgTypes[10]).
func (x *DeleteVectorsResponse) Reset() {
	*x = DeleteVectorsResponse{}
	mi := &file_vector_store_proto_msgTypes[10]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *DeleteVectorsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging DeleteVectorsResponse as a protobuf message.
func (*DeleteVectorsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *DeleteVectorsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[10]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use DeleteVectorsResponse.ProtoReflect.Descriptor instead.
func (*DeleteVectorsResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 10).
	return file_vector_store_proto_rawDescGZIP(), []int{10}
}

// QueryVectorsRequest is the input message of the VectorStore.QueryVectors RPC
// (protocol.QueryVectorsRequest, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type QueryVectorsRequest struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Collection (field 1) names the collection to search.
	Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"`
	// Query (field 2, packed fixed32) is the query vector.
	Query []float32 `protobuf:"fixed32,2,rep,packed,name=query,proto3" json:"query,omitempty"`
	// Categories (field 3); presumably restricts results to vectors carrying these
	// categories — NOTE(review): confirm against the server.
	Categories []string `protobuf:"bytes,3,rep,name=categories,proto3" json:"categories,omitempty"`
	// TopK (field 4, JSON name topK); presumably the maximum number of results — confirm.
	TopK          int32 `protobuf:"varint,4,opt,name=top_k,json=topK,proto3" json:"top_k,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero QueryVectorsRequest and re-attaches its message
// info (file_vector_store_proto_msgTypes[11]).
func (x *QueryVectorsRequest) Reset() {
	*x = QueryVectorsRequest{}
	mi := &file_vector_store_proto_msgTypes[11]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *QueryVectorsRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging QueryVectorsRequest as a protobuf message.
func (*QueryVectorsRequest) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *QueryVectorsRequest) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[11]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use QueryVectorsRequest.ProtoReflect.Descriptor instead.
func (*QueryVectorsRequest) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 11).
	return file_vector_store_proto_rawDescGZIP(), []int{11}
}

// GetCollection returns x.Collection, or "" when x itself is nil.
func (x *QueryVectorsRequest) GetCollection() string {
	if x != nil {
		return x.Collection
	}
	return ""
}

// GetQuery returns x.Query, or nil when x itself is nil.
func (x *QueryVectorsRequest) GetQuery() []float32 {
	if x != nil {
		return x.Query
	}
	return nil
}

// GetCategories returns x.Categories, or nil when x itself is nil.
func (x *QueryVectorsRequest) GetCategories() []string {
	if x != nil {
		return x.Categories
	}
	return nil
}

// GetTopK returns x.TopK, or 0 when x itself is nil.
func (x *QueryVectorsRequest) GetTopK() int32 {
	if x != nil {
		return x.TopK
	}
	return 0
}

// QueryVectorsResponse is the output message of the VectorStore.QueryVectors RPC
// (protocol.QueryVectorsResponse, generated from vector_store.proto — edit the
// .proto and regenerate instead of changing this code by hand).
type QueryVectorsResponse struct {
	// NOTE(review): Reset/ProtoReflect apply protoimpl.X.MessageStateOf to a pointer to the
	// whole message, which presumably relies on state staying the first field.
	state protoimpl.MessageState `protogen:"open.v1"`
	// Vectors (field 1) are the vectors returned by the query.
	Vectors       []*Vector `protobuf:"bytes,1,rep,name=vectors,proto3" json:"vectors,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset sets x to the zero QueryVectorsResponse and re-attaches its message
// info (file_vector_store_proto_msgTypes[12]).
func (x *QueryVectorsResponse) Reset() {
	*x = QueryVectorsResponse{}
	mi := &file_vector_store_proto_msgTypes[12]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String renders x using protoimpl.X.MessageStringOf.
func (x *QueryVectorsResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method tagging QueryVectorsResponse as a protobuf message.
func (*QueryVectorsResponse) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message info is
// stored on first use; a nil x is wrapped by mi.MessageOf.
func (x *QueryVectorsResponse) ProtoReflect() protoreflect.Message {
	mi := &file_vector_store_proto_msgTypes[12]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		// Attach the message info lazily, only if nothing has stored it yet.
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use QueryVectorsResponse.ProtoReflect.Descriptor instead.
func (*QueryVectorsResponse) Descriptor() ([]byte, []int) {
	// Gzipped raw descriptor of vector_store.proto plus this message's index path (message 12).
	return file_vector_store_proto_rawDescGZIP(), []int{12}
}

// GetVectors returns x.Vectors, or nil when x itself is nil.
func (x *QueryVectorsResponse) GetVectors() []*Vector {
	if x != nil {
		return x.Vectors
	}
	return nil
}

// File_vector_store_proto is the descriptor of vector_store.proto. It is assigned
// by file_vector_store_proto_init when the package is initialized.
var File_vector_store_proto protoreflect.FileDescriptor

// file_vector_store_proto_rawDesc is the serialized file descriptor of
// vector_store.proto, handed to protoimpl.DescBuilder as RawDescriptor below.
// It is generator output and must stay byte-for-byte in sync with the .proto.
const file_vector_store_proto_rawDesc = "" +
	"\n" +
	"\x12vector_store.proto\x12\bprotocol\x1a\x1fgoogle/protobuf/timestamp.proto\"\x8a\x01\n" +
	"\x06Vector\x12\x0e\n" +
	"\x02id\x18\x01 \x01(\tR\x02id\x12\x16\n" +
	"\x06values\x18\x02 \x03(\x02R\x06values\x12\x1e\n" +
	"\n" +
	"categories\x18\x03 \x03(\tR\n" +
	"categories\x128\n" +
	"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\x18\n" +
	"\x16ListCollectionsRequest\";\n" +
	"\x17ListCollectionsResponse\x12 \n" +
	"\vcollections\x18\x01 \x03(\tR\vcollections\"z\n" +
	"\x14AddCollectionRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\x12\x1e\n" +
	"\n" +
	"dimensions\x18\x02 \x01(\x05R\n" +
	"dimensions\x12.\n" +
	"\bdistance\x18\x03 \x01(\x0e2\x12.protocol.DistanceR\bdistance\"\x17\n" +
	"\x15AddCollectionResponse\"-\n" +
	"\x17DeleteCollectionRequest\x12\x12\n" +
	"\x04name\x18\x01 \x01(\tR\x04name\"\x1a\n" +
	"\x18DeleteCollectionResponse\"_\n" +
	"\x11AddVectorsRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x12*\n" +
	"\avectors\x18\x02 \x03(\v2\x10.protocol.VectorR\avectors\"\x14\n" +
	"\x12AddVectorsResponse\"p\n" +
	"\x14DeleteVectorsRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x128\n" +
	"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\x17\n" +
	"\x15DeleteVectorsResponse\"\x80\x01\n" +
	"\x13QueryVectorsRequest\x12\x1e\n" +
	"\n" +
	"collection\x18\x01 \x01(\tR\n" +
	"collection\x12\x14\n" +
	"\x05query\x18\x02 \x03(\x02R\x05query\x12\x1e\n" +
	"\n" +
	"categories\x18\x03 \x03(\tR\n" +
	"categories\x12\x13\n" +
	"\x05top_k\x18\x04 \x01(\x05R\x04topK\"B\n" +
	"\x14QueryVectorsResponse\x12*\n" +
	"\avectors\x18\x01 \x03(\v2\x10.protocol.VectorR\avectors*;\n" +
	"\bDistance\x12\v\n" +
	"\aUnknown\x10\x00\x12\n" +
	"\n" +
	"\x06Cosine\x10\x01\x12\r\n" +
	"\tEuclidean\x10\x02\x12\a\n" +
	"\x03Dot\x10\x032\x88\x04\n" +
	"\vVectorStore\x12X\n" +
	"\x0fListCollections\x12 .protocol.ListCollectionsRequest\x1a!.protocol.ListCollectionsResponse\"\x00\x12R\n" +
	"\rAddCollection\x12\x1e.protocol.AddCollectionRequest\x1a\x1f.protocol.AddCollectionResponse\"\x00\x12[\n" +
	"\x10DeleteCollection\x12!.protocol.DeleteCollectionRequest\x1a\".protocol.DeleteCollectionResponse\"\x00\x12I\n" +
	"\n" +
	"AddVectors\x12\x1b.protocol.AddVectorsRequest\x1a\x1c.protocol.AddVectorsResponse\"\x00\x12R\n" +
	"\rDeleteVectors\x12\x1e.protocol.DeleteVectorsRequest\x1a\x1f.protocol.DeleteVectorsResponse\"\x00\x12O\n" +
	"\fQueryVectors\x12\x1d.protocol.QueryVectorsRequest\x1a\x1e.protocol.QueryVectorsResponse\"\x00B$Z\"github.com/gorse-io/gorse/protocolb\x06proto3"

// rawDescOnce guards the one-time gzip compression cached in rawDescData.
var (
	file_vector_store_proto_rawDescOnce sync.Once
	file_vector_store_proto_rawDescData []byte
)

// file_vector_store_proto_rawDescGZIP returns the gzip-compressed raw descriptor used
// by the deprecated Descriptor methods. Compression happens at most once; every call
// returns the same cached slice, so callers must not modify it.
func file_vector_store_proto_rawDescGZIP() []byte {
	file_vector_store_proto_rawDescOnce.Do(func() {
		// unsafe.Slice(unsafe.StringData(s), len(s)) views the constant's bytes without copying.
		file_vector_store_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_vector_store_proto_rawDesc), len(file_vector_store_proto_rawDesc)))
	})
	return file_vector_store_proto_rawDescData
}

// Runtime type info for the file's 1 enum and 13 messages, filled in by the
// TypeBuilder in file_vector_store_proto_init (sizes match NumEnums/NumMessages).
var file_vector_store_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_vector_store_proto_msgTypes = make([]protoimpl.MessageInfo, 13)

// file_vector_store_proto_goTypes lists the Go types in descriptor order; the
// indices in the trailing comments are the ones referenced by depIdxs below.
var file_vector_store_proto_goTypes = []any{
	(Distance)(0),                    // 0: protocol.Distance
	(*Vector)(nil),                   // 1: protocol.Vector
	(*ListCollectionsRequest)(nil),   // 2: protocol.ListCollectionsRequest
	(*ListCollectionsResponse)(nil),  // 3: protocol.ListCollectionsResponse
	(*AddCollectionRequest)(nil),     // 4: protocol.AddCollectionRequest
	(*AddCollectionResponse)(nil),    // 5: protocol.AddCollectionResponse
	(*DeleteCollectionRequest)(nil),  // 6: protocol.DeleteCollectionRequest
	(*DeleteCollectionResponse)(nil), // 7: protocol.DeleteCollectionResponse
	(*AddVectorsRequest)(nil),        // 8: protocol.AddVectorsRequest
	(*AddVectorsResponse)(nil),       // 9: protocol.AddVectorsResponse
	(*DeleteVectorsRequest)(nil),     // 10: protocol.DeleteVectorsRequest
	(*DeleteVectorsResponse)(nil),    // 11: protocol.DeleteVectorsResponse
	(*QueryVectorsRequest)(nil),      // 12: protocol.QueryVectorsRequest
	(*QueryVectorsResponse)(nil),     // 13: protocol.QueryVectorsResponse
	(*timestamppb.Timestamp)(nil),    // 14: google.protobuf.Timestamp
}

// file_vector_store_proto_depIdxs maps field types and RPC input/output types to
// goTypes indices; the last five entries delimit the sub-lists, as commented.
var file_vector_store_proto_depIdxs = []int32{
	14, // 0: protocol.Vector.timestamp:type_name -> google.protobuf.Timestamp
	0,  // 1: protocol.AddCollectionRequest.distance:type_name -> protocol.Distance
	1,  // 2: protocol.AddVectorsRequest.vectors:type_name -> protocol.Vector
	14, // 3: protocol.DeleteVectorsRequest.timestamp:type_name -> google.protobuf.Timestamp
	1,  // 4: protocol.QueryVectorsResponse.vectors:type_name -> protocol.Vector
	2,  // 5: protocol.VectorStore.ListCollections:input_type -> protocol.ListCollectionsRequest
	4,  // 6: protocol.VectorStore.AddCollection:input_type -> protocol.AddCollectionRequest
	6,  // 7: protocol.VectorStore.DeleteCollection:input_type -> protocol.DeleteCollectionRequest
	8,  // 8: protocol.VectorStore.AddVectors:input_type -> protocol.AddVectorsRequest
	10, // 9: protocol.VectorStore.DeleteVectors:input_type -> protocol.DeleteVectorsRequest
	12, // 10: protocol.VectorStore.QueryVectors:input_type -> protocol.QueryVectorsRequest
	3,  // 11: protocol.VectorStore.ListCollections:output_type -> protocol.ListCollectionsResponse
	5,  // 12: protocol.VectorStore.AddCollection:output_type -> protocol.AddCollectionResponse
	7,  // 13: protocol.VectorStore.DeleteCollection:output_type -> protocol.DeleteCollectionResponse
	9,  // 14: protocol.VectorStore.AddVectors:output_type -> protocol.AddVectorsResponse
	11, // 15: protocol.VectorStore.DeleteVectors:output_type -> protocol.DeleteVectorsResponse
	13, // 16: protocol.VectorStore.QueryVectors:output_type -> protocol.QueryVectorsResponse
	11, // [11:17] is the sub-list for method output_type
	5,  // [5:11] is the sub-list for method input_type
	5,  // [5:5] is the sub-list for extension type_name
	5,  // [5:5] is the sub-list for extension extendee
	0,  // [0:5] is the sub-list for field type_name
}

// init builds the vector_store.proto descriptor when the package is loaded.
func init() { file_vector_store_proto_init() }

// file_vector_store_proto_init builds the file descriptor and the Go type info for
// vector_store.proto via protoimpl.TypeBuilder, stores the result in
// File_vector_store_proto, and then drops goTypes/depIdxs (they are only needed
// during the build). A second call returns immediately.
func file_vector_store_proto_init() {
	if File_vector_store_proto != nil {
		return
	}
	// A throwaway local type whose PkgPath is this package's import path.
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: unsafe.Slice(unsafe.StringData(file_vector_store_proto_rawDesc), len(file_vector_store_proto_rawDesc)),
			NumEnums:      1,
			NumMessages:   13,
			NumExtensions: 0,
			NumServices:   1,
		},
		GoTypes:           file_vector_store_proto_goTypes,
		DependencyIndexes: file_vector_store_proto_depIdxs,
		EnumInfos:         file_vector_store_proto_enumTypes,
		MessageInfos:      file_vector_store_proto_msgTypes,
	}.Build()
	File_vector_store_proto = out.File
	file_vector_store_proto_goTypes = nil
	file_vector_store_proto_depIdxs = nil
}
================================================
FILE: protocol/vector_store.proto
================================================
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

option go_package = "github.com/gorse-io/gorse/protocol";

package protocol;

import "google/protobuf/timestamp.proto";

// Distance selects the metric a collection uses to compare vectors
// (set through AddCollectionRequest.distance).
enum Distance {
  // Zero value: proto3 requires the first enum value to be 0, so an unset
  // distance field reads as Unknown.
  Unknown = 0;
  Cosine = 1;
  Euclidean = 2;
  Dot = 3;
}

// Vector is one stored vector; it is submitted by AddVectors and returned by QueryVectors.
message Vector {
  // Identifier of the vector.
  string id = 1;
  // Vector components.
  repeated float values = 2;
  // NOTE(review): presumably matched against QueryVectorsRequest.categories when
  // filtering query results — confirm in the server implementation.
  repeated string categories = 3;
  // NOTE(review): presumably the time the vector was written, which
  // DeleteVectorsRequest.timestamp is compared against — confirm.
  google.protobuf.Timestamp timestamp = 4;
}

// Request for VectorStore.ListCollections; it takes no parameters.
message ListCollectionsRequest {}

// Response of VectorStore.ListCollections.
message ListCollectionsResponse {
  // Names of the collections.
  repeated string collections = 1;
}

// Request for VectorStore.AddCollection.
message AddCollectionRequest {
  // Name of the collection to create.
  string name = 1;
  // NOTE(review): presumably the required length of every vector in the collection.
  int32 dimensions = 2;
  // Metric used to compare vectors in this collection.
  Distance distance = 3;
}

// Response of VectorStore.AddCollection (no fields).
message AddCollectionResponse {}

// Request for VectorStore.DeleteCollection.
message DeleteCollectionRequest {
  // Name of the collection to delete.
  string name = 1;
}

// Response of VectorStore.DeleteCollection (no fields).
message DeleteCollectionResponse {}

// Request for VectorStore.AddVectors.
message AddVectorsRequest {
  // Name of the target collection.
  string collection = 1;
  // Vectors to add.
  repeated Vector vectors = 2;
}

// Response of VectorStore.AddVectors (no fields).
message AddVectorsResponse {}

// Request for VectorStore.DeleteVectors.
message DeleteVectorsRequest {
  // Name of the target collection.
  string collection = 1;
  // NOTE(review): cut-off time; which vectors it selects (e.g. older than it) is
  // defined by the server implementation — confirm.
  google.protobuf.Timestamp timestamp = 2;
}

// Response of VectorStore.DeleteVectors (no fields).
message DeleteVectorsResponse {}

// Request for VectorStore.QueryVectors.
message QueryVectorsRequest {
  // Name of the collection to search.
  string collection = 1;
  // Query vector.
  repeated float query = 2;
  // NOTE(review): presumably restricts results to vectors carrying these categories.
  repeated string categories = 3;
  // NOTE(review): presumably the maximum number of vectors to return.
  int32 top_k = 4;
}

// Response of VectorStore.QueryVectors.
message QueryVectorsResponse {
  // Vectors returned by the query.
  repeated Vector vectors = 1;
}

// VectorStore manages named vector collections and the vectors stored in them.
// All RPCs are unary.
service VectorStore {
  rpc ListCollections(ListCollectionsRequest) returns (ListCollectionsResponse) {}
  rpc AddCollection(AddCollectionRequest) returns (AddCollectionResponse) {}
  rpc DeleteCollection(DeleteCollectionRequest) returns (DeleteCollectionResponse) {}
  rpc AddVectors(AddVectorsRequest) returns (AddVectorsResponse) {}
  rpc DeleteVectors(DeleteVectorsRequest) returns (DeleteVectorsResponse) {}
  rpc QueryVectors(QueryVectorsRequest) returns (QueryVectorsResponse) {}
}
================================================
FILE: protocol/vector_store_grpc.pb.go
================================================
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.1
// - protoc v7.34.0--rc2
// source: vector_store.proto

package protocol

import (
	context "context"
	grpc "google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	status "google.golang.org/grpc/status"
)

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9

// Fully-qualified gRPC method names ("/<package>.<service>/<method>") of the
// VectorStore RPCs. The client stubs pass them to Invoke and the server handlers
// report them in grpc.UnaryServerInfo.FullMethod.
const (
	VectorStore_ListCollections_FullMethodName  = "/protocol.VectorStore/ListCollections"
	VectorStore_AddCollection_FullMethodName    = "/protocol.VectorStore/AddCollection"
	VectorStore_DeleteCollection_FullMethodName = "/protocol.VectorStore/DeleteCollection"
	VectorStore_AddVectors_FullMethodName       = "/protocol.VectorStore/AddVectors"
	VectorStore_DeleteVectors_FullMethodName    = "/protocol.VectorStore/DeleteVectors"
	VectorStore_QueryVectors_FullMethodName     = "/protocol.VectorStore/QueryVectors"
)

// VectorStoreClient is the client API for VectorStore service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type VectorStoreClient interface {
	ListCollections(ctx context.Context, in *ListCollectionsRequest, opts ...grpc.CallOption) (*ListCollectionsResponse, error)
	AddCollection(ctx context.Context, in *AddCollectionRequest, opts ...grpc.CallOption) (*AddCollectionResponse, error)
	DeleteCollection(ctx context.Context, in *DeleteCollectionRequest, opts ...grpc.CallOption) (*DeleteCollectionResponse, error)
	AddVectors(ctx context.Context, in *AddVectorsRequest, opts ...grpc.CallOption) (*AddVectorsResponse, error)
	DeleteVectors(ctx context.Context, in *DeleteVectorsRequest, opts ...grpc.CallOption) (*DeleteVectorsResponse, error)
	QueryVectors(ctx context.Context, in *QueryVectorsRequest, opts ...grpc.CallOption) (*QueryVectorsResponse, error)
}

// vectorStoreClient implements VectorStoreClient by issuing unary calls on cc.
type vectorStoreClient struct {
	cc grpc.ClientConnInterface
}

// NewVectorStoreClient returns a VectorStoreClient that sends its RPCs over cc.
func NewVectorStoreClient(cc grpc.ClientConnInterface) VectorStoreClient {
	return &vectorStoreClient{cc}
}

// ListCollections invokes /protocol.VectorStore/ListCollections and returns the
// decoded response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) ListCollections(ctx context.Context, in *ListCollectionsRequest, opts ...grpc.CallOption) (*ListCollectionsResponse, error) {
	// Appending to a fresh one-element slice prepends grpc.StaticMethod() without
	// mutating the caller's opts; every stub below does the same.
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ListCollectionsResponse)
	err := c.cc.Invoke(ctx, VectorStore_ListCollections_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// AddCollection invokes /protocol.VectorStore/AddCollection and returns the
// decoded response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) AddCollection(ctx context.Context, in *AddCollectionRequest, opts ...grpc.CallOption) (*AddCollectionResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(AddCollectionResponse)
	err := c.cc.Invoke(ctx, VectorStore_AddCollection_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// DeleteCollection invokes /protocol.VectorStore/DeleteCollection and returns the
// decoded response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) DeleteCollection(ctx context.Context, in *DeleteCollectionRequest, opts ...grpc.CallOption) (*DeleteCollectionResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteCollectionResponse)
	err := c.cc.Invoke(ctx, VectorStore_DeleteCollection_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// AddVectors invokes /protocol.VectorStore/AddVectors and returns the decoded
// response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) AddVectors(ctx context.Context, in *AddVectorsRequest, opts ...grpc.CallOption) (*AddVectorsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(AddVectorsResponse)
	err := c.cc.Invoke(ctx, VectorStore_AddVectors_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// DeleteVectors invokes /protocol.VectorStore/DeleteVectors and returns the
// decoded response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) DeleteVectors(ctx context.Context, in *DeleteVectorsRequest, opts ...grpc.CallOption) (*DeleteVectorsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(DeleteVectorsResponse)
	err := c.cc.Invoke(ctx, VectorStore_DeleteVectors_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// QueryVectors invokes /protocol.VectorStore/QueryVectors and returns the decoded
// response, or nil and the error reported by cc.Invoke.
func (c *vectorStoreClient) QueryVectors(ctx context.Context, in *QueryVectorsRequest, opts ...grpc.CallOption) (*QueryVectorsResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(QueryVectorsResponse)
	err := c.cc.Invoke(ctx, VectorStore_QueryVectors_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// VectorStoreServer is the server API for VectorStore service.
// All implementations must embed UnimplementedVectorStoreServer
// for forward compatibility.
type VectorStoreServer interface {
	ListCollections(context.Context, *ListCollectionsRequest) (*ListCollectionsResponse, error)
	AddCollection(context.Context, *AddCollectionRequest) (*AddCollectionResponse, error)
	DeleteCollection(context.Context, *DeleteCollectionRequest) (*DeleteCollectionResponse, error)
	AddVectors(context.Context, *AddVectorsRequest) (*AddVectorsResponse, error)
	DeleteVectors(context.Context, *DeleteVectorsRequest) (*DeleteVectorsResponse, error)
	QueryVectors(context.Context, *QueryVectorsRequest) (*QueryVectorsResponse, error)
	mustEmbedUnimplementedVectorStoreServer()
}

// UnimplementedVectorStoreServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedVectorStoreServer struct{}

// Each RPC method below returns a codes.Unimplemented status error; an embedding
// server overrides the ones it actually serves.
func (UnimplementedVectorStoreServer) ListCollections(context.Context, *ListCollectionsRequest) (*ListCollectionsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method ListCollections not implemented")
}
func (UnimplementedVectorStoreServer) AddCollection(context.Context, *AddCollectionRequest) (*AddCollectionResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method AddCollection not implemented")
}
func (UnimplementedVectorStoreServer) DeleteCollection(context.Context, *DeleteCollectionRequest) (*DeleteCollectionResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteCollection not implemented")
}
func (UnimplementedVectorStoreServer) AddVectors(context.Context, *AddVectorsRequest) (*AddVectorsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method AddVectors not implemented")
}
func (UnimplementedVectorStoreServer) DeleteVectors(context.Context, *DeleteVectorsRequest) (*DeleteVectorsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method DeleteVectors not implemented")
}
func (UnimplementedVectorStoreServer) QueryVectors(context.Context, *QueryVectorsRequest) (*QueryVectorsResponse, error) {
	return nil, status.Error(codes.Unimplemented, "method QueryVectors not implemented")
}

// mustEmbedUnimplementedVectorStoreServer satisfies the unexported method of
// VectorStoreServer, so only types embedding this struct (or
// UnsafeVectorStoreServer) implement the interface.
func (UnimplementedVectorStoreServer) mustEmbedUnimplementedVectorStoreServer() {}

// testEmbeddedByValue is probed by RegisterVectorStoreServer so that a nil
// embedded pointer panics at registration rather than on the first call.
func (UnimplementedVectorStoreServer) testEmbeddedByValue() {}

// UnsafeVectorStoreServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to VectorStoreServer will
// result in compilation errors.
type UnsafeVectorStoreServer interface {
	mustEmbedUnimplementedVectorStoreServer()
}

// RegisterVectorStoreServer registers srv with s as the implementation of
// protocol.VectorStore, using VectorStore_ServiceDesc.
func RegisterVectorStoreServer(s grpc.ServiceRegistrar, srv VectorStoreServer) {
	// If the following call panics, it indicates UnimplementedVectorStoreServer was
	// embedded by pointer and is nil. This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&VectorStore_ServiceDesc, srv)
}

// _VectorStore_ListCollections_Handler decodes a ListCollectionsRequest with dec and
// calls srv.ListCollections directly, or through interceptor (with this method's
// UnaryServerInfo) when one is configured. The handlers below follow the same shape.
func _VectorStore_ListCollections_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ListCollectionsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).ListCollections(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_ListCollections_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).ListCollections(ctx, req.(*ListCollectionsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _VectorStore_AddCollection_Handler adapts VectorStoreServer.AddCollection to a unary handler.
func _VectorStore_AddCollection_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(AddCollectionRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).AddCollection(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_AddCollection_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).AddCollection(ctx, req.(*AddCollectionRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _VectorStore_DeleteCollection_Handler adapts VectorStoreServer.DeleteCollection to a unary handler.
func _VectorStore_DeleteCollection_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteCollectionRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).DeleteCollection(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_DeleteCollection_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).DeleteCollection(ctx, req.(*DeleteCollectionRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _VectorStore_AddVectors_Handler adapts VectorStoreServer.AddVectors to a unary handler.
func _VectorStore_AddVectors_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(AddVectorsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).AddVectors(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_AddVectors_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).AddVectors(ctx, req.(*AddVectorsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _VectorStore_DeleteVectors_Handler adapts VectorStoreServer.DeleteVectors to a unary handler.
func _VectorStore_DeleteVectors_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(DeleteVectorsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).DeleteVectors(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_DeleteVectors_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).DeleteVectors(ctx, req.(*DeleteVectorsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _VectorStore_QueryVectors_Handler adapts VectorStoreServer.QueryVectors to a unary handler.
func _VectorStore_QueryVectors_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(QueryVectorsRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(VectorStoreServer).QueryVectors(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: VectorStore_QueryVectors_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(VectorStoreServer).QueryVectors(ctx, req.(*QueryVectorsRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// VectorStore_ServiceDesc is the grpc.ServiceDesc for VectorStore service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
//
// NOTE(review): this looks like protoc-gen-go-grpc output (Metadata names
// "vector_store.proto"); presumably it should be regenerated from the proto
// rather than hand-edited — confirm against the generator setup.
var VectorStore_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "protocol.VectorStore",
	HandlerType: (*VectorStoreServer)(nil),
	// Unary RPCs: each entry routes an RPC method name to its
	// _VectorStore_<Method>_Handler function.
	Methods: []grpc.MethodDesc{
		{
			MethodName: "ListCollections",
			Handler:    _VectorStore_ListCollections_Handler,
		},
		{
			MethodName: "AddCollection",
			Handler:    _VectorStore_AddCollection_Handler,
		},
		{
			MethodName: "DeleteCollection",
			Handler:    _VectorStore_DeleteCollection_Handler,
		},
		{
			MethodName: "AddVectors",
			Handler:    _VectorStore_AddVectors_Handler,
		},
		{
			MethodName: "DeleteVectors",
			Handler:    _VectorStore_DeleteVectors_Handler,
		},
		{
			MethodName: "QueryVectors",
			Handler:    _VectorStore_QueryVectors_Handler,
		},
	},
	// No streaming RPCs are declared for this service.
	Streams:  []grpc.StreamDesc{},
	Metadata: "vector_store.proto",
}

================================================
FILE: server/metrics.go
================================================
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) var ( RestAPIRequestSecondsVec = promauto.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "gorse", Subsystem: "server", Name: "rest_api_request_seconds", }, []string{"api"}) ) ================================================ FILE: server/rest.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package server import ( "context" "encoding/json" "fmt" "net/http" "net/http/pprof" "strconv" "strings" "time" "github.com/araddon/dateparse" mapset "github.com/deckarep/golang-set/v2" restfulspec "github.com/emicklei/go-restful-openapi/v2" "github.com/emicklei/go-restful/v3" "github.com/google/uuid" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/heap" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/logics" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/juju/errors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/samber/lo" "github.com/swaggest/swgui/v5emb" "go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful" "go.uber.org/zap" ) const ( HealthAPITag = "health" UsersAPITag = "users" ItemsAPITag = "items" FeedbackAPITag = "feedback" RecommendationAPITag = "recommendation" MeasurementsAPITag = "measurements" DetractedAPITag = "deprecated" apiDocsPath = "/apidocs/" ) // RestServer implements a REST-ful API server. type RestServer struct { Config *config.Config CacheClient cache.Database DataClient data.Database HttpHost string HttpPort int DisableLog bool WebService *restful.WebService HttpServer *http.Server } // StartHttpServer starts the REST-ful API server. 
func (s *RestServer) StartHttpServer(container *restful.Container) { // register restful APIs s.CreateWebService() container.Add(s.WebService) // register swagger UI specConfig := restfulspec.Config{ WebServices: []*restful.WebService{s.WebService}, APIPath: "/apidocs.json", } container.Add(restfulspec.NewOpenAPIService(specConfig)) container.Handle(apiDocsPath, v5emb.New( "Gorse REST API", "/apidocs.json", apiDocsPath, )) // register prometheus container.Handle("/metrics", promhttp.Handler()) // register pprof container.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) container.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) container.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) container.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) container.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) container.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) container.Handle("/debug/pprof/block", pprof.Handler("block")) container.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) container.Handle("/debug/pprof/heap", pprof.Handler("heap")) container.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) container.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) // Add container filter to enable CORS cors := restful.CrossOriginResourceSharing{ AllowedHeaders: []string{"Content-Type", "Accept", "X-API-Key"}, AllowedDomains: s.Config.Master.HttpCorsDomains, AllowedMethods: s.Config.Master.HttpCorsMethods, CookiesAllowed: false, Container: container} container.Filter(cors.Filter) log.Logger().Info("start http server", zap.String("url", fmt.Sprintf("http://%s:%d", s.HttpHost, s.HttpPort)), zap.Strings("cors_methods", s.Config.Master.HttpCorsMethods), zap.Strings("cors_domains", s.Config.Master.HttpCorsDomains), ) s.HttpServer = &http.Server{ Addr: fmt.Sprintf("%s:%d", s.HttpHost, s.HttpPort), Handler: container, } if err := s.HttpServer.ListenAndServe(); err != 
http.ErrServerClosed { log.Logger().Fatal("failed to start http server", zap.Error(err)) } } func (s *RestServer) LogFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { // generate request id requestId := uuid.New().String() resp.AddHeader("X-Request-ID", requestId) start := time.Now() chain.ProcessFilter(req, resp) responseTime := time.Since(start) if !s.DisableLog && req.Request.URL.Path != "/api/dashboard/cluster" && req.Request.URL.Path != "/api/dashboard/tasks" { log.ResponseLogger(resp).Info(fmt.Sprintf("%s %s", req.Request.Method, req.Request.URL), zap.Int("status_code", resp.StatusCode()), zap.Duration("response_time", responseTime)) } } func (s *RestServer) AuthFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { if strings.HasPrefix(req.SelectedRoute().Path(), "/api/health/") { // Health check APIs don't need API key, chain.ProcessFilter(req, resp) return } if s.Config.Server.APIKey == "" { chain.ProcessFilter(req, resp) return } apikey := req.HeaderParameter("X-API-Key") if apikey == s.Config.Server.APIKey { chain.ProcessFilter(req, resp) return } log.ResponseLogger(resp).Error("unauthorized", zap.String("api_key", s.Config.Server.APIKey), zap.String("X-API-Key", apikey)) if err := resp.WriteError(http.StatusUnauthorized, fmt.Errorf("unauthorized")); err != nil { log.ResponseLogger(resp).Error("failed to write error", zap.Error(err)) } } func (s *RestServer) MetricsFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { startTime := time.Now() chain.ProcessFilter(req, resp) if req.SelectedRoute() != nil && resp.StatusCode() == http.StatusOK { routePath := req.SelectedRoutePath() if !strings.HasPrefix(routePath, "/api/dashboard") { RestAPIRequestSecondsVec.WithLabelValues(fmt.Sprintf("%s %s", req.Request.Method, routePath)). Observe(time.Since(startTime).Seconds()) } } } // CreateWebService creates web service. 
func (s *RestServer) CreateWebService() { // Create a server ws := s.WebService ws.Path("/api/"). Produces(restful.MIME_JSON). Filter(s.LogFilter). Filter(s.AuthFilter). Filter(s.MetricsFilter). Filter(otelrestful.OTelFilter("gorse")) /* Health check */ ws.Route(ws.GET("/health/live").To(s.checkLive). Doc("Probe the liveness of this node. Return OK once the server starts."). Metadata(restfulspec.KeyOpenAPITags, []string{HealthAPITag}). Returns(http.StatusOK, "OK", HealthStatus{}). Writes(HealthStatus{})) ws.Route(ws.GET("/health/ready").To(s.checkReady). Doc("Probe the readiness of this node. Return OK if the server is able to handle requests."). Metadata(restfulspec.KeyOpenAPITags, []string{HealthAPITag}). Returns(http.StatusOK, "OK", HealthStatus{}). Writes(HealthStatus{})) // Insert a user ws.Route(ws.POST("/user").To(s.insertUser). Doc("Insert a user."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads(data.User{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Modify a user ws.Route(ws.PATCH("/user/{user-id}").To(s.modifyUser). Doc("Modify a user."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Reads(data.UserPatch{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Get a user ws.Route(ws.GET("/user/{user-id}").To(s.getUser). Doc("Get a user."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Returns(http.StatusOK, "OK", data.User{}). Writes(data.User{})) // Insert users ws.Route(ws.POST("/users").To(s.insertUsers). Doc("Insert users."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). 
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads([]data.User{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Get users ws.Route(ws.GET("/users").To(s.getUsers). Doc("List users."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("n", "Number of returned users").DataType("integer")). Param(ws.QueryParameter("cursor", "Cursor for the next page").DataType("string")). Returns(http.StatusOK, "OK", UserIterator{}). Writes(UserIterator{})) // Delete a user ws.Route(ws.DELETE("/user/{user-id}").To(s.deleteUser). Doc("Delete a user. His or her feedback will also be deleted."). Metadata(restfulspec.KeyOpenAPITags, []string{UsersAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Insert an item ws.Route(ws.POST("/item").To(s.insertItem). Doc("Insert an item."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads(data.Item{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Modify an item ws.Route(ws.PATCH("/item/{item-id}").To(s.modifyItem). Doc("Modify an item."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Reads(data.ItemPatch{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Get items ws.Route(ws.GET("/items").To(s.getItems). Doc("List items."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). 
Param(ws.QueryParameter("cursor", "Cursor for the next page").DataType("string")). Returns(http.StatusOK, "OK", ItemIterator{}). Writes(ItemIterator{})) // Get item ws.Route(ws.GET("/item/{item-id}").To(s.getItem). Doc("Get an item."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID.").DataType("string")). Returns(http.StatusOK, "OK", data.Item{}). Writes(data.Item{})) // Insert items ws.Route(ws.POST("/items").To(s.insertItems). Doc("Insert items."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads([]data.Item{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Delete item ws.Route(ws.DELETE("/item/{item-id}").To(s.deleteItem). Doc("Delete an item and its feedback."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Insert category ws.Route(ws.PUT("/item/{item-id}/category/{category}").To(s.insertItemCategory). Doc("Insert a category for a item."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Param(ws.PathParameter("category", "Category to insert").DataType("string")). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Delete category ws.Route(ws.DELETE("/item/{item-id}/category/{category}").To(s.deleteItemCategory). Doc("Delete a category from a item."). Metadata(restfulspec.KeyOpenAPITags, []string{ItemsAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). 
Param(ws.PathParameter("category", "Category to delete").DataType("string")). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Insert feedback ws.Route(ws.POST("/feedback").To(s.insertFeedback(false)). Doc("Insert feedbacks. Accumulate value if feedback already exists."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads([]data.Feedback{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) ws.Route(ws.PUT("/feedback").To(s.insertFeedback(true)). Doc("Insert feedbacks. Existed feedback will be overwritten."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Reads([]data.Feedback{}). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) // Get feedback ws.Route(ws.GET("/feedback").To(s.getFeedback). Doc("List feedbacks."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("cursor", "Cursor for the next page").DataType("string")). Param(ws.QueryParameter("n", "Number of returned feedback").DataType("integer")). Returns(http.StatusOK, "OK", FeedbackIterator{}). Writes(FeedbackIterator{})) ws.Route(ws.GET("/feedback/{user-id}/{item-id}").To(s.getUserItemFeedback). Doc("List feedbacks between a user and a item."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", []data.Feedback{}). Writes([]data.Feedback{})) ws.Route(ws.DELETE("/feedback/{user-id}/{item-id}").To(s.deleteUserItemFeedback). Doc("Delete feedbacks between a user and a item."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). 
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", Success{}). Writes(Success{})) ws.Route(ws.GET("/feedback/{feedback-type}").To(s.getTypedFeedback). Doc("List feedbacks with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("feedback-type", "Feedback type").DataType("string")). Param(ws.QueryParameter("cursor", "Cursor for the next page").DataType("string")). Param(ws.QueryParameter("n", "Number of returned feedbacks").DataType("integer")). Returns(http.StatusOK, "OK", FeedbackIterator{}). Writes(FeedbackIterator{})) ws.Route(ws.GET("/feedback/{feedback-type}/{user-id}/{item-id}").To(s.getTypedUserItemFeedback). Doc("Get feedbacks between a user and a item with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("feedback-type", "Feedback type").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", data.Feedback{}). Writes(data.Feedback{})) ws.Route(ws.DELETE("/feedback/{feedback-type}/{user-id}/{item-id}").To(s.deleteTypedUserItemFeedback). Doc("Delete feedbacks between a user and a item with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("feedback-type", "Feedback type").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", Success{}). 
Writes(Success{})) // Get feedback by user id ws.Route(ws.GET("/user/{user-id}/feedback/{feedback-type}").To(s.getTypedFeedbackByUser). Doc("Get feedbacks by user id with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("feedback-type", "Feedback type").DataType("string")). Returns(http.StatusOK, "OK", []data.Feedback{}). Writes([]data.Feedback{})) ws.Route(ws.GET("/user/{user-id}/feedback").To(s.getFeedbackByUser). Doc("Get feedbacks by user id."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Returns(http.StatusOK, "OK", []data.Feedback{}). Writes([]data.Feedback{})) // Get feedback by item-id ws.Route(ws.GET("/item/{item-id}/feedback/{feedback-type}").To(s.getTypedFeedbackByItem). Doc("Get feedbacks by item id with feedback type."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Param(ws.PathParameter("feedback-type", "Feedback type").DataType("string")). Returns(http.StatusOK, "OK", []data.Feedback{}). Writes([]data.Feedback{})) ws.Route(ws.GET("/item/{item-id}/feedback/").To(s.getFeedbackByItem). Doc("Get feedbacks by item id."). Metadata(restfulspec.KeyOpenAPITags, []string{FeedbackAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Returns(http.StatusOK, "OK", []data.Feedback{}). Writes([]data.Feedback{})) // Get collaborative filtering recommendation by user id ws.Route(ws.GET("/collaborative-filtering/{user-id}").To(s.getCollaborativeFiltering). 
Doc("Get collaborative filtering recommendation for a user."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/collaborative-filtering/{user-id}/{category}").To(s.getCollaborativeFiltering). Deprecate().Doc("Get the collaborative filtering recommendation for a user."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("category", "Category of returned items.").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) // Get latest items ws.Route(ws.GET("/latest").To(s.getLatest). Doc("Get latest items."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). 
Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/latest/{category}").To(s.getLatest). Deprecate().Doc("Get the latest items in category."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("category", "Category of returned items.").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) // Get non-personalized ws.Route(ws.GET("/non-personalized/{name}").To(s.getNonPersonalized). Doc("Get non-personalized recommendations."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). Param(ws.QueryParameter("n", "Number of returned users").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned users").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) // Get item-to-item recommendation ws.Route(ws.GET("/item-to-item/{name}/{item-id}").To(s.getItemToItem). Doc("Get item-to-item recommendation."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("name", "Name of the item-to-item recommendation").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). 
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) // Get user-to-user recommendation ws.Route(ws.GET("/user-to-user/{name}/{user-id}").To(s.getUserToUser). Doc("Get user-to-user recommendation."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("name", "Name of the user-to-user recommendation").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.QueryParameter("n", "Number of returned users").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned users").DataType("integer")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) // Get neighbors ws.Route(ws.GET("/item/{item-id}/neighbors/").To(s.getItemNeighbors). Doc("Get neighbors of a item"). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/item/{item-id}/neighbors/{category}").To(s.getItemNeighbors). Deprecate().Doc("Get neighbors of a item in category."). 
Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("item-id", "Item ID").DataType("string")). Param(ws.PathParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/user/{user-id}/neighbors/").To(s.getUserNeighbors). Doc("Get neighbors of a user."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.QueryParameter("n", "Number of returned users").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned users").DataType("integer")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.GET("/recommend/{user-id}").To(s.getRecommend). Doc("Get recommendation for user. Set X-API-Version: 2 to return scores."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.HeaderParameter("X-API-Version", "API version (set to 2 to return scores)").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.QueryParameter("category", "Category of the returned items (support multi-categories filtering)").DataType("string")). Param(ws.QueryParameter("write-back-type", "Type of write back feedback").DataType("string")). Param(ws.QueryParameter("write-back-delay", "Timestamp delay of write back feedback (format 0h0m0s)").DataType("string")). 
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Returns(http.StatusOK, "OK", []string{}). Writes([]string{})) ws.Route(ws.GET("/recommend/{user-id}/{category}").To(s.getRecommend). Deprecate().Doc("Get recommendation for user. Set X-API-Version: 2 to return scores."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.HeaderParameter("X-API-Version", "API version (set to 2 to return scores)").DataType("string")). Param(ws.PathParameter("user-id", "User ID").DataType("string")). Param(ws.PathParameter("category", "Category of the returned items").DataType("string")). Param(ws.QueryParameter("write-back-type", "Type of write back feedback").DataType("string")). Param(ws.QueryParameter("write-back-delay", "Timestamp delay of write back feedback (format 0h0m0s)").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Returns(http.StatusOK, "OK", []string{}). Writes([]string{})) ws.Route(ws.POST("/session/recommend").To(s.sessionRecommend). Doc("Get recommendation for session."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Reads([]Feedback{}). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) ws.Route(ws.POST("/session/recommend/{category}").To(s.sessionRecommend). Deprecate().Doc("Get recommendation for session."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). 
Param(ws.PathParameter("category", "Category of the returned items").DataType("string")).
		Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")).
		Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")).
		Reads([]Feedback{}).
		Returns(http.StatusOK, "OK", []cache.Score{}).
		Writes([]cache.Score{}))
}

// ParseInt parses integers from the query parameter.
//
// A missing (empty) parameter yields fallback and a nil error; any other value
// rejected by strconv.Atoi is returned as an error. Negative values are
// accepted unchanged, so callers must range-check them themselves.
func ParseInt(request *restful.Request, name string, fallback int) (value int, err error) {
	valueString := request.QueryParameter(name)
	value, err = strconv.Atoi(valueString)
	if err != nil && valueString == "" {
		// Parameter absent: use the default instead of failing.
		value = fallback
		err = nil
	}
	return
}

// ParseDuration parses duration from the query parameter.
//
// A missing (empty) parameter yields a zero duration and a nil error;
// otherwise the value is parsed with time.ParseDuration (e.g. "1h30m").
func ParseDuration(request *restful.Request, name string) (time.Duration, error) {
	valueString := request.QueryParameter(name)
	if valueString == "" {
		return 0, nil
	}
	return time.ParseDuration(valueString)
}

// SearchDocuments responds with cached scores stored under (collection, subset)
// and restricted to categories.
//
// Query parameters:
//   - offset: leading entries skipped by the cache query (default 0).
//   - n: page size (default Config.Server.DefaultN); truncation only applies when n > 0.
//   - user-id: when set, items this user already has feedback on are removed.
//
// If iteratee is non-nil, every score is converted through it and the converted
// values are returned instead; the first iteratee error aborts with 500.
func (s *RestServer) SearchDocuments(collection, subset string, categories []string,
	iteratee func(item cache.Score) (any, error),
	request *restful.Request, response *restful.Response,
) {
	var (
		ctx    = request.Request.Context()
		n      int
		offset int
		userId string
		err    error
	)
	// parse arguments
	if offset, err = ParseInt(request, "offset", 0); err != nil {
		BadRequest(response, err)
		return
	}
	if n, err = ParseInt(request, "n", s.Config.Server.DefaultN); err != nil {
		BadRequest(response, err)
		return
	}
	userId = request.QueryParameter("user-id")

	// Collect the items the user has already interacted with.
	readItems := mapset.NewSet[string]()
	if userId != "" {
		feedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now())
		if err != nil {
			InternalServerError(response, err)
			return
		}
		for _, f := range feedback {
			readItems.Add(f.ItemId)
		}
	}

	// Over-fetch by the number of read items so that up to n entries can
	// remain after they are pruned below.
	// NOTE(review): offset is applied by the cache query before read items are
	// pruned (getLatest prunes first) — confirm this difference is intended.
	end := offset + n
	if end > 0 && readItems.Cardinality() > 0 {
		end += readItems.Cardinality()
	}
	// Get the sorted list
	items, err := s.CacheClient.SearchScores(ctx, collection, subset, categories, offset, end)
	if err != nil {
		InternalServerError(response, err)
		return
	}
	// Remove read items
	if userId != "" {
		prunedItems := make([]cache.Score, 0, len(items))
		for _, item := range items {
			if !readItems.Contains(item.Id) {
				prunedItems = append(prunedItems, item)
			}
		}
		items = prunedItems
	}
	// Send result
	if n > 0 && len(items) > n {
		items = items[:n]
	}
	if iteratee != nil {
		var results []any
		for _, item := range items {
			result, err := iteratee(item)
			if err != nil {
				InternalServerError(response, err)
				return
			}
			results = append(results, result)
		}
		Ok(response, results)
	} else {
		Ok(response, items)
	}
}

// getLatest responds with the most recent items (optionally filtered by
// category) as scores whose value is the item's Unix timestamp. Items the
// "user-id" user already has feedback on are excluded before offset is applied.
func (s *RestServer) getLatest(request *restful.Request, response *restful.Response) {
	var (
		offset int
		n      int
		err    error
	)
	ctx := request.Request.Context()
	if offset, err = ParseInt(request, "offset", 0); err != nil {
		BadRequest(response, errors.Errorf("invalid offset parameter: %v", err))
		return
	}
	if n, err = ParseInt(request, "n", s.Config.Server.DefaultN); err != nil {
		BadRequest(response, errors.Errorf("invalid n parameter: %v", err))
		return
	}
	categories := ReadCategories(request, nil)
	userId := request.QueryParameter("user-id")
	readItems := mapset.NewSet[string]()
	if userId != "" {
		feedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now())
		if err != nil {
			InternalServerError(response, err)
			return
		}
		for _, f := range feedback {
			readItems.Add(f.ItemId)
		}
	}
	// Fetch enough rows to cover the offset, the page and the read items that
	// will be filtered out.
	limit := offset + n
	if readItems.Cardinality() > 0 {
		limit += readItems.Cardinality()
	}
	items, err := s.DataClient.GetLatestItems(ctx, limit, categories)
	if err != nil {
		InternalServerError(response, err)
		return
	}
	if readItems.Cardinality() > 0 {
		filtered := make([]data.Item, 0, len(items))
		for _, item := range items {
			if !readItems.Contains(item.ItemId) {
				filtered = append(filtered, item)
			}
		}
		items = filtered
	}
	// NOTE(review): ParseInt accepts negative values and a negative offset
	// makes this slice expression panic; consider rejecting offset < 0 above.
	items = items[min(offset, len(items)):]
	if n > 0 && len(items) > n {
		items = items[:n]
	}
	Ok(response, lo.Map(items, func(item data.Item, _ int) cache.Score {
		return cache.Score{
			Id:    item.ItemId,
			Score: float64(item.Timestamp.Unix()),
		}
	}))
}

// getNonPersonalized responds with the cached non-personalized leaderboard
// named by the "name" path parameter; the category defaults to "" when none
// is given.
func (s *RestServer) getNonPersonalized(request *restful.Request, response *restful.Response) {
	name := request.PathParameter("name")
	categories := ReadCategories(request, []string{""})
	log.ResponseLogger(response).Debug("get leaderboard", zap.String("name", name))
	s.SetLastModified(request, response, cache.Key(cache.NonPersonalizedUpdateTime, name))
	s.SearchDocuments(cache.NonPersonalized, name, categories, nil, request, response)
}

// getItemToItem responds with cached neighbors of "item-id" produced by the
// item-to-item recommender "name"; categories come only from "category"
// query values.
func (s *RestServer) getItemToItem(request *restful.Request, response *restful.Response) {
	name := request.PathParameter("name")
	itemId := request.PathParameter("item-id")
	categories := request.QueryParameters("category")
	s.SetLastModified(request, response, cache.Key(cache.ItemToItemUpdateTime, name, itemId))
	s.SearchDocuments(cache.ItemToItem, cache.Key(name, itemId), categories, nil, request, response)
}

// getUserToUser responds with cached neighbors of "user-id" produced by the
// user-to-user recommender "name" (no category filter).
func (s *RestServer) getUserToUser(request *restful.Request, response *restful.Response) {
	name := request.PathParameter("name")
	userId := request.PathParameter("user-id")
	s.SetLastModified(request, response, cache.Key(cache.UserToUserUpdateTime, name, userId))
	s.SearchDocuments(cache.UserToUser, cache.Key(name, userId), nil, nil, request, response)
}

// SetLastModified copies the time stored in the cache under key into the
// Last-Modified response header. A cache read failure is only logged; the
// response then proceeds without the header.
func (s *RestServer) SetLastModified(request *restful.Request, response *restful.Response, key string) {
	lastModified, err := s.CacheClient.Get(request.Request.Context(), key).Time()
	if err != nil {
		log.ResponseLogger(response).Error("failed to get last modified time", zap.Error(err))
		return
	}
	// NOTE(review): HTTP dates must be GMT; Format(time.RFC1123) keeps the
	// value's own location — confirm the cached time is UTC (or use http.TimeFormat).
	response.AddHeader("Last-Modified", lastModified.Format(time.RFC1123))
}

// get feedback by item-id with feedback type
func (s *RestServer) getTypedFeedbackByItem(request *restful.Request, response *restful.Response) {
	ctx := context.Background()
	if request != nil && request.Request != nil {
		ctx = request.Request.Context()
	}
	feedbackType := request.PathParameter("feedback-type")
	itemId := request.PathParameter("item-id")
	feedback, err := s.DataClient.GetItemFeedback(ctx, itemId, feedbackType)
	if err != nil {
		InternalServerError(response, err)
		return
	}
	Ok(response, feedback)
}

// get feedback by item-id
func (s *RestServer) getFeedbackByItem(request *restful.Request, response *restful.Response) {
	ctx := context.Background()
	if request != nil && request.Request != nil {
		ctx = request.Request.Context()
	}
	itemId := request.PathParameter("item-id")
	feedback, err := s.DataClient.GetItemFeedback(ctx, itemId)
	if err != nil {
		InternalServerError(response, err)
		return
	}
	Ok(response, feedback)
}

// getItemNeighbors gets neighbors of a item from database.
// It uses the first configured item-to-item recommender and responds 404 when
// none is configured.
func (s *RestServer) getItemNeighbors(request *restful.Request, response *restful.Response) {
	// Get item id
	itemId := request.PathParameter("item-id")
	categories := ReadCategories(request, nil)
	if len(s.Config.Recommend.ItemToItem) == 0 {
		PageNotFound(response, errors.New("item-to-item recommendation is not enabled"))
		return
	} else {
		name := s.Config.Recommend.ItemToItem[0].Name
		s.SetLastModified(request, response, cache.Key(cache.ItemToItemUpdateTime, name, itemId))
		s.SearchDocuments(cache.ItemToItem, cache.Key(name, itemId), categories, nil, request, response)
	}
}

// getUserNeighbors gets neighbors of a user from database.
// It uses the first configured user-to-user recommender and responds 404 when
// none is configured.
func (s *RestServer) getUserNeighbors(request *restful.Request, response *restful.Response) {
	// Get user id
	userId := request.PathParameter("user-id")
	if len(s.Config.Recommend.UserToUser) == 0 {
		PageNotFound(response, errors.New("user-to-user recommendation is not enabled"))
		return
	} else {
		name := s.Config.Recommend.UserToUser[0].Name
		s.SetLastModified(request, response, cache.Key(cache.UserToUserUpdateTime, name, userId))
		s.SearchDocuments(cache.UserToUser, cache.Key(name, userId), nil, nil, request, response)
	}
}

// getCollaborativeFiltering gets cached recommended items from database.
func (s *RestServer) getCollaborativeFiltering(request *restful.Request, response *restful.Response) { if strings.EqualFold(s.Config.Recommend.Collaborative.Type, "none") { PageNotFound(response, errors.New("collaborative filtering recommendation is disabled")) return } // Get user id userId := request.PathParameter("user-id") categories := ReadCategories(request, nil) s.SetLastModified(request, response, cache.Key(cache.RecommendUpdateTime, userId)) s.SearchDocuments(cache.Recommend, userId, categories, nil, request, response) } func (s *RestServer) getRecommend(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // parse arguments userId := request.PathParameter("user-id") apiVersion := strings.TrimSpace(request.HeaderParameter("X-API-Version")) n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } categories := ReadCategories(request, nil) offset, err := ParseInt(request, "offset", 0) if err != nil { BadRequest(response, err) return } writeBackFeedback := request.QueryParameter("write-back-type") writeBackDelay, err := ParseDuration(request, "write-back-delay") if err != nil { BadRequest(response, err) return } // online recommendation recommender, err := logics.NewRecommender(s.Config.Recommend, s.CacheClient, s.DataClient, true, userId, categories) if err != nil { InternalServerError(response, err) return } scores, err := recommender.Recommend(ctx, n+offset) if err != nil { InternalServerError(response, err) return } if len(scores) > offset { scores = scores[offset:] } else { scores = []cache.Score{} } results := lo.Map(scores, func(item cache.Score, index int) string { return item.Id }) // write back if writeBackFeedback != "" { startTime := time.Now() for _, itemId := range results { // insert to data store feedback := data.Feedback{ FeedbackKey: data.FeedbackKey{ UserId: userId, ItemId: 
itemId, FeedbackType: writeBackFeedback, }, Timestamp: startTime.Add(writeBackDelay), } err = s.DataClient.BatchInsertFeedback(ctx, []data.Feedback{feedback}, false, false, false) if err != nil { InternalServerError(response, err) return } } } // Send result if apiVersion == "2" { Ok(response, scores) return } Ok(response, results) } func (s *RestServer) sessionRecommend(request *restful.Request, response *restful.Response) { ctx := context.Background() if len(s.Config.Recommend.ItemToItem) == 0 { PageNotFound(response, errors.New("item-to-item recommendation is not enabled")) return } name := s.Config.Recommend.ItemToItem[0].Name if request != nil && request.Request != nil { ctx = request.Request.Context() } // parse arguments var feedbacks []Feedback if err := request.ReadEntity(&feedbacks); err != nil { BadRequest(response, err) return } n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } category := request.PathParameter("category") offset, err := ParseInt(request, "offset", 0) if err != nil { BadRequest(response, err) return } // pre-process feedback dataFeedback := make([]data.Feedback, len(feedbacks)) for i := range dataFeedback { var err error dataFeedback[i], err = feedbacks[i].ToDataFeedback() if err != nil { BadRequest(response, err) return } } data.SortFeedbacks(dataFeedback) // item-based recommendation var excludeSet = mapset.NewSet[string]() var userFeedback []data.Feedback for _, feedback := range dataFeedback { excludeSet.Add(feedback.ItemId) if expression.MatchFeedbackTypeExpressions(s.Config.Recommend.DataSource.PositiveFeedbackTypes, feedback.FeedbackType, feedback.Value) { userFeedback = append(userFeedback, feedback) } } // collect candidates candidates := make(map[string]float64) usedFeedbackCount := 0 for _, feedback := range userFeedback { // load similar items similarItems, err := s.CacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key(name, feedback.ItemId), []string{category}, 
0, s.Config.Recommend.CacheSize) if err != nil { BadRequest(response, err) return } // add unseen items // similarItems = s.FilterOutHiddenScores(response, similarItems, "") for _, item := range similarItems { if !excludeSet.Contains(item.Id) { candidates[item.Id] += item.Score } } // finish recommendation if the number of used feedbacks is enough if len(similarItems) > 0 { usedFeedbackCount++ if usedFeedbackCount >= s.Config.Recommend.ContextSize { break } } } // collect top k filter := heap.NewTopKFilter[string, float64](n + offset) for id, score := range candidates { filter.Push(id, score) } scores := filter.PopAll() result := lo.Map(scores, func(score heap.Elem[string, float64], _ int) cache.Score { return cache.Score{ Id: score.Value, Score: score.Weight, } }) if len(result) > offset { result = result[offset:] } else { result = nil } result = result[:lo.Min([]int{len(result), n})] // Send result Ok(response, result) } // Success is the returned data structure for data insert operations. 
type Success struct { RowAffected int } func (s *RestServer) insertUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } temp := data.User{} // get userInfo from request and put into temp if err := request.ReadEntity(&temp); err != nil { BadRequest(response, err) return } // validate labels if err := data.ValidateLabels(temp.Labels); err != nil { BadRequest(response, err) return } if err := s.DataClient.BatchInsertUsers(ctx, []data.User{temp}); err != nil { InternalServerError(response, err) return } // insert modify timestamp if err := s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, temp.UserId), time.Now())); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } func (s *RestServer) modifyUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // get user id userId := request.PathParameter("user-id") // modify user var patch data.UserPatch if err := request.ReadEntity(&patch); err != nil { BadRequest(response, err) return } // validate labels if err := data.ValidateLabels(patch.Labels); err != nil { BadRequest(response, err) return } if err := s.DataClient.ModifyUser(ctx, userId, patch); err != nil { InternalServerError(response, err) return } // insert modify timestamp if err := s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, userId), time.Now())); err != nil { return } Ok(response, Success{RowAffected: 1}) } func (s *RestServer) getUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // get user id userId := request.PathParameter("user-id") // get user user, err := s.DataClient.GetUser(ctx, userId) if err != nil { if errors.Is(err, 
errors.NotFound) { PageNotFound(response, err) } else { InternalServerError(response, err) } return } Ok(response, user) } func (s *RestServer) insertUsers(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } var temp []data.User // get param from request and put into temp if err := request.ReadEntity(&temp); err != nil { BadRequest(response, err) return } // validate labels for _, user := range temp { if err := data.ValidateLabels(user.Labels); err != nil { BadRequest(response, err) return } } // range temp and achieve user if err := s.DataClient.BatchInsertUsers(ctx, temp); err != nil { InternalServerError(response, err) return } // insert modify timestamp values := make([]cache.Value, len(temp)) for i, user := range temp { values[i] = cache.Time(cache.Key(cache.LastModifyUserTime, user.UserId), time.Now()) } if err := s.CacheClient.Set(ctx, values...); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: len(temp)}) } type UserIterator struct { Cursor string Users []data.User } func (s *RestServer) getUsers(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } cursor := request.QueryParameter("cursor") n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } // get all users cursor, users, err := s.DataClient.GetUsers(ctx, cursor, n) if err != nil { InternalServerError(response, err) return } Ok(response, UserIterator{Cursor: cursor, Users: users}) } // delete a user by user-id func (s *RestServer) deleteUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // get user-id and put into temp userId := request.PathParameter("user-id") if err := 
s.DataClient.DeleteUser(ctx, userId); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } // get feedback by user-id with feedback type func (s *RestServer) getTypedFeedbackByUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } feedbackType := request.PathParameter("feedback-type") var feednackTypeExpr expression.FeedbackTypeExpression if err := feednackTypeExpr.FromString(feedbackType); err != nil { BadRequest(response, err) return } userId := request.PathParameter("user-id") feedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now(), feednackTypeExpr) if err != nil { InternalServerError(response, err) return } Ok(response, feedback) } // get feedback by user-id func (s *RestServer) getFeedbackByUser(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } userId := request.PathParameter("user-id") feedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now()) if err != nil { InternalServerError(response, err) return } Ok(response, feedback) } // Item is the data structure for the item but stores the timestamp using string. 
type Item struct { ItemId string IsHidden bool Categories []string Timestamp string Labels any Comment string } func (s *RestServer) batchInsertItems(ctx context.Context, response *restful.Response, temp []Item) { var ( count int items = make([]data.Item, 0, len(temp)) loadExistedItemsTime time.Duration parseTimesatmpTime time.Duration insertItemsTime time.Duration insertCacheTime time.Duration ) // load existed items start := time.Now() existedItems, err := s.DataClient.BatchGetItems(ctx, lo.Map(temp, func(t Item, i int) string { return t.ItemId })) if err != nil { InternalServerError(response, err) return } existedItemsSet := make(map[string]data.Item) for _, item := range existedItems { existedItemsSet[item.ItemId] = item } loadExistedItemsTime = time.Since(start) start = time.Now() for _, item := range temp { // parse datetime var timestamp time.Time var err error if item.Timestamp != "" { if timestamp, err = dateparse.ParseAny(item.Timestamp); err != nil { BadRequest(response, err) return } } items = append(items, data.Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment, }) // update items cache if err = s.CacheClient.UpdateScores(ctx, cache.ItemCache, nil, item.ItemId, cache.ScorePatch{ Categories: withWildCard(item.Categories), IsHidden: &item.IsHidden, }); err != nil { InternalServerError(response, err) return } count++ } parseTimesatmpTime = time.Since(start) // insert items start = time.Now() if err = s.DataClient.BatchInsertItems(ctx, items); err != nil { InternalServerError(response, err) return } insertItemsTime = time.Since(start) // insert modify timestamp start = time.Now() values := make([]cache.Value, len(items)) for i, item := range items { values[i] = cache.Time(cache.Key(cache.LastModifyItemTime, item.ItemId), time.Now()) } if err = s.CacheClient.Set(ctx, values...); err != nil { InternalServerError(response, err) return } insertCacheTime = 
time.Since(start) log.ResponseLogger(response).Info("batch insert items", zap.Duration("load_existed_items_time", loadExistedItemsTime), zap.Duration("parse_timestamp_time", parseTimesatmpTime), zap.Duration("insert_items_time", insertItemsTime), zap.Duration("insert_cache_time", insertCacheTime)) Ok(response, Success{RowAffected: count}) } func (s *RestServer) insertItems(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } var items []Item if err := request.ReadEntity(&items); err != nil { BadRequest(response, err) return } // validate labels for _, user := range items { if err := data.ValidateLabels(user.Labels); err != nil { BadRequest(response, err) return } } // Insert items s.batchInsertItems(ctx, response, items) } func (s *RestServer) insertItem(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } var item Item var err error if err = request.ReadEntity(&item); err != nil { BadRequest(response, err) return } // validate labels if err := data.ValidateLabels(item.Labels); err != nil { BadRequest(response, err) return } s.batchInsertItems(ctx, response, []Item{item}) } func (s *RestServer) modifyItem(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } itemId := request.PathParameter("item-id") var patch data.ItemPatch if err := request.ReadEntity(&patch); err != nil { BadRequest(response, err) return } // validate labels if err := data.ValidateLabels(patch.Labels); err != nil { BadRequest(response, err) return } // remove hidden item from cache if patch.IsHidden != nil { if err := s.CacheClient.UpdateScores(ctx, cache.ItemCache, nil, itemId, cache.ScorePatch{IsHidden: patch.IsHidden}); err != nil { InternalServerError(response, err) 
return } } // update categories in cache if patch.Categories != nil { if err := s.CacheClient.UpdateScores(ctx, cache.ItemCache, nil, itemId, cache.ScorePatch{Categories: withWildCard(patch.Categories)}); err != nil { InternalServerError(response, err) return } } // modify item if err := s.DataClient.ModifyItem(ctx, itemId, patch); err != nil { InternalServerError(response, err) return } // insert modify timestamp if err := s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyItemTime, itemId), time.Now())); err != nil { return } Ok(response, Success{RowAffected: 1}) } // ItemIterator is the iterator for items. type ItemIterator struct { Cursor string Items []data.Item } func (s *RestServer) getItems(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } cursor := request.QueryParameter("cursor") n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } cursor, items, err := s.DataClient.GetItems(ctx, cursor, n, nil) if err != nil { InternalServerError(response, err) return } Ok(response, ItemIterator{Cursor: cursor, Items: items}) } func (s *RestServer) getItem(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Get item id itemId := request.PathParameter("item-id") // Get item item, err := s.DataClient.GetItem(ctx, itemId) if err != nil { if errors.Is(err, errors.NotFound) { PageNotFound(response, err) } else { InternalServerError(response, err) } return } Ok(response, item) } func (s *RestServer) deleteItem(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } itemId := request.PathParameter("item-id") // delete item from database if err := s.DataClient.DeleteItem(ctx, 
itemId); err != nil { InternalServerError(response, err) return } // delete item from cache if err := s.CacheClient.DeleteScores(ctx, cache.ItemCache, cache.ScoreCondition{Id: &itemId}); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } func (s *RestServer) insertItemCategory(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // fetch item id and category itemId := request.PathParameter("item-id") category := request.PathParameter("category") // fetch item item, err := s.DataClient.GetItem(ctx, itemId) if err != nil { InternalServerError(response, err) return } if !lo.Contains(item.Categories, category) { item.Categories = append(item.Categories, category) } // insert category to database if err = s.DataClient.BatchInsertItems(ctx, []data.Item{item}); err != nil { InternalServerError(response, err) return } // insert category to cache if err = s.CacheClient.UpdateScores(ctx, cache.ItemCache, nil, itemId, cache.ScorePatch{Categories: withWildCard(item.Categories)}); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } func (s *RestServer) deleteItemCategory(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // fetch item id and category itemId := request.PathParameter("item-id") category := request.PathParameter("category") // fetch item item, err := s.DataClient.GetItem(ctx, itemId) if err != nil { InternalServerError(response, err) return } categories := make([]string, 0, len(item.Categories)) for _, cat := range item.Categories { if cat != category { categories = append(categories, cat) } } item.Categories = categories // delete category from cache if err = s.CacheClient.UpdateScores(ctx, cache.ItemCache, nil, itemId, cache.ScorePatch{Categories: 
withWildCard(categories)}); err != nil { InternalServerError(response, err) return } // delete category from database if err = s.DataClient.BatchInsertItems(ctx, []data.Item{item}); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } // Feedback is the data structure for the feedback but stores the timestamp using string. type Feedback struct { data.FeedbackKey Value float64 Timestamp string Comment string } func (f Feedback) ToDataFeedback() (data.Feedback, error) { var feedback data.Feedback feedback.FeedbackKey = f.FeedbackKey feedback.Value = f.Value feedback.Comment = f.Comment if f.Timestamp != "" { var err error feedback.Timestamp, err = dateparse.ParseAny(f.Timestamp) if err != nil { return data.Feedback{}, err } } return feedback, nil } func (s *RestServer) insertFeedback(overwrite bool) func(request *restful.Request, response *restful.Response) { return func(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // add ratings var feedbackLiterTime []Feedback if err := request.ReadEntity(&feedbackLiterTime); err != nil { BadRequest(response, err) return } // parse datetime var err error feedback := make([]data.Feedback, len(feedbackLiterTime)) users := mapset.NewSet[string]() items := mapset.NewSet[string]() for i := range feedback { users.Add(feedbackLiterTime[i].UserId) items.Add(feedbackLiterTime[i].ItemId) feedback[i], err = feedbackLiterTime[i].ToDataFeedback() if err != nil { BadRequest(response, err) return } } // insert feedback to data store err = s.DataClient.BatchInsertFeedback(ctx, feedback, s.Config.Server.AutoInsertUser, s.Config.Server.AutoInsertItem, overwrite) if err != nil { InternalServerError(response, err) return } values := make([]cache.Value, 0, users.Cardinality()+items.Cardinality()) for _, userId := range users.ToSlice() { values = append(values, 
cache.Time(cache.Key(cache.LastModifyUserTime, userId), time.Now())) } for _, itemId := range items.ToSlice() { values = append(values, cache.Time(cache.Key(cache.LastModifyItemTime, itemId), time.Now())) } if err = s.CacheClient.Set(ctx, values...); err != nil { InternalServerError(response, err) return } log.ResponseLogger(response).Info("Insert feedback successfully", zap.Int("num_feedback", len(feedback))) Ok(response, Success{RowAffected: len(feedback)}) } } // FeedbackIterator is the iterator for feedback. type FeedbackIterator struct { Cursor string Feedback []data.Feedback } func (s *RestServer) getFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Parse parameters cursor := request.QueryParameter("cursor") n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } cursor, feedback, err := s.DataClient.GetFeedback(ctx, cursor, n, nil, s.Config.Now()) if err != nil { InternalServerError(response, err) return } Ok(response, FeedbackIterator{Cursor: cursor, Feedback: feedback}) } func (s *RestServer) getTypedFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Parse parameters feedbackType := request.PathParameter("feedback-type") cursor := request.QueryParameter("cursor") n, err := ParseInt(request, "n", s.Config.Server.DefaultN) if err != nil { BadRequest(response, err) return } cursor, feedback, err := s.DataClient.GetFeedback(ctx, cursor, n, nil, s.Config.Now(), feedbackType) if err != nil { InternalServerError(response, err) return } Ok(response, FeedbackIterator{Cursor: cursor, Feedback: feedback}) } func (s *RestServer) getUserItemFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request 
!= nil { ctx = request.Request.Context() } // Parse parameters userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") if feedback, err := s.DataClient.GetUserItemFeedback(ctx, userId, itemId); err != nil { InternalServerError(response, err) } else { Ok(response, feedback) } } func (s *RestServer) deleteUserItemFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Parse parameters userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") if deleteCount, err := s.DataClient.DeleteUserItemFeedback(ctx, userId, itemId); err != nil { InternalServerError(response, err) } else { Ok(response, Success{RowAffected: deleteCount}) } } func (s *RestServer) getTypedUserItemFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Parse parameters feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") if feedback, err := s.DataClient.GetUserItemFeedback(ctx, userId, itemId, feedbackType); err != nil { InternalServerError(response, err) } else if feedbackType == "" { Text(response, "{}") } else { Ok(response, feedback[0]) } } func (s *RestServer) deleteTypedUserItemFeedback(request *restful.Request, response *restful.Response) { ctx := context.Background() if request != nil && request.Request != nil { ctx = request.Request.Context() } // Parse parameters feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") if deleteCount, err := s.DataClient.DeleteUserItemFeedback(ctx, userId, itemId, feedbackType); err != nil { InternalServerError(response, err) } else { Ok(response, Success{deleteCount}) } } type HealthStatus 
struct { Ready bool DataStoreError error CacheStoreError error DataStoreConnected bool CacheStoreConnected bool } func (s *RestServer) checkHealth() HealthStatus { healthStatus := HealthStatus{} healthStatus.DataStoreError = s.DataClient.Ping() healthStatus.CacheStoreError = s.CacheClient.Ping() healthStatus.DataStoreConnected = healthStatus.DataStoreError == nil healthStatus.CacheStoreConnected = healthStatus.CacheStoreError == nil healthStatus.Ready = healthStatus.DataStoreConnected && healthStatus.CacheStoreConnected return healthStatus } func (s *RestServer) checkReady(_ *restful.Request, response *restful.Response) { healthStatus := s.checkHealth() if healthStatus.Ready { Ok(response, healthStatus) } else { errReason, err := json.Marshal(healthStatus) if err != nil { Error(response, http.StatusInternalServerError, err) } else { Error(response, http.StatusServiceUnavailable, errors.New(string(errReason))) } } } func (s *RestServer) checkLive(_ *restful.Request, response *restful.Response) { healthStatus := s.checkHealth() Ok(response, healthStatus) } // BadRequest returns a bad request error. func BadRequest(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") log.ResponseLogger(response).Error("bad request", zap.Error(err)) if err = response.WriteError(http.StatusBadRequest, err); err != nil { log.ResponseLogger(response).Error("failed to write error", zap.Error(err)) } } // InternalServerError returns a internal server error. func InternalServerError(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") log.ResponseLogger(response).Error("internal server error", zap.Error(err)) if err = response.WriteError(http.StatusInternalServerError, err); err != nil { log.ResponseLogger(response).Error("failed to write error", zap.Error(err)) } } // PageNotFound returns a not found error. 
func PageNotFound(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") if err := response.WriteError(http.StatusNotFound, err); err != nil { log.ResponseLogger(response).Error("failed to write error", zap.Error(err)) } } // Ok sends the content as JSON to the client. func Ok(response *restful.Response, content interface{}) { response.AddHeader("Access-Control-Allow-Origin", "*") if err := response.WriteAsJson(content); err != nil { log.ResponseLogger(response).Error("failed to write json", zap.Error(err)) } } func Error(response *restful.Response, httpStatus int, responseError error) { response.Header().Set("Access-Control-Allow-Origin", "*") if err := response.WriteError(httpStatus, responseError); err != nil { log.ResponseLogger(response).Error("failed to write error", zap.Error(err)) } } // Text returns a plain text. func Text(response *restful.Response, content string) { response.Header().Set("Access-Control-Allow-Origin", "*") if _, err := response.Write([]byte(content)); err != nil { log.ResponseLogger(response).Error("failed to write text", zap.Error(err)) } } func withWildCard(categories []string) []string { result := make([]string, len(categories), len(categories)+1) copy(result, categories) result = append(result, "") return result } // ReadCategories tries to read categories from the request. If the category is not found, it returns an empty string. 
func ReadCategories(request *restful.Request, defaultCategories []string) []string { if pathValue := request.PathParameter("category"); pathValue != "" { return []string{pathValue} } else if queryValues := request.QueryParameters("category"); len(queryValues) > 0 { return lo.Filter(queryValues, func(cat string, _ int) bool { return len(cat) > 0 }) } else { return defaultCategories } } ================================================ FILE: server/rest_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package server import ( "encoding/json" "fmt" "net/http" "strconv" "testing" "time" "github.com/emicklei/go-restful/v3" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/samber/lo/mutable" "github.com/steinfletcher/apitest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) const apiKey = "test_api_key" type ServerTestSuite struct { suite.Suite RestServer handler *restful.Container } func (suite *ServerTestSuite) SetupSuite() { // create mock redis server var err error // open database suite.Config = config.GetDefaultConfig() suite.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir()), "") suite.NoError(err) suite.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", suite.T().TempDir()), "") suite.NoError(err) // init database err = suite.DataClient.Init() suite.NoError(err) err = suite.CacheClient.Init() suite.NoError(err) suite.WebService = new(restful.WebService) suite.CreateWebService() // create handler suite.handler = restful.NewContainer() suite.handler.Add(suite.WebService) } func (suite *ServerTestSuite) TearDownSuite() { err := suite.DataClient.Close() suite.NoError(err) err = suite.CacheClient.Close() suite.NoError(err) } func (suite *ServerTestSuite) SetupTest() { err := suite.DataClient.Purge() suite.NoError(err) err = suite.CacheClient.Purge() suite.NoError(err) // configuration suite.Config = config.GetDefaultConfig() suite.Config.Server.APIKey = apiKey suite.Config.Recommend.Collaborative.Type = "mf" suite.Config.Recommend.Ranker.Type = "fm" suite.Config.Recommend.Fallback.Recommenders = []string{"latest"} } func (suite *ServerTestSuite) marshal(v interface{}) string { s, err := json.Marshal(v) suite.NoError(err) return string(s) } func (suite *ServerTestSuite) TestUsers() { t := suite.T() users := []data.User{ {UserId: "0"}, {UserId: "1"}, {UserId: "2"}, 
{UserId: "3"}, {UserId: "4"}, } apitest.New(). Handler(suite.handler). Post("/api/user"). Header("X-API-Key", apiKey). JSON(users[0]). Expect(t). Status(http.StatusOK). Body(`{"RowAffected":1}`). End() apitest.New(). Handler(suite.handler). Get("/api/user/0"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(users[0])). End() apitest.New(). Handler(suite.handler). Post("/api/users"). Header("X-API-Key", apiKey). JSON(users[1:]). Expect(t). Status(http.StatusOK). Body(`{"RowAffected":4}`). End() apitest.New(). Handler(suite.handler). Get("/api/users"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "cursor": "", "n": "100", }). Expect(t). Status(http.StatusOK). Body(suite.marshal(UserIterator{ Cursor: "", Users: users, })). End() apitest.New(). Handler(suite.handler). Delete("/api/user/0"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() apitest.New(). Handler(suite.handler). Get("/api/user/0"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusNotFound). End() // test modify apitest.New(). Handler(suite.handler). Patch("/api/user/1"). Header("X-API-Key", apiKey). JSON(data.UserPatch{Labels: []string{"a", "b", "c"}, Comment: new("modified")}). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() apitest.New(). Handler(suite.handler). Get("/api/user/1"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(data.User{ UserId: "1", Comment: "modified", Labels: []string{"a", "b", "c"}, })). End() // malicious labels apitest.New(). Handler(suite.handler). Post("/api/user"). Header("X-API-Key", apiKey). JSON(data.User{UserId: "malicious", Labels: []any{"price", 100}}). Expect(t). Status(http.StatusBadRequest). End() apitest.New(). Handler(suite.handler). Post("/api/users"). Header("X-API-Key", apiKey). JSON([]data.User{{UserId: "malicious", Labels: []any{"price", 100}}}). Expect(t). Status(http.StatusBadRequest). 
End() apitest.New(). Handler(suite.handler). Patch("/api/user/malicious"). Header("X-API-Key", apiKey). JSON(data.UserPatch{Labels: []any{"price", 100}}). Expect(t). Status(http.StatusBadRequest). End() } func (suite *ServerTestSuite) TestItems() { t := suite.T() // Items items := []data.Item{ { ItemId: "0", IsHidden: true, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []string{"a"}, Comment: "comment_0", }, { ItemId: "2", Categories: []string{"*"}, Timestamp: time.Date(1997, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []string{"a"}, Comment: "comment_2", }, { ItemId: "4", IsHidden: true, Timestamp: time.Date(1998, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []string{"a", "b"}, Comment: "comment_4", }, { ItemId: "6", Categories: []string{"*"}, Timestamp: time.Date(1999, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []string{"b"}, Comment: "comment_6", }, { ItemId: "8", IsHidden: true, Timestamp: time.Date(2000, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []string{"b"}, Comment: "comment_8", }, } // insert items apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(items[0]). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() // batch insert items apitest.New(). Handler(suite.handler). Post("/api/items"). Header("X-API-Key", apiKey). JSON(items[1:]). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 4}`). End() // get items apitest.New(). Handler(suite.handler). Get("/api/items"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "cursor": "", "n": "100", }). Expect(t). Status(http.StatusOK). Body(suite.marshal(ItemIterator{ Cursor: "", Items: items, })). End() // get latest items apitest.New(). Handler(suite.handler). Get("/api/latest"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(t). Status(http.StatusOK). 
Body(suite.marshal([]cache.Score{ {Id: items[3].ItemId, Score: float64(items[3].Timestamp.Unix())}, {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() apitest.New(). Handler(suite.handler). Get("/api/latest/"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "offset": "1", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() apitest.New(). Handler(suite.handler). Get("/api/latest/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: items[3].ItemId, Score: float64(items[3].Timestamp.Unix())}, {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() err := suite.DataClient.BatchInsertFeedback(suite.T().Context(), []data.Feedback{{ FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "0", ItemId: "6"}, Timestamp: time.Now().Truncate(time.Hour), }}, true, true, true) suite.NoError(err) apitest.New(). Handler(suite.handler). Get("/api/latest"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "user-id": "0", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() // delete item apitest.New(). Handler(suite.handler). Delete("/api/item/6"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() // get item apitest.New(). Handler(suite.handler). Get("/api/item/6"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusNotFound). End() // get latest items apitest.New(). Handler(suite.handler). Get("/api/latest"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() apitest.New(). 
Handler(suite.handler). Get("/api/latest/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: items[1].ItemId, Score: float64(items[1].Timestamp.Unix())}, })). End() // test modify timestamp := time.Date(2010, 1, 1, 1, 1, 1, 0, time.UTC) apitest.New(). Handler(suite.handler). Patch("/api/item/2"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{ IsHidden: new(true), Categories: []string{"-"}, Labels: []string{"a", "b", "c"}, Comment: new("modified"), Timestamp: ×tamp, }). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() apitest.New(). Handler(suite.handler). Get("/api/item/2"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(data.Item{ ItemId: "2", IsHidden: true, Categories: []string{"-"}, Comment: "modified", Labels: []string{"a", "b", "c"}, Timestamp: timestamp, })). End() apitest.New(). Handler(suite.handler). Patch("/api/item/2"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{ IsHidden: new(false), }). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() // get latest items apitest.New(). Handler(suite.handler). Get("/api/latest/-"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "1", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: "2", Score: float64(timestamp.Unix())}, })). End() apitest.New(). Handler(suite.handler). Get("/api/latest/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{})). End() // insert category apitest.New(). Handler(suite.handler). Put("/api/item/2/category/@"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(Success{RowAffected: 1})). End() apitest.New(). Handler(suite.handler). Get("/api/item/2"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). 
Body(suite.marshal(data.Item{ ItemId: "2", IsHidden: false, Categories: []string{"-", "@"}, Comment: "modified", Labels: []string{"a", "b", "c"}, Timestamp: timestamp, })). End() // get latest items apitest.New(). Handler(suite.handler). Get("/api/latest/@"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "1", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: "2", Score: float64(timestamp.Unix())}, })). End() // delete category apitest.New(). Handler(suite.handler). Delete("/api/item/2/category/@"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(Success{RowAffected: 1})). End() apitest.New(). Handler(suite.handler). Get("/api/item/2"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(data.Item{ ItemId: "2", IsHidden: false, Categories: []string{"-"}, Comment: "modified", Labels: []string{"a", "b", "c"}, Timestamp: timestamp, })). End() // get latest items apitest.New(). Handler(suite.handler). Get("/api/latest/@"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "1", }). Expect(t). Status(http.StatusOK). Body(suite.marshal([]cache.Score{})). End() // insert items without timestamp apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "256"}). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() // malicious labels apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "malicious", Labels: []any{"price", 1}}). Expect(t). Status(http.StatusBadRequest). End() apitest.New(). Handler(suite.handler). Post("/api/items"). Header("X-API-Key", apiKey). JSON([]Item{{ItemId: "malicious", Labels: []any{"price", 1}}}). Expect(t). Status(http.StatusBadRequest). End() apitest.New(). Handler(suite.handler). Patch("/api/item/malicious"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{Labels: []any{"price", 1}}). Expect(t). 
Status(http.StatusBadRequest). End() } func (suite *ServerTestSuite) TestFeedback() { ctx := suite.T().Context() t := suite.T() // Insert ret feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "0"}, Value: 1.0}, {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "1", ItemId: "2"}, Value: 1.0}, {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "2", ItemId: "4"}, Value: 1.0}, {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "3", ItemId: "6"}, Value: 1.0}, {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "4", ItemId: "8"}, Value: 1.0}, } //BatchInsertFeedback apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 5}`). End() //Get Feedback apitest.New(). Handler(suite.handler). Get("/api/feedback"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "cursor": "", "n": "100", }). Expect(t). Body(suite.marshal(FeedbackIterator{ Cursor: "", Feedback: feedback, })). Status(http.StatusOK). End() // get feedback by user apitest.New(). Handler(suite.handler). Get("/api/user/1/feedback"). Header("X-API-Key", apiKey). Expect(t). Body(suite.marshal([]data.Feedback{feedback[1]})). Status(http.StatusOK). End() // get feedback by item apitest.New(). Handler(suite.handler). Get("/api/item/2/feedback"). Header("X-API-Key", apiKey). Expect(t). Body(suite.marshal([]data.Feedback{feedback[1]})). Status(http.StatusOK). End() //Get Items apitest.New(). Handler(suite.handler). Get("/api/items"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(ItemIterator{ Cursor: "", Items: []data.Item{ {ItemId: "0"}, {ItemId: "2"}, {ItemId: "4"}, {ItemId: "6"}, {ItemId: "8"}, }, })). End() apitest.New(). Handler(suite.handler). Get("/api/users"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). 
Body(suite.marshal(UserIterator{ Cursor: "", Users: []data.User{ {UserId: "0"}, {UserId: "1"}, {UserId: "2"}, {UserId: "3"}, {UserId: "4"}}, })). End() apitest.New(). Handler(suite.handler). Get("/api/user/2/feedback/click"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(`[{"FeedbackType":"click", "UserId": "2", "ItemId": "4", "Timestamp":"0001-01-01T00:00:00Z", "Updated":"0001-01-01T00:00:00Z", "Comment":"", "Value":1}]`). End() apitest.New(). Handler(suite.handler). Get("/api/item/4/feedback/click"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(`[{"FeedbackType":"click", "UserId": "2", "ItemId": "4", "Timestamp":"0001-01-01T00:00:00Z", "Updated":"0001-01-01T00:00:00Z", "Comment":"", "Value":1}]`). End() // test overwrite apitest.New(). Handler(suite.handler). Put("/api/feedback"). Header("X-API-Key", apiKey). JSON([]data.Feedback{{ FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "0"}, Comment: "override", }}). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() ret, err := suite.DataClient.GetUserFeedback(ctx, "0", suite.Config.Now(), expression.MustParseFeedbackTypeExpression("click")) assert.NoError(t, err) assert.Equal(t, 1, len(ret)) assert.Equal(t, "override", ret[0].Comment) // test not overwrite apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON([]data.Feedback{{ FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "0"}, Comment: "not_override", }}). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() ret, err = suite.DataClient.GetUserFeedback(ctx, "0", suite.Config.Now(), expression.MustParseFeedbackTypeExpression("click")) assert.NoError(t, err) assert.Equal(t, 1, len(ret)) assert.Equal(t, "not_override", ret[0].Comment) // insert feedback without timestamp apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). 
JSON([]Feedback{{FeedbackKey: data.FeedbackKey{UserId: "100", ItemId: "100", FeedbackType: "Type"}}}). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() } func (suite *ServerTestSuite) TestNonPersonalizedRecommend() { ctx := suite.T().Context() suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}} type ListOperator struct { Name string Collection string Subset string Category string URL string } operators := []ListOperator{ // TODO: Support hide users in the future. //{"User Neighbors", cache.Collection(cache.UserNeighbors, "0"), "/api/user/0/neighbors"}, {"Item Neighbors", cache.ItemToItem, cache.Key("default", "0"), "", "/api/item/0/neighbors"}, {"Item Neighbors in Category", cache.ItemToItem, cache.Key("default", "0"), "0", "/api/item/0/neighbors/0"}, {"NonPersonalized", cache.NonPersonalized, "trending", "", "/api/non-personalized/trending"}, {"NonPersonalizedCategory", cache.NonPersonalized, "trending", "0", "/api/non-personalized/trending"}, {"ItemToItem", cache.ItemToItem, cache.Key("lookalike", "0"), "", "/api/item-to-item/lookalike/0"}, {"ItemToItemCategory", cache.ItemToItem, cache.Key("lookalike", "0"), "0", "/api/item-to-item/lookalike/0"}, {"CollaborativeFiltering", cache.Recommend, "0", "", "/api/collaborative-filtering/0"}, {"CollaborativeFilteringCategory", cache.Recommend, "0", "0", "/api/collaborative-filtering/0/0"}, } lastModified := time.Now() for i, operator := range operators { suite.T().Run(operator.Name, func(t *testing.T) { // insert documents documents := []cache.Score{ {Id: strconv.Itoa(i) + "0", Score: 100, Categories: []string{operator.Category}}, {Id: strconv.Itoa(i) + "1", Score: 99, Categories: []string{operator.Category}}, {Id: strconv.Itoa(i) + "2", Score: 98, Categories: []string{operator.Category}}, {Id: strconv.Itoa(i) + "3", Score: 97, Categories: []string{operator.Category}}, {Id: strconv.Itoa(i) + "4", Score: 96, Categories: []string{operator.Category}}, } err := 
suite.CacheClient.AddScores(ctx, operator.Collection, operator.Subset, documents) assert.NoError(t, err) err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(operator.Collection+"_update_time", operator.Subset), lastModified)) assert.NoError(t, err) // hidden item apitest.New(). Handler(suite.handler). Patch("/api/item/"+strconv.Itoa(i)+"3"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{IsHidden: new(true)}). Expect(t). Status(http.StatusOK). End() // insert read feedback err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{{ FeedbackKey: data.FeedbackKey{ FeedbackType: "read", UserId: "0", ItemId: strconv.Itoa(i) + "1", }, Timestamp: time.Now().Add(-time.Hour), }}, true, true, true) assert.NoError(t, err) apitest.New(). Handler(suite.handler). Get(operator.URL). Query("category", operator.Category). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). HeaderPresent("Last-Modified"). Body(suite.marshal([]cache.Score{documents[0], documents[1], documents[2], documents[4]})). End() apitest.New(). Handler(suite.handler). Get(operator.URL). Query("category", operator.Category). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "offset": "0", "n": "3"}). Expect(t). Status(http.StatusOK). HeaderPresent("Last-Modified"). Body(suite.marshal([]cache.Score{documents[0], documents[1], documents[2]})). End() apitest.New(). Handler(suite.handler). Get(operator.URL). Query("category", operator.Category). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "offset": "1", "n": "3"}). Expect(t). Status(http.StatusOK). HeaderPresent("Last-Modified"). Body(suite.marshal([]cache.Score{documents[1], documents[2], documents[4]})). End() apitest.New(). Handler(suite.handler). Get(operator.URL). Query("category", operator.Category). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "offset": "0", }). Expect(t). Status(http.StatusOK). HeaderPresent("Last-Modified"). 
Body(suite.marshal([]cache.Score{documents[0], documents[1], documents[2], documents[4]})). End() apitest.New(). Handler(suite.handler). Get(operator.URL). Query("category", operator.Category). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "user-id": "0", "offset": "0", }). Expect(t). Status(http.StatusOK). HeaderPresent("Last-Modified"). Body(suite.marshal([]cache.Score{documents[0], documents[2], documents[4]})). End() }) } } func (suite *ServerTestSuite) TestUserToUser() { ctx := suite.T().Context() suite.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default"}} err := suite.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "0"), []cache.Score{ {Id: "1", Score: 100}, {Id: "2", Score: 99}, {Id: "3", Score: 98}, {Id: "4", Score: 97}, {Id: "5", Score: 96}, }) suite.NoError(err) apitest.New(). Handler(suite.handler). Get("/api/user-to-user/default/0"). Header("X-API-Key", apiKey). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: "1", Score: 100}, {Id: "2", Score: 99}, {Id: "3", Score: 98}, {Id: "4", Score: 97}, {Id: "5", Score: 96}, })). End() apitest.New(). Handler(suite.handler). Get("/api/user/0/neighbors"). Header("X-API-Key", apiKey). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score{ {Id: "1", Score: 100}, {Id: "2", Score: 99}, {Id: "3", Score: 98}, {Id: "4", Score: 97}, {Id: "5", Score: 96}, })). End() } func (suite *ServerTestSuite) TestDeleteFeedback() { t := suite.T() // Insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "type1", UserId: "2", ItemId: "3"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "type2", UserId: "2", ItemId: "3"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "type3", UserId: "2", ItemId: "3"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "type1", UserId: "1", ItemId: "6"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "type1", UserId: "4", ItemId: "8"}}, } apitest.New(). 
Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 5}`). End() // Get Feedback apitest.New(). Handler(suite.handler). Get("/api/feedback/2/3"). Header("X-API-Key", apiKey). Expect(t). Body(suite.marshal([]data.Feedback{feedback[0], feedback[1], feedback[2]})). Status(http.StatusOK). End() // Get typed feedback apitest.New(). Handler(suite.handler). Get("/api/feedback/type2/2/3"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(suite.marshal(feedback[1])). End() // delete feedback apitest.New(). Handler(suite.handler). Delete("/api/feedback/2/3"). Header("X-API-Key", apiKey). Expect(t). Body(`{"RowAffected": 3}`). Status(http.StatusOK). End() // delete typed feedback apitest.New(). Handler(suite.handler). Delete("/api/feedback/type1/4/8"). Header("X-API-Key", apiKey). Expect(t). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() } func (suite *ServerTestSuite) TestGetRecommends() { ctx := suite.T().Context() // insert hidden items err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{{Id: "0", Score: 100, Categories: []string{""}}}) suite.NoError(err) // hide item apitest.New(). Handler(suite.handler). Patch("/api/item/0"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{IsHidden: new(true)}). Expect(suite.T()). Status(http.StatusOK). 
End() // insert items err = suite.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "1"}, {ItemId: "2"}, {ItemId: "3"}, {ItemId: "4"}, {ItemId: "5"}, {ItemId: "6"}, {ItemId: "7"}, {ItemId: "8"}, }) suite.NoError(err) // insert recommendation err = suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99, Categories: []string{""}}, {Id: "2", Score: 98, Categories: []string{""}}, {Id: "3", Score: 97, Categories: []string{""}}, {Id: "4", Score: 96, Categories: []string{""}}, {Id: "5", Score: 95, Categories: []string{""}}, {Id: "6", Score: 94, Categories: []string{""}}, {Id: "7", Score: 93, Categories: []string{""}}, {Id: "8", Score: 92, Categories: []string{""}}, }) suite.NoError(err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "1"}, Timestamp: time.Now().Add(time.Hour)}, } apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(`{"RowAffected": 3}`). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "3", "5"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). Header("X-API-Version", "2"). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score{{Id: "1", Score: 99}, {Id: "3", Score: 97}, {Id: "5", Score: 95}})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "offset": "3", }). Expect(suite.T()). Status(http.StatusOK). 
Body(suite.marshal([]string{"6", "7", "8"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "offset": "10000", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "write-back-type": "read", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "3", "5"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", "write-back-type": "read", "write-back-delay": "10m", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"6", "7", "8"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"6", "7", "8"})). End() } func (suite *ServerTestSuite) TestGetRecommendsMultiCategories() { ctx := suite.T().Context() // insert recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "2", Score: 2, Categories: []string{"", "2"}}, {Id: "3", Score: 3, Categories: []string{"", "3"}}, {Id: "4", Score: 4, Categories: []string{"", "2"}}, {Id: "5", Score: 5, Categories: []string{"", "5"}}, {Id: "6", Score: 6, Categories: []string{"", "2", "3"}}, {Id: "7", Score: 7, Categories: []string{"", "7"}}, {Id: "8", Score: 8, Categories: []string{"", "2"}}, {Id: "9", Score: 9, Categories: []string{"", "3"}}, }) suite.NoError(err) apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryCollection(map[string][]string{ "n": {"3"}, "category": {"2", "3"}, }). Expect(suite.T()). Status(http.StatusOK). 
Body(suite.marshal([]string{"6"})). End() } func (suite *ServerTestSuite) TestGetRecommendsReplacement() { ctx := suite.T().Context() suite.Config.Recommend.Replacement.EnableReplacement = true // insert recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "0", Score: 100, Categories: []string{""}}, {Id: "1", Score: 99, Categories: []string{""}}, {Id: "2", Score: 98, Categories: []string{""}}, {Id: "3", Score: 97, Categories: []string{""}}, {Id: "4", Score: 96, Categories: []string{""}}, {Id: "5", Score: 95, Categories: []string{""}}, {Id: "6", Score: 94, Categories: []string{""}}, {Id: "7", Score: 93, Categories: []string{""}}, {Id: "8", Score: 92, Categories: []string{""}}, }) suite.NoError(err) // hide item apitest.New(). Handler(suite.handler). Patch("/api/item/0"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{IsHidden: new(true)}). Expect(suite.T()). Status(http.StatusOK). End() // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "1"}, Timestamp: time.Now().Add(time.Hour)}, } apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(`{"RowAffected": 3}`). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3"})). 
End() } func (suite *ServerTestSuite) TestGetRecommendsFallbackItemToItem() { ctx := suite.T().Context() suite.Config.Recommend.ContextSize = 4 suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("a")} suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}} // insert recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99}, {Id: "2", Score: 98}, {Id: "3", Score: 97}, {Id: "4", Score: 96}}) suite.NoError(err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "1"}, Timestamp: time.Date(2010, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}, Timestamp: time.Date(2009, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "3"}, Timestamp: time.Date(2008, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}, Timestamp: time.Date(2007, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "5"}, Timestamp: time.Date(2006, 1, 1, 1, 1, 1, 1, time.UTC)}, } apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(`{"RowAffected": 5}`). 
End() // insert similar items err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "1"), []cache.Score{ {Id: "2", Score: 100000, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "2"), []cache.Score{ {Id: "3", Score: 100000, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "3"), []cache.Score{ {Id: "4", Score: 100000, Categories: []string{""}}, {Id: "7", Score: 1, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "4"), []cache.Score{ {Id: "1", Score: 100000, Categories: []string{"", "*"}}, {Id: "6", Score: 1, Categories: []string{""}}, {Id: "7", Score: 1, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "5"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "6", Score: 1, Categories: []string{""}}, {Id: "7", Score: 100000, Categories: []string{""}}, {Id: "8", Score: 100, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{""}}, }) suite.NoError(err) // test fallback suite.Config.Recommend.Fallback.Recommenders = []string{"item-to-item/default"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"9", "8", "7"})). 
End() suite.Config.Recommend.Fallback.Recommenders = []string{"item-to-item/default"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"9", "7"})). End() } func (suite *ServerTestSuite) TestGetRecommendsFallbackUserToUser() { ctx := suite.T().Context() suite.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default"}} // insert recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{{Id: "1", Score: 99}, {Id: "2", Score: 98}, {Id: "3", Score: 97}, {Id: "4", Score: 96}}) suite.NoError(err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "1"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "3"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } apitest.New(). Handler(suite.handler). Post("/api/feedback"). Header("X-API-Key", apiKey). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(`{"RowAffected": 4}`). 
End() // insert similar users err = suite.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "0"), []cache.Score{ {Id: "1", Score: 2, Categories: []string{""}}, {Id: "2", Score: 1.5, Categories: []string{""}}, {Id: "3", Score: 1, Categories: []string{""}}, }) suite.NoError(err) err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "1", ItemId: "11"}}, }, true, true, true) suite.NoError(err) err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "2", ItemId: "12"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "2", ItemId: "48"}}, }, true, true, true) suite.NoError(err) err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "3", ItemId: "13"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "3", ItemId: "48"}}, }, true, true, true) suite.NoError(err) // insert categorized items err = suite.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "12", Categories: []string{"*"}}, {ItemId: "48", Categories: []string{"*"}}, }) suite.NoError(err) // test fallback suite.Config.Recommend.Fallback.Recommenders = []string{"user-to-user/default"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"48", "11", "12"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"48", "12"})). 
End() } func (suite *ServerTestSuite) TestRecommendFallbackLatest() { ctx := suite.T().Context() // insert offline recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99, Categories: []string{"*"}}, {Id: "2", Score: 98, Categories: []string{"*"}}, {Id: "3", Score: 97, Categories: []string{"*"}}, {Id: "4", Score: 96, Categories: []string{"*"}}}) suite.NoError(err) // insert latest items err = suite.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "5", Timestamp: time.Unix(95, 0)}, {ItemId: "6", Timestamp: time.Unix(94, 0)}, {ItemId: "7", Timestamp: time.Unix(93, 0)}, {ItemId: "8", Timestamp: time.Unix(92, 0)}, {ItemId: "105", Categories: []string{"*"}, Timestamp: time.Unix(85, 0)}, {ItemId: "106", Categories: []string{"*"}, Timestamp: time.Unix(84, 0)}, {ItemId: "107", Categories: []string{"*"}, Timestamp: time.Unix(83, 0)}, {ItemId: "108", Categories: []string{"*"}, Timestamp: time.Unix(82, 0)}, }) suite.NoError(err) // test latest fallback suite.Config.Recommend.Fallback.Recommenders = []string{"latest"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "5", "6", "7", "8"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "105", "106", "107", "108"})). 
End() } func (suite *ServerTestSuite) TestGetRecommendsFallbackCollaborativeFiltering() { ctx := suite.T().Context() // insert offline recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99, Categories: []string{"*"}}, {Id: "2", Score: 98, Categories: []string{"*"}}, {Id: "3", Score: 97, Categories: []string{"*"}}, {Id: "4", Score: 96, Categories: []string{"*"}}}) suite.NoError(err) // insert collaborative filtering recommendation err = suite.CacheClient.AddScores(ctx, cache.CollaborativeFiltering, "0", []cache.Score{ {Id: "13", Score: 79, Categories: []string{"*"}}, {Id: "14", Score: 78, Categories: []string{"*"}}, {Id: "15", Score: 77, Categories: []string{"*"}}, {Id: "16", Score: 76, Categories: []string{"*"}}}) suite.NoError(err) // test collaborative filtering suite.Config.Recommend.Fallback.Recommenders = []string{"collaborative"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "13", "14", "15", "16"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "13", "14", "15", "16"})). 
End() } func (suite *ServerTestSuite) TestGetRecommendsFallbackNonPersonalized() { ctx := suite.T().Context() // insert offline recommendation err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{ {Id: "1", Score: 99, Categories: []string{"*"}}, {Id: "2", Score: 98, Categories: []string{"*"}}, {Id: "3", Score: 97, Categories: []string{"*"}}, {Id: "4", Score: 96, Categories: []string{"*"}}}) suite.NoError(err) // insert non-personalized recommendation err = suite.CacheClient.AddScores(ctx, cache.NonPersonalized, "popular", []cache.Score{ {Id: "5", Score: 95, Categories: []string{""}}, {Id: "6", Score: 94, Categories: []string{""}}, {Id: "7", Score: 93, Categories: []string{""}}, {Id: "8", Score: 92, Categories: []string{""}}, {Id: "105", Score: 91, Categories: []string{"*"}}, {Id: "106", Score: 90, Categories: []string{"*"}}, {Id: "107", Score: 89, Categories: []string{"*"}}, {Id: "108", Score: 88, Categories: []string{"*"}}, }) suite.NoError(err) // test non-personalized fallback suite.Config.Recommend.Fallback.Recommenders = []string{"non-personalized/popular"} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "5", "6", "7", "8"})). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/0/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "8", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"1", "2", "3", "4", "105", "106", "107", "108"})). 
End() } func (suite *ServerTestSuite) TestGetRecommendsLatest() { ctx := suite.T().Context() err := suite.DataClient.BatchInsertItems(ctx, []data.Item{ {ItemId: "5", Timestamp: time.Unix(95, 0)}, {ItemId: "6", Timestamp: time.Unix(94, 0)}, {ItemId: "7", Timestamp: time.Unix(93, 0)}, {ItemId: "8", Timestamp: time.Unix(92, 0)}, }) suite.NoError(err) suite.Config.Recommend.Ranker.Type = "none" suite.Config.Recommend.Ranker.Recommenders = []string{"latest"} suite.Config.Recommend.Fallback.Recommenders = []string{} apitest.New(). Handler(suite.handler). Get("/api/recommend/0"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "4", }). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]string{"5", "6", "7", "8"})). End() } func (suite *ServerTestSuite) TestSessionRecommend() { ctx := suite.T().Context() suite.Config.Recommend.ContextSize = 4 suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{ expression.MustParseFeedbackTypeExpression("a")} suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}} // insert similar items err := suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "1"), []cache.Score{ {Id: "2", Score: 100000, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, {Id: "100", Score: 100000, Categories: []string{""}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "2"), []cache.Score{ {Id: "3", Score: 100000, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "3"), []cache.Score{ {Id: "4", Score: 100000, Categories: []string{""}}, {Id: "7", Score: 1, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) 
suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "4"), []cache.Score{ {Id: "1", Score: 100000, Categories: []string{"", "*"}}, {Id: "6", Score: 1, Categories: []string{""}}, {Id: "7", Score: 1, Categories: []string{"", "*"}}, {Id: "8", Score: 1, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{"", "*"}}, }) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "5"), []cache.Score{ {Id: "1", Score: 1, Categories: []string{""}}, {Id: "6", Score: 1, Categories: []string{""}}, {Id: "7", Score: 100000, Categories: []string{""}}, {Id: "8", Score: 100, Categories: []string{""}}, {Id: "9", Score: 1, Categories: []string{""}}, }) suite.NoError(err) // hide items apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "100", IsHidden: true}). Expect(suite.T()). Status(http.StatusOK). Body(`{"RowAffected": 1}`). End() // test fallback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "1"}, Timestamp: time.Date(2010, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}, Timestamp: time.Date(2009, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "3"}, Timestamp: time.Date(2008, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}, Timestamp: time.Date(2007, 1, 1, 1, 1, 1, 1, time.UTC)}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "5"}, Timestamp: time.Date(2006, 1, 1, 1, 1, 1, 1, time.UTC)}, } apitest.New(). Handler(suite.handler). Post("/api/session/recommend"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score{{Id: "9", Score: 4}, {Id: "8", Score: 3}, {Id: "7", Score: 2}})). 
End() apitest.New(). Handler(suite.handler). Post("/api/session/recommend"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "offset": "100", }). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score(nil))). End() suite.Config.Recommend.Fallback.Recommenders = []string{"item_based"} apitest.New(). Handler(suite.handler). Post("/api/session/recommend/*"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). JSON(feedback). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal([]cache.Score{{Id: "9", Score: 4}, {Id: "7", Score: 2}})). End() } func (suite *ServerTestSuite) TestVisibility() { ctx := suite.T().Context() suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}} // insert items: 0, 1, 2, 3, 4 var items []Item for i := 0; i < 5; i++ { items = append(items, Item{ ItemId: strconv.Itoa(i), Categories: []string{"a"}, Timestamp: time.Date(1989, 6, i+1, 1, 1, 1, 1, time.UTC).String(), }) } apitest.New(). Handler(suite.handler). Post("/api/items"). Header("X-API-Key", apiKey). QueryParams(map[string]string{ "n": "3", }). JSON(items). Expect(suite.T()). Status(http.StatusOK). End() // insert cache var documents []cache.Score for i := range items { documents = append(documents, cache.Score{ Id: strconv.Itoa(i), Score: float64(time.Date(1989, 6, i+1, 1, 1, 1, 1, time.UTC).Unix()), Categories: []string{"", "a"}, }) } mutable.Reverse(documents) err := suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "100"), documents) suite.NoError(err) err = suite.CacheClient.AddScores(ctx, cache.Recommend, "100", documents) suite.NoError(err) // delete item apitest.New(). Handler(suite.handler). Delete("/api/item/0"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). End() // modify item apitest.New(). Handler(suite.handler). Patch("/api/item/1"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{IsHidden: new(true)}). 
Expect(suite.T()). Status(http.StatusOK). End() // overwrite item apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "2", IsHidden: true}). Expect(suite.T()). Status(http.StatusOK). End() // recommend apitest.New(). Handler(suite.handler). Get("/api/latest"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:2])). End() apitest.New(). Handler(suite.handler). Get("/api/item/100/neighbors/"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:2])). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/100/"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(cache.ConvertDocumentsToValues(documents[:2]))). End() // insert item apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "0", Timestamp: time.Date(1989, 6, 1, 1, 1, 1, 1, time.UTC).String()}). Expect(suite.T()). Status(http.StatusOK). End() // modify item apitest.New(). Handler(suite.handler). Patch("/api/item/1"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{IsHidden: new(false)}). Expect(suite.T()). Status(http.StatusOK). End() // overwrite item apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "2", IsHidden: false, Timestamp: time.Date(1989, 6, 3, 1, 1, 1, 1, time.UTC).String()}). Expect(suite.T()). Status(http.StatusOK). End() // recommend apitest.New(). Handler(suite.handler). Get("/api/latest"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents)). End() apitest.New(). Handler(suite.handler). Get("/api/item/100/neighbors/"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:len(documents)-1])). End() apitest.New(). 
Handler(suite.handler). Get("/api/recommend/100/"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(cache.ConvertDocumentsToValues(documents))). End() // delete category apitest.New(). Handler(suite.handler). Delete("/api/item/0/category/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). End() // modify category apitest.New(). Handler(suite.handler). Patch("/api/item/1"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{Categories: []string{}}). Expect(suite.T()). Status(http.StatusOK). End() // overwrite category apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). JSON(Item{ItemId: "2", Categories: []string{}}). Expect(suite.T()). Status(http.StatusOK). End() // recommend apitest.New(). Handler(suite.handler). Get("/api/latest/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:2])). End() apitest.New(). Handler(suite.handler). Get("/api/item/100/neighbors/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:2])). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/100/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(cache.ConvertDocumentsToValues(documents[:2]))). End() // delete category apitest.New(). Handler(suite.handler). Put("/api/item/0/category/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). End() // modify category apitest.New(). Handler(suite.handler). Patch("/api/item/1"). Header("X-API-Key", apiKey). JSON(data.ItemPatch{Categories: []string{"a"}}). Expect(suite.T()). Status(http.StatusOK). End() // overwrite category apitest.New(). Handler(suite.handler). Post("/api/item"). Header("X-API-Key", apiKey). 
JSON(Item{ItemId: "2", Categories: []string{"a"}, Timestamp: time.Date(1989, 6, 3, 1, 1, 1, 1, time.UTC).String()}). Expect(suite.T()). Status(http.StatusOK). End() // recommend apitest.New(). Handler(suite.handler). Get("/api/latest/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents)). End() apitest.New(). Handler(suite.handler). Get("/api/item/100/neighbors/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(documents[:len(documents)-1])). End() apitest.New(). Handler(suite.handler). Get("/api/recommend/100/a"). Header("X-API-Key", apiKey). JSON(items). Expect(suite.T()). Status(http.StatusOK). Body(suite.marshal(cache.ConvertDocumentsToValues(documents))). End() } func (suite *ServerTestSuite) TestHealth() { t := suite.T() // ready apitest.New(). Handler(suite.handler). Get("/api/health/live"). Expect(t). Status(http.StatusOK). Body(suite.marshal(HealthStatus{ Ready: true, DataStoreError: nil, CacheStoreError: nil, DataStoreConnected: true, CacheStoreConnected: true, })). End() apitest.New(). Handler(suite.handler). Get("/api/health/ready"). Expect(t). Status(http.StatusOK). Body(suite.marshal(HealthStatus{ Ready: true, DataStoreError: nil, CacheStoreError: nil, DataStoreConnected: true, CacheStoreConnected: true, })). End() // not ready dataClient, cacheClient := suite.DataClient, suite.CacheClient suite.DataClient, suite.CacheClient = data.NoDatabase{}, cache.NoDatabase{} apitest.New(). Handler(suite.handler). Get("/api/health/live"). Expect(t). Status(http.StatusOK). Body(suite.marshal(HealthStatus{ Ready: false, DataStoreError: data.ErrNoDatabase, CacheStoreError: cache.ErrNoDatabase, DataStoreConnected: false, CacheStoreConnected: false, })). End() apitest.New(). Handler(suite.handler). Get("/api/health/ready"). Expect(t). Status(http.StatusServiceUnavailable). 
Body(suite.marshal(HealthStatus{ Ready: false, DataStoreError: data.ErrNoDatabase, CacheStoreError: cache.ErrNoDatabase, DataStoreConnected: false, CacheStoreConnected: false, })). End() suite.DataClient, suite.CacheClient = dataClient, cacheClient } func TestServer(t *testing.T) { suite.Run(t, new(ServerTestSuite)) } ================================================ FILE: server/server.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package server import ( "context" "crypto/md5" "encoding/hex" "encoding/json" "math/rand" "net" "os" "strconv" "strings" "time" "github.com/emicklei/go-restful/v3" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/protocol" "github.com/gorse-io/gorse/storage" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) // Server manages states of a server node. 
type Server struct { RestServer traceConfig config.TracingConfig cachePath string cachePrefix string dataPath string dataPrefix string conn *grpc.ClientConn masterClient protocol.MasterClient serverName string masterHost string masterPort int tlsConfig *util.TLSConfig testMode bool cacheFile string } // NewServer creates a server node. func NewServer( masterHost string, masterPort int, serverHost string, serverPort int, cacheFile string, tlsConfig *util.TLSConfig, ) *Server { s := &Server{ masterHost: masterHost, masterPort: masterPort, tlsConfig: tlsConfig, cacheFile: cacheFile, RestServer: RestServer{ Config: config.GetDefaultConfig(), CacheClient: new(cache.NoDatabase), DataClient: new(data.NoDatabase), HttpHost: serverHost, HttpPort: serverPort, WebService: new(restful.WebService), }, } return s } // Serve starts a server node. func (s *Server) Serve() { rand.Seed(time.Now().UTC().UnixNano()) var err error s.serverName, err = s.ServerName() if err != nil { log.Logger().Fatal("failed to get server name", zap.Error(err)) } log.Logger().Info("start server", zap.String("server_name", s.serverName), zap.String("server_host", s.HttpHost), zap.Int("server_port", s.HttpPort), zap.String("master_host", s.masterHost), zap.Int("master_port", s.masterPort)) // connect to master var opts []grpc.DialOption if s.tlsConfig != nil { c, err := util.NewClientCreds(s.tlsConfig) if err != nil { log.Logger().Fatal("failed to create credentials", zap.Error(err)) } opts = append(opts, grpc.WithTransportCredentials(c)) } else { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } s.conn, err = grpc.Dial(net.JoinHostPort(s.masterHost, strconv.Itoa(s.masterPort)), opts...) 
if err != nil { log.Logger().Fatal("failed to connect master", zap.Error(err)) } s.masterClient = protocol.NewMasterClient(s.conn) go s.Sync() container := restful.NewContainer() s.StartHttpServer(container) } func (s *Server) ServerName() (string, error) { hostname, err := os.Hostname() if err != nil { return "", err } hash := md5.New() hash.Write([]byte(hostname)) hash.Write([]byte(s.HttpHost)) hash.Write([]byte(strconv.Itoa(s.HttpPort))) b := hash.Sum(nil) return hex.EncodeToString(b), nil } func (s *Server) Shutdown() { err := s.HttpServer.Shutdown(context.TODO()) if err != nil { log.Logger().Fatal("failed to shutdown http server", zap.Error(err)) } } // Sync this server to the master. func (s *Server) Sync() { defer util.CheckPanic() log.Logger().Info("start meta sync", zap.Duration("meta_timeout", s.Config.Master.MetaTimeout)) for { var meta *protocol.Meta var err error if meta, err = s.masterClient.GetMeta(context.Background(), &protocol.NodeInfo{ NodeType: protocol.NodeType_Server, Uuid: s.serverName, BinaryVersion: version.Version, Hostname: lo.Must(os.Hostname()), }); err != nil { log.Logger().Error("failed to get meta", zap.Error(err)) goto sleep } // load master config err = json.Unmarshal([]byte(meta.Config), &s.Config) if err != nil { log.Logger().Error("failed to parse master config", zap.Error(err)) goto sleep } // connect to data store if s.dataPath != s.Config.Database.DataStore || s.dataPrefix != s.Config.Database.DataTablePrefix { if strings.HasPrefix(s.Config.Database.DataStore, storage.SQLitePrefix) { log.Logger().Info("connect cache store via master") s.DataClient = data.NewProxyClient(s.conn) } else { log.Logger().Info("connect data store", zap.String("database", log.RedactDBURL(s.Config.Database.DataStore))) dataOpts := s.Config.Database.StorageOptions(s.Config.Database.DataStore) if s.DataClient, err = data.Open(s.Config.Database.DataStore, s.Config.Database.DataTablePrefix, dataOpts...); err != nil { log.Logger().Error("failed to connect 
data store", zap.Error(err)) goto sleep } } s.dataPath = s.Config.Database.DataStore s.dataPrefix = s.Config.Database.DataTablePrefix } // connect to cache store if s.cachePath != s.Config.Database.CacheStore || s.cachePrefix != s.Config.Database.CacheTablePrefix { if strings.HasPrefix(s.Config.Database.CacheStore, storage.SQLitePrefix) { log.Logger().Info("connect cache store via master") s.CacheClient = cache.NewProxyClient(s.conn) } else { log.Logger().Info("connect cache store", zap.String("database", log.RedactDBURL(s.Config.Database.CacheStore))) cacheOpts := s.Config.Database.StorageOptions(s.Config.Database.CacheStore) if s.CacheClient, err = cache.Open(s.Config.Database.CacheStore, s.Config.Database.CacheTablePrefix, cacheOpts...); err != nil { log.Logger().Error("failed to connect cache store", zap.Error(err)) goto sleep } } s.cachePath = s.Config.Database.CacheStore s.cachePrefix = s.Config.Database.CacheTablePrefix } // create trace provider if !s.traceConfig.Equal(s.Config.Tracing) { log.Logger().Info("create trace provider", zap.Any("tracing_config", s.Config.Tracing)) tp, err := s.Config.Tracing.NewTracerProvider() if err != nil { log.Logger().Fatal("failed to create trace provider", zap.Error(err)) } otel.SetTracerProvider(tp) otel.SetErrorHandler(log.GetErrorHandler()) otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) s.traceConfig = s.Config.Tracing } sleep: if s.testMode { return } time.Sleep(s.Config.Master.MetaTimeout) } } ================================================ FILE: server/server_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package server import ( "context" "encoding/json" "fmt" "net" "testing" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/protocol" "github.com/gorse-io/gorse/storage/cache" "github.com/gorse-io/gorse/storage/data" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) type mockMaster struct { protocol.UnimplementedMasterServer addr chan string grpcServer *grpc.Server meta *protocol.Meta cacheTempFile string dataTempFile string } func newMockMaster(t *testing.T) *mockMaster { cfg := config.GetDefaultConfig() cfg.Database.DataStore = fmt.Sprintf("sqlite://%s/data.db", t.TempDir()) cfg.Database.CacheStore = fmt.Sprintf("sqlite://%s/cache.db", t.TempDir()) bytes, err := json.Marshal(cfg) assert.NoError(t, err) return &mockMaster{ addr: make(chan string), meta: &protocol.Meta{Config: string(bytes)}, dataTempFile: cfg.Database.DataStore, cacheTempFile: cfg.Database.CacheStore, } } func (m *mockMaster) GetMeta(_ context.Context, _ *protocol.NodeInfo) (*protocol.Meta, error) { return m.meta, nil } func (m *mockMaster) Start(t *testing.T) { listen, err := net.Listen("tcp", "localhost:0") assert.NoError(t, err) m.addr <- listen.Addr().String() var opts []grpc.ServerOption m.grpcServer = grpc.NewServer(opts...) 
protocol.RegisterMasterServer(m.grpcServer, m) err = m.grpcServer.Serve(listen) assert.NoError(t, err) } func (m *mockMaster) Stop() { m.grpcServer.Stop() } func TestServer_Sync(t *testing.T) { master := newMockMaster(t) go master.Start(t) address := <-master.addr conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials())) assert.NoError(t, err) serv := &Server{ testMode: true, masterClient: protocol.NewMasterClient(conn), RestServer: RestServer{ Config: config.GetDefaultConfig(), CacheClient: new(cache.NoDatabase), DataClient: new(data.NoDatabase), }, } serv.Sync() assert.Equal(t, master.dataTempFile, serv.dataPath) assert.Equal(t, master.cacheTempFile, serv.cachePath) assert.NoError(t, serv.DataClient.Close()) assert.NoError(t, serv.CacheClient.Close()) master.Stop() } ================================================ FILE: storage/blob/azure.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blob

import (
	"context"
	"fmt"
	"io"
	"path"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/config"
	"github.com/juju/errors"
	"go.uber.org/zap"
)

// AzureBlob stores blobs in an Azure Blob Storage container, with every
// object key placed under an optional prefix.
type AzureBlob struct {
	client    *azblob.Client
	container string
	prefix    string
}

// NewAzureBlob creates an AzureBlob for the given container and key prefix.
// A connection string, when set, takes precedence; otherwise both account
// name and account key are required, and the endpoint defaults to
// https://<account_name>.blob.core.windows.net/. A leading "/" is stripped
// from prefix.
func NewAzureBlob(cfg config.AzureBlobConfig, container string, prefix string) (*AzureBlob, error) {
	var (
		client *azblob.Client
		err    error
	)
	if cfg.ConnectionString != "" {
		client, err = azblob.NewClientFromConnectionString(cfg.ConnectionString, nil)
		if err != nil {
			return nil, errors.Trace(err)
		}
	} else {
		if cfg.AccountName == "" || cfg.AccountKey == "" {
			return nil, errors.New("azure blob requires account_name and account_key or connection_string")
		}
		endpoint := cfg.Endpoint
		if endpoint == "" {
			endpoint = fmt.Sprintf("https://%s.blob.core.windows.net/", cfg.AccountName)
		}
		cred, err := azblob.NewSharedKeyCredential(cfg.AccountName, cfg.AccountKey)
		if err != nil {
			return nil, errors.Trace(err)
		}
		client, err = azblob.NewClientWithSharedKeyCredential(endpoint, cred, nil)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}
	return &AzureBlob{
		client:    client,
		container: container,
		prefix:    strings.TrimPrefix(prefix, "/"),
	}, nil
}

// Open streams the blob stored under name. The caller must close the
// returned reader.
func (a *AzureBlob) Open(name string) (io.ReadCloser, error) {
	// Blob keys are always slash-separated: use path, not path/filepath,
	// which would produce backslashes on Windows.
	fullPath := path.Join(a.prefix, name)
	resp, err := a.client.DownloadStream(context.Background(), a.container, fullPath, nil)
	if err != nil {
		return nil, err
	}
	return resp.Body, nil
}

// Create returns a writer whose data is uploaded to the blob stored under
// name. The returned channel is closed when the upload goroutine finishes,
// whether or not it succeeded; upload errors are logged, not returned.
func (a *AzureBlob) Create(name string) (io.WriteCloser, chan struct{}, error) {
	fullPath := path.Join(a.prefix, name)
	pr, pw := io.Pipe()
	done := make(chan struct{})
	go func() {
		defer close(done)
		_, err := a.client.UploadStream(context.Background(), a.container, fullPath, pr, nil)
		if err != nil {
			log.Logger().Error("failed to upload file to Azure Blob", zap.String("file", fullPath), zap.Error(err))
			// io.Pipe is synchronous: once the uploader stops reading, writers
			// would block forever. Fail pending and future writes instead.
			_ = pr.CloseWithError(err)
		}
	}()
	return pw, done, nil
}

// List returns the names of all blobs under the prefix, relative to it.
func (a *AzureBlob) List() ([]string, error) {
	var (
		listPrefix *string
		dirPrefix  string
		names      []string
	)
	if a.prefix != "" {
		// List by "<prefix>/" rather than the bare prefix so sibling keys such
		// as "<prefix>xyz/..." are not matched and mangled by the trim below.
		dirPrefix = strings.TrimSuffix(a.prefix, "/") + "/"
		listPrefix = &dirPrefix
	}
	pager := a.client.NewListBlobsFlatPager(a.container, &azblob.ListBlobsFlatOptions{Prefix: listPrefix})
	for pager.More() {
		resp, err := pager.NextPage(context.Background())
		if err != nil {
			return nil, err
		}
		for _, item := range resp.Segment.BlobItems {
			if item.Name == nil {
				continue
			}
			name := strings.TrimPrefix(*item.Name, dirPrefix)
			if name != "" {
				names = append(names, name)
			}
		}
	}
	return names, nil
}

// Remove deletes the blob stored under name. Failures are logged and returned.
func (a *AzureBlob) Remove(name string) error {
	fullPath := path.Join(a.prefix, name)
	_, err := a.client.DeleteBlob(context.Background(), a.container, fullPath, nil)
	if err != nil {
		log.Logger().Error("failed to remove file from Azure Blob", zap.String("file", fullPath), zap.Error(err))
		return err
	}
	return nil
}

================================================
FILE: storage/blob/azure_test.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package blob

import (
	"errors"
	"io"
	"os"
	"testing"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
	"github.com/gorse-io/gorse/config"
	"github.com/stretchr/testify/assert"
)

// TestAzureBlobEmulator runs a create/open/list/remove round trip against the
// Azure Blob service named by AZURE_STORAGE_CONNECTION_STRING. It is skipped
// when that variable is not set.
func TestAzureBlobEmulator(t *testing.T) {
	connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING")
	if connectionString == "" {
		t.Skip("AZURE_STORAGE_CONNECTION_STRING is not set, skipping Azure Blob emulator test")
	}
	client, err := NewAzureBlob(config.AzureBlobConfig{ConnectionString: connectionString}, "gorse-test", "blob")
	if !assert.NoError(t, err) {
		// Every later step dereferences client; stop rather than panic.
		return
	}

	// A container left over from a previous run is fine.
	ctx := t.Context()
	_, err = client.client.CreateContainer(ctx, client.container, nil)
	if err != nil {
		var respErr *azcore.ResponseError
		if !errors.As(err, &respErr) || respErr.ErrorCode != string(bloberror.ContainerAlreadyExists) {
			assert.NoError(t, err)
		}
	}

	// Write a blob and wait for the upload to finish.
	w, done, err := client.Create("test.txt")
	if !assert.NoError(t, err) {
		return
	}
	_, err = w.Write([]byte("hello"))
	assert.NoError(t, err)
	assert.NoError(t, w.Close())
	<-done

	// Read it back.
	r, err := client.Open("test.txt")
	if !assert.NoError(t, err) {
		return
	}
	data, err := io.ReadAll(r)
	assert.NoError(t, err)
	assert.Equal(t, "hello", string(data))
	assert.NoError(t, r.Close())

	// It is listed until it is removed.
	names, err := client.List()
	assert.NoError(t, err)
	assert.Contains(t, names, "test.txt")

	err = client.Remove("test.txt")
	assert.NoError(t, err)
	names, err = client.List()
	assert.NoError(t, err)
	assert.NotContains(t, names, "test.txt")
}

================================================
FILE: storage/blob/blob.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package blob import ( "context" "io" "net/url" "os" "path" "strings" "time" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/protocol" "github.com/juju/errors" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" ) type Store interface { Open(name string) (io.ReadCloser, error) Create(name string) (io.WriteCloser, chan struct{}, error) List() ([]string, error) Remove(name string) error } func NewStore(cfg config.BlobConfig, masterConn *grpc.ClientConn) (Store, error) { if strings.Contains(cfg.URI, "://") { parsed, err := url.Parse(cfg.URI) if err != nil { return nil, errors.Trace(err) } switch parsed.Scheme { case "s3": if parsed.Host == "" { return nil, errors.New("blob.uri must include bucket for s3://") } store, err := NewS3(cfg.S3, parsed.Host, strings.TrimPrefix(parsed.Path, "/")) if err != nil { return nil, err } return store, nil case "gs": if parsed.Host == "" { return nil, errors.New("blob.uri must include bucket for gs://") } store, err := NewGCS(cfg.GCS, parsed.Host, strings.TrimPrefix(parsed.Path, "/")) if err != nil { return nil, err } return store, nil case "az": if parsed.Host == "" { return nil, errors.New("blob.uri must include container for az://") } store, err := NewAzureBlob(cfg.Azure, parsed.Host, strings.TrimPrefix(parsed.Path, "/")) if err != nil { return nil, err } return store, nil default: return nil, errors.Errorf("unsupported blob.uri scheme: %s", parsed.Scheme) } } if masterConn != nil { return NewMasterStoreClient(masterConn), nil } 
return NewPOSIX(cfg.URI), nil } type MasterStoreServer struct { protocol.UnimplementedBlobStoreServer dir string } func NewMasterStoreServer(dir string) *MasterStoreServer { // Create directory if not exists err := os.MkdirAll(dir, os.ModePerm) if err != nil { log.Logger().Fatal("failed to create directory", zap.Error(err)) } return &MasterStoreServer{dir: dir} } func (s *MasterStoreServer) UploadBlob(stream grpc.ClientStreamingServer[protocol.UploadBlobRequest, protocol.UploadBlobResponse]) error { // Create temp file file, err := os.CreateTemp(s.dir, "upload-*") if err != nil { return err } defer func(file *os.File) { _ = file.Close() }(file) // Write data var ( name string timestamp time.Time ) for { req, err := stream.Recv() if err != nil { if errors.Is(err, io.EOF) { break } return err } // Assign name if name == "" { name = req.Name } else if name != req.Name { return errors.New("inconsistent name") } // Assign timestamp if timestamp.IsZero() { timestamp = req.Timestamp.AsTime() } else if !timestamp.Equal(req.Timestamp.AsTime()) { return errors.New("inconsistent timestamp") } // Write data _, err = file.Write(req.Data) if err != nil { return err } } // Close file err = file.Close() if err != nil { return err } // Rename file err = os.Rename(file.Name(), path.Join(s.dir, name)) if err != nil { return err } return stream.SendAndClose(&protocol.UploadBlobResponse{}) } func (s *MasterStoreServer) DownloadBlob(request *protocol.DownloadBlobRequest, stream grpc.ServerStreamingServer[protocol.DownloadBlobResponse]) error { // Open file file, err := os.Open(path.Join(s.dir, request.Name)) if err != nil { return err } defer func(file *os.File) { err = file.Close() if err != nil { log.Logger().Error("failed to close file", zap.Error(err)) } }(file) // Send data for { data := make([]byte, 1024*1024*4) n, err := file.Read(data) if err != nil { if errors.Is(err, io.EOF) { break } return err } err = stream.Send(&protocol.DownloadBlobResponse{Data: data[:n]}) if err != nil 
{ return err } } return nil } func (s *MasterStoreServer) ListBlobs(ctx context.Context, request *protocol.ListBlobsRequest) (*protocol.ListBlobsResponse, error) { files, err := os.ReadDir(s.dir) if err != nil { return nil, err } var names []string for _, file := range files { if !file.IsDir() { names = append(names, file.Name()) } } return &protocol.ListBlobsResponse{Names: names}, nil } func (s *MasterStoreServer) RemoveBlob(ctx context.Context, request *protocol.RemoveBlobRequest) (*protocol.RemoveBlobResponse, error) { err := os.Remove(path.Join(s.dir, request.Name)) if err != nil { return nil, err } return &protocol.RemoveBlobResponse{}, nil } type MasterStoreClient struct { client protocol.BlobStoreClient } func NewMasterStoreClient(clientConn *grpc.ClientConn) *MasterStoreClient { return &MasterStoreClient{client: protocol.NewBlobStoreClient(clientConn)} } func (c *MasterStoreClient) Open(name string) (io.ReadCloser, error) { stream, err := c.client.DownloadBlob(context.Background(), &protocol.DownloadBlobRequest{Name: name}) if err != nil { return nil, err } pr, pw := io.Pipe() go func() { for { resp, err := stream.Recv() if err != nil { _ = pw.CloseWithError(err) return } _, err = pw.Write(resp.Data) if err != nil { _ = pw.CloseWithError(err) return } } }() return pr, nil } func (c *MasterStoreClient) Create(name string) (io.WriteCloser, chan struct{}, error) { stream, err := c.client.UploadBlob(context.Background()) if err != nil { return nil, nil, err } done := make(chan struct{}) pr, pw := io.Pipe() go func() { defer close(done) for { data := make([]byte, 1024*1024*4) n, err := pr.Read(data) if err != nil { if errors.Is(err, io.EOF) { break } log.Logger().Error("failed to read data", zap.Error(err)) return } err = stream.Send(&protocol.UploadBlobRequest{ Name: name, Timestamp: timestamppb.Now(), Data: data[:n], }) if err != nil { log.Logger().Error("failed to send data", zap.Error(err)) return } } _, err = stream.CloseAndRecv() if err != nil { 
log.Logger().Error("failed to close stream", zap.Error(err)) } }() return pw, done, nil } func (c *MasterStoreClient) List() ([]string, error) { resp, err := c.client.ListBlobs(context.Background(), &protocol.ListBlobsRequest{}) if err != nil { return nil, err } return resp.Names, nil } func (c *MasterStoreClient) Remove(name string) error { _, err := c.client.RemoveBlob(context.Background(), &protocol.RemoveBlobRequest{Name: name}) if err != nil { return err } return nil } ================================================ FILE: storage/blob/blob_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blob

import (
	"io"
	"net"
	"path"
	"testing"

	"github.com/gorse-io/gorse/protocol"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc"
)

// TestBlob exercises MasterStoreClient against an in-process MasterStoreServer:
// write a blob, read it back, list it and remove it.
func TestBlob(t *testing.T) {
	// start server
	lis, err := net.Listen("tcp", "localhost:0")
	if !assert.NoError(t, err) {
		return
	}
	grpcServer := grpc.NewServer()
	protocol.RegisterBlobStoreServer(grpcServer, NewMasterStoreServer(path.Join(t.TempDir(), "blob")))
	// Serve's result goes through a channel and is checked on the test
	// goroutine. Sharing err with the server goroutine was a data race, and
	// asserting from it could run after the test had returned.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- grpcServer.Serve(lis)
	}()
	defer func() {
		grpcServer.Stop()
		assert.NoError(t, <-serveErr)
	}()

	// create client
	clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
	if !assert.NoError(t, err) {
		return
	}
	defer func() {
		assert.NoError(t, clientConn.Close())
	}()
	client := NewMasterStoreClient(clientConn)

	// write a temp file
	w, done, err := client.Create("test")
	if !assert.NoError(t, err) {
		return
	}
	_, err = w.Write([]byte("hello world"))
	assert.NoError(t, err)
	assert.NoError(t, w.Close())
	<-done

	// read the file
	r, err := client.Open("test")
	if !assert.NoError(t, err) {
		return
	}
	content := make([]byte, 11)
	_, err = r.Read(content)
	assert.NoError(t, err)
	assert.Equal(t, "hello world", string(content))
	_, err = r.Read(content)
	assert.ErrorIs(t, err, io.EOF)
	assert.NoError(t, r.Close())

	// list files
	files, err := client.List()
	assert.NoError(t, err)
	assert.Contains(t, files, "test")

	// remove the file
	err = client.Remove("test")
	assert.NoError(t, err)
	files, err = client.List()
	assert.NoError(t, err)
	assert.NotContains(t, files, "test")
}

================================================
FILE: storage/blob/gcs.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package blob import ( "context" "io" "os" "path/filepath" "cloud.google.com/go/storage" "github.com/gorse-io/gorse/config" "google.golang.org/api/iterator" "google.golang.org/api/option" ) type GCS struct { client *storage.Client bucket string prefix string } func NewGCS(cfg config.GCSConfig, bucket string, prefix string) (*GCS, error) { var opts []option.ClientOption if os.Getenv("GCS_EMULATOR_ENDPOINT") != "" { opts = append(opts, option.WithEndpoint(os.Getenv("GCS_EMULATOR_ENDPOINT"))) opts = append(opts, option.WithoutAuthentication()) } if cfg.CredentialsFile != "" { opts = append(opts, option.WithCredentialsFile(cfg.CredentialsFile)) } client, err := storage.NewClient(context.Background(), opts...) if err != nil { return nil, err } return &GCS{ client: client, bucket: bucket, prefix: prefix, }, nil } func (g *GCS) Open(name string) (io.ReadCloser, error) { path := filepath.Join(g.prefix, name) return g.client.Bucket(g.bucket).Object(path).NewReader(context.Background()) } func (g *GCS) Create(name string) (io.WriteCloser, chan struct{}, error) { path := filepath.Join(g.prefix, name) wc := g.client.Bucket(g.bucket).Object(path).NewWriter(context.Background()) done := make(chan struct{}) return &gcsWriter{wc, done}, done, nil } type gcsWriter struct { *storage.Writer done chan struct{} } func (w *gcsWriter) Close() error { err := w.Writer.Close() close(w.done) return err } func (g *GCS) List() ([]string, error) { var names []string it := g.client.Bucket(g.bucket).Objects(context.Background(), &storage.Query{ Prefix: g.prefix, }) for { attrs, err := it.Next() if err == iterator.Done { break } if err != nil { return nil, err } name := attrs.Name[len(g.prefix):] if len(name) > 0 && name[0] == os.PathSeparator { name = name[1:] } names = append(names, name) } return names, nil } func (g *GCS) Remove(name string) error { path := filepath.Join(g.prefix, name) 
return g.client.Bucket(g.bucket).Object(path).Delete(context.Background()) } ================================================ FILE: storage/blob/gcs_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package blob import ( "io" "testing" "github.com/fsouza/fake-gcs-server/fakestorage" "github.com/gorse-io/gorse/config" "github.com/stretchr/testify/assert" ) func TestGCS(t *testing.T) { server, err := fakestorage.NewServerWithOptions(fakestorage.Options{ Scheme: "http", Port: 5050, PublicHost: "localhost:5050", }) assert.NoError(t, err) defer server.Stop() t.Setenv("GCS_EMULATOR_ENDPOINT", "http://localhost:5050/storage/v1/") // create client client, err := NewGCS(config.GCSConfig{}, "gorse-test", "blob") assert.NoError(t, err) // create bucket if not exists err = client.client.Bucket("gorse-test").Create(t.Context(), "test-project", nil) if err != nil { assert.ErrorContains(t, err, "A Cloud Storage bucket named 'gorse-test' already exists.") } // create file w, done, err := client.Create("test.txt") assert.NoError(t, err) _, err = w.Write([]byte("hello")) assert.NoError(t, err) err = w.Close() assert.NoError(t, err) <-done // list files names, err := client.List() assert.NoError(t, err) assert.Equal(t, []string{"test.txt"}, names) // read file r, err := client.Open("test.txt") assert.NoError(t, err) data, err := io.ReadAll(r) assert.NoError(t, err) assert.Equal(t, "hello", 
string(data)) err = r.Close() assert.NoError(t, err) // remove file err = client.Remove("test.txt") assert.NoError(t, err) // list files again names, err = client.List() assert.NoError(t, err) assert.Empty(t, names) } ================================================ FILE: storage/blob/posix.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package blob import ( "io" "os" "path" "github.com/gorse-io/gorse/common/log" "go.uber.org/zap" ) type POSIX struct { dir string } func NewPOSIX(dir string) *POSIX { return &POSIX{dir: dir} } // Open a file for reading. It returns an io.Reader that can be used to read the file's content. func (p *POSIX) Open(name string) (io.ReadCloser, error) { fullPath := path.Join(p.dir, name) return os.Open(fullPath) } // Create a new file for writing. It returns an io.WriteCloser that can be used to write data to the file. It also // returns a done channel that is closed when the writing is complete. 
func (p *POSIX) Create(name string) (io.WriteCloser, chan struct{}, error) { fullPath := path.Join(p.dir, name) if err := os.MkdirAll(path.Dir(fullPath), os.ModePerm); err != nil { return nil, nil, err } file, err := os.Create(fullPath) if err != nil { return nil, nil, err } done := make(chan struct{}) pr, pw := io.Pipe() go func() { defer func() { _ = file.Close() close(done) }() _, err := io.Copy(file, pr) if err != nil { log.Logger().Error("failed to write to file", zap.String("file", fullPath), zap.Error(err)) } }() return pw, done, err } // List files in the directory. It returns a slice of file names. func (p *POSIX) List() ([]string, error) { files, err := os.ReadDir(p.dir) if err != nil { return nil, err } var names []string for _, file := range files { if !file.IsDir() { names = append(names, file.Name()) } } return names, nil } // Remove a file by its name. It deletes the file from the filesystem. func (p *POSIX) Remove(name string) error { fullPath := path.Join(p.dir, name) if err := os.Remove(fullPath); err != nil { log.Logger().Error("failed to remove file", zap.String("file", fullPath), zap.Error(err)) return err } return nil } ================================================ FILE: storage/blob/posix_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blob

import (
	"path"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestPOSIX writes a blob into a temporary directory, reads it back, and checks
// that it is listed until it is removed.
func TestPOSIX(t *testing.T) {
	store := NewPOSIX(path.Join(t.TempDir(), "blob"))

	// Write the blob and wait until the background copy has finished.
	writer, done, err := store.Create("test")
	assert.NoError(t, err)
	_, err = writer.Write([]byte("hello world"))
	assert.NoError(t, err)
	assert.NoError(t, writer.Close())
	<-done

	// Read it back in one call.
	reader, err := store.Open("test")
	assert.NoError(t, err)
	buf := make([]byte, 11)
	_, err = reader.Read(buf)
	assert.NoError(t, err)
	assert.Equal(t, "hello world", string(buf))
	assert.NoError(t, reader.Close())

	// It shows up in the listing...
	names, err := store.List()
	assert.NoError(t, err)
	assert.Contains(t, names, "test")

	// ...until it is removed.
	assert.NoError(t, store.Remove("test"))
	names, err = store.List()
	assert.NoError(t, err)
	assert.NotContains(t, names, "test")
}

================================================
FILE: storage/blob/s3.go
================================================
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package blob

import (
	"context"
	"io"
	"path/filepath"

	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/config"
	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"go.uber.org/zap"
)

// S3 is a Store backed by an S3-compatible bucket, accessed through the MinIO
// client. Every object it manages is keyed "<prefix>/<name>".
type S3 struct {
	*minio.Client
	bucket string
	prefix string
}

// NewS3 creates an S3 store that talks to cfg.Endpoint with static V4
// credentials from cfg.
func NewS3(cfg config.S3Config, bucket string, prefix string) (*S3, error) {
	minioClient, err := minio.New(cfg.Endpoint, &minio.Options{
		Creds: credentials.NewStaticV4(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
	})
	if err != nil {
		return nil, err
	}
	return &S3{
		Client: minioClient,
		bucket: bucket,
		prefix: prefix,
	}, nil
}

// objectKey maps a store-relative name to its object key. Object keys are
// "/"-separated on every OS; ToSlash undoes the "\" that filepath.Join
// produces on Windows.
func (s *S3) objectKey(name string) string {
	return filepath.ToSlash(filepath.Join(s.prefix, name))
}

// Open a file in S3 for reading. This function returns an io.Reader that can be used to read the file's content.
func (s *S3) Open(name string) (io.ReadCloser, error) {
	object, err := s.Client.GetObject(context.Background(), s.bucket, s.objectKey(name), minio.GetObjectOptions{})
	if err != nil {
		return nil, err
	}
	return object, nil
}

// Create a new file in S3 for writing. This function returns an io.WriteCloser that can be used to write data to the
// file. It also returns a done channel that is closed when the upload has finished, successfully or not; failures
// are logged.
func (s *S3) Create(name string) (io.WriteCloser, chan struct{}, error) {
	fullPath := s.objectKey(name)
	pr, pw := io.Pipe()
	done := make(chan struct{})
	go func() {
		defer close(done)
		_, err := s.Client.PutObject(context.Background(), s.bucket, fullPath, pr, -1, minio.PutObjectOptions{})
		if err != nil {
			log.Logger().Error("failed to upload file to S3", zap.String("file", fullPath), zap.Error(err))
		}
		// If the upload stopped before EOF, nobody reads the pipe any more.
		// Closing the reader makes pending and future writes fail with err
		// instead of blocking forever. After a successful upload the writer
		// has already been closed.
		_ = pr.CloseWithError(err)
	}()
	return pw, done, nil
}

// List files in the S3 bucket with the specified prefix. This function returns a slice of file names.
func (s *S3) List() ([]string, error) { var names []string for object := range s.Client.ListObjects(context.Background(), s.bucket, minio.ListObjectsOptions{ Prefix: s.prefix, Recursive: true, }) { if object.Err != nil { return nil, object.Err } name := object.Key[len(s.prefix):] if len(name) > 0 && name[0] == '/' { name = name[1:] } names = append(names, name) } return names, nil } // Remove a file from the S3 bucket by its name. This function deletes the file from S3. func (s *S3) Remove(name string) error { fullPath := filepath.Join(s.prefix, name) err := s.Client.RemoveObject(context.Background(), s.bucket, fullPath, minio.RemoveObjectOptions{}) if err != nil { log.Logger().Error("failed to remove file from S3", zap.String("file", fullPath), zap.Error(err)) return err } return nil } ================================================ FILE: storage/blob/s3_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package blob

import (
	"io"
	"os"
	"testing"

	"github.com/gorse-io/gorse/config"
	"github.com/minio/minio-go/v7"
	"github.com/stretchr/testify/assert"
)

var (
	endpoint        = os.Getenv("S3_ENDPOINT")
	accessKeyID     = os.Getenv("S3_ACCESS_KEY_ID")
	secretAccessKey = os.Getenv("S3_SECRET_ACCESS_KEY")
)

// TestS3 runs a create/open/list/remove round trip against an S3-compatible
// service. It is skipped unless S3_ENDPOINT, S3_ACCESS_KEY_ID and
// S3_SECRET_ACCESS_KEY are all set.
func TestS3(t *testing.T) {
	if endpoint == "" || accessKeyID == "" || secretAccessKey == "" {
		t.Skip("S3 environment variables are not set, skipping S3 tests")
	}

	// create client
	client, err := NewS3(config.S3Config{
		Endpoint:        endpoint,
		AccessKeyID:     accessKeyID,
		SecretAccessKey: secretAccessKey,
	}, "gorse-test", "blob")
	if !assert.NoError(t, err) {
		return
	}

	// Create the bucket only if a previous run has not already done so;
	// MakeBucket reports an error for an existing bucket.
	exists, err := client.Client.BucketExists(t.Context(), client.bucket)
	assert.NoError(t, err)
	if !exists {
		err = client.Client.MakeBucket(t.Context(), client.bucket, minio.MakeBucketOptions{})
		assert.NoError(t, err)
	}

	// write a temp file
	w, done, err := client.Create("test")
	if !assert.NoError(t, err) {
		return
	}
	_, err = w.Write([]byte("hello world"))
	assert.NoError(t, err)
	assert.NoError(t, w.Close())
	<-done

	// read the file
	r, err := client.Open("test")
	if !assert.NoError(t, err) {
		return
	}
	content := make([]byte, 11)
	_, err = r.Read(content)
	assert.ErrorIs(t, err, io.EOF)
	assert.Equal(t, "hello world", string(content))
	assert.NoError(t, r.Close())

	// list files
	files, err := client.List()
	assert.NoError(t, err)
	assert.Contains(t, files, "test")

	// remove the file
	err = client.Remove("test")
	assert.NoError(t, err)
	files, err = client.List()
	assert.NoError(t, err)
	assert.NotContains(t, files, "test")
}

================================================
FILE: storage/cache/database.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cache

import (
	"context"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/araddon/dateparse"
	"github.com/gorse-io/gorse/storage"
	"github.com/juju/errors"
)

const (
	// Names of cached recommendation collections. Each one has companion
	// "<name>_digest" and "<name>_update_time" keys.
	NonPersonalized                  = "non-personalized"
	NonPersonalizedDigest            = "non-personalized_digest"
	NonPersonalizedUpdateTime        = "non-personalized_update_time"
	ItemToItem                       = "item-to-item"
	ItemToItemDigest                 = "item-to-item_digest"
	ItemToItemUpdateTime             = "item-to-item_update_time"
	UserToUser                       = "user-to-user"
	UserToUserDigest                 = "user-to-user_digest"
	UserToUserUpdateTime             = "user-to-user_update_time"
	CollaborativeFiltering           = "collaborative-filtering"
	CollaborativeFilteringDigest     = "collaborative-filtering_digest"
	CollaborativeFilteringUpdateTime = "collaborative-filtering_update_time"
	Recommend                        = "recommend"
	RecommendDigest                  = "recommend_digest"
	RecommendUpdateTime              = "recommend_update_time"

	// ItemCategories is the set of item categories. The format of key:
	//	Global item categories - item_categories
	ItemCategories = "item_categories"

	LastModifyItemTime = "last_modify_item_time" // the latest timestamp that item related data was modified
	LastModifyUserTime = "last_modify_user_time" // the latest timestamp that user related data was modified

	// GlobalMeta is global meta information. The keys below are presumably
	// fields stored under it (counters, training metrics, update times);
	// confirm against callers.
	GlobalMeta                 = "global_meta"
	NumUsers                   = "num_users"
	NumItems                   = "num_items"
	NumFeedback                = "num_feedback"
	NumPosFeedbacks            = "num_pos_feedbacks"
	NumNegFeedbacks            = "num_neg_feedbacks"
	NumUserLabels              = "num_user_labels"
	NumItemLabels              = "num_item_labels"
	NumTotalPosFeedbacks       = "num_total_pos_feedbacks"
	NumValidPosFeedbacks       = "num_valid_pos_feedbacks"
	NumValidNegFeedbacks       = "num_valid_neg_feedbacks"
	LastFitMatchingModelTime   = "last_fit_matching_model_time"
	LastFitRankingModelTime    = "last_fit_ranking_model_time"
	LastUpdateLatestItemsTime  = "last_update_latest_items_time"  // the latest timestamp that latest items were updated
	LastUpdatePopularItemsTime = "last_update_popular_items_time" // the latest timestamp that popular items were updated
	CFNDCG                     = "cf_ndcg"
	CFPrecision                = "cf_precision"
	CFRecall                   = "cf_recall"
	CTRPrecision               = "ctr_precision"
	CTRRecall                  = "ctr_recall"
	CTRAUC                     = "ctr_auc"
	PositiveFeedbackRatio      = "positive_feedback_ratio"
)

// ItemCache lists the cache names NonPersonalized, ItemToItem and Recommend.
// Presumably these are the caches that hold item ids and must be updated when
// items change; confirm against callers.
var ItemCache = []string{
	NonPersonalized,
	ItemToItem,
	Recommend,
}

var (
	// ErrObjectNotExist is a juju NotFound error for a missing object.
	ErrObjectNotExist = errors.NotFoundf("object")
	// ErrNoDatabase is a juju NotAssigned error for "database"; presumably
	// returned when no database is configured.
	ErrNoDatabase = errors.NotAssignedf("database")
)

// Key joins keys with "/" to build a cache key. Empty keys after the first are
// skipped; the first key is always written, even when empty. Key() returns "".
func Key(keys ...string) string { if len(keys) == 0 { return "" } var builder strings.Builder builder.WriteString(keys[0]) for _, key := range keys[1:] { if key != "" { builder.WriteRune('/') builder.WriteString(key) } } return builder.String() } type Value struct { name string value string } func String(name, value string) Value { return Value{name: name, value: value} } func Integer(name string, value int) Value { return Value{name: name, value: strconv.Itoa(value)} } func Time(name string, value time.Time) Value { return Value{name: name, value: value.String()} } type ReturnValue struct { value string err error exists bool } func (r *ReturnValue) String() (string, error) { if r.err != nil { return "", r.err } if !r.exists { return "", nil } return r.value, nil } func (r *ReturnValue) Integer() (int, error) { if r.err != nil { return 0, r.err } if !r.exists { return 0, nil } return strconv.Atoi(r.value) } func (r *ReturnValue) Time() (time.Time, error) { if r.err != nil { return time.Time{}, r.err } if !r.exists { return time.Time{}, nil } t, err := dateparse.ParseAny(r.value) if err != nil { return time.Time{}, errors.Trace(err) } return t.In(time.UTC), nil } func (r *ReturnValue) Exists() bool { return r.exists } type Score struct { Id string Score float64 IsHidden bool `json:"-"` Categories []string `json:"-" gorm:"type:text;serializer:json"` Timestamp time.Time `json:"-"` } func SortDocuments(documents []Score) { sort.Slice(documents, func(i, j int) bool { return documents[i].Score > documents[j].Score }) } func ConvertDocumentsToValues(documents []Score) []string { values := make([]string, len(documents)) for i := range values { values[i] = documents[i].Id } return values } type ScoreCondition struct { Subset *string Id *string Before *time.Time } func (condition *ScoreCondition) Check() error { if condition.Id == nil && condition.Before == nil && condition.Subset == nil { return errors.NotValidf("document condition") } return nil } type ScorePatch struct { 
IsHidden *bool Categories []string Score *float64 } type TimeSeriesPoint struct { Name string `gorm:"primaryKey"` Timestamp time.Time `gorm:"primaryKey"` Value float64 } // Database is the common interface for cache store. type Database interface { Close() error Ping() error Init() error Scan(work func(string) error) error Purge() error Set(ctx context.Context, values ...Value) error Get(ctx context.Context, name string) *ReturnValue Delete(ctx context.Context, name string) error Push(ctx context.Context, name, value string) error Pop(ctx context.Context, name string) (string, error) Remain(ctx context.Context, name string) (int64, error) AddScores(ctx context.Context, collection, subset string, documents []Score) error SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) DeleteScores(ctx context.Context, collection []string, condition ScoreCondition) error UpdateScores(ctx context.Context, collections []string, subset *string, id string, patch ScorePatch) error ScanScores(ctx context.Context, callback func(collection, id, subset string, timestamp time.Time) error) error AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) } // Creator creates a database instance. type Creator func(path, tablePrefix string, opts ...storage.Option) (Database, error) var creators = make(map[string]Creator) // Register a database creator. func Register(prefixes []string, creator Creator) { for _, p := range prefixes { creators[p] = creator } } // Open a connection to a database. func Open(path, tablePrefix string, opts ...storage.Option) (Database, error) { for prefix, creator := range creators { if strings.HasPrefix(path, prefix) { return creator(path, tablePrefix, opts...) 
} } return nil, errors.Errorf("Unknown database: %s", path) } ================================================ FILE: storage/cache/database_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "context" "io" "math" "math/rand" "os" "strconv" "testing" "time" "github.com/fxtlabs/primes" "github.com/juju/errors" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type baseTestSuite struct { suite.Suite Database } func (suite *baseTestSuite) TearDownSuite() { err := suite.Database.Close() suite.NoError(err) } func (suite *baseTestSuite) SetupTest() { err := suite.Database.Ping() suite.NoError(err) err = suite.Database.Purge() suite.NoError(err) } func (suite *baseTestSuite) TearDownTest() { err := suite.Database.Purge() suite.NoError(err) } func (suite *baseTestSuite) TestInit() { err := suite.Database.Init() suite.NoError(err) } func (suite *baseTestSuite) TestMeta() { ctx := suite.T().Context() // Set meta string err := suite.Database.Set(ctx, String(Key("meta", "1"), "2"), String(Key("meta", "1000"), "10")) suite.NoError(err) // Get meta string value, err := suite.Database.Get(ctx, Key("meta", "1")).String() suite.NoError(err) suite.Equal("2", value) value, err = suite.Database.Get(ctx, Key("meta", "1000")).String() suite.NoError(err) 
suite.Equal("10", value) // Delete string err = suite.Database.Delete(ctx, Key("meta", "1")) suite.NoError(err) // Get meta not existed ret := suite.Database.Get(ctx, Key("meta", "1")) suite.False(ret.Exists()) value, err = ret.String() suite.NoError(err) suite.Equal("", value) // Set meta int err = suite.Database.Set(ctx, Integer(Key("meta", "1"), 2)) suite.NoError(err) // Get meta int valInt, err := suite.Database.Get(ctx, Key("meta", "1")).Integer() suite.NoError(err) suite.Equal(2, valInt) // set meta time err = suite.Database.Set(ctx, Time(Key("meta", "1"), time.Date(1996, 4, 8, 0, 0, 0, 0, time.UTC))) suite.NoError(err) // get meta time valTime, err := suite.Database.Get(ctx, Key("meta", "1")).Time() suite.NoError(err) suite.Equal(1996, valTime.Year()) suite.Equal(time.Month(4), valTime.Month()) suite.Equal(8, valTime.Day()) // test set empty err = suite.Database.Set(ctx) suite.NoError(err) // test set duplicate err = suite.Database.Set(ctx, String("100", "1"), String("100", "2")) suite.NoError(err) } func (suite *baseTestSuite) TestExists() { ctx := suite.T().Context() // Test non-existent key ret := suite.Database.Get(ctx, Key("test", "nonexistent")) suite.False(ret.Exists()) value, err := ret.String() suite.NoError(err) suite.Equal("", value) // Set a value err = suite.Database.Set(ctx, String(Key("test", "exists"), "somevalue")) suite.NoError(err) // Test existing key ret = suite.Database.Get(ctx, Key("test", "exists")) suite.True(ret.Exists()) value, err = ret.String() suite.NoError(err) suite.Equal("somevalue", value) // Delete the key err = suite.Database.Delete(ctx, Key("test", "exists")) suite.NoError(err) // Test deleted key no longer exists ret = suite.Database.Get(ctx, Key("test", "exists")) suite.False(ret.Exists()) // Test with integer values err = suite.Database.Set(ctx, Integer(Key("test", "int"), 42)) suite.NoError(err) ret = suite.Database.Get(ctx, Key("test", "int")) suite.True(ret.Exists()) intVal, err := ret.Integer() suite.NoError(err) 
suite.Equal(42, intVal) // Test non-existent integer - should return 0 with no error ret = suite.Database.Get(ctx, Key("test", "noint")) suite.False(ret.Exists()) intVal, err = ret.Integer() suite.NoError(err) suite.Equal(0, intVal) } func (suite *baseTestSuite) TestScan() { ctx := suite.T().Context() err := suite.Database.Set(ctx, String("1", "1")) suite.NoError(err) var keys []string err = suite.Database.Scan(func(s string) error { keys = append(keys, s) return nil }) suite.NoError(err) suite.ElementsMatch([]string{"1"}, keys) } func (suite *baseTestSuite) TestPurge() { ctx := suite.T().Context() // insert data err := suite.Database.Set(ctx, String("key", "value")) suite.NoError(err) ret := suite.Database.Get(ctx, "key") suite.NoError(ret.err) suite.Equal("value", ret.value) suite.True(ret.Exists()) // purge data err = suite.Database.Purge() suite.NoError(err) ret = suite.Database.Get(ctx, "key") suite.False(ret.Exists()) // purge empty dataset err = suite.Database.Purge() suite.NoError(err) } func (suite *baseTestSuite) TestPushPop() { ctx := suite.T().Context() err := suite.Push(ctx, "a", "1") suite.NoError(err) err = suite.Push(ctx, "a", "2") suite.NoError(err) count, err := suite.Remain(ctx, "a") suite.NoError(err) suite.Equal(int64(2), count) err = suite.Push(ctx, "b", "1") suite.NoError(err) err = suite.Push(ctx, "b", "2") suite.NoError(err) err = suite.Push(ctx, "b", "1") suite.NoError(err) count, err = suite.Remain(ctx, "b") suite.NoError(err) suite.Equal(int64(2), count) value, err := suite.Pop(ctx, "a") suite.NoError(err) suite.Equal("1", value) value, err = suite.Pop(ctx, "a") suite.NoError(err) suite.Equal("2", value) _, err = suite.Pop(ctx, "a") suite.ErrorIs(err, io.EOF) value, err = suite.Pop(ctx, "b") suite.NoError(err) suite.Equal("2", value) value, err = suite.Pop(ctx, "b") suite.NoError(err) suite.Equal("1", value) _, err = suite.Pop(ctx, "b") suite.ErrorIs(err, io.EOF) } func (suite *baseTestSuite) TestDocument() { ts := time.Date(2023, 1, 1, 
0, 0, 0, 0, time.UTC) ctx := suite.T().Context() err := suite.AddScores(ctx, "a", "", []Score{{ Id: "0", Score: math.MaxFloat64, IsHidden: true, Categories: []string{"a", "b"}, Timestamp: ts, }}) suite.NoError(err) err = suite.AddScores(ctx, "a", "", []Score{{ Id: "1", Score: 100, Categories: []string{"a", "b"}, Timestamp: ts, }}) suite.NoError(err) err = suite.AddScores(ctx, "a", "", []Score{ { Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts, }, { Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts, }, { Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), }, { Id: "4", Score: 4, Categories: []string{""}, Timestamp: ts, }, { Id: "5", Score: 5, Categories: []string{"b"}, Timestamp: ts, }, }) suite.NoError(err) err = suite.AddScores(ctx, "b", "", []Score{{ Id: "6", Score: 6, Categories: []string{"b"}, Timestamp: ts, }}) suite.NoError(err) // search documents documents, err := suite.SearchScores(ctx, "a", "", []string{"b"}, 1, 3) suite.NoError(err) suite.Equal([]Score{ {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)}, {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts}, }, documents) documents, err = suite.SearchScores(ctx, "a", "", []string{"b"}, 0, -1) suite.NoError(err) suite.Equal([]Score{ {Id: "5", Score: 5, Categories: []string{"b"}, Timestamp: ts}, {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)}, {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts}, {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}, }, documents) // search documents with nil category documents, err = suite.SearchScores(ctx, "a", "", nil, 0, -1) suite.NoError(err) suite.Equal([]Score{ {Id: "5", Score: 5, Categories: []string{"b"}, Timestamp: ts}, {Id: "4", Score: 4, Categories: []string{""}, Timestamp: ts}, {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: 
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)}, {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts}, {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}, }, documents) // search documents with empty category documents, err = suite.SearchScores(ctx, "a", "", []string{""}, 0, -1) suite.NoError(err) suite.Equal([]Score{{Id: "4", Score: 4, Categories: []string{""}, Timestamp: ts}}, documents) // delete nothing err = suite.DeleteScores(ctx, []string{"a"}, ScoreCondition{}) suite.ErrorIs(err, errors.NotValid) // delete by value err = suite.DeleteScores(ctx, []string{"a"}, ScoreCondition{Id: new("5")}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("3", documents[0].Id) // delete by timestamp err = suite.DeleteScores(ctx, []string{"a"}, ScoreCondition{Before: lo.ToPtr(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC))}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) // update categories err = suite.UpdateScores(ctx, []string{"a"}, nil, "2", ScorePatch{Categories: []string{"c", "s"}}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "", []string{"s"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) err = suite.UpdateScores(ctx, []string{"a"}, nil, "2", ScorePatch{Categories: []string{"c"}}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "", []string{"s"}, 0, 1) suite.NoError(err) suite.Empty(documents) // update is hidden err = suite.UpdateScores(ctx, []string{"a"}, nil, "0", ScorePatch{IsHidden: new(false)}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("0", documents[0].Id) } func (suite *baseTestSuite) TestSubsetDocument() { ts := time.Date(2023, 1, 1, 0, 0, 0, 0, 
time.UTC) ctx := suite.T().Context() err := suite.AddScores(ctx, "a", "a", []Score{ { Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts, }, { Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts, }, { Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: ts, }, }) suite.NoError(err) err = suite.AddScores(ctx, "b", "", []Score{ { Id: "4", Score: 4, Categories: []string{}, Timestamp: ts, }, { Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: ts, }, { Id: "2", Score: 2, Categories: []string{"b"}, Timestamp: ts, }, }) suite.NoError(err) // search documents documents, err := suite.SearchScores(ctx, "a", "a", []string{"b"}, 0, -1) suite.NoError(err) suite.Equal([]Score{ {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: ts}, {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts}, {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}, }, documents) // update categories err = suite.UpdateScores(ctx, []string{"a", "b"}, nil, "2", ScorePatch{Categories: []string{"b", "s"}}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "a", []string{"s"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) documents, err = suite.SearchScores(ctx, "b", "", []string{"s"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) // update categories in subset err = suite.UpdateScores(ctx, []string{"a", "b"}, new("a"), "2", ScorePatch{Categories: []string{"b", "x"}}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "a", []string{"x"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) documents, err = suite.SearchScores(ctx, "b", "", []string{"x"}, 0, 1) suite.NoError(err) suite.Empty(documents) // delete by value err = suite.DeleteScores(ctx, []string{"a", "b"}, ScoreCondition{Id: new("3")}) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "a", []string{"b"}, 0, 1) suite.NoError(err) 
suite.Len(documents, 1) suite.Equal("2", documents[0].Id) documents, err = suite.SearchScores(ctx, "b", "", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) // delete in subset err = suite.DeleteScores(ctx, []string{"a", "b"}, ScoreCondition{ Subset: new("a"), Id: new("2"), }) suite.NoError(err) documents, err = suite.SearchScores(ctx, "a", "a", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("1", documents[0].Id) documents, err = suite.SearchScores(ctx, "b", "", []string{"b"}, 0, 1) suite.NoError(err) suite.Len(documents, 1) suite.Equal("2", documents[0].Id) } func (suite *baseTestSuite) TestScanScores() { // add scores timestamp := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) scores := map[lo.Tuple2[string, string]][]Score{ {"a", "b"}: { {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: timestamp}, {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: timestamp}, {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: timestamp}, }, {"a", "c"}: { {Id: "4", Score: 4, Categories: []string{"a", "b"}, Timestamp: timestamp}, {Id: "5", Score: 5, Categories: []string{"b", "c"}, Timestamp: timestamp}, {Id: "6", Score: 6, Categories: []string{"b"}, Timestamp: timestamp}, }, {"b", "c"}: { {Id: "7", Score: 7, Categories: []string{"a", "b"}, Timestamp: timestamp}, {Id: "8", Score: 8, Categories: []string{"b", "c"}, Timestamp: timestamp}, {Id: "9", Score: 9, Categories: []string{"b"}, Timestamp: timestamp}, }, } for k, v := range scores { err := suite.AddScores(suite.T().Context(), k.A, k.B, v) suite.NoError(err) } // scan scores totalScores := 0 ctx := suite.T().Context() err := suite.ScanScores(ctx, func(collection, id, subset string, t time.Time) error { totalScores++ suite.Equal(timestamp, t.UTC()) return nil }) suite.NoError(err) suite.Equal(9, totalScores) // scan scores with timeout scanScores := 0 ctx, cancel := context.WithTimeout(suite.T().Context(), time.Millisecond) 
defer cancel() err = suite.ScanScores(ctx, func(collection, id, subset string, timestamp time.Time) error { time.Sleep(time.Millisecond) scanScores++ return nil }) if err != nil && status.Code(err) != codes.DeadlineExceeded { suite.ErrorIs(err, context.DeadlineExceeded) } suite.Less(scanScores, 9) } func (suite *baseTestSuite) TestTimeSeries() { ts := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) ctx := suite.T().Context() err := suite.AddTimeSeriesPoints(ctx, []TimeSeriesPoint{ {Name: "a", Value: 1, Timestamp: ts.Add(1 * time.Second)}, {Name: "a", Value: 2, Timestamp: ts.Add(2 * time.Second)}, {Name: "a", Value: 3, Timestamp: ts.Add(3 * time.Second)}, {Name: "a", Value: 4, Timestamp: ts.Add(4 * time.Second)}, {Name: "a", Value: 5, Timestamp: ts.Add(5 * time.Second)}, {Name: "b", Value: 3, Timestamp: ts.Add(3 * time.Second)}, }) suite.NoError(err) points, err := suite.GetTimeSeriesPoints(ctx, "a", ts.Add(2*time.Second), ts.Add(4*time.Second), time.Second) suite.NoError(err) suite.Equal([]TimeSeriesPoint{ {Name: "a", Value: 2, Timestamp: ts.Add(2 * time.Second)}, {Name: "a", Value: 3, Timestamp: ts.Add(3 * time.Second)}, {Name: "a", Value: 4, Timestamp: ts.Add(4 * time.Second)}, }, points) points, err = suite.GetTimeSeriesPoints(ctx, "a", ts.Add(2*time.Second), ts.Add(4*time.Second), 2*time.Second) suite.NoError(err) suite.Equal([]TimeSeriesPoint{ {Name: "a", Value: 3, Timestamp: ts.Add(2 * time.Second)}, {Name: "a", Value: 4, Timestamp: ts.Add(4 * time.Second)}, }, points) } func (suite *baseTestSuite) TestTimestampPrecision() { ctx := suite.T().Context() timestamp := time.Date(2023, 1, 1, 0, 0, 0, 500, time.UTC) // add scores err := suite.Database.AddScores(ctx, "a", "s", []Score{ {Id: "1", Score: 1, Categories: []string{""}, Timestamp: timestamp}, }) suite.NoError(err) // remove by timestamp err = suite.Database.DeleteScores(ctx, []string{"a"}, ScoreCondition{ Subset: new("s"), Before: lo.ToPtr(timestamp)}) suite.NoError(err) // search scores documents, err := 
suite.Database.SearchScores(ctx, "a", "s", nil, 0, -1) suite.NoError(err) suite.NotEmpty(documents) } func TestKey(t *testing.T) { assert.Empty(t, Key()) assert.Equal(t, "a", Key("a")) assert.Equal(t, "a", Key("a", "")) assert.Equal(t, "a/b", Key("a", "b")) } var ( benchmarkDataSize = 100000 primeTable []int ) func init() { benchmarkDataSizeStr := os.Getenv("BENCHMARK_DATA_SIZE") if benchmarkDataSizeStr != "" { benchmarkDataSize, _ = strconv.Atoi(benchmarkDataSizeStr) } primeTable = primes.Sieve(benchmarkDataSize) } func primeFactor(n int) []int { var factors []int for _, p := range primeTable { if n%p == 0 { factors = append(factors, p) } } return factors } func benchmark(b *testing.B, database Database) { b.Run("AddScores", func(b *testing.B) { benchmarkAddDocuments(b, database) }) b.Run("SearchScores", func(b *testing.B) { benchmarkSearchDocuments(b, database) }) b.Run("UpdateScores", func(b *testing.B) { benchmarkUpdateDocuments(b, database) }) } func benchmarkAddDocuments(b *testing.B, database Database) { ctx := b.Context() var documents []Score for i := 1; i <= b.N; i++ { documents = append(documents, Score{ Id: strconv.Itoa(i), Score: float64(-i), Categories: lo.Map(primeFactor(i), func(n, _ int) string { return strconv.Itoa(n) }), Timestamp: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), }) } b.ResetTimer() err := database.AddScores(ctx, "a", "", documents) assert.NoError(b, err) } func benchmarkSearchDocuments(b *testing.B, database Database) { // insert data ctx := b.Context() var documents []Score for i := 1; i <= benchmarkDataSize; i++ { documents = append(documents, Score{ Id: strconv.Itoa(i), Score: float64(-i), Categories: lo.Map(primeFactor(i), func(n, _ int) string { return strconv.Itoa(n) }), Timestamp: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), }) } err := database.AddScores(ctx, "a", "", documents) assert.NoError(b, err) // search data b.ResetTimer() for i := 0; i < b.N; i++ { // select a random prime p := primeTable[rand.Intn(len(primeTable))] 
// search documents r, err := database.SearchScores(ctx, "a", "", []string{strconv.Itoa(p)}, 0, 10) assert.NoError(b, err) assert.NotEmpty(b, r) } } func benchmarkUpdateDocuments(b *testing.B, database Database) { ctx := b.Context() b.ResetTimer() for i := 1; i <= b.N; i++ { // select a random number n := rand.Intn(benchmarkDataSize) + 1 // update documents err := database.UpdateScores(ctx, []string{"a"}, nil, strconv.Itoa(n), ScorePatch{ Score: new(float64(n)), }) assert.NoError(b, err) } } ================================================ FILE: storage/cache/mongodb.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cache import ( "context" "io" "time" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo" ) func init() { Register([]string{storage.MongoPrefix, storage.MongoSrvPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { // connect to database database := new(MongoDB) clientOpts := options.Client() clientOpts.Monitor = otelmongo.NewMonitor() clientOpts.ApplyURI(path) var err error if database.client, err = mongo.Connect(context.Background(), clientOpts); err != nil { return nil, errors.Trace(err) } // parse DSN and extract database name if cs, err := connstring.ParseAndValidate(path); err != nil { return nil, errors.Trace(err) } else { database.dbName = cs.Database database.TablePrefix = storage.TablePrefix(tablePrefix) } return database, nil }) } type MongoDB struct { storage.TablePrefix client *mongo.Client dbName string } func (m MongoDB) Init() error { ctx := context.Background() d := m.client.Database(m.dbName) // list collections var hasValues bool collections, err := d.ListCollectionNames(ctx, bson.M{}) if err != nil { return errors.Trace(err) } for _, collectionName := range collections { switch collectionName { case m.ValuesTable(): hasValues = true } } // create collections if !hasValues { if err = d.CreateCollection(ctx, m.ValuesTable()); err != nil { return errors.Trace(err) } } _, err = d.Collection(m.MessageTable()).Indexes().CreateMany(ctx, []mongo.IndexModel{ { // update set ... where name = ? and value = ? Keys: bson.D{ {"name", 1}, {"value", 1}, }, }, { // select * from messages where name = ? 
order by timestamp asc limit 1 Keys: bson.D{ {"name", 1}, {"timestamp", 1}, }, }, }) if err != nil { return errors.Trace(err) } _, err = d.Collection(m.DocumentTable()).Indexes().CreateMany(ctx, []mongo.IndexModel{ { Keys: bson.D{ {"collection", 1}, {"subset", 1}, {"id", 1}, }, }, { Keys: bson.D{ {"collection", 1}, {"subset", 1}, {"categories", 1}, {"is_hidden", 1}, {"score", -1}, }, }, }) if err != nil { return errors.Trace(err) } _, err = d.Collection(m.PointsTable()).Indexes().CreateMany(ctx, []mongo.IndexModel{ { // update set ... where name = ? and timestammp = ? Keys: bson.D{ {"name", 1}, {"timestamp", 1}, }, }, }) if err != nil { return errors.Trace(err) } return nil } func (m MongoDB) Close() error { return m.client.Disconnect(context.Background()) } func (m MongoDB) Ping() error { return m.client.Ping(context.Background(), nil) } func (m MongoDB) Scan(work func(string) error) error { ctx := context.Background() // scan values valuesCollection := m.client.Database(m.dbName).Collection(m.ValuesTable()) valuesIterator, err := valuesCollection.Find(ctx, bson.M{}) if err != nil { return errors.Trace(err) } defer valuesIterator.Close(ctx) for valuesIterator.Next(ctx) { var row bson.Raw if err = valuesIterator.Decode(&row); err != nil { return errors.Trace(err) } if err = work(row.Lookup("_id").StringValue()); err != nil { return errors.Trace(err) } } return nil } func (m MongoDB) Purge() error { tables := []string{m.ValuesTable(), m.DocumentTable()} for _, tableName := range tables { c := m.client.Database(m.dbName).Collection(tableName) _, err := c.DeleteMany(context.Background(), bson.D{}) if err != nil { return errors.Trace(err) } } return nil } func (m MongoDB) Set(ctx context.Context, values ...Value) error { if len(values) == 0 { return nil } c := m.client.Database(m.dbName).Collection(m.ValuesTable()) var models []mongo.WriteModel for _, value := range values { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). 
SetFilter(bson.M{"_id": value.name}). SetUpdate(bson.M{"$set": bson.M{"_id": value.name, "value": value.value}})) } _, err := c.BulkWrite(ctx, models) return errors.Trace(err) } func (m MongoDB) Get(ctx context.Context, name string) *ReturnValue { c := m.client.Database(m.dbName).Collection(m.ValuesTable()) r := c.FindOne(ctx, bson.M{"_id": bson.M{"$eq": name}}) if err := r.Err(); err == mongo.ErrNoDocuments { return &ReturnValue{value: "", exists: false} } else if err != nil { return &ReturnValue{err: errors.Trace(err), exists: false} } if raw, err := r.DecodeBytes(); err != nil { return &ReturnValue{err: errors.Trace(err), exists: false} } else { return &ReturnValue{value: raw.Lookup("value").StringValue(), exists: true} } } func (m MongoDB) Delete(ctx context.Context, name string) error { c := m.client.Database(m.dbName).Collection(m.ValuesTable()) _, err := c.DeleteOne(ctx, bson.M{"_id": bson.M{"$eq": name}}) return errors.Trace(err) } func (m MongoDB) Push(ctx context.Context, name, value string) error { _, err := m.client.Database(m.dbName).Collection(m.MessageTable()).UpdateOne(ctx, bson.M{"name": name, "value": value}, bson.M{"$set": bson.M{"name": name, "value": value, "timestamp": time.Now().UnixNano()}}, options.Update().SetUpsert(true)) return err } func (m MongoDB) Pop(ctx context.Context, name string) (string, error) { result := m.client.Database(m.dbName).Collection(m.MessageTable()).FindOneAndDelete(ctx, bson.M{"name": name}, options.FindOneAndDelete().SetSort(bson.M{"timestamp": 1})) if err := result.Err(); err == mongo.ErrNoDocuments { return "", io.EOF } else if err != nil { return "", errors.Trace(err) } var b bson.M if err := result.Decode(&b); err != nil { return "", errors.Trace(err) } return b["value"].(string), nil } func (m MongoDB) Remain(ctx context.Context, name string) (int64, error) { return m.client.Database(m.dbName).Collection(m.MessageTable()).CountDocuments(ctx, bson.M{ "name": name, }) } func (m MongoDB) AddScores(ctx 
context.Context, collection, subset string, documents []Score) error { if len(documents) == 0 { return nil } var models []mongo.WriteModel for _, document := range documents { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{ "collection": collection, "subset": subset, "id": document.Id, }). SetUpdate(bson.M{"$set": bson.M{ "score": document.Score, "is_hidden": document.IsHidden, "categories": document.Categories, "timestamp": document.Timestamp, }})) } _, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).BulkWrite(ctx, models) return errors.Trace(err) } func (m MongoDB) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { opt := options.Find().SetSkip(int64(begin)).SetSort(bson.M{"score": -1}) if end != -1 { opt.SetLimit(int64(end - begin)) } filter := bson.M{ "collection": collection, "subset": subset, "is_hidden": false, } if len(query) > 0 { filter["categories"] = bson.M{"$all": query} } cur, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).Find(ctx, filter, opt) if err != nil { return nil, errors.Trace(err) } documents := make([]Score, 0) for cur.Next(ctx) { var document Score if err = cur.Decode(&document); err != nil { return nil, errors.Trace(err) } documents = append(documents, document) } return documents, nil } func (m MongoDB) UpdateScores(ctx context.Context, collections []string, subset *string, id string, patch ScorePatch) error { if len(collections) == 0 { return nil } if patch.IsHidden == nil && patch.Categories == nil && patch.Score == nil { return nil } filter := bson.M{ "collection": bson.M{"$in": collections}, "id": id, } if subset != nil { filter["subset"] = *subset } update := bson.D{} if patch.IsHidden != nil { update = append(update, bson.E{Key: "$set", Value: bson.M{"is_hidden": *patch.IsHidden}}) } if patch.Categories != nil { update = append(update, bson.E{Key: "$set", Value: bson.M{"categories": 
patch.Categories}}) } if patch.Score != nil { update = append(update, bson.E{Key: "$set", Value: bson.M{"score": *patch.Score}}) } _, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).UpdateMany(ctx, filter, update) return errors.Trace(err) } func (m MongoDB) DeleteScores(ctx context.Context, collections []string, condition ScoreCondition) error { if err := condition.Check(); err != nil { return errors.Trace(err) } filter := bson.M{"collection": bson.M{"$in": collections}} if condition.Subset != nil { filter["subset"] = condition.Subset } if condition.Id != nil { filter["id"] = condition.Id } if condition.Before != nil { filter["timestamp"] = bson.M{"$lt": condition.Before} } _, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).DeleteMany(ctx, filter) return errors.Trace(err) } func (m MongoDB) ScanScores(ctx context.Context, callback func(collection string, id string, subset string, timestamp time.Time) error) error { cursor, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).Find(ctx, bson.M{}) if err != nil { return errors.Trace(err) } defer cursor.Close(ctx) for cursor.Next(ctx) { // check context cancellation select { case <-ctx.Done(): return errors.Trace(ctx.Err()) default: } // decode document collection := cursor.Current.Lookup("collection").StringValue() subset := cursor.Current.Lookup("subset").StringValue() id := cursor.Current.Lookup("id").StringValue() timestamp := cursor.Current.Lookup("timestamp").Time() if err = callback(collection, id, subset, timestamp); err != nil { return errors.Trace(err) } } return nil } func (m MongoDB) AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error { var models []mongo.WriteModel for _, point := range points { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{ "name": point.Name, "timestamp": point.Timestamp, }). 
SetUpdate(bson.M{"$set": bson.M{ "value": point.Value, }})) } _, err := m.client.Database(m.dbName).Collection(m.PointsTable()).BulkWrite(ctx, models) return errors.Trace(err) } func (m MongoDB) GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) { cursor, err := m.client.Database(m.dbName).Collection(m.PointsTable()).Aggregate(ctx, []bson.M{ {"$match": bson.M{ "name": name, "timestamp": bson.M{"$gte": begin, "$lte": end}, }}, {"$sort": bson.M{"timestamp": -1}}, {"$group": bson.M{ "_id": bson.M{ "$multiply": bson.A{ bson.M{"$floor": bson.M{"$divide": bson.A{bson.M{"$toLong": "$timestamp"}, duration.Milliseconds()}}}, duration.Milliseconds(), }, }, "name": bson.M{"$first": "$name"}, "value": bson.M{"$first": "$value"}, }}, {"$project": bson.M{ "_id": 0, "timestamp": bson.M{"$toDate": "$_id"}, "name": 1, "value": 1, }}, }) if err != nil { return nil, errors.Trace(err) } defer cursor.Close(ctx) var points []TimeSeriesPoint for cursor.Next(ctx) { var point TimeSeriesPoint if err = cursor.Decode(&point); err != nil { return nil, errors.Trace(err) } points = append(points, point) } if err = cursor.Err(); err != nil { return nil, errors.Trace(err) } return points, nil } ================================================ FILE: storage/cache/mongodb_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cache import ( "os" "testing" "github.com/gorse-io/gorse/common/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) var ( mongoUri string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } mongoUri = env("MONGO_URI", "mongodb://root:password@127.0.0.1:27017/") } type MongoTestSuite struct { baseTestSuite } func (suite *MongoTestSuite) SetupSuite() { ctx := suite.T().Context() var err error // create database suite.Database, err = Open(mongoUri, "gorse_") suite.NoError(err) dbName := "gorse_cache_test" databaseComm := suite.getMongoDB() suite.NoError(err) err = databaseComm.client.Database(dbName).Drop(ctx) if err == nil { suite.T().Log("delete existed database:", dbName) } err = suite.Database.Close() suite.NoError(err) // create schema suite.Database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") suite.NoError(err) err = suite.Database.Init() suite.NoError(err) } func (suite *MongoTestSuite) getMongoDB() *MongoDB { var mongoDatabase *MongoDB var ok bool mongoDatabase, ok = suite.Database.(*MongoDB) suite.True(ok) return mongoDatabase } func TestMongo(t *testing.T) { suite.Run(t, new(MongoTestSuite)) } func BenchmarkMongo(b *testing.B) { log.CloseLogger() ctx := b.Context() // create database database, err := Open(mongoUri, "gorse_") assert.NoError(b, err) dbName := "gorse_cache_benchmark" databaseComm := database.(*MongoDB) _ = databaseComm.client.Database(dbName).Drop(ctx) err = database.Close() assert.NoError(b, err) // create schema database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") assert.NoError(b, err) err = database.Init() assert.NoError(b, err) // benchmark benchmark(b, database) } ================================================ FILE: storage/cache/no_database.go ================================================ // Copyright 2020 gorse Project 
Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "context" "time" ) // NoDatabase means no database used for cache. type NoDatabase struct{} // Close method of NoDatabase returns ErrNoDatabase. func (NoDatabase) Close() error { return ErrNoDatabase } func (NoDatabase) Ping() error { return ErrNoDatabase } // Init method of NoDatabase returns ErrNoDatabase. func (NoDatabase) Init() error { return ErrNoDatabase } func (NoDatabase) Scan(_ func(string) error) error { return ErrNoDatabase } func (NoDatabase) Purge() error { return ErrNoDatabase } func (NoDatabase) Set(_ context.Context, _ ...Value) error { return ErrNoDatabase } // Get method of NoDatabase returns ErrNoDatabase. func (NoDatabase) Get(_ context.Context, _ string) *ReturnValue { return &ReturnValue{err: ErrNoDatabase, exists: false} } // Delete method of NoDatabase returns ErrNoDatabase. 
func (NoDatabase) Delete(_ context.Context, _ string) error { return ErrNoDatabase } func (NoDatabase) Push(_ context.Context, _, _ string) error { return ErrNoDatabase } func (NoDatabase) Pop(_ context.Context, _ string) (string, error) { return "", ErrNoDatabase } func (NoDatabase) Remain(_ context.Context, _ string) (int64, error) { return 0, ErrNoDatabase } func (NoDatabase) AddScores(_ context.Context, _, _ string, _ []Score) error { return ErrNoDatabase } func (NoDatabase) SearchScores(_ context.Context, _, _ string, _ []string, _, _ int) ([]Score, error) { return nil, ErrNoDatabase } func (NoDatabase) UpdateScores(context.Context, []string, *string, string, ScorePatch) error { return ErrNoDatabase } func (NoDatabase) DeleteScores(_ context.Context, _ []string, _ ScoreCondition) error { return ErrNoDatabase } func (NoDatabase) ScanScores(context.Context, func(collection, id, subset string, timestamp time.Time) error) error { return ErrNoDatabase } func (NoDatabase) AddTimeSeriesPoints(_ context.Context, _ []TimeSeriesPoint) error { return ErrNoDatabase } func (NoDatabase) GetTimeSeriesPoints(_ context.Context, _ string, _, _ time.Time, _ time.Duration) ([]TimeSeriesPoint, error) { return nil, ErrNoDatabase } ================================================ FILE: storage/cache/no_database_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cache import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestNoDatabase(t *testing.T) { ctx := t.Context() var database NoDatabase err := database.Close() assert.ErrorIs(t, err, ErrNoDatabase) err = database.Ping() assert.ErrorIs(t, err, ErrNoDatabase) err = database.Init() assert.ErrorIs(t, err, ErrNoDatabase) err = database.Scan(nil) assert.ErrorIs(t, err, ErrNoDatabase) err = database.Purge() assert.ErrorIs(t, err, ErrNoDatabase) err = database.Set(ctx) assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.Get(ctx, Key("", "")).String() assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.Get(ctx, Key("", "")).Integer() assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.Get(ctx, Key("", "")).Time() assert.ErrorIs(t, err, ErrNoDatabase) err = database.Delete(ctx, Key("", "")) assert.ErrorIs(t, err, ErrNoDatabase) err = database.Push(ctx, "", "") assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.Pop(ctx, "") assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.Remain(ctx, "") assert.ErrorIs(t, err, ErrNoDatabase) err = database.AddScores(ctx, "", "", nil) assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.SearchScores(ctx, "", "", nil, 0, 0) assert.ErrorIs(t, err, ErrNoDatabase) err = database.UpdateScores(ctx, nil, nil, "", ScorePatch{}) assert.ErrorIs(t, err, ErrNoDatabase) err = database.DeleteScores(ctx, nil, ScoreCondition{}) assert.ErrorIs(t, err, ErrNoDatabase) err = database.ScanScores(ctx, nil) assert.ErrorIs(t, err, ErrNoDatabase) err = database.AddTimeSeriesPoints(ctx, nil) assert.ErrorIs(t, err, ErrNoDatabase) _, err = database.GetTimeSeriesPoints(ctx, "", time.Time{}, time.Time{}, 0) assert.ErrorIs(t, err, ErrNoDatabase) } ================================================ FILE: storage/cache/proxy.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in 
compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "context" "io" "net" "time" "github.com/gorse-io/gorse/protocol" "github.com/juju/errors" "github.com/samber/lo" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" ) type ProxyServer struct { protocol.UnimplementedCacheStoreServer database Database server *grpc.Server } func NewProxyServer(database Database) *ProxyServer { return &ProxyServer{database: database} } func (p *ProxyServer) Serve(lis net.Listener) error { p.server = grpc.NewServer() protocol.RegisterCacheStoreServer(p.server, p) return p.server.Serve(lis) } func (p *ProxyServer) Stop() { p.server.Stop() } func (p *ProxyServer) Ping(context.Context, *protocol.PingRequest) (*protocol.PingResponse, error) { return &protocol.PingResponse{}, p.database.Ping() } func (p *ProxyServer) Get(ctx context.Context, request *protocol.GetRequest) (*protocol.GetResponse, error) { value := p.database.Get(ctx, request.GetName()) if !value.Exists() { return &protocol.GetResponse{}, nil } return &protocol.GetResponse{Value: new(value.value)}, value.err } func (p *ProxyServer) Set(ctx context.Context, request *protocol.SetRequest) (*protocol.SetResponse, error) { values := make([]Value, len(request.Values)) for i, value := range request.Values { values[i] = Value{ name: value.GetName(), value: value.GetValue(), } } return &protocol.SetResponse{}, p.database.Set(ctx, values...) 
} func (p *ProxyServer) Delete(ctx context.Context, request *protocol.DeleteRequest) (*protocol.DeleteResponse, error) { return &protocol.DeleteResponse{}, p.database.Delete(ctx, request.GetName()) } func (p *ProxyServer) Push(ctx context.Context, request *protocol.PushRequest) (*protocol.PushResponse, error) { return &protocol.PushResponse{}, p.database.Push(ctx, request.GetName(), request.GetValue()) } func (p *ProxyServer) Pop(ctx context.Context, request *protocol.PopRequest) (*protocol.PopResponse, error) { value, err := p.database.Pop(ctx, request.GetName()) if err != nil { if errors.Is(err, io.EOF) { return &protocol.PopResponse{}, nil } return nil, err } return &protocol.PopResponse{Value: &value}, nil } func (p *ProxyServer) Remain(ctx context.Context, request *protocol.RemainRequest) (*protocol.RemainResponse, error) { count, err := p.database.Remain(ctx, request.GetName()) if err != nil { return nil, err } return &protocol.RemainResponse{Count: count}, nil } func (p *ProxyServer) AddScores(ctx context.Context, request *protocol.AddScoresRequest) (*protocol.AddScoresResponse, error) { scores := make([]Score, len(request.Documents)) for i, doc := range request.Documents { scores[i] = Score{ Id: doc.GetId(), Score: doc.GetScore(), IsHidden: doc.GetIsHidden(), Categories: doc.GetCategories(), Timestamp: doc.GetTimestamp().AsTime(), } } return &protocol.AddScoresResponse{}, p.database.AddScores(ctx, request.GetCollection(), request.GetSubset(), scores) } func (p *ProxyServer) SearchScores(ctx context.Context, request *protocol.SearchScoresRequest) (*protocol.SearchScoresResponse, error) { resp, err := p.database.SearchScores(ctx, request.GetCollection(), request.GetSubset(), request.GetQuery(), int(request.GetBegin()), int(request.GetEnd())) if err != nil { return nil, err } scores := make([]*protocol.Score, len(resp)) for i, score := range resp { scores[i] = &protocol.Score{ Id: score.Id, Score: score.Score, IsHidden: score.IsHidden, Categories: 
score.Categories, Timestamp: timestamppb.New(score.Timestamp), } } return &protocol.SearchScoresResponse{Documents: scores}, nil } func (p *ProxyServer) DeleteScores(ctx context.Context, request *protocol.DeleteScoresRequest) (*protocol.DeleteScoresResponse, error) { var before *time.Time if request.Condition.Before != nil { before = lo.ToPtr(request.Condition.Before.AsTime()) } return &protocol.DeleteScoresResponse{}, p.database.DeleteScores(ctx, request.GetCollection(), ScoreCondition{ Subset: request.Condition.Subset, Id: request.Condition.Id, Before: before, }) } func (p *ProxyServer) UpdateScores(ctx context.Context, request *protocol.UpdateScoresRequest) (*protocol.UpdateScoresResponse, error) { return &protocol.UpdateScoresResponse{}, p.database.UpdateScores(ctx, request.GetCollection(), request.Subset, request.GetId(), ScorePatch{ IsHidden: request.GetPatch().IsHidden, Categories: request.GetPatch().Categories, Score: request.GetPatch().Score, }) } func (p *ProxyServer) ScanScores(request *protocol.ScanScoresRequest, stream grpc.ServerStreamingServer[protocol.ScanScoresResponse]) error { err := p.database.ScanScores(stream.Context(), func(collection, id, subset string, timestamp time.Time) error { return stream.Send(&protocol.ScanScoresResponse{ Collection: collection, Id: id, Subset: subset, Timestamp: timestamppb.New(timestamp), }) }) if err != nil { return status.Errorf(codes.Internal, "failed to scan scores: %v", err) } return nil } func (p *ProxyServer) AddTimeSeriesPoints(ctx context.Context, request *protocol.AddTimeSeriesPointsRequest) (*protocol.AddTimeSeriesPointsResponse, error) { points := make([]TimeSeriesPoint, len(request.Points)) for i, point := range request.Points { points[i] = TimeSeriesPoint{ Name: point.Name, Timestamp: point.Timestamp.AsTime(), Value: point.Value, } } return &protocol.AddTimeSeriesPointsResponse{}, p.database.AddTimeSeriesPoints(ctx, points) } func (p *ProxyServer) GetTimeSeriesPoints(ctx context.Context, request 
*protocol.GetTimeSeriesPointsRequest) (*protocol.GetTimeSeriesPointsResponse, error) { resp, err := p.database.GetTimeSeriesPoints(ctx, request.GetName(), request.GetBegin().AsTime(), request.GetEnd().AsTime(), time.Duration(request.GetDuration())) if err != nil { return nil, err } points := make([]*protocol.TimeSeriesPoint, len(resp)) for i, point := range resp { points[i] = &protocol.TimeSeriesPoint{ Name: point.Name, Timestamp: timestamppb.New(point.Timestamp), Value: point.Value, } } return &protocol.GetTimeSeriesPointsResponse{Points: points}, nil } type ProxyClient struct { protocol.CacheStoreClient } func (p ProxyClient) Ping() error { _, err := p.CacheStoreClient.Ping(context.Background(), &protocol.PingRequest{}) return err } func (p ProxyClient) Close() error { return nil } func (p ProxyClient) Init() error { return errors.MethodNotAllowedf("init is not allowed in proxy client") } func (p ProxyClient) Scan(_ func(string) error) error { return errors.MethodNotAllowedf("scan is not allowed in proxy client") } func (p ProxyClient) Purge() error { return errors.MethodNotAllowedf("purge is not allowed in proxy client") } func (p ProxyClient) Set(ctx context.Context, values ...Value) error { pbValues := make([]*protocol.Value, len(values)) for i, value := range values { pbValues[i] = &protocol.Value{ Name: value.name, Value: value.value, } } _, err := p.CacheStoreClient.Set(ctx, &protocol.SetRequest{ Values: pbValues, }) return err } func (p ProxyClient) Get(ctx context.Context, name string) *ReturnValue { resp, err := p.CacheStoreClient.Get(ctx, &protocol.GetRequest{ Name: name, }) if err != nil { return &ReturnValue{err: err, exists: false} } if resp.Value == nil { return &ReturnValue{value: "", exists: false} } return &ReturnValue{value: resp.GetValue(), err: err, exists: true} } func (p ProxyClient) Delete(ctx context.Context, name string) error { _, err := p.CacheStoreClient.Delete(ctx, &protocol.DeleteRequest{ Name: name, }) return err } func (p 
ProxyClient) Push(ctx context.Context, name, value string) error { _, err := p.CacheStoreClient.Push(ctx, &protocol.PushRequest{ Name: name, Value: value, }) return err } func (p ProxyClient) Pop(ctx context.Context, name string) (string, error) { resp, err := p.CacheStoreClient.Pop(ctx, &protocol.PopRequest{ Name: name, }) if err != nil { return "", err } if resp.Value == nil { return "", io.EOF } return resp.GetValue(), nil } func (p ProxyClient) Remain(ctx context.Context, name string) (int64, error) { resp, err := p.CacheStoreClient.Remain(ctx, &protocol.RemainRequest{ Name: name, }) if err != nil { return 0, err } return resp.Count, nil } func (p ProxyClient) AddScores(ctx context.Context, collection, subset string, documents []Score) error { scores := make([]*protocol.Score, len(documents)) for i, doc := range documents { scores[i] = &protocol.Score{ Id: doc.Id, Score: doc.Score, IsHidden: doc.IsHidden, Categories: doc.Categories, Timestamp: timestamppb.New(doc.Timestamp), } } _, err := p.CacheStoreClient.AddScores(ctx, &protocol.AddScoresRequest{ Collection: collection, Subset: subset, Documents: scores, }) return err } func (p ProxyClient) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { resp, err := p.CacheStoreClient.SearchScores(ctx, &protocol.SearchScoresRequest{ Collection: collection, Subset: subset, Query: query, Begin: int32(begin), End: int32(end), }) if err != nil { return nil, err } scores := make([]Score, len(resp.Documents)) for i, score := range resp.Documents { scores[i] = Score{ Id: score.Id, Score: score.Score, IsHidden: score.IsHidden, Categories: score.Categories, Timestamp: score.Timestamp.AsTime(), } } return scores, nil } func (p ProxyClient) DeleteScores(ctx context.Context, collection []string, condition ScoreCondition) error { if err := condition.Check(); err != nil { return errors.Trace(err) } var before *timestamppb.Timestamp if condition.Before != nil { before = 
timestamppb.New(*condition.Before) } _, err := p.CacheStoreClient.DeleteScores(ctx, &protocol.DeleteScoresRequest{ Collection: collection, Condition: &protocol.ScoreCondition{ Subset: condition.Subset, Id: condition.Id, Before: before, }, }) return err } func (p ProxyClient) UpdateScores(ctx context.Context, collection []string, subset *string, id string, patch ScorePatch) error { _, err := p.CacheStoreClient.UpdateScores(ctx, &protocol.UpdateScoresRequest{ Collection: collection, Subset: subset, Id: id, Patch: &protocol.ScorePatch{ Score: patch.Score, IsHidden: patch.IsHidden, Categories: patch.Categories, }, }) return err } func (p ProxyClient) ScanScores(ctx context.Context, callback func(collection string, id string, subset string, timestamp time.Time) error) error { stream, err := p.CacheStoreClient.ScanScores(ctx, &protocol.ScanScoresRequest{}) if err != nil { return err } for { // check for context cancellation select { case <-ctx.Done(): return ctx.Err() default: } // receive the next message resp, err := stream.Recv() if err == io.EOF { return nil } if err != nil { return err } if err := callback(resp.Collection, resp.Id, resp.Subset, resp.Timestamp.AsTime()); err != nil { return err } } } func (p ProxyClient) AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error { pbPoints := make([]*protocol.TimeSeriesPoint, len(points)) for i, point := range points { pbPoints[i] = &protocol.TimeSeriesPoint{ Name: point.Name, Timestamp: timestamppb.New(point.Timestamp), Value: point.Value, } } _, err := p.CacheStoreClient.AddTimeSeriesPoints(ctx, &protocol.AddTimeSeriesPointsRequest{ Points: pbPoints, }) return err } func (p ProxyClient) GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) { resp, err := p.CacheStoreClient.GetTimeSeriesPoints(ctx, &protocol.GetTimeSeriesPointsRequest{ Name: name, Begin: timestamppb.New(begin), End: timestamppb.New(end), Duration: int64(duration), 
}) if err != nil { return nil, err } points := make([]TimeSeriesPoint, len(resp.Points)) for i, point := range resp.Points { points[i] = TimeSeriesPoint{ Name: point.Name, Timestamp: point.Timestamp.AsTime(), Value: point.Value, } } return points, nil } func NewProxyClient(conn *grpc.ClientConn) *ProxyClient { return &ProxyClient{ CacheStoreClient: protocol.NewCacheStoreClient(conn), } } ================================================ FILE: storage/cache/proxy_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cache import ( "fmt" "net" "testing" "github.com/stretchr/testify/suite" "google.golang.org/grpc" ) type ProxyTestSuite struct { baseTestSuite sqlite Database server *ProxyServer clientConn *grpc.ClientConn } func (suite *ProxyTestSuite) SetupSuite() { // create database var err error path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) suite.sqlite, err = Open(path, "gorse_") suite.NoError(err) // create schema err = suite.sqlite.Init() suite.NoError(err) // start server lis, err := net.Listen("tcp", "localhost:0") suite.NoError(err) suite.server = NewProxyServer(suite.sqlite) go func() { err = suite.server.Serve(lis) suite.NoError(err) }() // create proxy client suite.clientConn, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) suite.NoError(err) suite.Database = NewProxyClient(suite.clientConn) } func (suite *ProxyTestSuite) TearDownSuite() { suite.server.Stop() suite.NoError(suite.clientConn.Close()) suite.NoError(suite.sqlite.Close()) } func (suite *ProxyTestSuite) SetupTest() { err := suite.sqlite.Ping() suite.NoError(err) err = suite.sqlite.Purge() suite.NoError(err) } func (suite *ProxyTestSuite) TearDownTest() { err := suite.sqlite.Purge() suite.NoError(err) } func (suite *ProxyTestSuite) TestInit() { suite.T().Skip() } func (suite *ProxyTestSuite) TestPurge() { suite.T().Skip() } func (suite *ProxyTestSuite) TestScan() { suite.T().Skip() } func TestProxy(t *testing.T) { suite.Run(t, new(ProxyTestSuite)) } ================================================ FILE: storage/cache/redis.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "context" "encoding/base64" "fmt" "io" "strconv" "strings" "time" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/redis/go-redis/extra/redisotel/v9" "github.com/redis/go-redis/v9" "github.com/samber/lo" semconv "go.opentelemetry.io/otel/semconv/v1.8.0" "go.uber.org/zap" ) func init() { Register([]string{storage.RedisPrefix, storage.RedissPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { opt, err := redis.ParseURL(path) if err != nil { return nil, err } opt.Protocol = 2 database := new(Redis) database.client = redis.NewClient(opt) database.TablePrefix = storage.TablePrefix(tablePrefix) database.maxSearchResults = storage.NewOptions(opts...).MaxSearchResults if err = redisotel.InstrumentTracing(database.client, redisotel.WithAttributes(semconv.DBSystemRedis)); err != nil { log.Logger().Error("failed to add tracing for redis", zap.Error(err)) return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.RedisClusterPrefix, storage.RedissClusterPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { var newURL string if strings.HasPrefix(path, storage.RedisClusterPrefix) { newURL = strings.Replace(path, storage.RedisClusterPrefix, storage.RedisPrefix, 1) } else if strings.HasPrefix(path, storage.RedissClusterPrefix) { newURL = strings.Replace(path, storage.RedissClusterPrefix, storage.RedissPrefix, 1) } opt, err := redis.ParseClusterURL(newURL) if err != 
nil { return nil, err } opt.Protocol = 2 database := new(Redis) database.client = redis.NewClusterClient(opt) database.TablePrefix = storage.TablePrefix(tablePrefix) database.maxSearchResults = storage.NewOptions(opts...).MaxSearchResults if err = redisotel.InstrumentTracing(database.client, redisotel.WithAttributes(semconv.DBSystemRedis)); err != nil { log.Logger().Error("failed to add tracing for redis", zap.Error(err)) return nil, errors.Trace(err) } return database, nil }) } // Redis cache storage. type Redis struct { storage.TablePrefix client redis.UniversalClient maxSearchResults int } // Close redis connection. func (r *Redis) Close() error { return r.client.Close() } func (r *Redis) Ping() error { return r.client.Ping(context.Background()).Err() } // Init nothing. func (r *Redis) Init() error { // list indices indices, err := r.client.FT_List(context.Background()).Result() if err != nil { return errors.Trace(err) } // create index if !lo.Contains(indices, r.DocumentTable()) { _, err = r.client.FTCreate(context.TODO(), r.DocumentTable(), &redis.FTCreateOptions{ OnHash: true, Prefix: []any{r.DocumentTable() + ":"}, }, &redis.FieldSchema{FieldName: "collection", FieldType: redis.SearchFieldTypeTag}, &redis.FieldSchema{FieldName: "subset", FieldType: redis.SearchFieldTypeTag}, &redis.FieldSchema{FieldName: "id", FieldType: redis.SearchFieldTypeTag}, &redis.FieldSchema{FieldName: "score", FieldType: redis.SearchFieldTypeNumeric}, &redis.FieldSchema{FieldName: "is_hidden", FieldType: redis.SearchFieldTypeNumeric}, &redis.FieldSchema{FieldName: "categories", FieldType: redis.SearchFieldTypeTag, Separator: ";"}, &redis.FieldSchema{FieldName: "timestamp", FieldType: redis.SearchFieldTypeNumeric}, ).Result() if err != nil { return errors.Trace(err) } } return nil } func (r *Redis) Scan(work func(string) error) error { ctx := context.Background() if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster { return clusterClient.ForEachMaster(ctx, 
func(ctx context.Context, client *redis.Client) error { return r.scan(ctx, client, work) }) } else { return r.scan(ctx, r.client, work) } } func (r *Redis) scan(ctx context.Context, client redis.UniversalClient, work func(string) error) error { var ( result []string cursor uint64 err error ) for { result, cursor, err = client.Scan(ctx, cursor, string(r.TablePrefix)+"*", 0).Result() if err != nil { return errors.Trace(err) } for _, key := range result { if err = work(key[len(r.TablePrefix):]); err != nil { return errors.Trace(err) } } if cursor == 0 { return nil } } } func (r *Redis) Purge() error { ctx := context.Background() if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster { return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error { return r.purge(ctx, client, isCluster) }) } else { return r.purge(ctx, r.client, isCluster) } } func (r *Redis) purge(ctx context.Context, client redis.UniversalClient, isCluster bool) error { var ( result []string cursor uint64 err error ) for { result, cursor, err = client.Scan(ctx, cursor, string(r.TablePrefix)+"*", 0).Result() if err != nil { return errors.Trace(err) } if len(result) > 0 { if isCluster { p := client.Pipeline() for _, key := range result { if err = p.Del(ctx, key).Err(); err != nil { return errors.Trace(err) } } if _, err = p.Exec(ctx); err != nil { return errors.Trace(err) } } else { if err = client.Del(ctx, result...).Err(); err != nil { return errors.Trace(err) } } } if cursor == 0 { return nil } } } func (r *Redis) Set(ctx context.Context, values ...Value) error { p := r.client.Pipeline() for _, v := range values { if err := p.Set(ctx, r.Key(v.name), v.value, 0).Err(); err != nil { return errors.Trace(err) } } _, err := p.Exec(ctx) return errors.Trace(err) } // Get returns a value from Redis. 
func (r *Redis) Get(ctx context.Context, key string) *ReturnValue { val, err := r.client.Get(ctx, r.Key(key)).Result() if err != nil { if err == redis.Nil { return &ReturnValue{value: "", exists: false} } return &ReturnValue{err: err, exists: false} } return &ReturnValue{value: val, exists: true} } // Delete object from Redis. func (r *Redis) Delete(ctx context.Context, key string) error { return r.client.Del(ctx, r.Key(key)).Err() } func (r *Redis) Push(ctx context.Context, name string, message string) error { _, err := r.client.ZAdd(ctx, r.Key(name), redis.Z{Member: message, Score: float64(time.Now().UnixNano())}).Result() return err } func (r *Redis) Pop(ctx context.Context, name string) (string, error) { z, err := r.client.ZPopMin(ctx, r.Key(name), 1).Result() if err != nil { return "", errors.Trace(err) } if len(z) == 0 { return "", io.EOF } return z[0].Member.(string), nil } func (r *Redis) Remain(ctx context.Context, name string) (int64, error) { return r.client.ZCard(ctx, r.Key(name)).Result() } func (r *Redis) documentKey(collection, subset, value string) string { return r.DocumentTable() + ":" + collection + ":" + subset + ":" + value } func (r *Redis) AddScores(ctx context.Context, collection, subset string, documents []Score) error { p := r.client.Pipeline() for _, document := range documents { p.HSet(ctx, r.documentKey(collection, subset, document.Id), "collection", collection, "subset", subset, "id", document.Id, "score", document.Score, "is_hidden", document.IsHidden, "categories", encodeCategories(document.Categories), "timestamp", document.Timestamp.UnixMicro()) } _, err := p.Exec(ctx) return errors.Trace(err) } func (r *Redis) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { var builder strings.Builder fmt.Fprintf(&builder, "@collection:{ %s } @is_hidden:[0 0]", escape(collection)) if subset != "" { fmt.Fprintf(&builder, " @subset:{ %s }", escape(subset)) } for _, q := range query { 
fmt.Fprintf(&builder, " @categories:{ %s }", escape(encodeCategory(q))) } options := &redis.FTSearchOptions{ SortBy: []redis.FTSearchSortBy{{FieldName: "score", Desc: true}}, LimitOffset: begin, } if end == -1 { options.Limit = 10000 } else { options.Limit = end - begin } result, err := r.client.FTSearchWithArgs(ctx, r.DocumentTable(), builder.String(), options).Result() if err != nil { return nil, errors.Trace(err) } documents := make([]Score, 0, len(result.Docs)) for _, doc := range result.Docs { var document Score document.Id = doc.Fields["id"] score, err := strconv.ParseFloat(doc.Fields["score"], 64) if err != nil { return nil, errors.Trace(err) } document.Score = score isHidden, err := strconv.ParseInt(doc.Fields["is_hidden"], 10, 64) if err != nil { return nil, errors.Trace(err) } document.IsHidden = isHidden != 0 categories, err := decodeCategories(doc.Fields["categories"]) if err != nil { return nil, errors.Trace(err) } document.Categories = categories timestamp, err := strconv.ParseInt(doc.Fields["timestamp"], 10, 64) if err != nil { return nil, errors.Trace(err) } document.Timestamp = time.UnixMicro(timestamp).In(time.UTC) documents = append(documents, document) } return documents, nil } func (r *Redis) UpdateScores(ctx context.Context, collections []string, subset *string, id string, patch ScorePatch) error { if len(collections) == 0 { return nil } if patch.Score == nil && patch.IsHidden == nil && patch.Categories == nil { return nil } var builder strings.Builder fmt.Fprintf(&builder, "@collection:{ %s }", escape(strings.Join(collections, " | "))) fmt.Fprintf(&builder, " @id:{ %s }", escape(id)) if subset != nil { fmt.Fprintf(&builder, " @subset:{ %s }", escape(*subset)) } limit := r.maxSearchResults if limit <= 0 { limit = 10000 } // Two-phase update: // 1) collect matched document IDs with pagination, // 2) mutate documents by key. // This avoids pagination drift when patch.Score changes the sort order. 
keys := make([]string, 0) keySet := make(map[string]struct{}) offset := 0 for { // search documents result, err := r.client.FTSearchWithArgs(ctx, r.DocumentTable(), builder.String(), &redis.FTSearchOptions{ SortBy: []redis.FTSearchSortBy{{FieldName: "score", Desc: true}}, LimitOffset: offset, Limit: limit, }).Result() if err != nil { return errors.Trace(err) } if len(result.Docs) == 0 { break } newKeys := 0 for _, doc := range result.Docs { if _, exists := keySet[doc.ID]; !exists { keySet[doc.ID] = struct{}{} keys = append(keys, doc.ID) newKeys++ } } offset += len(result.Docs) // Stop when: // 1) the last page is shorter than the limit (common Redis behavior), or // 2) no new keys are discovered (defensive for engines with non-standard total/offset semantics). if len(result.Docs) < limit || newKeys == 0 { break } } values := make([]any, 0) if patch.Score != nil { values = append(values, "score", *patch.Score) } if patch.IsHidden != nil { values = append(values, "is_hidden", *patch.IsHidden) } if patch.Categories != nil { values = append(values, "categories", encodeCategories(patch.Categories)) } for _, key := range keys { if err := r.client.Watch(ctx, func(tx *redis.Tx) error { if exist, err := tx.Exists(ctx, key).Result(); err != nil { return err } else if exist == 0 { return nil } return tx.HSet(ctx, key, values...).Err() }, key); err != nil { return errors.Trace(err) } } return nil } func (r *Redis) DeleteScores(ctx context.Context, collections []string, condition ScoreCondition) error { if err := condition.Check(); err != nil { return errors.Trace(err) } var builder strings.Builder fmt.Fprintf(&builder, "@collection:{ %s }", escape(strings.Join(collections, " | "))) if condition.Subset != nil { fmt.Fprintf(&builder, " @subset:{ %s }", escape(*condition.Subset)) } if condition.Id != nil { fmt.Fprintf(&builder, " @id:{ %s }", escape(*condition.Id)) } if condition.Before != nil { fmt.Fprintf(&builder, " @timestamp:[-inf (%d]", condition.Before.UnixMicro()) } for { 
// search documents result, err := r.client.FTSearchWithArgs(ctx, r.DocumentTable(), builder.String(), &redis.FTSearchOptions{ SortBy: []redis.FTSearchSortBy{{FieldName: "score", Desc: true}}, LimitOffset: 0, Limit: 10000, }).Result() if err != nil { return errors.Trace(err) } // delete documents p := r.client.Pipeline() for _, doc := range result.Docs { p.Del(ctx, doc.ID) } _, err = p.Exec(ctx) if err != nil { return errors.Trace(err) } // break if no more documents if result.Total == len(result.Docs) { break } } return nil } func (r *Redis) ScanScores(ctx context.Context, callback func(collection string, id string, subset string, timestamp time.Time) error) error { if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster { return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error { return r.scanScores(ctx, client, callback) }) } else { return r.scanScores(ctx, r.client, callback) } } func (r *Redis) scanScores(ctx context.Context, client redis.UniversalClient, callback func(collection string, id string, subset string, timestamp time.Time) error) error { var ( result []string cursor uint64 err error ) for { result, cursor, err = client.Scan(ctx, cursor, r.DocumentTable()+"*", 0).Result() if err != nil { return errors.Trace(err) } for _, key := range result { var row map[string]string row, err = client.HGetAll(ctx, key).Result() if err != nil { return errors.Trace(err) } var usec int64 usec, err = util.ParseInt[int64](row["timestamp"]) if err != nil { return errors.Trace(err) } if err = callback(row["collection"], row["id"], row["subset"], time.UnixMicro(usec).In(time.UTC)); err != nil { return errors.Trace(err) } } if cursor == 0 { return nil } } } func (r *Redis) AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error { p := r.client.Pipeline() opt := &redis.TSOptions{DuplicatePolicy: "LAST"} for _, point := range points { if err := p.TSAddWithArgs(ctx, r.PointsTable()+":"+point.Name, 
point.Timestamp.UnixMilli(), point.Value, opt).Err(); err != nil { return errors.Trace(err) } } _, err := p.Exec(ctx) return errors.Trace(err) } func (r *Redis) GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) { result, err := r.client.TSRangeWithArgs(ctx, r.PointsTable()+":"+name, int(begin.UnixMilli()), int(end.UnixMilli()), &redis.TSRangeOptions{Aggregator: redis.Last, BucketDuration: int(duration / time.Millisecond)}).Result() if err != nil { return nil, errors.Trace(err) } points := make([]TimeSeriesPoint, 0, len(result)) for _, doc := range result { var point TimeSeriesPoint point.Name = name point.Value = doc.Value point.Timestamp = time.UnixMilli(doc.Timestamp).UTC() points = append(points, point) } return points, nil } func encodeCategory(category string) string { return base64.RawStdEncoding.EncodeToString([]byte("_" + category)) } func decodeCategory(s string) (string, error) { b, err := base64.RawStdEncoding.DecodeString(s) if err != nil { return "", errors.Trace(err) } return string(b[1:]), nil } func encodeCategories(categories []string) string { var builder strings.Builder for i, category := range categories { if i > 0 { builder.WriteByte(';') } builder.WriteString(encodeCategory(category)) } return builder.String() } func decodeCategories(s string) ([]string, error) { if s == "" { return []string{}, nil } var categories []string for _, category := range strings.Split(s, ";") { category, err := decodeCategory(category) if err != nil { return nil, errors.Trace(err) } categories = append(categories, category) } return categories, nil } // escape -:. 
func escape(s string) string { r := strings.NewReplacer( "-", "\\-", ":", "\\:", ".", "\\.", "/", "\\/", "+", "\\+", ) return r.Replace(s) } ================================================ FILE: storage/cache/redis_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "context" "fmt" "math" "os" "testing" "time" "github.com/gorse-io/gorse/common/log" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) var ( redisDSN string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } redisDSN = env("REDIS_URI", "redis://127.0.0.1:6379/") } type RedisTestSuite struct { baseTestSuite } func (suite *RedisTestSuite) SetupSuite() { var err error suite.Database, err = Open(redisDSN, "gorse_") suite.NoError(err) // flush db redisClient, ok := suite.Database.(*Redis) suite.True(ok) if clusterClient, ok := redisClient.client.(*redis.ClusterClient); ok { err = clusterClient.ForEachMaster(suite.T().Context(), func(ctx context.Context, client *redis.Client) error { return client.FlushDB(ctx).Err() }) suite.NoError(err) } else { err = redisClient.client.FlushDB(suite.T().Context()).Err() suite.NoError(err) } // create schema err = suite.Database.Init() suite.NoError(err) } func (suite *RedisTestSuite) 
TestEscapeCharacters() { ts := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) ctx := suite.T().Context() for _, c := range []string{"-", ":", ".", "/"} { suite.Run(c, func() { collection := fmt.Sprintf("a%s1", c) subset := fmt.Sprintf("b%s2", c) id := fmt.Sprintf("c%s3", c) err := suite.AddScores(ctx, collection, subset, []Score{{ Id: id, Score: math.MaxFloat64, Categories: []string{"a", "b"}, Timestamp: ts, }}) suite.NoError(err) documents, err := suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) suite.NoError(err) suite.Equal([]Score{{Id: id, Score: math.MaxFloat64, Categories: []string{"a", "b"}, Timestamp: ts}}, documents) err = suite.UpdateScores(ctx, []string{collection}, nil, id, ScorePatch{Score: new(float64(1))}) suite.NoError(err) documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) suite.NoError(err) suite.Equal([]Score{{Id: id, Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}}, documents) err = suite.DeleteScores(ctx, []string{collection}, ScoreCondition{ Subset: new(subset), Id: new(id), }) suite.NoError(err) documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) suite.NoError(err) suite.Empty(documents) }) } } func (suite *RedisTestSuite) TestUpdateScoresWithPagination() { ctx := suite.T().Context() db, ok := suite.Database.(*Redis) suite.True(ok) limit := db.maxSearchResults db.maxSearchResults = 2 defer func() { db.maxSearchResults = limit }() for i := 0; i < 5; i++ { subset := fmt.Sprintf("subset-%d", i) err := suite.AddScores(ctx, "collection-a", subset, []Score{{ Id: "shared-item", Score: float64(i), Categories: []string{"old"}, Timestamp: time.Now().UTC(), }}) suite.NoError(err) } err := suite.UpdateScores(ctx, []string{"collection-a"}, nil, "shared-item", ScorePatch{ Categories: []string{"new"}, }) suite.NoError(err) for i := 0; i < 5; i++ { subset := fmt.Sprintf("subset-%d", i) docs, err := suite.SearchScores(ctx, "collection-a", subset, []string{"new"}, 0, -1) 
suite.NoError(err) suite.Require().Len(docs, 1) suite.Equal("shared-item", docs[0].Id) } } func (suite *RedisTestSuite) TestUpdateScoresWithPaginationAndScorePatch() { ctx := suite.T().Context() db, ok := suite.Database.(*Redis) suite.True(ok) limit := db.maxSearchResults db.maxSearchResults = 1 defer func() { db.maxSearchResults = limit }() initialScores := []float64{3, 2, 1} for i, score := range initialScores { subset := fmt.Sprintf("score-subset-%d", i) err := suite.AddScores(ctx, "collection-b", subset, []Score{{ Id: "shared-item", Score: score, Categories: []string{"score-old"}, Timestamp: time.Now().UTC(), }}) suite.NoError(err) } targetScore := float64(0) err := suite.UpdateScores(ctx, []string{"collection-b"}, nil, "shared-item", ScorePatch{ Score: &targetScore, }) suite.NoError(err) for i := range initialScores { subset := fmt.Sprintf("score-subset-%d", i) docs, err := suite.SearchScores(ctx, "collection-b", subset, nil, 0, -1) suite.NoError(err) suite.Require().Len(docs, 1) suite.Equal(targetScore, docs[0].Score) } } func (suite *RedisTestSuite) TestUpdateScoresWithPaginationAndTiedScores() { ctx := suite.T().Context() db, ok := suite.Database.(*Redis) suite.True(ok) limit := db.maxSearchResults db.maxSearchResults = 2 defer func() { db.maxSearchResults = limit }() for i := 0; i < 5; i++ { subset := fmt.Sprintf("tie-subset-%d", i) err := suite.AddScores(ctx, "collection-c", subset, []Score{{ Id: "shared-item", Score: 1, Categories: []string{"tie-old"}, Timestamp: time.Now().UTC(), }}) suite.NoError(err) } err := suite.UpdateScores(ctx, []string{"collection-c"}, nil, "shared-item", ScorePatch{ Categories: []string{"tie-new"}, }) suite.NoError(err) for i := 0; i < 5; i++ { subset := fmt.Sprintf("tie-subset-%d", i) docs, err := suite.SearchScores(ctx, "collection-c", subset, []string{"tie-new"}, 0, -1) suite.NoError(err) suite.Require().Len(docs, 1) suite.Equal("shared-item", docs[0].Id) } } func TestRedis(t *testing.T) { suite.Run(t, new(RedisTestSuite)) } 
func TestEncodeDecodeCategories(t *testing.T) { encoded := encodeCategories([]string{"z", "h"}) decoded, err := decodeCategories(encoded) assert.NoError(t, err) assert.Equal(t, []string{"z", "h"}, decoded) encoded = encodeCategories(nil) decoded, err = decodeCategories(encoded) assert.NoError(t, err) assert.Equal(t, []string{}, decoded) } func BenchmarkRedis(b *testing.B) { log.CloseLogger() // open db database, err := Open(redisDSN, "gorse_") assert.NoError(b, err) // flush db err = database.(*Redis).client.FlushDB(b.Context()).Err() assert.NoError(b, err) // create schema err = database.Init() assert.NoError(b, err) // benchmark benchmark(b, database) } ================================================ FILE: storage/cache/sql.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cache import ( "context" "database/sql" "encoding/json" "fmt" "io" "math" "strings" "time" "github.com/XSAM/otelsql" "github.com/araddon/dateparse" mapset "github.com/deckarep/golang-set/v2" "github.com/go-sql-driver/mysql" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/lib/pq" "github.com/samber/lo" semconv "go.opentelemetry.io/otel/semconv/v1.8.0" gormmysql "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/clause" ) func init() { Register([]string{storage.PostgresPrefix, storage.PostgreSQLPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(SQLDatabase) database.driver = Postgres database.TablePrefix = storage.TablePrefix(tablePrefix) option := storage.NewOptions(opts...) var err error if database.client, err = otelsql.Open("postgres", path, otelsql.WithAttributes(semconv.DBSystemPostgreSQL), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } storage.ApplySQLPool(database.client, option) database.gormDB, err = gorm.Open(postgres.New(postgres.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.MySQLPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { name := path[len(storage.MySQLPrefix):] option := storage.NewOptions(opts...) 
// probe isolation variable name isolationVarName, err := storage.ProbeMySQLIsolationVariableName(name) if err != nil { return nil, errors.Trace(err) } // append parameters if name, err = storage.AppendMySQLParams(name, map[string]string{ isolationVarName: fmt.Sprintf("'%s'", option.IsolationLevel), "parseTime": "true", }); err != nil { return nil, errors.Trace(err) } // connect to database database := new(SQLDatabase) database.driver = MySQL database.TablePrefix = storage.TablePrefix(tablePrefix) if database.client, err = otelsql.Open("mysql", name, otelsql.WithAttributes(semconv.DBSystemMySQL), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } storage.ApplySQLPool(database.client, option) database.gormDB, err = gorm.Open(gormmysql.New(gormmysql.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.SQLitePrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { dataSourceName := path[len(storage.SQLitePrefix):] // append parameters var err error if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{ {"_pragma", "busy_timeout(10000)"}, {"_pragma", "journal_mode(wal)"}, }); err != nil { return nil, errors.Trace(err) } // connect to database database := new(SQLDatabase) database.driver = SQLite database.TablePrefix = storage.TablePrefix(tablePrefix) if database.client, err = otelsql.Open("sqlite", dataSourceName, otelsql.WithAttributes(semconv.DBSystemSqlite), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } database.gormDB, err = gorm.Open(sqlite.Dialector{Conn: database.client}, storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type SQLDriver int const ( MySQL SQLDriver = iota Postgres SQLite ) type 
SQLValue struct { Name string `gorm:"type:varchar(256);primaryKey"` Value string `gorm:"type:varchar(256);not null"` } type Message struct { Name string `gorm:"primaryKey;index:timestamp"` Value string `gorm:"primaryKey"` Timestamp int64 `gorm:"index:timestamp"` } type PostgresDocument struct { Collection string `gorm:"primaryKey"` Subset string `gorm:"primaryKey"` Id string `gorm:"primaryKey"` IsHidden bool Categories pq.StringArray `gorm:"type:text[]"` Score float64 Timestamp time.Time } type SQLDocument struct { Collection string `gorm:"primaryKey"` Subset string `gorm:"primaryKey"` Id string `gorm:"primaryKey"` IsHidden bool Categories []string `gorm:"type:text;serializer:json"` Score float64 Timestamp time.Time } type SQLDatabase struct { storage.TablePrefix gormDB *gorm.DB client *sql.DB driver SQLDriver } func (db *SQLDatabase) Close() error { return db.client.Close() } func (db *SQLDatabase) Ping() error { return db.client.Ping() } func (db *SQLDatabase) Init() error { err := db.gormDB.AutoMigrate(&SQLValue{}, &Message{}, &TimeSeriesPoint{}) if err != nil { return errors.Trace(err) } switch db.driver { case Postgres: err = db.gormDB.AutoMigrate(&PostgresDocument{}) if err != nil { return errors.Trace(err) } // create extension btree_gin err = db.gormDB.Exec("CREATE EXTENSION IF NOT EXISTS btree_gin").Error if err != nil { return errors.Trace(err) } // create index err = db.gormDB.Exec(fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_collection_subset_categories ON %s USING GIN (collection, subset, categories)", db.DocumentTable())).Error if err != nil { return errors.Trace(err) } err = db.gormDB.Exec(fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_collection_id ON %s (collection, id)", db.DocumentTable())).Error if err != nil { return errors.Trace(err) } case MySQL: err = db.gormDB.AutoMigrate(&SQLDocument{}) if err != nil { return errors.Trace(err) } // create index err = db.gormDB.Exec(fmt.Sprintf("ALTER TABLE %s ADD INDEX idx_collection_subset_categories 
(collection, subset, (CAST(categories AS CHAR(255) ARRAY)))", db.DocumentTable())).Error if err != nil { if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1061 { // ignore duplicate index error } else { return errors.Trace(err) } } err = db.gormDB.Exec(fmt.Sprintf("ALTER TABLE %s ADD INDEX idx_collection_id (collection, id)", db.DocumentTable())).Error if err != nil { if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1061 { // ignore duplicate index error err = nil } else { return errors.Trace(err) } } case SQLite: err = db.gormDB.AutoMigrate(&SQLDocument{}) } return errors.Trace(err) } func (db *SQLDatabase) Scan(work func(string) error) error { var ( valuerRows *sql.Rows err error ) // scan values valuerRows, err = db.gormDB.Table(db.ValuesTable()).Select("name").Rows() if err != nil { return errors.Trace(err) } defer valuerRows.Close() for valuerRows.Next() { var key string if err = valuerRows.Scan(&key); err != nil { return errors.Trace(err) } if err = work(key); err != nil { return errors.Trace(err) } } return nil } func (db *SQLDatabase) Purge() error { tables := []any{SQLValue{}, Message{}, SQLDocument{}} for _, table := range tables { err := db.gormDB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&table).Error if err != nil { return errors.Trace(err) } } return nil } func (db *SQLDatabase) Set(ctx context.Context, values ...Value) error { if len(values) == 0 { return nil } valueSet := mapset.NewSet[string]() rows := make([]SQLValue, 0, len(values)) for _, value := range values { if !valueSet.Contains(value.name) { rows = append(rows, SQLValue{ Name: value.name, Value: value.value, }) valueSet.Add(value.name) } } err := db.gormDB.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "name"}}, DoUpdates: clause.AssignmentColumns([]string{"value"}), }).Create(rows).Error return errors.Trace(err) } func (db *SQLDatabase) Get(ctx context.Context, name string) *ReturnValue { rs, err := 
db.gormDB.WithContext(ctx).Table(db.ValuesTable()).Where("name = ?", name).Select("value").Rows() if err != nil { return &ReturnValue{err: errors.Trace(err), exists: false} } defer rs.Close() if rs.Next() { var value string err := rs.Scan(&value) if err != nil { return &ReturnValue{err: errors.Trace(err), exists: false} } return &ReturnValue{value: value, exists: true} } return &ReturnValue{value: "", exists: false} } func (db *SQLDatabase) Delete(ctx context.Context, name string) error { err := db.gormDB.WithContext(ctx).Delete(&SQLValue{Name: name}).Error return errors.Trace(err) } func (db *SQLDatabase) Push(ctx context.Context, name, value string) error { return db.gormDB.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "name"}, {Name: "value"}}, DoUpdates: clause.AssignmentColumns([]string{"timestamp"}), }).Create(&Message{ Name: name, Value: value, Timestamp: time.Now().UnixNano(), }).Error } func (db *SQLDatabase) Pop(ctx context.Context, name string) (string, error) { var message Message err := db.gormDB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := db.gormDB.Order("timestamp").First(&message, "name = ?", name).Error; err != nil { return err } if err := db.gormDB.Delete(&message).Error; err != nil { return err } return nil }) if err == gorm.ErrRecordNotFound { return "", io.EOF } return message.Value, err } func (db *SQLDatabase) Remain(ctx context.Context, name string) (count int64, err error) { err = db.gormDB.WithContext(ctx).Model(&Message{}).Where("name = ?", name).Count(&count).Error return } func (db *SQLDatabase) AddScores(ctx context.Context, collection, subset string, documents []Score) error { var rows any switch db.driver { case Postgres: rows = lo.Map(documents, func(document Score, _ int) PostgresDocument { return PostgresDocument{ Collection: collection, Subset: subset, Id: document.Id, Score: document.Score, IsHidden: document.IsHidden, Categories: document.Categories, Timestamp: 
document.Timestamp, } }) case SQLite, MySQL: rows = lo.Map(documents, func(document Score, _ int) SQLDocument { return SQLDocument{ Collection: collection, Subset: subset, Id: document.Id, Score: document.Score, IsHidden: document.IsHidden, Categories: document.Categories, Timestamp: document.Timestamp, } }) } err := db.gormDB.WithContext(ctx).Table(db.DocumentTable()).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "collection"}, {Name: "subset"}, {Name: "id"}}, DoUpdates: clause.AssignmentColumns([]string{"score", "categories", "timestamp"}), }).Create(rows).Error return errors.Trace(err) } func (db *SQLDatabase) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { tx := db.gormDB.WithContext(ctx). Model(&PostgresDocument{}). Select("id, score, categories, timestamp"). Where("collection = ? and subset = ? and is_hidden = false", collection, subset) if len(query) > 0 { switch db.driver { case Postgres: tx.Where("categories @> ?", pq.StringArray(query)) case SQLite, MySQL: q, err := json.Marshal(query) if err != nil { return nil, errors.Trace(err) } tx.Where("JSON_CONTAINS(categories,?)", string(q)) } } tx.Order("score desc").Offset(begin) if end != -1 { tx.Limit(end - begin) } else { tx.Limit(math.MaxInt64) } rows, err := tx.Rows() if err != nil { return nil, errors.Trace(err) } documents := make([]Score, 0, 10) for rows.Next() { switch db.driver { case Postgres: var document PostgresDocument if err = rows.Scan(&document.Id, &document.Score, &document.Categories, &document.Timestamp); err != nil { return nil, errors.Trace(err) } documents = append(documents, Score{ Id: document.Id, Score: document.Score, Categories: document.Categories, Timestamp: document.Timestamp, }) case SQLite, MySQL: var document Score if err = db.gormDB.ScanRows(rows, &document); err != nil { return nil, errors.Trace(err) } document.Timestamp = document.Timestamp.In(time.UTC) documents = append(documents, document) } } 
return documents, nil } func (db *SQLDatabase) UpdateScores(ctx context.Context, collections []string, subset *string, id string, patch ScorePatch) error { if len(collections) == 0 { return nil } if patch.Score == nil && patch.IsHidden == nil && patch.Categories == nil { return nil } tx := db.gormDB.WithContext(ctx). Model(&PostgresDocument{}). Where("collection in (?) and id = ?", collections, id) if subset != nil { tx.Where("subset = ?", subset) } if patch.Score != nil { tx.Update("score", *patch.Score) } if patch.IsHidden != nil { tx.Update("is_hidden", *patch.IsHidden) } if patch.Categories != nil { switch db.driver { case Postgres: tx.Update("categories", pq.StringArray(patch.Categories)) case SQLite, MySQL: q, err := json.Marshal(patch.Categories) if err != nil { return errors.Trace(err) } tx.Update("categories", string(q)) } } return tx.Error } func (db *SQLDatabase) DeleteScores(ctx context.Context, collections []string, condition ScoreCondition) error { if err := condition.Check(); err != nil { return errors.Trace(err) } var builder strings.Builder builder.WriteString("collection in (?)") var args []any args = append(args, collections) if condition.Subset != nil { builder.WriteString(" and subset = ?") args = append(args, *condition.Subset) } if condition.Id != nil { builder.WriteString(" and id = ?") args = append(args, *condition.Id) } if condition.Before != nil { builder.WriteString(" and timestamp < ?") if db.driver == MySQL { // In MySQL, we need to truncate the time to milliseconds because MySQL will round the time to milliseconds. 
args = append(args, condition.Before.Truncate(time.Millisecond)) } else { args = append(args, *condition.Before) } } return db.gormDB.WithContext(ctx).Delete(&SQLDocument{}, append([]any{builder.String()}, args...)...).Error } func (db *SQLDatabase) ScanScores(ctx context.Context, callback func(collection, id, subset string, timestamp time.Time) error) error { rows, err := db.gormDB.WithContext(ctx).Table(db.DocumentTable()).Select("collection, id, subset, timestamp").Rows() if err != nil { return errors.Trace(err) } defer rows.Close() for rows.Next() { var collection, id, subset string var timestamp time.Time if err = rows.Scan(&collection, &id, &subset, ×tamp); err != nil { return errors.Trace(err) } if err = callback(collection, id, subset, timestamp); err != nil { return errors.Trace(err) } } return nil } func (db *SQLDatabase) AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error { if len(points) == 0 { return nil } return db.gormDB.WithContext(ctx).Table(db.PointsTable()).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "name"}, {Name: "timestamp"}}, DoUpdates: clause.AssignmentColumns([]string{"value"}), }).Create(points).Error } func (db *SQLDatabase) GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) { var points []TimeSeriesPoint switch db.driver { case Postgres: if err := db.gormDB.WithContext(ctx). Raw(fmt.Sprintf("SELECT name, bucket_timestamp AS timestamp, value FROM ("+ "SELECT *, TO_TIMESTAMP((EXTRACT(epoch FROM timestamp)::int / ?) * ?) AS bucket_timestamp,"+ "ROW_NUMBER() OVER (PARTITION BY (EXTRACT(epoch FROM timestamp)::int / ?) ORDER BY timestamp DESC) AS rn "+ "FROM %s WHERE name = ? and timestamp >= ? and timestamp <= ?) AS t WHERE rn = 1", db.PointsTable()), int(duration.Seconds()), int(duration.Seconds()), int(duration.Seconds()), name, begin, end). 
Scan(&points).Error; err != nil { return nil, errors.Trace(err) } case MySQL: if err := db.gormDB.WithContext(ctx). Raw(fmt.Sprintf("SELECT name, bucket_timestamp AS timestamp, value FROM("+ "SELECT *, FROM_UNIXTIME(FLOOR(UNIX_TIMESTAMP(timestamp) / ?) * ?) AS bucket_timestamp,"+ "ROW_NUMBER() OVER (PARTITION BY FLOOR(UNIX_TIMESTAMP(timestamp) / ?) ORDER BY timestamp DESC) AS rn "+ "FROM %s WHERE name = ? and timestamp >= ? and timestamp <= ?) AS t WHERE rn = 1;", db.PointsTable()), int(duration.Seconds()), int(duration.Seconds()), int(duration.Seconds()), name, begin, end). Scan(&points).Error; err != nil { return nil, errors.Trace(err) } case SQLite: rows, err := db.gormDB.WithContext(ctx). Raw(fmt.Sprintf("select name, bucket_timestamp as timestamp, value from ("+ "select *, datetime(strftime('%%s', substr(timestamp, 0, 20)) / ? * ?, 'unixepoch') as bucket_timestamp,"+ "row_number() over (partition by strftime('%%s', substr(timestamp, 0, 20)) / ? order by timestamp desc) as rn "+ "from %s where name = ? and timestamp >= ? and timestamp <= ?) where rn = 1", db.PointsTable()), int(duration.Seconds()), int(duration.Seconds()), int(duration.Seconds()), name, begin, end). Rows() if err != nil { return nil, errors.Trace(err) } defer rows.Close() for rows.Next() { var point TimeSeriesPoint var timestamp string if err := rows.Scan(&point.Name, ×tamp, &point.Value); err != nil { return nil, errors.Trace(err) } point.Timestamp, err = dateparse.ParseAny(timestamp) if err != nil { return nil, errors.Trace(err) } points = append(points, point) } } return points, nil } ================================================ FILE: storage/cache/sql_test.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cache import ( "database/sql" "fmt" "os" "strings" "testing" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) var ( mySqlDSN string postgresDSN string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } mySqlDSN = env("MYSQL_URI", "mysql://root:password@tcp(127.0.0.1:3306)/") postgresDSN = env("POSTGRES_URI", "postgres://gorse:gorse_pass@127.0.0.1/") } type PostgresTestSuite struct { baseTestSuite } func (suite *PostgresTestSuite) SetupSuite() { var err error // create database databaseComm, err := sql.Open("postgres", postgresDSN+"?sslmode=disable") suite.NoError(err) const dbName = "gorse_cache_test" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) suite.NoError(err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) suite.NoError(err) err = databaseComm.Close() suite.NoError(err) // connect database suite.Database, err = Open(postgresDSN+strings.ToLower(dbName)+"?sslmode=disable", "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func TestPostgres(t *testing.T) { suite.Run(t, new(PostgresTestSuite)) } type MySQLTestSuite struct { baseTestSuite } func (suite *MySQLTestSuite) SetupSuite() { // create database databaseComm, err := sql.Open("mysql", mySqlDSN[len(storage.MySQLPrefix):]) suite.NoError(err) const dbName = "gorse_cache_test" _, err = databaseComm.Exec("DROP 
DATABASE IF EXISTS " + dbName) suite.NoError(err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) suite.NoError(err) err = databaseComm.Close() suite.NoError(err) // connect database suite.Database, err = Open(mySqlDSN+dbName, "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func (suite *MySQLTestSuite) TestInit() { err := suite.Database.Init() suite.NoError(err) name, err := storage.ProbeMySQLIsolationVariableName(mySqlDSN[len(storage.MySQLPrefix):]) suite.NoError(err) connection := suite.Database.(*SQLDatabase).client assertQuery(suite.T(), connection, fmt.Sprintf("SELECT @@%s", name), "READ-UNCOMMITTED") } func TestMySQL(t *testing.T) { suite.Run(t, new(MySQLTestSuite)) } type SQLiteTestSuite struct { baseTestSuite } func (suite *SQLiteTestSuite) SetupSuite() { var err error // create database path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) suite.Database, err = Open(path, "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func (suite *SQLiteTestSuite) TearDownSuite() { suite.NoError(suite.Database.Close()) } func TestSQLite(t *testing.T) { suite.Run(t, new(SQLiteTestSuite)) } func assertQuery(t *testing.T, connection *sql.DB, sql string, expected string) { rows, err := connection.Query(sql) assert.NoError(t, err) assert.True(t, rows.Next()) var result string err = rows.Scan(&result) assert.NoError(t, err) assert.Equal(t, expected, result) } func BenchmarkPostgres(b *testing.B) { log.CloseLogger() // create database databaseComm, err := sql.Open("postgres", postgresDSN+"?sslmode=disable") assert.NoError(b, err) const dbName = "gorse_cache_benchmark" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) assert.NoError(b, err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) assert.NoError(b, err) err = databaseComm.Close() assert.NoError(b, err) // connect database database, err := 
Open(postgresDSN+strings.ToLower(dbName)+"?sslmode=disable", "gorse_") assert.NoError(b, err) // create schema err = database.Init() assert.NoError(b, err) // benchmark benchmark(b, database) // close database err = database.Close() assert.NoError(b, err) } func BenchmarkMySQL(b *testing.B) { log.CloseLogger() // create database databaseComm, err := sql.Open("mysql", mySqlDSN[len(storage.MySQLPrefix):]) assert.NoError(b, err) const dbName = "gorse_cache_benchmark" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) assert.NoError(b, err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) assert.NoError(b, err) err = databaseComm.Close() assert.NoError(b, err) // connect database database, err := Open(mySqlDSN+dbName, "gorse_") assert.NoError(b, err) // create schema err = database.Init() assert.NoError(b, err) // benchmark benchmark(b, database) } func BenchmarkSQLite(b *testing.B) { log.CloseLogger() // create database path := fmt.Sprintf("sqlite://%s/sqlite.db", b.TempDir()) database, err := Open(path, "gorse_") assert.NoError(b, err) // create schema err = database.Init() assert.NoError(b, err) // benchmark benchmark(b, database) } ================================================ FILE: storage/data/database.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package data import ( "context" "encoding/json" "reflect" "sort" "strings" "time" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/jsonutil" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" ) var ( ErrUserNotExist = errors.NotFoundf("user") ErrItemNotExist = errors.NotFoundf("item") ErrNoDatabase = errors.NotAssignedf("database") ) // ValidateLabels checks if labels are valid. Labels are valid if consists of: // - []string slice of strings // - []float64 slice of numbers // - map[string]any map of strings to valid labels or float64 func ValidateLabels(o any) error { if o == nil { return nil } switch labels := o.(type) { case []any: // must be []string or []float64 if len(labels) == 0 { return nil } switch labels[0].(type) { case string: for _, val := range labels { if _, ok := val.(string); !ok { return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels)) } } case json.Number: for _, val := range labels { if _, ok := val.(json.Number); !ok { return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels)) } } default: return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels)) } return nil case map[string]any: for _, val := range labels { if err := ValidateLabels(val); err != nil { return err } } return nil case string, json.Number: return nil default: return errors.Errorf("unsupported type in labels: %v", reflect.TypeOf(labels)) } } // Item stores meta data about item. type Item struct { ItemId string `gorm:"primaryKey" mapstructure:"item_id"` IsHidden bool `mapstructure:"is_hidden"` Categories []string `gorm:"serializer:json" mapstructure:"categories"` Timestamp time.Time `gorm:"column:time_stamp" mapstructure:"timestamp"` Labels any `gorm:"serializer:json" mapstructure:"labels"` Comment string `mapstructure:"comment"` } // ItemPatch is the modification on an item. 
type ItemPatch struct { IsHidden *bool Categories []string Timestamp *time.Time Labels any Comment *string } // User stores meta data about user. type User struct { UserId string `gorm:"primaryKey" mapstructure:"user_id"` Labels any `gorm:"serializer:json" mapstructure:"labels"` Comment string `mapstructure:"comment"` } // UserPatch is the modification on a user. type UserPatch struct { Labels any Comment *string } // FeedbackKey identifies feedback. type FeedbackKey struct { FeedbackType string `gorm:"column:feedback_type" mapstructure:"feedback_type"` UserId string `gorm:"column:user_id" mapstructure:"user_id"` ItemId string `gorm:"column:item_id" mapstructure:"item_id"` } // Feedback stores feedback. type Feedback struct { FeedbackKey `gorm:"embedded" mapstructure:",squash"` Value float64 `gorm:"column:value" mapstructure:"value"` Timestamp time.Time `gorm:"column:time_stamp" mapstructure:"timestamp"` Updated time.Time `gorm:"column:updated" mapstructure:"updated"` Comment string `gorm:"column:comment" mapstructure:"comment"` } type UserFeedback Feedback type ItemFeedback Feedback // SortFeedbacks sorts feedback from latest to oldest. func SortFeedbacks(feedback []Feedback) { sort.Sort(feedbackSorter(feedback)) } type feedbackSorter []Feedback func (sorter feedbackSorter) Len() int { return len(sorter) } func (sorter feedbackSorter) Less(i, j int) bool { return sorter[i].Timestamp.After(sorter[j].Timestamp) } func (sorter feedbackSorter) Swap(i, j int) { sorter[i], sorter[j] = sorter[j], sorter[i] } type ScanOptions struct { BeginUserId *string EndUserId *string BeginItemId *string EndItemId *string BeginTime *time.Time EndTime *time.Time FeedbackTypes []expression.FeedbackTypeExpression OrderByItemId bool } type ScanOption func(options *ScanOptions) // WithBeginUserId sets the begin user id. The begin user id is included in the result. 
func WithBeginUserId(userId string) ScanOption {
	return func(o *ScanOptions) { o.BeginUserId = &userId }
}

// WithEndUserId sets the upper bound of the user id range to scan. The bound
// is inclusive: the end user id itself is part of the result.
func WithEndUserId(userId string) ScanOption {
	return func(o *ScanOptions) { o.EndUserId = &userId }
}

// WithBeginItemId sets the lower bound of the item id range to scan. The bound
// is inclusive: the beginning item id itself is part of the result.
func WithBeginItemId(itemId string) ScanOption {
	return func(o *ScanOptions) { o.BeginItemId = &itemId }
}

// WithEndItemId sets the upper bound of the item id range to scan. The bound
// is inclusive: the end item id itself is part of the result.
func WithEndItemId(itemId string) ScanOption {
	return func(o *ScanOptions) { o.EndItemId = &itemId }
}

// WithBeginTime sets the earliest timestamp to scan. The bound is inclusive.
func WithBeginTime(t time.Time) ScanOption {
	return func(o *ScanOptions) { o.BeginTime = &t }
}

// WithEndTime sets the latest timestamp to scan. The bound is inclusive.
func WithEndTime(t time.Time) ScanOption {
	return func(o *ScanOptions) { o.EndTime = &t }
}

// WithFeedbackTypes restricts the scan to the given feedback type expressions.
// The slice is stored as-is, without copying.
func WithFeedbackTypes(feedbackTypes ...expression.FeedbackTypeExpression) ScanOption {
	return func(o *ScanOptions) { o.FeedbackTypes = feedbackTypes }
}

// WithOrderByItemId requests that scanned feedback be ordered by item id.
func WithOrderByItemId() ScanOption { return func(options *ScanOptions) { options.OrderByItemId = true } } func NewScanOptions(opts ...ScanOption) ScanOptions { options := ScanOptions{} for _, opt := range opts { if opt != nil { opt(&options) } } return options } type Database interface { Init() error Ping() error Close() error Optimize() error Purge() error BatchInsertItems(ctx context.Context, items []Item) error BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error) DeleteItem(ctx context.Context, itemId string) error GetItem(ctx context.Context, itemId string) (Item, error) ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error GetItems(ctx context.Context, cursor string, n int, beginTime *time.Time) (string, []Item, error) GetLatestItems(ctx context.Context, n int, categories []string) ([]Item, error) GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error) BatchInsertUsers(ctx context.Context, users []User) error DeleteUser(ctx context.Context, userId string) error GetUser(ctx context.Context, userId string) (User, error) ModifyUser(ctx context.Context, userId string, patch UserPatch) error GetUsers(ctx context.Context, cursor string, n int) (string, []User, error) GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...expression.FeedbackTypeExpression) ([]Feedback, error) GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error) DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error) BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error) GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error) GetItemStream(ctx context.Context, batchSize int, timeLimit 
*time.Time) (chan []Item, chan error) GetFeedbackStream(ctx context.Context, batchSize int, options ...ScanOption) (chan []Feedback, chan error) CountUsers(ctx context.Context) (int, error) CountItems(ctx context.Context) (int, error) CountFeedback(ctx context.Context) (int, error) } // Creator creates a database instance. type Creator func(path, tablePrefix string, opts ...storage.Option) (Database, error) var creators = make(map[string]Creator) // Register a database creator. func Register(prefixes []string, creator Creator) { for _, p := range prefixes { creators[p] = creator } } // Open a connection to a database. func Open(path, tablePrefix string, opts ...storage.Option) (Database, error) { for prefix, creator := range creators { if strings.HasPrefix(path, prefix) { return creator(path, tablePrefix, opts...) } } return nil, errors.Errorf("Unknown database: %s", path) } ================================================ FILE: storage/data/database_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package data import ( "context" "encoding/json" "fmt" "reflect" "strconv" "testing" "time" "github.com/gorse-io/gorse/common/expression" "github.com/jaswdr/faker" "github.com/juju/errors" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) var ( positiveFeedbackType = "positiveFeedbackType" positiveFeedbackType1 = "positiveFeedbackType1" positiveFeedbackType2 = "positiveFeedbackType2" negativeFeedbackType = "negativeFeedbackType" duplicateFeedbackType = "duplicateFeedbackType" dateTime64Zero = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) ) type baseTestSuite struct { suite.Suite Database } func (suite *baseTestSuite) getUsers(ctx context.Context, batchSize int) []User { users := make([]User, 0) var err error var data []User cursor := "" for { cursor, data, err = suite.Database.GetUsers(ctx, cursor, batchSize) suite.NoError(err) users = append(users, data...) if cursor == "" { suite.LessOrEqual(len(data), batchSize) return users } else { suite.Equal(batchSize, len(data)) } } } func (suite *baseTestSuite) getUsersStream(ctx context.Context, batchSize int) []User { var users []User userChan, errChan := suite.Database.GetUserStream(ctx, batchSize) for batchUsers := range userChan { users = append(users, batchUsers...) } suite.NoError(<-errChan) return users } func (suite *baseTestSuite) getItems(ctx context.Context, batchSize int) []Item { items := make([]Item, 0) var err error var data []Item cursor := "" for { cursor, data, err = suite.Database.GetItems(ctx, cursor, batchSize, nil) suite.NoError(err) items = append(items, data...) 
if cursor == "" { suite.LessOrEqual(len(data), batchSize) return items } else { suite.Equal(batchSize, len(data)) } } } func (suite *baseTestSuite) getItemStream(ctx context.Context, batchSize int) []Item { var items []Item itemChan, errChan := suite.Database.GetItemStream(ctx, batchSize, nil) for batchUsers := range itemChan { items = append(items, batchUsers...) } suite.NoError(<-errChan) return items } func (suite *baseTestSuite) getFeedback(ctx context.Context, batchSize int, beginTime, endTime *time.Time, feedbackTypes ...string) []Feedback { feedback := make([]Feedback, 0) var err error var data []Feedback cursor := "" for { cursor, data, err = suite.Database.GetFeedback(ctx, cursor, batchSize, beginTime, endTime, feedbackTypes...) suite.NoError(err) feedback = append(feedback, data...) if cursor == "" { suite.LessOrEqual(len(data), batchSize) return feedback } else { suite.Equal(batchSize, len(data)) } } } func (suite *baseTestSuite) getFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) []Feedback { var feedbacks []Feedback feedbackChan, errChan := suite.Database.GetFeedbackStream(ctx, batchSize, scanOptions...) for batchFeedback := range feedbackChan { feedbacks = append(feedbacks, batchFeedback...) 
} suite.NoError(<-errChan) return feedbacks } func (suite *baseTestSuite) isClickHouse() bool { if sqlDB, isSQL := suite.Database.(*SQLDatabase); !isSQL { return false } else { return sqlDB.driver == ClickHouse } } func (suite *baseTestSuite) analyzeTables() { sqlDatabase, ok := suite.Database.(*SQLDatabase) if ok && sqlDatabase.driver == Postgres { sqlDatabase := suite.Database.(*SQLDatabase) err := sqlDatabase.gormDB.Exec(fmt.Sprintf("ANALYZE %s", sqlDatabase.ItemsTable())).Error suite.NoError(err) err = sqlDatabase.gormDB.Exec(fmt.Sprintf("ANALYZE %s", sqlDatabase.UsersTable())).Error suite.NoError(err) err = sqlDatabase.gormDB.Exec(fmt.Sprintf("ANALYZE %s", sqlDatabase.FeedbackTable())).Error suite.NoError(err) } } func (suite *baseTestSuite) TearDownSuite() { err := suite.Database.Close() suite.NoError(err) } func (suite *baseTestSuite) SetupTest() { err := suite.Database.Ping() suite.NoError(err) err = suite.Database.Purge() suite.NoError(err) } func (suite *baseTestSuite) TearDownTest() { err := suite.Database.Purge() suite.NoError(err) } func (suite *baseTestSuite) TestInit() { err := suite.Database.Init() suite.NoError(err) } func (suite *baseTestSuite) TestUsers() { ctx := suite.T().Context() // Insert users var insertedUsers []User fake := faker.New() for i := 9; i >= 0; i-- { insertedUsers = append(insertedUsers, User{ UserId: strconv.Itoa(i), Labels: map[string]any{ "color": fake.Color().ColorName(), "company": lo.Map(lo.Range(3), func(_, _ int) any { return fake.Genre().Name() }), }, Comment: fmt.Sprintf("comment %d", i), }) } err := suite.Database.BatchInsertUsers(ctx, insertedUsers) suite.NoError(err) // Count users suite.analyzeTables() count, err := suite.Database.CountUsers(ctx) suite.NoError(err) suite.Equal(10, count) // Get users users := suite.getUsers(ctx, 3) suite.Equal(10, len(users)) for i, user := range users { suite.Equal(insertedUsers[9-i], user) } // Get user stream usersFromStream := suite.getUsersStream(ctx, 3) 
suite.ElementsMatch(insertedUsers, usersFromStream) // Get this user user, err := suite.Database.GetUser(ctx, "0") suite.NoError(err) suite.Equal("0", user.UserId) // Delete this user err = suite.Database.DeleteUser(ctx, "0") suite.NoError(err) _, err = suite.Database.GetUser(ctx, "0") suite.True(errors.Is(err, errors.NotFound), err) // test override err = suite.Database.BatchInsertUsers(ctx, []User{{UserId: "1", Comment: "override"}}) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) user, err = suite.Database.GetUser(ctx, "1") suite.NoError(err) suite.Equal("override", user.Comment) // test modify err = suite.Database.ModifyUser(ctx, "1", UserPatch{Comment: new("modify")}) suite.NoError(err) err = suite.Database.ModifyUser(ctx, "1", UserPatch{Labels: []string{"a", "b", "c"}}) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) user, err = suite.Database.GetUser(ctx, "1") suite.NoError(err) suite.Equal("modify", user.Comment) suite.Equal([]any{"a", "b", "c"}, user.Labels) // test insert empty err = suite.Database.BatchInsertUsers(ctx, nil) suite.NoError(err) // insert duplicate users err = suite.Database.BatchInsertUsers(ctx, []User{{UserId: "1"}, {UserId: "1"}}) suite.NoError(err) } func (suite *baseTestSuite) TestFeedback() { ctx := suite.T().Context() // users that already exists err := suite.Database.BatchInsertUsers(ctx, []User{{"0", []string{"a"}, "comment"}}) suite.NoError(err) // items that already exists err = suite.Database.BatchInsertItems(ctx, []Item{{ItemId: "0", Labels: []string{"b"}, Timestamp: time.Date(1996, 4, 8, 10, 0, 0, 0, time.UTC)}}) suite.NoError(err) // insert feedbacks timestamp := time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC) feedback := []Feedback{ {FeedbackKey: FeedbackKey{positiveFeedbackType1, "0", "8"}, Value: 1, Timestamp: timestamp, Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType1, "1", "6"}, Value: 1, Timestamp: timestamp, Comment: "comment"}, {FeedbackKey: 
FeedbackKey{positiveFeedbackType2, "2", "4"}, Value: 1, Timestamp: timestamp, Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType2, "3", "2"}, Value: 1, Timestamp: timestamp, Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType2, "4", "0"}, Value: 1, Timestamp: timestamp, Comment: "comment"}, } err = suite.Database.BatchInsertFeedback(ctx, feedback, true, true, true) suite.NoError(err) // set Updated for comparison for i := range feedback { feedback[i].Updated = feedback[i].Timestamp } // other type err = suite.Database.BatchInsertFeedback(ctx, []Feedback{{FeedbackKey: FeedbackKey{negativeFeedbackType, "0", "2"}}}, true, true, true) suite.NoError(err) err = suite.Database.BatchInsertFeedback(ctx, []Feedback{{FeedbackKey: FeedbackKey{negativeFeedbackType, "2", "4"}}}, true, true, true) suite.NoError(err) // future feedback futureFeedback := []Feedback{ {FeedbackKey: FeedbackKey{duplicateFeedbackType, "0", "0"}, Value: 0, Timestamp: time.Now().Add(time.Hour), Comment: "comment"}, {FeedbackKey: FeedbackKey{duplicateFeedbackType, "1", "2"}, Value: 0, Timestamp: time.Now().Add(time.Hour), Comment: "comment"}, {FeedbackKey: FeedbackKey{duplicateFeedbackType, "2", "4"}, Value: 0, Timestamp: time.Now().Add(time.Hour), Comment: "comment"}, {FeedbackKey: FeedbackKey{duplicateFeedbackType, "3", "6"}, Value: 0, Timestamp: time.Now().Add(time.Hour), Comment: "comment"}, {FeedbackKey: FeedbackKey{duplicateFeedbackType, "4", "8"}, Value: 0, Timestamp: time.Now().Add(time.Hour), Comment: "comment"}, } err = suite.Database.BatchInsertFeedback(ctx, futureFeedback, true, true, true) suite.NoError(err) // Count feedback suite.analyzeTables() count, err := suite.Database.CountFeedback(ctx) suite.NoError(err) suite.Equal(12, count) // Get feedback ret := suite.getFeedback(ctx, 3, nil, lo.ToPtr(time.Now()), positiveFeedbackType1, positiveFeedbackType2) suite.Equal(feedback, ret) ret = suite.getFeedback(ctx, 2, nil, lo.ToPtr(time.Now())) 
suite.Equal(len(feedback)+2, len(ret)) ret = suite.getFeedback(ctx, 2, lo.ToPtr(timestamp.Add(time.Second)), lo.ToPtr(time.Now())) suite.Empty(ret) // Get feedback stream feedbackFromStream := suite.getFeedbackStream(ctx, 3, WithEndTime(time.Now()), WithFeedbackTypes( expression.MustParseFeedbackTypeExpression(positiveFeedbackType1), expression.MustParseFeedbackTypeExpression(positiveFeedbackType2))) suite.ElementsMatch(feedback, feedbackFromStream) feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithEndTime(time.Now())) suite.Equal(len(feedback)+2, len(feedbackFromStream)) feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithBeginTime(timestamp.Add(time.Second)), WithEndTime(time.Now())) suite.Empty(feedbackFromStream) feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithBeginUserId("1"), WithEndUserId("3"), WithEndTime(time.Now()), WithFeedbackTypes( expression.MustParseFeedbackTypeExpression(positiveFeedbackType1), expression.MustParseFeedbackTypeExpression(positiveFeedbackType2))) suite.Equal(feedback[1:4], feedbackFromStream) feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithBeginItemId("2"), WithEndItemId("6"), WithEndTime(time.Now()), WithFeedbackTypes( expression.MustParseFeedbackTypeExpression(positiveFeedbackType1), expression.MustParseFeedbackTypeExpression(positiveFeedbackType2)), WithOrderByItemId()) suite.Equal([]Feedback{feedback[3], feedback[2], feedback[1]}, feedbackFromStream) // Get items err = suite.Database.Optimize() suite.NoError(err) items := suite.getItems(ctx, 3) suite.Equal(5, len(items)) for i, item := range items { suite.Equal(strconv.Itoa(i*2), item.ItemId) if item.ItemId != "0" { if suite.isClickHouse() { // ClickHouse returns 1900-01-01 00:00:00 +0000 UTC as zero date. 
suite.Equal(dateTime64Zero, item.Timestamp) } else { suite.Zero(item.Timestamp) } suite.Empty(item.Labels) suite.Empty(item.Comment) } } // Get users users := suite.getUsers(ctx, 2) suite.Equal(5, len(users)) for i, user := range users { suite.Equal(strconv.Itoa(i), user.UserId) if user.UserId != "0" { suite.Empty(user.Labels) suite.Empty(user.Comment) } } // check users that already exists user, err := suite.Database.GetUser(ctx, "0") suite.NoError(err) suite.Equal(User{"0", []any{"a"}, "comment"}, user) // check items that already exists item, err := suite.Database.GetItem(ctx, "0") suite.NoError(err) suite.Equal(Item{ItemId: "0", Labels: []any{"b"}, Timestamp: time.Date(1996, 4, 8, 10, 0, 0, 0, time.UTC)}, item) // Get typed feedback by user ret, err = suite.Database.GetUserFeedback(ctx, "2", lo.ToPtr(time.Now()), expression.MustParseFeedbackTypeExpression(positiveFeedbackType1), expression.MustParseFeedbackTypeExpression(positiveFeedbackType2)) suite.NoError(err) if suite.Equal(1, len(ret)) { suite.Equal("2", ret[0].UserId) suite.Equal("4", ret[0].ItemId) } // Get all feedback by user ret, err = suite.Database.GetUserFeedback(ctx, "2", lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal(2, len(ret)) // Get typed feedback by item ret, err = suite.Database.GetItemFeedback(ctx, "4", positiveFeedbackType1, positiveFeedbackType2) suite.NoError(err) suite.Equal(1, len(ret)) suite.Equal("2", ret[0].UserId) suite.Equal("4", ret[0].ItemId) // Get all feedback by item ret, err = suite.Database.GetItemFeedback(ctx, "4") suite.NoError(err) suite.Equal(2, len(ret)) // test override err = suite.Database.BatchInsertFeedback(ctx, []Feedback{{ FeedbackKey: FeedbackKey{positiveFeedbackType, "0", "8"}, Value: 100, Timestamp: time.Date(1996, 4, 8, 0, 0, 0, 0, time.UTC), Comment: "override", }}, true, true, true) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) // Get feedback by user with value filter ret, err = suite.Database.GetUserFeedback(ctx, "0", 
lo.ToPtr(time.Now()), expression.FeedbackTypeExpression{FeedbackType: positiveFeedbackType, Value: 50, ExprType: expression.Greater}) suite.NoError(err) suite.Equal(1, len(ret)) suite.Equal(float64(100), ret[0].Value) ret, err = suite.Database.GetUserFeedback(ctx, "0", lo.ToPtr(time.Now()), expression.MustParseFeedbackTypeExpression(positiveFeedbackType)) suite.NoError(err) suite.Equal(1, len(ret)) suite.Equal(float64(100), ret[0].Value) suite.Equal(time.Date(1996, 4, 8, 0, 0, 0, 0, time.UTC), ret[0].Timestamp) suite.Equal("override", ret[0].Comment) // test not overwrite err = suite.Database.BatchInsertFeedback(ctx, []Feedback{{ FeedbackKey: FeedbackKey{positiveFeedbackType, "0", "8"}, Value: 80, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "not_override", }}, true, true, false) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) ret, err = suite.Database.GetUserFeedback(ctx, "0", lo.ToPtr(time.Now()), expression.MustParseFeedbackTypeExpression(positiveFeedbackType)) suite.NoError(err) suite.Equal(1, len(ret)) suite.Equal(float64(180), ret[0].Value) suite.Equal(time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), ret[0].Timestamp) suite.Equal("not_override", ret[0].Comment) // insert no feedback err = suite.Database.BatchInsertFeedback(ctx, nil, true, true, true) suite.NoError(err) // not insert users or items err = suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"a", "100", "200"}}, {FeedbackKey: FeedbackKey{"a", "0", "200"}}, {FeedbackKey: FeedbackKey{"a", "100", "8"}}, }, false, false, false) suite.NoError(err) result, err := suite.Database.GetUserItemFeedback(ctx, "100", "200") suite.NoError(err) suite.Empty(result) result, err = suite.Database.GetUserItemFeedback(ctx, "0", "200") suite.NoError(err) suite.Empty(result) result, err = suite.Database.GetUserItemFeedback(ctx, "100", "8") suite.NoError(err) suite.Empty(result) // insert valid feedback and invalid feedback at the same time err = 
suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"a", "0", "8"}}, {FeedbackKey: FeedbackKey{"a", "100", "200"}}, }, false, false, false) suite.NoError(err) // insert duplicate feedback err = suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 1, Timestamp: timestamp}, {FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 1, Timestamp: timestamp}, }, true, true, true) suite.NoError(err) err = suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 1, Timestamp: timestamp}, }, true, true, false) suite.NoError(err) // check duplicate feedback ret, err = suite.Database.GetUserItemFeedback(ctx, "0", "0", "a") suite.NoError(err) suite.Equal([]Feedback{{FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 2, Timestamp: timestamp, Updated: timestamp, Comment: ""}}, ret) // put duplicate feedback err = suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 1, Timestamp: timestamp}, }, true, true, true) suite.NoError(err) // check duplicate feedback again ret, err = suite.Database.GetUserItemFeedback(ctx, "0", "0", "a") suite.NoError(err) if suite.isClickHouse() { suite.Equal([]Feedback{{FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 3, Timestamp: timestamp, Updated: timestamp, Comment: ""}}, ret) } else { suite.Equal([]Feedback{{FeedbackKey: FeedbackKey{"a", "0", "0"}, Value: 1, Timestamp: timestamp, Updated: timestamp, Comment: ""}}, ret) } } func (suite *baseTestSuite) TestItems() { ctx := suite.T().Context() // Items items := []Item{ { ItemId: "0", IsHidden: true, Categories: []string{"a"}, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a"}, Comment: "comment 0", }, { ItemId: "2", Categories: []string{"b"}, Timestamp: time.Date(1997, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a"}, Comment: "comment 2", }, { ItemId: "4", IsHidden: true, Categories: []string{"a"}, Timestamp: 
time.Date(1998, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a", "b"}, Comment: "comment 4", }, { ItemId: "6", Categories: []string{"b"}, Timestamp: time.Date(1999, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"b"}, Comment: "comment 6", }, { ItemId: "8", IsHidden: true, Categories: []string{"a"}, Timestamp: time.Date(2000, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"b"}, Comment: "comment 8", }, } // Insert item err := suite.Database.BatchInsertItems(ctx, items) suite.NoError(err) // Count items suite.analyzeTables() count, err := suite.Database.CountItems(ctx) suite.NoError(err) suite.Equal(5, count) // Get items totalItems := suite.getItems(ctx, 3) suite.Equal(items, totalItems) // Get item stream itemsFromStream := suite.getItemStream(ctx, 3) suite.ElementsMatch(items, itemsFromStream) // Get item for _, item := range items { ret, err := suite.Database.GetItem(ctx, item.ItemId) suite.NoError(err) suite.Equal(item, ret) } // batch get items batchItem, err := suite.Database.BatchGetItems(ctx, []string{"2", "6"}) suite.NoError(err) suite.Equal([]Item{items[1], items[3]}, batchItem) // Test GetLatestItems latestItems, err := suite.Database.GetLatestItems(ctx, 3, nil) suite.NoError(err) suite.Equal([]Item{items[3], items[1]}, latestItems) latestItemsWithCategory, err := suite.Database.GetLatestItems(ctx, 3, []string{"b"}) suite.NoError(err) suite.Equal([]Item{items[3], items[1]}, latestItemsWithCategory) // Delete item err = suite.Database.DeleteItem(ctx, "0") suite.NoError(err) _, err = suite.Database.GetItem(ctx, "0") suite.True(errors.Is(err, errors.NotFound), err) // test override err = suite.Database.BatchInsertItems(ctx, []Item{{ItemId: "4", IsHidden: false, Categories: []string{"b"}, Labels: []string{"o"}, Comment: "override"}}) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) item, err := suite.Database.GetItem(ctx, "4") suite.NoError(err) suite.False(item.IsHidden) suite.Equal([]string{"b"}, item.Categories) suite.Equal([]any{"o"}, 
item.Labels) suite.Equal("override", item.Comment) // test modify timestamp := time.Date(2000, 1, 1, 1, 1, 1, 0, time.UTC) err = suite.Database.ModifyItem(ctx, "2", ItemPatch{IsHidden: new(true)}) suite.NoError(err) err = suite.Database.ModifyItem(ctx, "2", ItemPatch{Categories: []string{"a"}}) suite.NoError(err) err = suite.Database.ModifyItem(ctx, "2", ItemPatch{Comment: new("modify")}) suite.NoError(err) err = suite.Database.ModifyItem(ctx, "2", ItemPatch{Labels: []string{"a", "b", "c"}}) suite.NoError(err) err = suite.Database.ModifyItem(ctx, "2", ItemPatch{Timestamp: ×tamp}) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) item, err = suite.Database.GetItem(ctx, "2") suite.NoError(err) suite.True(item.IsHidden) suite.Equal([]string{"a"}, item.Categories) suite.Equal("modify", item.Comment) suite.Equal([]any{"a", "b", "c"}, item.Labels) suite.Equal(timestamp, item.Timestamp) // test insert empty err = suite.Database.BatchInsertItems(ctx, nil) suite.NoError(err) // test get empty items, err = suite.Database.BatchGetItems(ctx, nil) suite.NoError(err) suite.Empty(items) // test insert duplicate items err = suite.Database.BatchInsertItems(ctx, []Item{{ItemId: "1"}, {ItemId: "1"}}) suite.NoError(err) } func (suite *baseTestSuite) TestDeleteUser() { ctx := suite.T().Context() // Insert ret feedback := []Feedback{ {FeedbackKey: FeedbackKey{positiveFeedbackType, "a", "0"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "a", "2"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "a", "4"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "a", "6"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "a", 
"8"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, } err := suite.Database.BatchInsertFeedback(ctx, feedback, true, true, true) suite.NoError(err) // Delete user err = suite.Database.DeleteUser(ctx, "a") suite.NoError(err) _, err = suite.Database.GetUser(ctx, "a") suite.NotNil(err, "failed to delete user") ret, err := suite.Database.GetUserFeedback(ctx, "a", lo.ToPtr(time.Now()), expression.MustParseFeedbackTypeExpression(positiveFeedbackType)) suite.NoError(err) suite.Equal(0, len(ret)) _, ret, err = suite.Database.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now()), positiveFeedbackType) suite.NoError(err) suite.Empty(ret) } func (suite *baseTestSuite) TestDeleteItem() { ctx := suite.T().Context() // Insert ret feedbacks := []Feedback{ {FeedbackKey: FeedbackKey{positiveFeedbackType, "0", "b"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "1", "b"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "2", "b"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "3", "b"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{positiveFeedbackType, "4", "b"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, } err := suite.Database.BatchInsertFeedback(ctx, feedbacks, true, true, true) suite.NoError(err) // Delete item err = suite.Database.DeleteItem(ctx, "b") suite.NoError(err) _, err = suite.Database.GetItem(ctx, "b") suite.Error(err, "failed to delete item") ret, err := suite.Database.GetItemFeedback(ctx, "b", positiveFeedbackType) suite.NoError(err) suite.Equal(0, len(ret)) _, ret, err = suite.Database.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now()), 
positiveFeedbackType) suite.NoError(err) suite.Empty(ret) } func (suite *baseTestSuite) TestDeleteFeedback() { ctx := suite.T().Context() feedbacks := []Feedback{ {FeedbackKey: FeedbackKey{"type1", "2", "3"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type2", "2", "3"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type3", "2", "3"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type1", "2", "4"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type1", "1", "3"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, } err := suite.Database.BatchInsertFeedback(ctx, feedbacks, true, true, true) suite.NoError(err) // set Updated for comparison for i := range feedbacks { feedbacks[i].Updated = feedbacks[i].Timestamp } // get user-item feedback ret, err := suite.Database.GetUserItemFeedback(ctx, "2", "3") suite.NoError(err) suite.ElementsMatch([]Feedback{feedbacks[0], feedbacks[1], feedbacks[2]}, ret) feedbackType2 := "type2" ret, err = suite.Database.GetUserItemFeedback(ctx, "2", "3", feedbackType2) suite.NoError(err) suite.Equal([]Feedback{feedbacks[1]}, ret) // delete user-item feedback deleteCount, err := suite.Database.DeleteUserItemFeedback(ctx, "2", "3") suite.NoError(err) if !suite.isClickHouse() { // RowAffected isn't supported by ClickHouse, suite.Equal(3, deleteCount) } err = suite.Database.Optimize() suite.NoError(err) ret, err = suite.Database.GetUserItemFeedback(ctx, "2", "3") suite.NoError(err) suite.Empty(ret) feedbackType1 := "type1" deleteCount, err = suite.Database.DeleteUserItemFeedback(ctx, "1", "3", feedbackType1) suite.NoError(err) if !suite.isClickHouse() { // RowAffected isn't supported by ClickHouse, suite.Equal(1, deleteCount) } ret, err 
= suite.Database.GetUserItemFeedback(ctx, "1", "3", feedbackType2) suite.NoError(err) suite.Empty(ret) } func (suite *baseTestSuite) TestTimeLimit() { ctx := suite.T().Context() // insert items items := []Item{ { ItemId: "0", Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a"}, Comment: "comment 0", }, { ItemId: "2", Timestamp: time.Date(1997, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a"}, Comment: "comment 2", }, { ItemId: "4", Timestamp: time.Date(1998, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"a", "b"}, Comment: "comment 4", }, { ItemId: "6", Timestamp: time.Date(1999, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"b"}, Comment: "comment 6", }, { ItemId: "8", Timestamp: time.Date(2000, 3, 15, 0, 0, 0, 0, time.UTC), Labels: []any{"b"}, Comment: "comment 8", }, } err := suite.Database.BatchInsertItems(ctx, items) suite.NoError(err) timeLimit := time.Date(1998, 1, 1, 0, 0, 0, 0, time.UTC) _, ret, err := suite.Database.GetItems(ctx, "", 100, &timeLimit) suite.NoError(err) suite.Equal([]Item{items[2], items[3], items[4]}, ret) // insert feedback feedbacks := []Feedback{ {FeedbackKey: FeedbackKey{"type1", "2", "3"}, Value: 0, Timestamp: time.Date(1996, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type2", "2", "3"}, Value: 0, Timestamp: time.Date(1997, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type3", "2", "3"}, Value: 0, Timestamp: time.Date(1998, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type1", "2", "4"}, Value: 0, Timestamp: time.Date(1999, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, {FeedbackKey: FeedbackKey{"type1", "1", "3"}, Value: 0, Timestamp: time.Date(2000, 3, 15, 0, 0, 0, 0, time.UTC), Comment: "comment"}, } err = suite.Database.BatchInsertFeedback(ctx, feedbacks, true, true, true) suite.NoError(err) // set Updated for comparison for i := range feedbacks { feedbacks[i].Updated = feedbacks[i].Timestamp } _, retFeedback, 
err := suite.Database.GetFeedback(ctx, "", 100, &timeLimit, lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal([]Feedback{feedbacks[4], feedbacks[3], feedbacks[2]}, retFeedback) typeFilter := "type1" _, retFeedback, err = suite.Database.GetFeedback(ctx, "", 100, &timeLimit, lo.ToPtr(time.Now()), typeFilter) suite.NoError(err) suite.Equal([]Feedback{feedbacks[4], feedbacks[3]}, retFeedback) } func (suite *baseTestSuite) TestTimezone() { ctx := suite.T().Context() loc, err := time.LoadLocation("Asia/Tokyo") suite.NoError(err) // insert feedbacks err = suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"read", "1", "1"}, Timestamp: time.Now().Add(-time.Second).In(loc)}, {FeedbackKey: FeedbackKey{"read", "1", "2"}, Timestamp: time.Now().Add(-time.Second).In(loc)}, {FeedbackKey: FeedbackKey{"read", "2", "2"}, Timestamp: time.Now().Add(-time.Second).In(loc)}, {FeedbackKey: FeedbackKey{"like", "1", "1"}, Timestamp: time.Now().Add(time.Hour).In(loc)}, {FeedbackKey: FeedbackKey{"like", "1", "2"}, Timestamp: time.Now().Add(time.Hour).In(loc)}, {FeedbackKey: FeedbackKey{"like", "2", "2"}, Timestamp: time.Now().Add(time.Hour).In(loc)}, }, true, true, true) suite.NoError(err) // get feedback stream feedback := suite.getFeedback(ctx, 10, nil, lo.ToPtr(time.Now())) suite.Equal(3, len(feedback)) // get feedback _, feedback, err = suite.Database.GetFeedback(ctx, "", 10, nil, lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal(3, len(feedback)) // get user feedback feedback, err = suite.Database.GetUserFeedback(ctx, "1", lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal(2, len(feedback)) // get item feedback feedback, err = suite.Database.GetItemFeedback(ctx, "2") // no future feedback by default suite.NoError(err) suite.Equal(2, len(feedback)) // get user item feedback feedback, err = suite.Database.GetUserItemFeedback(ctx, "1", "1") // return future feedback by default suite.NoError(err) suite.Equal(2, len(feedback)) // insert items now := 
time.Now().In(loc) err = suite.Database.BatchInsertItems(ctx, []Item{{ItemId: "100", Timestamp: now}, {ItemId: "200"}}) suite.NoError(err) err = suite.Database.ModifyItem(ctx, "200", ItemPatch{Timestamp: &now}) suite.NoError(err) err = suite.Database.Optimize() suite.NoError(err) switch database := suite.Database.(type) { case *SQLDatabase: switch suite.Database.(*SQLDatabase).driver { case Postgres: item, err := suite.Database.GetItem(ctx, "100") suite.NoError(err) suite.Equal(now.Round(time.Microsecond).In(time.UTC), item.Timestamp) item, err = suite.Database.GetItem(ctx, "200") suite.NoError(err) suite.Equal(now.Round(time.Microsecond).In(time.UTC), item.Timestamp) case ClickHouse: item, err := suite.Database.GetItem(ctx, "100") suite.NoError(err) suite.Equal(now.Truncate(time.Second).In(time.UTC), item.Timestamp) item, err = suite.Database.GetItem(ctx, "200") suite.NoError(err) suite.Equal(now.Truncate(time.Second).In(time.UTC), item.Timestamp) case SQLite: item, err := suite.Database.GetItem(ctx, "100") suite.NoError(err) suite.Equal(now.In(time.UTC), item.Timestamp.In(time.UTC)) item, err = suite.Database.GetItem(ctx, "200") suite.NoError(err) suite.Equal(now.In(time.UTC), item.Timestamp.In(time.UTC)) default: suite.T().Skipf("unknown sql database: %v", database.driver) } case *MongoDB: item, err := suite.Database.GetItem(ctx, "100") suite.NoError(err) suite.Equal(now.Truncate(time.Millisecond).In(time.UTC), item.Timestamp) item, err = suite.Database.GetItem(ctx, "200") suite.NoError(err) suite.Equal(now.Truncate(time.Millisecond).In(time.UTC), item.Timestamp) default: suite.T().Skipf("unknown database: %v", reflect.TypeOf(suite.Database)) } } func (suite *baseTestSuite) TestCollation() { ctx := suite.T().Context() err := suite.Database.BatchInsertFeedback(ctx, []Feedback{ {FeedbackKey: FeedbackKey{"type", "user", "A"}}, {FeedbackKey: FeedbackKey{"type", "user", "a"}}, {FeedbackKey: FeedbackKey{"type", "user", "B"}}, {FeedbackKey: FeedbackKey{"type", "user", 
"b"}}, }, true, true, true) suite.NoError(err) feedbacks := suite.getFeedbackStream(ctx, 10, WithBeginItemId("a"), WithEndItemId("B")) suite.Empty(feedbacks) } func (suite *baseTestSuite) TestPurge() { ctx := suite.T().Context() // insert data err := suite.Database.BatchInsertFeedback(ctx, lo.Map(lo.Range(100), func(t int, i int) Feedback { return Feedback{FeedbackKey: FeedbackKey{ FeedbackType: "click", UserId: strconv.Itoa(t), ItemId: strconv.Itoa(t), }} }), true, true, true) suite.NoError(err) _, users, err := suite.Database.GetUsers(ctx, "", 100) suite.NoError(err) suite.Equal(100, len(users)) _, items, err := suite.Database.GetItems(ctx, "", 100, nil) suite.NoError(err) suite.Equal(100, len(items)) _, feedbacks, err := suite.Database.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) suite.NoError(err) suite.Equal(100, len(feedbacks)) // purge data err = suite.Database.Purge() suite.NoError(err) _, users, err = suite.Database.GetUsers(ctx, "", 100) suite.NoError(err) suite.Empty(users) _, items, err = suite.Database.GetItems(ctx, "", 100, nil) suite.NoError(err) suite.Empty(items) _, feedbacks, err = suite.Database.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) suite.NoError(err) suite.Empty(feedbacks) // purge empty database err = suite.Database.Purge() suite.NoError(err) } func TestSortFeedbacks(t *testing.T) { feedback := []Feedback{ {FeedbackKey: FeedbackKey{"star", "1", "1"}, Timestamp: time.Date(2000, 10, 1, 0, 0, 0, 0, time.UTC)}, {FeedbackKey: FeedbackKey{"like", "1", "1"}, Timestamp: time.Date(2001, 10, 1, 0, 0, 0, 0, time.UTC)}, {FeedbackKey: FeedbackKey{"read", "1", "1"}, Timestamp: time.Date(2002, 10, 1, 0, 0, 0, 0, time.UTC)}, } SortFeedbacks(feedback) assert.Equal(t, []Feedback{ {FeedbackKey: FeedbackKey{"read", "1", "1"}, Timestamp: time.Date(2002, 10, 1, 0, 0, 0, 0, time.UTC)}, {FeedbackKey: FeedbackKey{"like", "1", "1"}, Timestamp: time.Date(2001, 10, 1, 0, 0, 0, 0, time.UTC)}, {FeedbackKey: FeedbackKey{"star", "1", "1"}, Timestamp: 
time.Date(2000, 10, 1, 0, 0, 0, 0, time.UTC)}, }, feedback) } func TestValidateLabels(t *testing.T) { assert.NoError(t, ValidateLabels(nil)) assert.NoError(t, ValidateLabels(json.Number("1"))) assert.NoError(t, ValidateLabels("label")) assert.NoError(t, ValidateLabels([]any{json.Number("1"), json.Number("2"), json.Number("3")})) assert.NoError(t, ValidateLabels([]any{"1", "2", "3"})) assert.NoError(t, ValidateLabels(map[string]any{"city": json.Number("1"), "tags": []any{json.Number("1"), json.Number("2"), json.Number("3")}})) assert.NoError(t, ValidateLabels(map[string]any{"city": "wenzhou", "tags": []any{"1", "2", "3"}})) assert.NoError(t, ValidateLabels(map[string]any{"address": map[string]any{"province": json.Number("1"), "city": json.Number("2")}})) assert.NoError(t, ValidateLabels(map[string]any{"address": map[string]any{"province": "zhejiang", "city": "wenzhou"}})) assert.Error(t, ValidateLabels(map[string]any{"price": 100, "tags": []any{json.Number("1"), "2", "3"}})) assert.Error(t, ValidateLabels(map[string]any{"city": "wenzhou", "tags": []any{"1", json.Number("2"), "3"}})) assert.Error(t, ValidateLabels(map[string]any{"city": "wenzhou", "tags": []any{"1", "2", json.Number("3")}})) } func benchmarkCountItems(b *testing.B, db Database) { ctx := b.Context() // Insert 10,000 items items := make([]Item, 100000) for i := range items { items[i] = Item{ItemId: strconv.Itoa(i)} } err := db.BatchInsertItems(ctx, items) require.NoError(b, err) // Benchmark count items b.ResetTimer() for i := 0; i < b.N; i++ { n, err := db.CountItems(ctx) require.NoError(b, err) require.Equal(b, 100000, n) } } ================================================ FILE: storage/data/mongodb.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package data import ( "context" "encoding/base64" "encoding/json" "time" mapset "github.com/deckarep/golang-set/v2" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo" ) func init() { Register([]string{storage.MongoPrefix, storage.MongoSrvPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { // connect to database database := new(MongoDB) clientOpts := options.Client() clientOpts.Monitor = otelmongo.NewMonitor() clientOpts.ApplyURI(path) var err error if database.client, err = mongo.Connect(context.Background(), clientOpts); err != nil { return nil, errors.Trace(err) } // parse DSN and extract database name if cs, err := connstring.ParseAndValidate(path); err != nil { return nil, errors.Trace(err) } else { database.dbName = cs.Database database.TablePrefix = storage.TablePrefix(tablePrefix) } return database, nil }) } func feedbackKeyFromString(s string) (*FeedbackKey, error) { var feedbackKey FeedbackKey err := json.Unmarshal([]byte(s), &feedbackKey) return &feedbackKey, err } func (k *FeedbackKey) toString() (string, error) { b, err := json.Marshal(k) return string(b), err } func unpack(o any) any { if o == nil { return nil } switch p := o.(type) { case primitive.A: 
return []any(p) case primitive.D: m := make(map[string]any) for _, e := range p { m[e.Key] = unpack(e.Value) } return m default: return p } } func FeedbackTypeExpressionToMongo(e expression.FeedbackTypeExpression) bson.M { filter := bson.M{"feedbackkey.feedbacktype": e.FeedbackType} switch e.ExprType { case expression.Less: filter["value"] = bson.M{"$lt": e.Value} case expression.LessOrEqual: filter["value"] = bson.M{"$lte": e.Value} case expression.Greater: filter["value"] = bson.M{"$gt": e.Value} case expression.GreaterOrEqual: filter["value"] = bson.M{"$gte": e.Value} } return filter } // MongoDB is the data storage based on MongoDB. type MongoDB struct { storage.TablePrefix client *mongo.Client dbName string } // Optimize is used by ClickHouse only. func (db *MongoDB) Optimize() error { return nil } // Init collections and indices in MongoDB. func (db *MongoDB) Init() error { ctx := context.Background() d := db.client.Database(db.dbName) // list collections var hasUsers, hasItems, hasFeedback bool collections, err := d.ListCollectionNames(ctx, bson.M{}) if err != nil { return errors.Trace(err) } for _, collectionName := range collections { switch collectionName { case db.UsersTable(): hasUsers = true case db.ItemsTable(): hasItems = true case db.FeedbackTable(): hasFeedback = true } } // create collections if !hasUsers { if err = d.CreateCollection(ctx, db.UsersTable()); err != nil { return errors.Trace(err) } } if !hasItems { if err = d.CreateCollection(ctx, db.ItemsTable()); err != nil { return errors.Trace(err) } } if !hasFeedback { if err = d.CreateCollection(ctx, db.FeedbackTable()); err != nil { return errors.Trace(err) } } // create index _, err = d.Collection(db.UsersTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "userid": 1, }, Options: options.Index().SetUnique(true), }) if err != nil { return errors.Trace(err) } _, err = d.Collection(db.ItemsTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "itemid": 1, }, Options: 
options.Index().SetUnique(true), }) if err != nil { return errors.Trace(err) } _, err = d.Collection(db.FeedbackTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "feedbackkey": 1, }, Options: options.Index().SetUnique(true), }) if err != nil { return errors.Trace(err) } _, err = d.Collection(db.FeedbackTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "feedbackkey.userid": 1, }, }) if err != nil { return errors.Trace(err) } _, err = d.Collection(db.FeedbackTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "feedbackkey.itemid": 1, }, }) if err != nil { return errors.Trace(err) } _, err = d.Collection(db.ItemsTable()).Indexes().CreateOne(ctx, mongo.IndexModel{ Keys: bson.M{ "timestamp": 1, }, }) if err != nil { return errors.Trace(err) } return nil } func (db *MongoDB) Ping() error { return db.client.Ping(context.Background(), nil) } // Close connection to MongoDB. func (db *MongoDB) Close() error { return db.client.Disconnect(context.Background()) } func (db *MongoDB) Purge() error { tables := []string{db.ItemsTable(), db.FeedbackTable(), db.UsersTable()} for _, tableName := range tables { c := db.client.Database(db.dbName).Collection(tableName) _, err := c.DeleteMany(context.Background(), bson.D{}) if err != nil { return errors.Trace(err) } } return nil } // BatchInsertItems insert items into MongoDB. func (db *MongoDB) BatchInsertItems(ctx context.Context, items []Item) error { if len(items) == 0 { return nil } c := db.client.Database(db.dbName).Collection(db.ItemsTable()) var models []mongo.WriteModel for _, item := range items { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{"itemid": bson.M{"$eq": item.ItemId}}). 
SetUpdate(bson.M{"$set": item})) } _, err := c.BulkWrite(ctx, models) return errors.Trace(err) } func (db *MongoDB) BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error) { if len(itemIds) == 0 { return nil, nil } c := db.client.Database(db.dbName).Collection(db.ItemsTable()) r, err := c.Find(ctx, bson.M{"itemid": bson.M{"$in": itemIds}}) if err != nil { return nil, errors.Trace(err) } items := make([]Item, 0) defer r.Close(ctx) for r.Next(ctx) { var item Item if err = r.Decode(&item); err != nil { return nil, errors.Trace(err) } item.Labels = unpack(item.Labels) items = append(items, item) } return items, nil } // ModifyItem modify an item in MongoDB. func (db *MongoDB) ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error { // create update update := bson.M{} if patch.IsHidden != nil { update["ishidden"] = patch.IsHidden } if patch.Categories != nil { update["categories"] = patch.Categories } if patch.Comment != nil { update["comment"] = patch.Comment } if patch.Labels != nil { update["labels"] = patch.Labels } if patch.Timestamp != nil { update["timestamp"] = patch.Timestamp } // execute c := db.client.Database(db.dbName).Collection(db.ItemsTable()) _, err := c.UpdateOne(ctx, bson.M{"itemid": bson.M{"$eq": itemId}}, bson.M{"$set": update}) return errors.Trace(err) } // DeleteItem deletes a item from MongoDB. func (db *MongoDB) DeleteItem(ctx context.Context, itemId string) error { c := db.client.Database(db.dbName).Collection(db.ItemsTable()) _, err := c.DeleteOne(ctx, bson.M{"itemid": itemId}) if err != nil { return errors.Trace(err) } c = db.client.Database(db.dbName).Collection(db.FeedbackTable()) _, err = c.DeleteMany(ctx, bson.M{ "feedbackkey.itemid": bson.M{"$eq": itemId}, }) return errors.Trace(err) } // GetItem returns a item from MongoDB. 
func (db *MongoDB) GetItem(ctx context.Context, itemId string) (item Item, err error) { c := db.client.Database(db.dbName).Collection(db.ItemsTable()) r := c.FindOne(ctx, bson.M{"itemid": itemId}) if r.Err() == mongo.ErrNoDocuments { err = errors.Annotate(ErrItemNotExist, itemId) return } err = r.Decode(&item) item.Labels = unpack(item.Labels) return } // GetItems returns items from MongoDB. func (db *MongoDB) GetItems(ctx context.Context, cursor string, n int, timeLimit *time.Time) (string, []Item, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } cursorItem := string(buf) c := db.client.Database(db.dbName).Collection(db.ItemsTable()) opt := options.Find() opt.SetLimit(int64(n)) opt.SetSort(bson.D{{"itemid", 1}}) filter := bson.M{"itemid": bson.M{"$gt": cursorItem}} if timeLimit != nil { filter["timestamp"] = bson.M{"$gt": *timeLimit} } r, err := c.Find(ctx, filter, opt) if err != nil { return "", nil, err } items := make([]Item, 0) defer r.Close(ctx) for r.Next(ctx) { var item Item if err = r.Decode(&item); err != nil { return "", nil, err } item.Labels = unpack(item.Labels) items = append(items, item) } if len(items) == n { cursor = items[n-1].ItemId } else { cursor = "" } return base64.StdEncoding.EncodeToString([]byte(cursor)), items, nil } // GetLatestItems returns the latest items from MongoDB. 
func (db *MongoDB) GetLatestItems(ctx context.Context, n int, categories []string) ([]Item, error) { c := db.client.Database(db.dbName).Collection(db.ItemsTable()) opt := options.Find() opt.SetLimit(int64(n)) opt.SetSort(bson.D{{"timestamp", -1}}) filter := bson.M{"ishidden": bson.M{"$ne": true}} if len(categories) > 0 { filter["categories"] = bson.M{"$all": categories} } r, err := c.Find(ctx, filter, opt) if err != nil { return nil, err } items := make([]Item, 0) defer r.Close(ctx) for r.Next(ctx) { var item Item if err = r.Decode(&item); err != nil { return nil, err } item.Labels = unpack(item.Labels) items = append(items, item) } return items, nil } // GetItemStream read items from MongoDB by stream. func (db *MongoDB) GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error) { itemChan := make(chan []Item, bufSize) errChan := make(chan error, 1) go func() { defer close(itemChan) defer close(errChan) // send query ctx := context.Background() c := db.client.Database(db.dbName).Collection(db.ItemsTable()) opt := options.Find() filter := bson.M{} if timeLimit != nil { filter["timestamp"] = bson.M{"$gt": *timeLimit} } r, err := c.Find(ctx, filter, opt) if err != nil { errChan <- errors.Trace(err) return } // fetch result items := make([]Item, 0, batchSize) defer r.Close(ctx) for r.Next(ctx) { var item Item if err = r.Decode(&item); err != nil { errChan <- errors.Trace(err) return } item.Labels = unpack(item.Labels) items = append(items, item) if len(items) == batchSize { itemChan <- items items = make([]Item, 0, batchSize) } } if len(items) > 0 { itemChan <- items } errChan <- nil }() return itemChan, errChan } // GetItemFeedback returns feedback of a item from MongoDB. 
func (db *MongoDB) GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error) { c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) var r *mongo.Cursor var err error filter := bson.M{ "feedbackkey.itemid": bson.M{"$eq": itemId}, "timestamp": bson.M{"$lte": time.Now()}, } if len(feedbackTypes) > 0 { var conditions []bson.M for _, feedbackType := range feedbackTypes { conditions = append(conditions, bson.M{ "feedbackkey.feedbacktype": bson.M{"$eq": feedbackType}, }) } filter["$or"] = conditions } r, err = c.Find(ctx, filter) if err != nil { return nil, err } feedbacks := make([]Feedback, 0) defer r.Close(ctx) for r.Next(ctx) { var feedback Feedback if err = r.Decode(&feedback); err != nil { return nil, err } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // BatchInsertUsers inserts a user into MongoDB. func (db *MongoDB) BatchInsertUsers(ctx context.Context, users []User) error { if len(users) == 0 { return nil } c := db.client.Database(db.dbName).Collection(db.UsersTable()) var models []mongo.WriteModel for _, user := range users { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{"userid": bson.M{"$eq": user.UserId}}). SetUpdate(bson.M{"$set": user})) } _, err := c.BulkWrite(ctx, models) return errors.Trace(err) } // ModifyUser modify a user in MongoDB. func (db *MongoDB) ModifyUser(ctx context.Context, userId string, patch UserPatch) error { // create patch update := bson.M{} if patch.Labels != nil { update["labels"] = patch.Labels } if patch.Comment != nil { update["comment"] = patch.Comment } // execute c := db.client.Database(db.dbName).Collection(db.UsersTable()) _, err := c.UpdateOne(ctx, bson.M{"userid": bson.M{"$eq": userId}}, bson.M{"$set": update}) return errors.Trace(err) } // DeleteUser deletes a user from MongoDB. 
func (db *MongoDB) DeleteUser(ctx context.Context, userId string) error { c := db.client.Database(db.dbName).Collection(db.UsersTable()) _, err := c.DeleteOne(ctx, bson.M{"userid": userId}) if err != nil { return errors.Trace(err) } c = db.client.Database(db.dbName).Collection(db.FeedbackTable()) _, err = c.DeleteMany(ctx, bson.M{ "feedbackkey.userid": bson.M{"$eq": userId}, }) return errors.Trace(err) } // GetUser returns a user from MongoDB. func (db *MongoDB) GetUser(ctx context.Context, userId string) (user User, err error) { c := db.client.Database(db.dbName).Collection(db.UsersTable()) r := c.FindOne(ctx, bson.M{"userid": userId}) if r.Err() == mongo.ErrNoDocuments { err = errors.Annotate(ErrUserNotExist, userId) return } err = r.Decode(&user) user.Labels = unpack(user.Labels) return } // GetUsers returns users from MongoDB. func (db *MongoDB) GetUsers(ctx context.Context, cursor string, n int) (string, []User, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } cursorUser := string(buf) c := db.client.Database(db.dbName).Collection(db.UsersTable()) opt := options.Find() opt.SetLimit(int64(n)) opt.SetSort(bson.D{{"userid", 1}}) r, err := c.Find(ctx, bson.M{"userid": bson.M{"$gt": cursorUser}}, opt) if err != nil { return "", nil, err } users := make([]User, 0) defer r.Close(ctx) for r.Next(ctx) { var user User if err = r.Decode(&user); err != nil { return "", nil, err } user.Labels = unpack(user.Labels) users = append(users, user) } if len(users) == n { cursor = users[n-1].UserId } else { cursor = "" } return base64.StdEncoding.EncodeToString([]byte(cursor)), users, nil } // GetUserStream reads users from MongoDB by stream. 
func (db *MongoDB) GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error) { userChan := make(chan []User, bufSize) errChan := make(chan error, 1) go func() { defer close(userChan) defer close(errChan) // send query ctx := context.Background() c := db.client.Database(db.dbName).Collection(db.UsersTable()) opt := options.Find() r, err := c.Find(ctx, bson.M{}, opt) if err != nil { errChan <- errors.Trace(err) return } users := make([]User, 0, batchSize) defer r.Close(ctx) for r.Next(ctx) { var user User if err = r.Decode(&user); err != nil { errChan <- errors.Trace(err) return } user.Labels = unpack(user.Labels) users = append(users, user) if len(users) == batchSize { userChan <- users users = make([]User, 0, batchSize) } } if len(users) > 0 { userChan <- users } errChan <- nil }() return userChan, errChan } // GetUserFeedback returns feedback of a user from MongoDB. func (db *MongoDB) GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...expression.FeedbackTypeExpression) ([]Feedback, error) { c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) var r *mongo.Cursor var err error filter := bson.M{ "feedbackkey.userid": bson.M{"$eq": userId}, } if endTime != nil { filter["timestamp"] = bson.M{"$lte": endTime} } if len(feedbackTypes) > 0 { var conditions []bson.M for _, feedbackType := range feedbackTypes { conditions = append(conditions, FeedbackTypeExpressionToMongo(feedbackType)) } filter["$or"] = conditions } r, err = c.Find(ctx, filter) if err != nil { return nil, err } feedbacks := make([]Feedback, 0) defer r.Close(ctx) for r.Next(ctx) { var feedback Feedback if err = r.Decode(&feedback); err != nil { return nil, err } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // BatchInsertFeedback returns multiple feedback into MongoDB. 
func (db *MongoDB) BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error { // skip empty list if len(feedback) == 0 { return nil } // collect users and items users := mapset.NewSet[string]() items := mapset.NewSet[string]() for _, v := range feedback { users.Add(v.UserId) items.Add(v.ItemId) } // insert users userList := users.ToSlice() if insertUser { var models []mongo.WriteModel for _, userId := range userList { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{"userid": bson.M{"$eq": userId}}). SetUpdate(bson.M{"$setOnInsert": User{UserId: userId}})) } c := db.client.Database(db.dbName).Collection(db.UsersTable()) _, err := c.BulkWrite(ctx, models) if err != nil { return errors.Trace(err) } } else { for _, userId := range userList { _, err := db.GetUser(ctx, userId) if err != nil { if errors.Is(err, errors.NotFound) { users.Remove(userId) continue } return errors.Trace(err) } } } // insert items itemList := items.ToSlice() if insertItem { var models []mongo.WriteModel for _, itemId := range itemList { models = append(models, mongo.NewUpdateOneModel(). SetUpsert(true). SetFilter(bson.M{"itemid": bson.M{"$eq": itemId}}). SetUpdate(bson.M{"$setOnInsert": Item{ItemId: itemId}})) } c := db.client.Database(db.dbName).Collection(db.ItemsTable()) _, err := c.BulkWrite(ctx, models) if err != nil { return errors.Trace(err) } } else { for _, itemId := range itemList { _, err := db.GetItem(ctx, itemId) if err != nil { if errors.Is(err, errors.NotFound) { items.Remove(itemId) continue } return errors.Trace(err) } } } // insert feedback c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) var models []mongo.WriteModel for _, f := range feedback { if users.Contains(f.UserId) && items.Contains(f.ItemId) { f.Updated = f.Timestamp model := mongo.NewUpdateOneModel(). SetUpsert(true). 
SetFilter(bson.M{ "feedbackkey": f.FeedbackKey, }) if overwrite { model.SetUpdate(bson.M{"$set": f}) } else { model.SetUpdate(bson.M{ "$setOnInsert": bson.M{ "feedbackkey": f.FeedbackKey, }, "$inc": bson.M{ "value": f.Value, }, "$min": bson.M{ "timestamp": f.Timestamp, }, "$max": bson.M{ "updated": f.Updated, }, "$set": bson.M{ "comment": f.Comment, }, }) } models = append(models, model) } } if len(models) == 0 { return nil } _, err := c.BulkWrite(ctx, models) return errors.Trace(err) } // GetFeedback returns multiple feedback from MongoDB. func (db *MongoDB) GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) opt := options.Find() opt.SetLimit(int64(n)) opt.SetSort(bson.D{{"feedbackkey", 1}}) filter := make(bson.M) // pass cursor to filter if len(buf) > 0 { feedbackKey, err := feedbackKeyFromString(string(buf)) if err != nil { return "", nil, err } filter["feedbackkey"] = bson.M{"$gt": feedbackKey} } // pass feedback type to filter if len(feedbackTypes) > 0 { var conditions []bson.M for _, feedbackType := range feedbackTypes { conditions = append(conditions, bson.M{ "feedbackkey.feedbacktype": bson.M{"$eq": feedbackType}, }) } filter["$or"] = conditions } // pass time limit to filter timestampConditions := bson.M{} if beginTime != nil { timestampConditions["$gt"] = *beginTime } if endTime != nil { timestampConditions["$lte"] = *endTime } filter["timestamp"] = timestampConditions r, err := c.Find(ctx, filter, opt) if err != nil { return "", nil, err } feedbacks := make([]Feedback, 0) defer r.Close(ctx) for r.Next(ctx) { var feedback Feedback if err = r.Decode(&feedback); err != nil { return "", nil, err } feedbacks = append(feedbacks, feedback) } if len(feedbacks) == n { cursor, err = feedbacks[n-1].toString() 
if err != nil { return "", nil, err } } else { cursor = "" } return base64.StdEncoding.EncodeToString([]byte(cursor)), feedbacks, nil } // GetFeedbackStream reads feedback from MongoDB by stream. func (db *MongoDB) GetFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) (chan []Feedback, chan error) { scan := NewScanOptions(scanOptions...) feedbackChan := make(chan []Feedback, bufSize) errChan := make(chan error, 1) go func() { defer close(feedbackChan) defer close(errChan) // send query ctx := context.Background() c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) opt := options.Find() filter := make(bson.M) // pass feedback type to filter if len(scan.FeedbackTypes) > 0 { var conditions []bson.M for _, feedbackType := range scan.FeedbackTypes { conditions = append(conditions, FeedbackTypeExpressionToMongo(feedbackType)) } filter["$or"] = conditions } // pass time limit to filter if scan.BeginTime != nil || scan.EndTime != nil { timestampConditions := bson.M{} if scan.BeginTime != nil { timestampConditions["$gt"] = *scan.BeginTime } if scan.EndTime != nil { timestampConditions["$lte"] = *scan.EndTime } filter["timestamp"] = timestampConditions } // pass user id to filter if scan.BeginUserId != nil || scan.EndUserId != nil { userIdConditions := bson.M{} if scan.BeginUserId != nil { userIdConditions["$gte"] = *scan.BeginUserId } if scan.EndUserId != nil { userIdConditions["$lte"] = *scan.EndUserId } filter["feedbackkey.userid"] = userIdConditions } if scan.BeginItemId != nil || scan.EndItemId != nil { itemIdConditions := bson.M{} if scan.BeginItemId != nil { itemIdConditions["$gte"] = *scan.BeginItemId } if scan.EndItemId != nil { itemIdConditions["$lte"] = *scan.EndItemId } filter["feedbackkey.itemid"] = itemIdConditions } if scan.OrderByItemId { opt.SetSort(bson.D{{"feedbackkey.itemid", 1}}) } r, err := c.Find(ctx, filter, opt) if err != nil { errChan <- errors.Trace(err) return } feedbacks := make([]Feedback, 0, batchSize) 
defer r.Close(ctx) for r.Next(ctx) { var feedback Feedback if err = r.Decode(&feedback); err != nil { errChan <- errors.Trace(err) return } feedbacks = append(feedbacks, feedback) if len(feedbacks) == batchSize { feedbackChan <- feedbacks feedbacks = make([]Feedback, 0, batchSize) } } if len(feedbacks) > 0 { feedbackChan <- feedbacks } errChan <- nil }() return feedbackChan, errChan } // GetUserItemFeedback returns a feedback return the user id and item id from MongoDB. func (db *MongoDB) GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error) { c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) var filter = bson.M{ "feedbackkey.userid": bson.M{"$eq": userId}, "feedbackkey.itemid": bson.M{"$eq": itemId}, } if len(feedbackTypes) > 0 { var conditions []bson.M for _, feedbackType := range feedbackTypes { conditions = append(conditions, bson.M{ "feedbackkey.feedbacktype": bson.M{"$eq": feedbackType}, }) } filter["$or"] = conditions } r, err := c.Find(ctx, filter) if err != nil { return nil, err } feedbacks := make([]Feedback, 0) defer r.Close(ctx) for r.Next(ctx) { var feedback Feedback if err = r.Decode(&feedback); err != nil { return nil, err } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // DeleteUserItemFeedback deletes a feedback return the user id and item id from MongoDB. 
func (db *MongoDB) DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error) { c := db.client.Database(db.dbName).Collection(db.FeedbackTable()) var filter = bson.M{ "feedbackkey.userid": bson.M{"$eq": userId}, "feedbackkey.itemid": bson.M{"$eq": itemId}, } if len(feedbackTypes) > 0 { filter["feedbackkey.feedbacktype"] = bson.M{"$in": feedbackTypes} } r, err := c.DeleteMany(ctx, filter) if err != nil { return 0, err } return int(r.DeletedCount), nil } func (db *MongoDB) CountUsers(ctx context.Context) (int, error) { n, err := db.client.Database(db.dbName).Collection(db.UsersTable()).EstimatedDocumentCount(ctx) return int(n), err } func (db *MongoDB) CountItems(ctx context.Context) (int, error) { n, err := db.client.Database(db.dbName).Collection(db.ItemsTable()).EstimatedDocumentCount(ctx) return int(n), err } func (db *MongoDB) CountFeedback(ctx context.Context) (int, error) { n, err := db.client.Database(db.dbName).Collection(db.FeedbackTable()).EstimatedDocumentCount(ctx) return int(n), err } ================================================ FILE: storage/data/mongodb_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package data

import (
	"os"
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

var (
	mongoUri string
)

func init() {
	// get environment variables
	env := func(key, defaultValue string) string {
		if value := os.Getenv(key); value != "" {
			return value
		}
		return defaultValue
	}
	mongoUri = env("MONGO_URI", "mongodb://root:password@127.0.0.1:27017/")
}

// MongoTestSuite runs baseTestSuite against the MongoDB instance at mongoUri.
type MongoTestSuite struct {
	baseTestSuite
}

// SetupSuite drops any leftover gorse_data_test database, then reconnects to
// it and initializes the schema. Connection or schema errors abort the suite
// immediately instead of failing later on a nil database.
func (suite *MongoTestSuite) SetupSuite() {
	ctx := suite.T().Context()
	var err error
	// create database
	suite.Database, err = Open(mongoUri, "gorse_")
	suite.Require().NoError(err)
	dbName := "gorse_data_test"
	databaseComm := suite.getMongoDB()
	err = databaseComm.client.Database(dbName).Drop(ctx)
	if err == nil {
		suite.T().Log("delete existed database:", dbName)
	}
	err = suite.Database.Close()
	suite.Require().NoError(err)
	// create schema
	suite.Database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_")
	suite.Require().NoError(err)
	err = suite.Database.Init()
	suite.Require().NoError(err)
}

// getMongoDB returns the suite database as *MongoDB, aborting the test if it
// has a different concrete type.
func (suite *MongoTestSuite) getMongoDB() *MongoDB {
	mongoDatabase, ok := suite.Database.(*MongoDB)
	suite.Require().True(ok)
	return mongoDatabase
}

func TestMongo(t *testing.T) {
	suite.Run(t, new(MongoTestSuite))
}

// BenchmarkMongo_CountItems recreates gorse_data_test and benchmarks
// CountItems against it via benchmarkCountItems.
func BenchmarkMongo_CountItems(b *testing.B) {
	ctx := b.Context()
	// connect without a database to drop leftovers from previous runs
	database, err := Open(mongoUri, "gorse_")
	require.NoError(b, err)
	dbName := "gorse_data_test"
	databaseComm, ok := database.(*MongoDB)
	require.True(b, ok)
	err = databaseComm.client.Database(dbName).Drop(ctx)
	require.NoError(b, err)
	// release the bootstrap connection before the variable is reused below
	err = database.Close()
	require.NoError(b, err)
	// create schema
	database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_")
	require.NoError(b, err)
	err = database.Init()
	require.NoError(b, err)
	// benchmark
	benchmarkCountItems(b, database)
	// close database
	err = database.Close()
	require.NoError(b, err)
}

================================================
FILE: storage/data/no_database.go
================================================
// Copyright 2021 gorse Project Authors
//
Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package data import ( "context" "time" "github.com/gorse-io/gorse/common/expression" ) // NoDatabase means that no database used. type NoDatabase struct{} // Optimize is used by ClickHouse only. func (NoDatabase) Optimize() error { return ErrNoDatabase } // Init method of NoDatabase returns ErrNoDatabase. func (NoDatabase) Init() error { return ErrNoDatabase } func (NoDatabase) Ping() error { return ErrNoDatabase } // Close method of NoDatabase returns ErrNoDatabase. func (NoDatabase) Close() error { return ErrNoDatabase } func (NoDatabase) Purge() error { return ErrNoDatabase } // BatchInsertItems method of NoDatabase returns ErrNoDatabase. func (NoDatabase) BatchInsertItems(_ context.Context, _ []Item) error { return ErrNoDatabase } // BatchGetItems method of NoDatabase returns ErrNoDatabase. func (NoDatabase) BatchGetItems(_ context.Context, _ []string) ([]Item, error) { return nil, ErrNoDatabase } // DeleteItem method of NoDatabase returns ErrNoDatabase. func (NoDatabase) DeleteItem(_ context.Context, _ string) error { return ErrNoDatabase } // GetItem method of NoDatabase returns ErrNoDatabase. func (NoDatabase) GetItem(_ context.Context, _ string) (Item, error) { return Item{}, ErrNoDatabase } // GetItems method of NoDatabase returns ErrNoDatabase. 
func (NoDatabase) GetItems(context.Context, string, int, *time.Time) (string, []Item, error) {
	return "", nil, ErrNoDatabase
}

// GetLatestItems always fails with ErrNoDatabase.
func (NoDatabase) GetLatestItems(context.Context, int, []string) ([]Item, error) {
	return nil, ErrNoDatabase
}

// GetItemStream returns an item channel that never carries a batch and an
// error channel that yields ErrNoDatabase; both channels end up closed.
func (NoDatabase) GetItemStream(context.Context, int, *time.Time) (chan []Item, chan error) {
	items := make(chan []Item, bufSize)
	errs := make(chan error, 1)
	// errs has room for exactly one value, so the error can be queued and
	// both channels closed before returning.
	errs <- ErrNoDatabase
	close(errs)
	close(items)
	return items, errs
}

// GetItemFeedback always fails with ErrNoDatabase.
func (NoDatabase) GetItemFeedback(context.Context, string, ...string) ([]Feedback, error) {
	return nil, ErrNoDatabase
}

// BatchInsertUsers always fails with ErrNoDatabase.
func (NoDatabase) BatchInsertUsers(context.Context, []User) error { return ErrNoDatabase }

// DeleteUser always fails with ErrNoDatabase.
func (NoDatabase) DeleteUser(context.Context, string) error { return ErrNoDatabase }

// GetUser always fails with ErrNoDatabase.
func (NoDatabase) GetUser(context.Context, string) (User, error) {
	return User{}, ErrNoDatabase
}

// GetUsers always fails with ErrNoDatabase.
func (NoDatabase) GetUsers(context.Context, string, int) (string, []User, error) {
	return "", nil, ErrNoDatabase
}

// GetUserStream returns a user channel that never carries a batch and an
// error channel that yields ErrNoDatabase; both channels end up closed.
func (NoDatabase) GetUserStream(context.Context, int) (chan []User, chan error) {
	users := make(chan []User, bufSize)
	errs := make(chan error, 1)
	errs <- ErrNoDatabase
	close(errs)
	close(users)
	return users, errs
}

// GetUserFeedback always fails with ErrNoDatabase.
func (NoDatabase) GetUserFeedback(context.Context, string, *time.Time, ...expression.FeedbackTypeExpression) ([]Feedback, error) {
	return nil, ErrNoDatabase
}

// GetUserItemFeedback always fails with ErrNoDatabase.
func (NoDatabase) GetUserItemFeedback(context.Context, string, string, ...string) ([]Feedback, error) {
	return nil, ErrNoDatabase
}

// DeleteUserItemFeedback always fails with ErrNoDatabase and deletes nothing.
func (NoDatabase) DeleteUserItemFeedback(context.Context, string, string, ...string) (int, error) {
	return 0, ErrNoDatabase
}

// BatchInsertFeedback always fails with ErrNoDatabase.
func (NoDatabase) BatchInsertFeedback(context.Context, []Feedback, bool, bool, bool) error {
	return ErrNoDatabase
}

// GetFeedback always fails with ErrNoDatabase.
func (NoDatabase) GetFeedback(context.Context, string, int, *time.Time, *time.Time, ...string) (string, []Feedback, error) {
	return "", nil, ErrNoDatabase
}

// GetFeedbackStream returns a feedback channel that never carries a batch
// and an error channel that yields ErrNoDatabase.
func (NoDatabase) GetFeedbackStream(_ context.Context, _ int, _ ...ScanOption) (chan []Feedback, chan error) { feedbackChan := make(chan []Feedback, bufSize) errChan := make(chan error, 1) go func() { defer close(feedbackChan) defer close(errChan) errChan <- ErrNoDatabase }() return feedbackChan, errChan } func (d NoDatabase) ModifyItem(_ context.Context, _ string, _ ItemPatch) error { return ErrNoDatabase } func (d NoDatabase) ModifyUser(_ context.Context, _ string, _ UserPatch) error { return ErrNoDatabase } func (d NoDatabase) CountUsers(_ context.Context) (int, error) { return 0, ErrNoDatabase } func (d NoDatabase) CountItems(_ context.Context) (int, error) { return 0, ErrNoDatabase } func (d NoDatabase) CountFeedback(_ context.Context) (int, error) { return 0, ErrNoDatabase } ================================================ FILE: storage/data/no_database_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package data

import (
	"testing"
	"time"

	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
)

// TestNoDatabase checks that every NoDatabase method fails with
// ErrNoDatabase. Stream methods must deliver it on their error channel and
// close their data channel without sending any batch.
func TestNoDatabase(t *testing.T) {
	ctx := t.Context()
	var database NoDatabase

	// lifecycle
	err := database.Close()
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.Optimize()
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.Init()
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.Ping()
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.Purge()
	assert.ErrorIs(t, err, ErrNoDatabase)

	// items
	err = database.BatchInsertItems(ctx, nil)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.BatchGetItems(ctx, nil)
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.ModifyItem(ctx, "", ItemPatch{})
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetItem(ctx, "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, _, err = database.GetItems(ctx, "", 0, nil)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetLatestItems(ctx, 0, nil)
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.DeleteItem(ctx, "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	items, c := database.GetItemStream(ctx, 0, nil)
	assert.ErrorIs(t, <-c, ErrNoDatabase)
	_, ok := <-items
	assert.False(t, ok)

	// users
	err = database.BatchInsertUsers(ctx, nil)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetUser(ctx, "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.ModifyUser(ctx, "", UserPatch{})
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, _, err = database.GetUsers(ctx, "", 0)
	assert.ErrorIs(t, err, ErrNoDatabase)
	err = database.DeleteUser(ctx, "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	users, c := database.GetUserStream(ctx, 0)
	assert.ErrorIs(t, <-c, ErrNoDatabase)
	_, ok = <-users
	assert.False(t, ok)

	// feedback
	err = database.BatchInsertFeedback(ctx, nil, false, false, false)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetUserFeedback(ctx, "", lo.ToPtr(time.Now()))
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetItemFeedback(ctx, "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, _, err = database.GetFeedback(ctx, "", 0, nil, lo.ToPtr(time.Now()))
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.GetUserItemFeedback(ctx, "", "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.DeleteUserItemFeedback(ctx, "", "")
	assert.ErrorIs(t, err, ErrNoDatabase)
	feedback, c := database.GetFeedbackStream(ctx, 0)
	assert.ErrorIs(t, <-c, ErrNoDatabase)
	_, ok = <-feedback
	assert.False(t, ok)

	// counts
	_, err = database.CountUsers(ctx)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.CountItems(ctx)
	assert.ErrorIs(t, err, ErrNoDatabase)
	_, err = database.CountFeedback(ctx)
	assert.ErrorIs(t, err, ErrNoDatabase)
}

================================================
FILE: storage/data/proxy.go
================================================
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package data

import (
	"context"
	"encoding/json"
	"io"
	"net"
	"time"

	"github.com/gorse-io/gorse/common/expression"
	"github.com/gorse-io/gorse/protocol"
	"github.com/juju/errors"
	"github.com/samber/lo"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// proxyItemToPB converts an Item to its protobuf form; Labels travel as JSON.
func proxyItemToPB(item Item) (*protocol.Item, error) {
	labels, err := json.Marshal(item.Labels)
	if err != nil {
		return nil, err
	}
	return &protocol.Item{
		ItemId:     item.ItemId,
		IsHidden:   item.IsHidden,
		Categories: item.Categories,
		Timestamp:  timestamppb.New(item.Timestamp),
		Labels:     labels,
		Comment:    item.Comment,
	}, nil
}

// proxyItemsToPB converts items with proxyItemToPB, failing on the first label encoding error.
func proxyItemsToPB(items []Item) ([]*protocol.Item, error) {
	pbItems := make([]*protocol.Item, len(items))
	for i, item := range items {
		pbItem, err := proxyItemToPB(item)
		if err != nil {
			return nil, err
		}
		pbItems[i] = pbItem
	}
	return pbItems, nil
}

// proxyItemFromPB converts a protobuf item to an Item, decoding its JSON Labels.
func proxyItemFromPB(pbItem *protocol.Item) (Item, error) {
	var labels any
	if err := json.Unmarshal(pbItem.Labels, &labels); err != nil {
		return Item{}, err
	}
	return Item{
		ItemId:     pbItem.ItemId,
		IsHidden:   pbItem.IsHidden,
		Categories: pbItem.Categories,
		Timestamp:  pbItem.Timestamp.AsTime(),
		Labels:     labels,
		Comment:    pbItem.Comment,
	}, nil
}

// proxyItemsFromPB converts protobuf items with proxyItemFromPB; the result is non-nil even when empty.
func proxyItemsFromPB(pbItems []*protocol.Item) ([]Item, error) {
	items := make([]Item, len(pbItems))
	for i, pbItem := range pbItems {
		item, err := proxyItemFromPB(pbItem)
		if err != nil {
			return nil, err
		}
		items[i] = item
	}
	return items, nil
}

// proxyUserToPB converts a User to its protobuf form; Labels travel as JSON.
func proxyUserToPB(user User) (*protocol.User, error) {
	labels, err := json.Marshal(user.Labels)
	if err != nil {
		return nil, err
	}
	return &protocol.User{
		UserId:  user.UserId,
		Labels:  labels,
		Comment: user.Comment,
	}, nil
}

// proxyUsersToPB converts users with proxyUserToPB, failing on the first label encoding error.
func proxyUsersToPB(users []User) ([]*protocol.User, error) {
	pbUsers := make([]*protocol.User, len(users))
	for i, user := range users {
		pbUser, err := proxyUserToPB(user)
		if err != nil {
			return nil, err
		}
		pbUsers[i] = pbUser
	}
	return pbUsers, nil
}

// proxyUserFromPB converts a protobuf user to a User, decoding its JSON Labels.
func proxyUserFromPB(pbUser *protocol.User) (User, error) {
	var labels any
	if err := json.Unmarshal(pbUser.Labels, &labels); err != nil {
		return User{}, err
	}
	return User{
		UserId:  pbUser.UserId,
		Labels:  labels,
		Comment: pbUser.Comment,
	}, nil
}

// proxyUsersFromPB converts protobuf users with proxyUserFromPB; the result is non-nil even when empty.
func proxyUsersFromPB(pbUsers []*protocol.User) ([]User, error) {
	users := make([]User, len(pbUsers))
	for i, pbUser := range pbUsers {
		user, err := proxyUserFromPB(pbUser)
		if err != nil {
			return nil, err
		}
		users[i] = user
	}
	return users, nil
}

// proxyFeedbackToPB converts feedback to protobuf, including the Updated timestamp.
func proxyFeedbackToPB(feedback []Feedback) []*protocol.Feedback {
	pbFeedback := make([]*protocol.Feedback, len(feedback))
	for i, f := range feedback {
		pbFeedback[i] = &protocol.Feedback{
			FeedbackType: f.FeedbackType,
			UserId:       f.UserId,
			ItemId:       f.ItemId,
			Value:        f.Value,
			Timestamp:    timestamppb.New(f.Timestamp),
			Updated:      timestamppb.New(f.Updated),
			Comment:      f.Comment,
		}
	}
	return pbFeedback
}

// proxyFeedbackFromPB converts protobuf feedback back to Feedback, including Updated.
func proxyFeedbackFromPB(pbFeedback []*protocol.Feedback) []Feedback {
	feedback := make([]Feedback, len(pbFeedback))
	for i, f := range pbFeedback {
		feedback[i] = Feedback{
			FeedbackKey: FeedbackKey{
				FeedbackType: f.FeedbackType,
				UserId:       f.UserId,
				ItemId:       f.ItemId,
			},
			Value:     f.Value,
			Timestamp: f.Timestamp.AsTime(),
			Updated:   f.Updated.AsTime(),
			Comment:   f.Comment,
		}
	}
	return feedback
}

// ProxyServer exposes a Database over gRPC as a protocol.DataStoreServer.
type ProxyServer struct {
	protocol.UnimplementedDataStoreServer
	database Database
	server   *grpc.Server
}

// NewProxyServer wraps database in a ProxyServer; call Serve to start handling requests.
func NewProxyServer(database Database) *ProxyServer {
	return &ProxyServer{database: database}
}

// Serve registers the data store service on a new gRPC server and serves
// connections accepted from lis until the server stops or accepting fails.
func (p *ProxyServer) Serve(lis net.Listener) error {
	p.server = grpc.NewServer()
	protocol.RegisterDataStoreServer(p.server, p)
	return p.server.Serve(lis)
}

// Stop stops the gRPC server started by Serve; it is a no-op before Serve.
func (p *ProxyServer) Stop() {
	if p.server != nil {
		p.server.Stop()
	}
}

func (p *ProxyServer) Ping(_ context.Context, _ *protocol.PingRequest) (*protocol.PingResponse, error) {
	return &protocol.PingResponse{}, p.database.Ping()
}

func (p *ProxyServer) BatchInsertItems(ctx context.Context, in *protocol.BatchInsertItemsRequest) (*protocol.BatchInsertItemsResponse, error) {
	items, err := proxyItemsFromPB(in.Items)
	if err != nil {
		return nil, err
	}
	err = p.database.BatchInsertItems(ctx, items)
	return &protocol.BatchInsertItemsResponse{}, err
}

func (p *ProxyServer) BatchGetItems(ctx context.Context, in *protocol.BatchGetItemsRequest) (*protocol.BatchGetItemsResponse, error) {
	items, err := p.database.BatchGetItems(ctx, in.ItemIds)
	if err != nil {
		return nil, err
	}
	pbItems, err := proxyItemsToPB(items)
	if err != nil {
		return nil, err
	}
	return &protocol.BatchGetItemsResponse{Items: pbItems}, nil
}

func (p *ProxyServer) DeleteItem(ctx context.Context, in *protocol.DeleteItemRequest) (*protocol.DeleteItemResponse, error) {
	err := p.database.DeleteItem(ctx, in.ItemId)
	return &protocol.DeleteItemResponse{}, err
}

// GetItem returns the item, or a response with a nil Item when it does not
// exist (ProxyClient.GetItem maps that to ErrItemNotExist).
func (p *ProxyServer) GetItem(ctx context.Context, in *protocol.GetItemRequest) (*protocol.GetItemResponse, error) {
	item, err := p.database.GetItem(ctx, in.ItemId)
	if err != nil {
		if errors.Is(err, errors.NotFound) {
			return &protocol.GetItemResponse{}, nil
		}
		return nil, err
	}
	pbItem, err := proxyItemToPB(item)
	if err != nil {
		return nil, err
	}
	return &protocol.GetItemResponse{Item: pbItem}, nil
}

// ModifyItem applies the request patch; a request without a patch is
// rejected instead of dereferencing nil.
func (p *ProxyServer) ModifyItem(ctx context.Context, in *protocol.ModifyItemRequest) (*protocol.ModifyItemResponse, error) {
	patch := in.Patch
	if patch == nil {
		return nil, errors.NotValidf("patch of item %q", in.ItemId)
	}
	var labels any
	if patch.Labels != nil {
		if err := json.Unmarshal(patch.Labels, &labels); err != nil {
			return nil, err
		}
	}
	var timestamp *time.Time
	if patch.Timestamp != nil {
		timestamp = lo.ToPtr(patch.Timestamp.AsTime())
	}
	err := p.database.ModifyItem(ctx, in.ItemId, ItemPatch{
		IsHidden:   patch.IsHidden,
		Categories: patch.Categories,
		Labels:     labels,
		Comment:    patch.Comment,
		Timestamp:  timestamp,
	})
	return &protocol.ModifyItemResponse{}, err
}

func (p *ProxyServer) GetItems(ctx context.Context, in *protocol.GetItemsRequest) (*protocol.GetItemsResponse, error) {
	var beginTime *time.Time
	if in.BeginTime != nil {
		beginTime = lo.ToPtr(in.BeginTime.AsTime())
	}
	cursor, items, err := p.database.GetItems(ctx, in.Cursor, int(in.N), beginTime)
	if err != nil {
		return nil, err
	}
	pbItems, err := proxyItemsToPB(items)
	if err != nil {
		return nil, err
	}
	return &protocol.GetItemsResponse{Cursor: cursor, Items: pbItems}, nil
}

func (p *ProxyServer) GetItemFeedback(ctx context.Context, in *protocol.GetItemFeedbackRequest) (*protocol.GetFeedbackResponse, error) {
	var types []string
	for _, t := range in.FeedbackTypes {
		types = append(types, t.FeedbackType)
	}
	feedback, err := p.database.GetItemFeedback(ctx, in.ItemId, types...)
	if err != nil {
		return nil, err
	}
	return &protocol.GetFeedbackResponse{Feedback: proxyFeedbackToPB(feedback)}, nil
}

func (p *ProxyServer) BatchInsertUsers(ctx context.Context, in *protocol.BatchInsertUsersRequest) (*protocol.BatchInsertUsersResponse, error) {
	users, err := proxyUsersFromPB(in.Users)
	if err != nil {
		return nil, err
	}
	err = p.database.BatchInsertUsers(ctx, users)
	return &protocol.BatchInsertUsersResponse{}, err
}

func (p *ProxyServer) DeleteUser(ctx context.Context, in *protocol.DeleteUserRequest) (*protocol.DeleteUserResponse, error) {
	err := p.database.DeleteUser(ctx, in.UserId)
	return &protocol.DeleteUserResponse{}, err
}

// GetUser returns the user, or a response with a nil User when it does not
// exist (ProxyClient.GetUser maps that to ErrUserNotExist).
func (p *ProxyServer) GetUser(ctx context.Context, in *protocol.GetUserRequest) (*protocol.GetUserResponse, error) {
	user, err := p.database.GetUser(ctx, in.UserId)
	if err != nil {
		if errors.Is(err, errors.NotFound) {
			return &protocol.GetUserResponse{}, nil
		}
		return nil, err
	}
	pbUser, err := proxyUserToPB(user)
	if err != nil {
		return nil, err
	}
	return &protocol.GetUserResponse{User: pbUser}, nil
}

// ModifyUser applies the request patch; a request without a patch is
// rejected instead of dereferencing nil.
func (p *ProxyServer) ModifyUser(ctx context.Context, in *protocol.ModifyUserRequest) (*protocol.ModifyUserResponse, error) {
	patch := in.Patch
	if patch == nil {
		return nil, errors.NotValidf("patch of user %q", in.UserId)
	}
	var labels any
	if patch.Labels != nil {
		if err := json.Unmarshal(patch.Labels, &labels); err != nil {
			return nil, err
		}
	}
	err := p.database.ModifyUser(ctx, in.UserId, UserPatch{
		Labels:  labels,
		Comment: patch.Comment,
	})
	return &protocol.ModifyUserResponse{}, err
}

func (p *ProxyServer) GetUsers(ctx context.Context, in *protocol.GetUsersRequest) (*protocol.GetUsersResponse, error) {
	cursor, users, err := p.database.GetUsers(ctx, in.Cursor, int(in.N))
	if err != nil {
		return nil, err
	}
	pbUsers, err := proxyUsersToPB(users)
	if err != nil {
		return nil, err
	}
	return &protocol.GetUsersResponse{Cursor: cursor, Users: pbUsers}, nil
}

func (p *ProxyServer) GetUserFeedback(ctx context.Context, in *protocol.GetUserFeedbackRequest) (*protocol.GetFeedbackResponse, error) {
	var endTime *time.Time
	if in.EndTime != nil {
		endTime = lo.ToPtr(in.EndTime.AsTime())
	}
	types := make([]expression.FeedbackTypeExpression, len(in.FeedbackTypes))
	for i, t := range in.FeedbackTypes {
		types[i].FromPB(t)
	}
	feedback, err := p.database.GetUserFeedback(ctx, in.UserId, endTime, types...)
	if err != nil {
		return nil, err
	}
	return &protocol.GetFeedbackResponse{Feedback: proxyFeedbackToPB(feedback)}, nil
}

func (p *ProxyServer) GetUserItemFeedback(ctx context.Context, in *protocol.GetUserItemFeedbackRequest) (*protocol.GetFeedbackResponse, error) {
	var types []string
	for _, t := range in.FeedbackTypes {
		types = append(types, t.FeedbackType)
	}
	feedback, err := p.database.GetUserItemFeedback(ctx, in.UserId, in.ItemId, types...)
	if err != nil {
		return nil, err
	}
	return &protocol.GetFeedbackResponse{Feedback: proxyFeedbackToPB(feedback)}, nil
}

func (p *ProxyServer) DeleteUserItemFeedback(ctx context.Context, in *protocol.DeleteUserItemFeedbackRequest) (*protocol.DeleteUserItemFeedbackResponse, error) {
	count, err := p.database.DeleteUserItemFeedback(ctx, in.UserId, in.ItemId, in.FeedbackTypes...)
	return &protocol.DeleteUserItemFeedbackResponse{Count: int32(count)}, err
}

// BatchInsertFeedback inserts the request feedback; the Updated field of the
// request is not forwarded.
func (p *ProxyServer) BatchInsertFeedback(ctx context.Context, in *protocol.BatchInsertFeedbackRequest) (*protocol.BatchInsertFeedbackResponse, error) {
	feedback := make([]Feedback, len(in.Feedback))
	for i, f := range in.Feedback {
		feedback[i] = Feedback{
			FeedbackKey: FeedbackKey{
				FeedbackType: f.FeedbackType,
				UserId:       f.UserId,
				ItemId:       f.ItemId,
			},
			Value:     f.Value,
			Timestamp: f.Timestamp.AsTime(),
			Comment:   f.Comment,
		}
	}
	err := p.database.BatchInsertFeedback(ctx, feedback, in.InsertUser, in.InsertItem, in.Overwrite)
	return &protocol.BatchInsertFeedbackResponse{}, err
}

func (p *ProxyServer) GetFeedback(ctx context.Context, in *protocol.GetFeedbackRequest) (*protocol.GetFeedbackResponse, error) {
	var beginTime, endTime *time.Time
	if in.BeginTime != nil {
		beginTime = lo.ToPtr(in.BeginTime.AsTime())
	}
	if in.EndTime != nil {
		endTime = lo.ToPtr(in.EndTime.AsTime())
	}
	var types []string
	for _, t := range in.FeedbackTypes {
		types = append(types, t.FeedbackType)
	}
	cursor, feedback, err := p.database.GetFeedback(ctx, in.Cursor, int(in.N), beginTime, endTime, types...)
	if err != nil {
		return nil, err
	}
	return &protocol.GetFeedbackResponse{Cursor: cursor, Feedback: proxyFeedbackToPB(feedback)}, nil
}

// GetUserStream relays the database user stream to the client batch by batch.
func (p *ProxyServer) GetUserStream(in *protocol.GetUserStreamRequest, stream grpc.ServerStreamingServer[protocol.GetUserStreamResponse]) error {
	ctx, cancel := context.WithCancel(stream.Context())
	usersChan, errChan := p.database.GetUserStream(ctx, int(in.BatchSize))
	defer func() {
		// On an early return, stop the producer and drain its channel so its
		// goroutine is not left blocked on a send (not every backend honors
		// ctx). After a complete scan the channel is already closed.
		cancel()
		for range usersChan {
		}
	}()
	for users := range usersChan {
		pbUsers, err := proxyUsersToPB(users)
		if err != nil {
			return err
		}
		if err = stream.Send(&protocol.GetUserStreamResponse{Users: pbUsers}); err != nil {
			return err
		}
	}
	return <-errChan
}

// GetItemStream relays the database item stream to the client batch by batch.
func (p *ProxyServer) GetItemStream(in *protocol.GetItemStreamRequest, stream grpc.ServerStreamingServer[protocol.GetItemStreamResponse]) error {
	var timeLimit *time.Time
	if in.TimeLimit != nil {
		timeLimit = lo.ToPtr(in.TimeLimit.AsTime())
	}
	ctx, cancel := context.WithCancel(stream.Context())
	itemsChan, errChan := p.database.GetItemStream(ctx, int(in.BatchSize), timeLimit)
	defer func() {
		// Stop and drain the producer on early return (see GetUserStream).
		cancel()
		for range itemsChan {
		}
	}()
	for items := range itemsChan {
		pbItems, err := proxyItemsToPB(items)
		if err != nil {
			return err
		}
		if err = stream.Send(&protocol.GetItemStreamResponse{Items: pbItems}); err != nil {
			return err
		}
	}
	return <-errChan
}

// GetFeedbackStream relays the database feedback stream to the client batch
// by batch. A request without scan options scans all feedback instead of
// panicking on a nil dereference.
func (p *ProxyServer) GetFeedbackStream(in *protocol.GetFeedbackStreamRequest, stream grpc.ServerStreamingServer[protocol.GetFeedbackStreamResponse]) error {
	var opts []ScanOption
	if so := in.ScanOptions; so != nil {
		if so.BeginTime != nil {
			opts = append(opts, WithBeginTime(so.BeginTime.AsTime()))
		}
		if so.EndTime != nil {
			opts = append(opts, WithEndTime(so.EndTime.AsTime()))
		}
		if so.FeedbackTypes != nil {
			types := make([]expression.FeedbackTypeExpression, len(so.FeedbackTypes))
			for i, t := range so.FeedbackTypes {
				types[i].FromPB(t)
			}
			opts = append(opts, WithFeedbackTypes(types...))
		}
		if so.BeginUserId != nil {
			opts = append(opts, WithBeginUserId(*so.BeginUserId))
		}
		if so.EndUserId != nil {
			opts = append(opts, WithEndUserId(*so.EndUserId))
		}
		if so.BeginItemId != nil {
			opts = append(opts, WithBeginItemId(*so.BeginItemId))
		}
		if so.EndItemId != nil {
			opts = append(opts, WithEndItemId(*so.EndItemId))
		}
		if so.OrderByItemId {
			opts = append(opts, WithOrderByItemId())
		}
	}
	ctx, cancel := context.WithCancel(stream.Context())
	feedbackChan, errChan := p.database.GetFeedbackStream(ctx, int(in.BatchSize), opts...)
	defer func() {
		// Stop and drain the producer on early return (see GetUserStream).
		cancel()
		for range feedbackChan {
		}
	}()
	for feedback := range feedbackChan {
		err := stream.Send(&protocol.GetFeedbackStreamResponse{Feedback: proxyFeedbackToPB(feedback)})
		if err != nil {
			return err
		}
	}
	return <-errChan
}

func (p *ProxyServer) CountUsers(ctx context.Context, _ *protocol.CountUsersRequest) (*protocol.CountUsersResponse, error) {
	count, err := p.database.CountUsers(ctx)
	return &protocol.CountUsersResponse{Count: int32(count)}, err
}

func (p *ProxyServer) CountItems(ctx context.Context, _ *protocol.CountItemsRequest) (*protocol.CountItemsResponse, error) {
	count, err := p.database.CountItems(ctx)
	return &protocol.CountItemsResponse{Count: int32(count)}, err
}

func (p *ProxyServer) CountFeedback(ctx context.Context, _ *protocol.CountFeedbackRequest) (*protocol.CountFeedbackResponse, error) {
	count, err := p.database.CountFeedback(ctx)
	return &protocol.CountFeedbackResponse{Count: int32(count)}, err
}

func (p *ProxyServer) GetLatestItems(ctx context.Context, in *protocol.GetLatestItemsRequest) (*protocol.GetLatestItemsResponse, error) {
	items, err := p.database.GetLatestItems(ctx, int(in.N), in.Categories)
	if err != nil {
		return nil, err
	}
	pbItems, err := proxyItemsToPB(items)
	if err != nil {
		return nil, err
	}
	return &protocol.GetLatestItemsResponse{Items: pbItems}, nil
}

// ProxyClient forwards data store calls to a remote ProxyServer over gRPC.
type ProxyClient struct {
	protocol.DataStoreClient
}

// NewProxyClient creates a ProxyClient on top of an established connection.
func NewProxyClient(conn *grpc.ClientConn) *ProxyClient {
	return &ProxyClient{
		DataStoreClient: protocol.NewDataStoreClient(conn),
	}
}

// Init is not allowed through the proxy.
func (p ProxyClient) Init() error {
	return errors.MethodNotAllowedf("method Init is not allowed in ProxyClient")
}

// Ping pings the remote data store.
func (p ProxyClient) Ping() error {
	_, err := p.DataStoreClient.Ping(context.Background(), &protocol.PingRequest{})
	return err
}

// Close is a no-op; the connection passed to NewProxyClient is not closed here.
func (p ProxyClient) Close() error {
	return nil
}

// Optimize is a no-op for ProxyClient.
func (p ProxyClient) Optimize() error {
	return nil
}

// Purge is not allowed through the proxy.
func (p ProxyClient) Purge() error {
	return errors.MethodNotAllowedf("method Purge is not allowed in ProxyClient")
}

func (p ProxyClient) BatchInsertItems(ctx context.Context, items []Item) error {
	pbItems, err := proxyItemsToPB(items)
	if err != nil {
		return err
	}
	_, err = p.DataStoreClient.BatchInsertItems(ctx, &protocol.BatchInsertItemsRequest{Items: pbItems})
	return err
}

func (p ProxyClient) BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error) {
	resp, err := p.DataStoreClient.BatchGetItems(ctx, &protocol.BatchGetItemsRequest{ItemIds: itemIds})
	if err != nil {
		return nil, err
	}
	return proxyItemsFromPB(resp.Items)
}

func (p ProxyClient) DeleteItem(ctx context.Context, itemId string) error {
	_, err := p.DataStoreClient.DeleteItem(ctx, &protocol.DeleteItemRequest{ItemId: itemId})
	return err
}

// GetItem returns the item, or an error wrapping ErrItemNotExist when the
// server responds without one.
func (p ProxyClient) GetItem(ctx context.Context, itemId string) (Item, error) {
	resp, err := p.DataStoreClient.GetItem(ctx, &protocol.GetItemRequest{ItemId: itemId})
	if err != nil {
		return Item{}, err
	}
	if resp.Item == nil {
		return Item{}, errors.Annotate(ErrItemNotExist, itemId)
	}
	return proxyItemFromPB(resp.Item)
}

func (p ProxyClient) GetLatestItems(ctx context.Context, n int, categories []string) ([]Item, error) {
	resp, err := p.DataStoreClient.GetLatestItems(ctx, &protocol.GetLatestItemsRequest{N: int32(n), Categories: categories})
	if err != nil {
		return nil, err
	}
	return proxyItemsFromPB(resp.Items)
}

func (p ProxyClient) ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error {
	var labels []byte
	if patch.Labels != nil {
		var err error
		labels, err = json.Marshal(patch.Labels)
		if err != nil {
			return err
		}
	}
	var timestamp *timestamppb.Timestamp
	if patch.Timestamp != nil {
		timestamp = timestamppb.New(*patch.Timestamp)
	}
	_, err := p.DataStoreClient.ModifyItem(ctx, &protocol.ModifyItemRequest{
		ItemId: itemId,
		Patch: &protocol.ItemPatch{
			IsHidden:   patch.IsHidden,
			Categories: patch.Categories,
			Labels:     labels,
			Comment:    patch.Comment,
			Timestamp:  timestamp,
		},
	})
	return err
}

func (p ProxyClient) GetItems(ctx context.Context, cursor string, n int, beginTime *time.Time) (string, []Item, error) {
	var beginTimeProto *timestamppb.Timestamp
	if beginTime != nil {
		beginTimeProto = timestamppb.New(*beginTime)
	}
	resp, err := p.DataStoreClient.GetItems(ctx, &protocol.GetItemsRequest{Cursor: cursor, N: int32(n), BeginTime: beginTimeProto})
	if err != nil {
		return "", nil, err
	}
	items, err := proxyItemsFromPB(resp.Items)
	if err != nil {
		return "", nil, err
	}
	return resp.Cursor, items, nil
}

func (p ProxyClient) GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error) {
	var types []*protocol.FeedbackTypeExpression
	for _, t := range feedbackTypes {
		types = append(types, &protocol.FeedbackTypeExpression{FeedbackType: t})
	}
	resp, err := p.DataStoreClient.GetItemFeedback(ctx, &protocol.GetItemFeedbackRequest{
		ItemId:        itemId,
		FeedbackTypes: types,
	})
	if err != nil {
		return nil, err
	}
	return proxyFeedbackFromPB(resp.Feedback), nil
}

func (p ProxyClient) BatchInsertUsers(ctx context.Context, users []User) error {
	pbUsers, err := proxyUsersToPB(users)
	if err != nil {
		return err
	}
	_, err = p.DataStoreClient.BatchInsertUsers(ctx, &protocol.BatchInsertUsersRequest{Users: pbUsers})
	return err
}

func (p ProxyClient) DeleteUser(ctx context.Context, userId string) error {
	_, err := p.DataStoreClient.DeleteUser(ctx, &protocol.DeleteUserRequest{UserId: userId})
	return err
}

// GetUser returns the user, or an error wrapping ErrUserNotExist when the
// server responds without one.
func (p ProxyClient) GetUser(ctx context.Context, userId string) (User, error) {
	resp, err := p.DataStoreClient.GetUser(ctx, &protocol.GetUserRequest{UserId: userId})
	if err != nil {
		return User{}, err
	}
	if resp.User == nil {
		return User{}, errors.Annotate(ErrUserNotExist, userId)
	}
	return proxyUserFromPB(resp.User)
}

func (p ProxyClient) ModifyUser(ctx context.Context, userId string, patch UserPatch) error {
	var labels []byte
	if patch.Labels != nil {
		var err error
		labels, err = json.Marshal(patch.Labels)
		if err != nil {
			return err
		}
	}
	_, err := p.DataStoreClient.ModifyUser(ctx, &protocol.ModifyUserRequest{
		UserId: userId,
		Patch: &protocol.UserPatch{
			Labels:  labels,
			Comment: patch.Comment,
		},
	})
	return err
}

func (p ProxyClient) GetUsers(ctx context.Context, cursor string, n int) (string, []User, error) {
	resp, err := p.DataStoreClient.GetUsers(ctx, &protocol.GetUsersRequest{Cursor: cursor, N: int32(n)})
	if err != nil {
		return "", nil, err
	}
	users, err := proxyUsersFromPB(resp.Users)
	if err != nil {
		return "", nil, err
	}
	return resp.Cursor, users, nil
}

func (p ProxyClient) GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...expression.FeedbackTypeExpression) ([]Feedback, error) {
	req := &protocol.GetUserFeedbackRequest{UserId: userId}
	if endTime != nil {
		req.EndTime = timestamppb.New(*endTime)
	}
	if len(feedbackTypes) > 0 {
		var types []*protocol.FeedbackTypeExpression
		for _, t := range feedbackTypes {
			types = append(types, t.ToPB())
		}
		req.FeedbackTypes = types
	}
	resp, err := p.DataStoreClient.GetUserFeedback(ctx, req)
	if err != nil {
		return nil, err
	}
	feedback := make([]Feedback, len(resp.Feedback))
	for i, f := range resp.Feedback {
		feedback[i] = Feedback{
			FeedbackKey: FeedbackKey{
				FeedbackType: f.FeedbackType,
				UserId:       f.UserId,
				ItemId:       f.ItemId,
			},
			Value:     f.Value,
			Timestamp: f.Timestamp.AsTime(),
			Updated:   f.Updated.AsTime(),
			Comment:   f.Comment,
} } return feedback, nil } func (p ProxyClient) GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error) { var types []*protocol.FeedbackTypeExpression for _, t := range feedbackTypes { types = append(types, &protocol.FeedbackTypeExpression{FeedbackType: t}) } resp, err := p.DataStoreClient.GetUserItemFeedback(ctx, &protocol.GetUserItemFeedbackRequest{ UserId: userId, ItemId: itemId, FeedbackTypes: types, }) if err != nil { return nil, err } feedback := make([]Feedback, len(resp.Feedback)) for i, f := range resp.Feedback { feedback[i] = Feedback{ FeedbackKey: FeedbackKey{ FeedbackType: f.FeedbackType, UserId: f.UserId, ItemId: f.ItemId, }, Value: f.Value, Timestamp: f.Timestamp.AsTime(), Updated: f.Updated.AsTime(), Comment: f.Comment, } } return feedback, nil } func (p ProxyClient) DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error) { resp, err := p.DataStoreClient.DeleteUserItemFeedback(ctx, &protocol.DeleteUserItemFeedbackRequest{ UserId: userId, ItemId: itemId, FeedbackTypes: feedbackTypes, }) if err != nil { return 0, err } return int(resp.Count), nil } func (p ProxyClient) BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error { reqFeedback := make([]*protocol.Feedback, len(feedback)) for i, f := range feedback { reqFeedback[i] = &protocol.Feedback{ FeedbackType: f.FeedbackType, UserId: f.UserId, ItemId: f.ItemId, Value: f.Value, Timestamp: timestamppb.New(f.Timestamp), Comment: f.Comment, } } _, err := p.DataStoreClient.BatchInsertFeedback(ctx, &protocol.BatchInsertFeedbackRequest{ Feedback: reqFeedback, InsertUser: insertUser, InsertItem: insertItem, Overwrite: overwrite, }) return err } func (p ProxyClient) GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error) { req := &protocol.GetFeedbackRequest{ Cursor: cursor, N: 
int32(n), } if beginTime != nil { req.BeginTime = timestamppb.New(*beginTime) } if endTime != nil { req.EndTime = timestamppb.New(*endTime) } if len(feedbackTypes) > 0 { var types []*protocol.FeedbackTypeExpression for _, t := range feedbackTypes { types = append(types, &protocol.FeedbackTypeExpression{FeedbackType: t}) } req.FeedbackTypes = types } resp, err := p.DataStoreClient.GetFeedback(ctx, req) if err != nil { return "", nil, err } feedback := make([]Feedback, len(resp.Feedback)) for i, f := range resp.Feedback { feedback[i] = Feedback{ FeedbackKey: FeedbackKey{ FeedbackType: f.FeedbackType, UserId: f.UserId, ItemId: f.ItemId, }, Value: f.Value, Timestamp: f.Timestamp.AsTime(), Updated: f.Updated.AsTime(), Comment: f.Comment, } } return resp.Cursor, feedback, nil } func (p ProxyClient) GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error) { usersChan := make(chan []User, bufSize) errChan := make(chan error, 1) go func() { defer close(usersChan) defer close(errChan) stream, err := p.DataStoreClient.GetUserStream(ctx, &protocol.GetUserStreamRequest{BatchSize: int32(batchSize)}) if err != nil { errChan <- err return } for { resp, err := stream.Recv() if err != nil { if err == io.EOF { break } errChan <- err return } users := make([]User, len(resp.Users)) for i, user := range resp.Users { var labels any if err = json.Unmarshal(user.Labels, &labels); err != nil { errChan <- err return } users[i] = User{ UserId: user.UserId, Labels: labels, Comment: user.Comment, } } usersChan <- users } }() return usersChan, errChan } func (p ProxyClient) GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error) { itemsChan := make(chan []Item, bufSize) errChan := make(chan error, 1) go func() { defer close(itemsChan) defer close(errChan) stream, err := p.DataStoreClient.GetItemStream(ctx, &protocol.GetItemStreamRequest{BatchSize: int32(batchSize)}) if err != nil { errChan <- err return } for { resp, err := 
stream.Recv() if err != nil { if err == io.EOF { break } errChan <- err return } items := make([]Item, len(resp.Items)) for i, item := range resp.Items { var labels any if err = json.Unmarshal(item.Labels, &labels); err != nil { errChan <- err return } items[i] = Item{ ItemId: item.ItemId, IsHidden: item.IsHidden, Categories: item.Categories, Timestamp: item.Timestamp.AsTime(), Labels: labels, Comment: item.Comment, } } itemsChan <- items } }() return itemsChan, errChan } func (p ProxyClient) GetFeedbackStream(ctx context.Context, batchSize int, options ...ScanOption) (chan []Feedback, chan error) { var o ScanOptions for _, opt := range options { opt(&o) } var types []*protocol.FeedbackTypeExpression for _, t := range o.FeedbackTypes { types = append(types, t.ToPB()) } pbOptions := &protocol.ScanOptions{ BeginUserId: o.BeginUserId, EndUserId: o.EndUserId, BeginItemId: o.BeginItemId, EndItemId: o.EndItemId, FeedbackTypes: types, OrderByItemId: o.OrderByItemId, } if o.BeginTime != nil { pbOptions.BeginTime = timestamppb.New(*o.BeginTime) } if o.EndTime != nil { pbOptions.EndTime = timestamppb.New(*o.EndTime) } feedbackChan := make(chan []Feedback, bufSize) errChan := make(chan error, 1) go func() { defer close(feedbackChan) defer close(errChan) req := &protocol.GetFeedbackStreamRequest{ BatchSize: int32(batchSize), ScanOptions: pbOptions, } stream, err := p.DataStoreClient.GetFeedbackStream(ctx, req) if err != nil { errChan <- err return } for { resp, err := stream.Recv() if err != nil { if err == io.EOF { break } errChan <- err return } feedback := make([]Feedback, len(resp.Feedback)) for i, f := range resp.Feedback { feedback[i] = Feedback{ FeedbackKey: FeedbackKey{ FeedbackType: f.FeedbackType, UserId: f.UserId, ItemId: f.ItemId, }, Value: f.Value, Timestamp: f.Timestamp.AsTime(), Updated: f.Updated.AsTime(), Comment: f.Comment, } } feedbackChan <- feedback } }() return feedbackChan, errChan } func (p ProxyClient) CountUsers(ctx context.Context) (int, error) { 
resp, err := p.DataStoreClient.CountUsers(ctx, &protocol.CountUsersRequest{}) if err != nil { return 0, err } return int(resp.Count), nil } func (p ProxyClient) CountItems(ctx context.Context) (int, error) { resp, err := p.DataStoreClient.CountItems(ctx, &protocol.CountItemsRequest{}) if err != nil { return 0, err } return int(resp.Count), nil } func (p ProxyClient) CountFeedback(ctx context.Context) (int, error) { resp, err := p.DataStoreClient.CountFeedback(ctx, &protocol.CountFeedbackRequest{}) if err != nil { return 0, err } return int(resp.Count), nil } ================================================ FILE: storage/data/proxy_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package data

import (
	"fmt"
	"net"
	"testing"

	"github.com/stretchr/testify/suite"
	"google.golang.org/grpc"
)

// ProxyTestSuite runs the shared database test suite against a ProxyClient that talks
// over gRPC to a ProxyServer backed by a temporary SQLite database.
type ProxyTestSuite struct {
	baseTestSuite
	sqlite     Database
	server     *ProxyServer
	clientConn *grpc.ClientConn
}

// SetupSuite creates the SQLite backend, starts the proxy server on a random local port
// and connects the proxy client under test to it.
func (suite *ProxyTestSuite) SetupSuite() {
	// create database
	var err error
	path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir())
	suite.sqlite, err = Open(path, "gorse_")
	suite.NoError(err)
	// create schema
	err = suite.sqlite.Init()
	suite.NoError(err)
	// start server
	lis, err := net.Listen("tcp", "localhost:0")
	suite.NoError(err)
	suite.server = NewProxyServer(suite.sqlite)
	go func() {
		// Use a goroutine-local error: assigning the outer err here would race with
		// the grpc.Dial assignment below.
		serveErr := suite.server.Serve(lis)
		suite.NoError(serveErr)
	}()
	// create proxy client
	suite.clientConn, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
	suite.NoError(err)
	suite.Database = NewProxyClient(suite.clientConn)
}

// TearDownSuite stops the server and releases the client connection and backend.
func (suite *ProxyTestSuite) TearDownSuite() {
	suite.server.Stop()
	suite.NoError(suite.clientConn.Close())
	suite.NoError(suite.sqlite.Close())
}

// SetupTest starts every test from an empty backend.
func (suite *ProxyTestSuite) SetupTest() {
	err := suite.sqlite.Ping()
	suite.NoError(err)
	err = suite.sqlite.Purge()
	suite.NoError(err)
}

// TearDownTest empties the backend after every test.
func (suite *ProxyTestSuite) TearDownTest() {
	err := suite.sqlite.Purge()
	suite.NoError(err)
}

// TestInit is skipped: ProxyClient.Init is not allowed.
func (suite *ProxyTestSuite) TestInit() {
	suite.T().Skip()
}

// TestPurge is skipped: ProxyClient.Purge is not allowed.
func (suite *ProxyTestSuite) TestPurge() {
	suite.T().Skip()
}

func TestProxy(t *testing.T) {
	suite.Run(t, new(ProxyTestSuite))
}

================================================
FILE: storage/data/sql.go
================================================
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package data import ( "context" "database/sql" "encoding/base64" "fmt" "net/url" "strings" "time" "github.com/XSAM/otelsql" mapset "github.com/deckarep/golang-set/v2" _ "github.com/go-sql-driver/mysql" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/jsonutil" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" _ "github.com/lib/pq" _ "github.com/mailru/go-clickhouse/v2" "github.com/samber/lo" semconv "go.opentelemetry.io/otel/semconv/v1.12.0" "gorm.io/driver/clickhouse" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/clause" _ "modernc.org/sqlite" ) const bufSize = 1 func init() { Register([]string{storage.MySQLPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { name := path[len(storage.MySQLPrefix):] option := storage.NewOptions(opts...) 
// probe isolation variable name isolationVarName, err := storage.ProbeMySQLIsolationVariableName(name) if err != nil { return nil, errors.Trace(err) } // append parameters if name, err = storage.AppendMySQLParams(name, map[string]string{ "sql_mode": "'ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION'", isolationVarName: fmt.Sprintf("'%s'", option.IsolationLevel), "parseTime": "true", }); err != nil { return nil, errors.Trace(err) } // connect to database database := new(SQLDatabase) database.driver = MySQL database.TablePrefix = storage.TablePrefix(tablePrefix) if database.client, err = otelsql.Open("mysql", name, otelsql.WithAttributes(semconv.DBSystemMySQL), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } storage.ApplySQLPool(database.client, option) database.gormDB, err = gorm.Open(mysql.New(mysql.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.PostgresPrefix, storage.PostgreSQLPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(SQLDatabase) database.driver = Postgres database.TablePrefix = storage.TablePrefix(tablePrefix) option := storage.NewOptions(opts...) 
var err error if database.client, err = otelsql.Open("postgres", path, otelsql.WithAttributes(semconv.DBSystemPostgreSQL), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } storage.ApplySQLPool(database.client, option) database.gormDB, err = gorm.Open(postgres.New(postgres.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.ClickhousePrefix, storage.CHHTTPPrefix, storage.CHHTTPSPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { // replace schema parsed, err := url.Parse(path) if err != nil { return nil, errors.Trace(err) } if strings.HasPrefix(path, storage.CHHTTPSPrefix) { parsed.Scheme = "https" } else { parsed.Scheme = "http" } uri := parsed.String() database := new(SQLDatabase) database.driver = ClickHouse database.TablePrefix = storage.TablePrefix(tablePrefix) if database.client, err = otelsql.Open("chhttp", uri, otelsql.WithAttributes(semconv.DBSystemKey.String("clickhouse")), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } database.gormDB, err = gorm.Open(clickhouse.New(clickhouse.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) Register([]string{storage.SQLitePrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { dataSourceName := path[len(storage.SQLitePrefix):] // append parameters var err error if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{ {"_pragma", "busy_timeout(10000)"}, {"_pragma", "journal_mode(wal)"}, }); err != nil { return nil, errors.Trace(err) } // connect to database database := new(SQLDatabase) database.driver = SQLite database.TablePrefix = storage.TablePrefix(tablePrefix) if database.client, err = 
otelsql.Open("sqlite", dataSourceName, otelsql.WithAttributes(semconv.DBSystemSqlite), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } database.gormDB, err = gorm.Open(sqlite.Dialector{Conn: database.client}, storage.NewGORMConfig(tablePrefix)) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type SQLDriver int const ( MySQL SQLDriver = iota Postgres ClickHouse SQLite ) type SQLItem struct { ItemId string `gorm:"column:item_id;primaryKey"` IsHidden bool `gorm:"column:is_hidden"` Categories string `gorm:"column:categories"` Timestamp time.Time `gorm:"column:time_stamp"` Labels string `gorm:"column:labels"` Comment string `gorm:"column:comment"` } func NewSQLItem(item Item) (sqlItem SQLItem) { var buf []byte sqlItem.ItemId = item.ItemId sqlItem.IsHidden = item.IsHidden buf, _ = jsonutil.Marshal(item.Categories) sqlItem.Categories = string(buf) sqlItem.Timestamp = item.Timestamp buf, _ = jsonutil.Marshal(item.Labels) sqlItem.Labels = string(buf) sqlItem.Comment = item.Comment return } type SQLUser struct { UserId string `gorm:"column:user_id;primaryKey"` Labels string `gorm:"column:labels"` Comment string `gorm:"column:comment"` } func NewSQLUser(user User) (sqlUser SQLUser) { var buf []byte sqlUser.UserId = user.UserId buf, _ = jsonutil.Marshal(user.Labels) sqlUser.Labels = string(buf) sqlUser.Comment = user.Comment return } type ClickHouseItem struct { SQLItem `gorm:"embedded"` Version time.Time `gorm:"column:version"` } func NewClickHouseItem(item Item) (clickHouseItem ClickHouseItem) { clickHouseItem.SQLItem = NewSQLItem(item) clickHouseItem.Timestamp = item.Timestamp.In(time.UTC) clickHouseItem.Version = time.Now().In(time.UTC) return } type ClickhouseUser struct { SQLUser `gorm:"embedded"` Version time.Time `gorm:"column:version"` } func NewClickhouseUser(user User) (clickhouseUser ClickhouseUser) { clickhouseUser.SQLUser = NewSQLUser(user) clickhouseUser.Version = 
time.Now().In(time.UTC) return } func FeedbackTypeExpressionToSQL(db *gorm.DB, e expression.FeedbackTypeExpression) *gorm.DB { switch e.ExprType { case expression.Less: return db.Or("feedback_type = ? AND value < ?", e.FeedbackType, e.Value) case expression.LessOrEqual: return db.Or("feedback_type = ? AND value <= ?", e.FeedbackType, e.Value) case expression.Greater: return db.Or("feedback_type = ? AND value > ?", e.FeedbackType, e.Value) case expression.GreaterOrEqual: return db.Or("feedback_type = ? AND value >= ?", e.FeedbackType, e.Value) default: return db.Or("feedback_type = ?", e.FeedbackType) } } // SQLDatabase use MySQL as data storage. type SQLDatabase struct { storage.TablePrefix gormDB *gorm.DB client *sql.DB driver SQLDriver } // Optimize is used by ClickHouse only. func (d *SQLDatabase) Optimize() error { if d.driver == ClickHouse { for _, tableName := range []string{d.UsersTable(), d.ItemsTable(), d.FeedbackTable(), d.AggregatingFeedbackTable(), d.UserFeedbackTable(), d.ItemFeedbackTable(), d.LatestItemsTable()} { _, err := d.client.Exec("OPTIMIZE TABLE " + tableName) if err != nil { return errors.Trace(err) } } } return nil } // Init tables and indices in MySQL. 
func (d *SQLDatabase) Init() error {
	// Init migrates the schema for the configured driver. GORM derives each table name
	// from the local struct type name below (Items, Users, Feedback, ...), so these type
	// names are significant. Presumably storage.NewGORMConfig applies the table prefix — verify there.
	switch d.driver {
	case MySQL:
		// create tables
		// IDs use the binary collation utf8mb4_bin so they compare byte-for-byte.
		type Items struct {
			ItemId     string    `gorm:"column:item_id;type:varchar(256) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin not null;primaryKey"`
			IsHidden   bool      `gorm:"column:is_hidden;type:bool;not null"`
			Categories []string  `gorm:"column:categories;type:json;not null"`
			Timestamp  time.Time `gorm:"column:time_stamp;type:datetime;not null;index:time_stamp_index"`
			Labels     []string  `gorm:"column:labels;type:json;not null"`
			Comment    string    `gorm:"column:comment;type:text;not null"`
		}
		type Users struct {
			UserId  string   `gorm:"column:user_id;type:varchar(256) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin not null;primaryKey"`
			Labels  []string `gorm:"column:labels;type:json;not null"`
			Comment string   `gorm:"column:comment;type:text;not null"`
		}
		// Feedback is keyed by (feedback_type, user_id, item_id), with secondary indices on
		// user_id and item_id.
		type Feedback struct {
			FeedbackType string    `gorm:"column:feedback_type;type:varchar(256) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin not null;primaryKey"`
			UserId       string    `gorm:"column:user_id;type:varchar(256) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin not null;primaryKey;index:user_id"`
			ItemId       string    `gorm:"column:item_id;type:varchar(256) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin not null;primaryKey;index:item_id"`
			Value        float64   `gorm:"column:value;type:float;not null;default:0"`
			Timestamp    time.Time `gorm:"column:time_stamp;type:datetime;not null"`
			Updated      time.Time `gorm:"column:updated;type:datetime;not null;default:'2000-01-01 00:00:00'"`
			Comment      string    `gorm:"column:comment;type:text;not null"`
		}
		err := d.gormDB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(Users{}, Items{}, Feedback{})
		if err != nil {
			return errors.Trace(err)
		}
	case Postgres:
		// create tables
		// IDs use COLLATE "C" so they compare and sort byte-wise.
		type Items struct {
			ItemId     string    `gorm:"column:item_id;type:varchar(256) COLLATE \"C\";not null;primaryKey"`
			IsHidden   bool      `gorm:"column:is_hidden;type:bool;not null;default:false"`
			Categories string    `gorm:"column:categories;type:json;not null;default:'[]'"`
			Timestamp  time.Time `gorm:"column:time_stamp;type:timestamptz;not null;index:time_stamp_index"`
			Labels     string    `gorm:"column:labels;type:json;not null;default:'[]'"`
			Comment    string    `gorm:"column:comment;type:text;not null;default:''"`
		}
		type Users struct {
			UserId  string `gorm:"column:user_id;type:varchar(256) COLLATE \"C\" not null;primaryKey"`
			Labels  string `gorm:"column:labels;type:json;not null;default:'[]'"`
			Comment string `gorm:"column:comment;type:text;not null;default:''"`
		}
		type Feedback struct {
			FeedbackType string    `gorm:"column:feedback_type;type:varchar(256) COLLATE \"C\";not null;primaryKey"`
			UserId       string    `gorm:"column:user_id;type:varchar(256) COLLATE \"C\";not null;primaryKey;index:user_id_index"`
			ItemId       string    `gorm:"column:item_id;type:varchar(256) COLLATE \"C\";not null;primaryKey;index:item_id_index"`
			Value        float64   `gorm:"column:value;type:float8;not null;default:0"`
			Timestamp    time.Time `gorm:"column:time_stamp;type:timestamptz;not null"`
			Updated      time.Time `gorm:"column:updated;type:timestamptz;not null;default:'2000-01-01 00:00:00'"`
			Comment      string    `gorm:"column:comment;type:text;not null;default:''"`
		}
		err := d.gormDB.AutoMigrate(Users{}, Items{}, Feedback{})
		if err != nil {
			return errors.Trace(err)
		}
	case SQLite:
		// create tables
		// Time columns are declared as datetime but mapped to Go strings here, with
		// '0001-01-01' as the default.
		type Items struct {
			ItemId     string `gorm:"column:item_id;type:varchar(256);not null;primaryKey"`
			IsHidden   bool   `gorm:"column:is_hidden;type:bool;not null;default:false"`
			Categories string `gorm:"column:categories;type:json;not null;default:'[]'"`
			Timestamp  string `gorm:"column:time_stamp;type:datetime;not null;default:'0001-01-01';index:time_stamp_index"`
			Labels     string `gorm:"column:labels;type:json;not null;default:'[]'"`
			Comment    string `gorm:"column:comment;type:text;not null;default:''"`
		}
		type Users struct {
			UserId  string `gorm:"column:user_id;type:varchar(256) not null;primaryKey"`
			Labels  string `gorm:"column:labels;type:json;not null;default:'null'"`
			Comment string `gorm:"column:comment;type:text;not null;default:''"`
		}
		type Feedback struct {
			FeedbackType string  `gorm:"column:feedback_type;type:varchar(256);not null;primaryKey"`
			UserId       string  `gorm:"column:user_id;type:varchar(256);not null;primaryKey;index:user_id_index"`
			ItemId       string  `gorm:"column:item_id;type:varchar(256);not null;primaryKey;index:item_id_index"`
			Value        float64 `gorm:"column:value;type:real;not null;default:0"`
			Timestamp    string  `gorm:"column:time_stamp;type:datetime;not null;default:'0001-01-01'"`
			Updated      string  `gorm:"column:updated;type:datetime;not null;default:'0001-01-01'"`
			Comment      string  `gorm:"column:comment;type:text;not null;default:''"`
		}
		err := d.gormDB.AutoMigrate(Users{}, Items{}, Feedback{})
		if err != nil {
			return errors.Trace(err)
		}
	case ClickHouse:
		// create tables
		// Items and users use ReplacingMergeTree(version): rows sharing the ORDER BY key are
		// collapsed during merges, keeping the highest version.
		// NOTE(review): Version is typed struct{} here; only its column tag appears to matter
		// for the migration — confirm.
		type Items struct {
			ItemId     string    `gorm:"column:item_id;type:String"`
			IsHidden   int       `gorm:"column:is_hidden;type:Boolean;default:0"`
			Categories string    `gorm:"column:categories;type:String;default:'[]'"`
			Timestamp  time.Time `gorm:"column:time_stamp;type:Datetime64(9,'UTC')"`
			Labels     string    `gorm:"column:labels;type:String;default:'[]'"`
			Comment    string    `gorm:"column:comment;type:String"`
			Version    struct{}  `gorm:"column:version;type:DateTime"`
		}
		err := d.gormDB.Set("gorm:table_options", "ENGINE = ReplacingMergeTree(version) ORDER BY item_id").AutoMigrate(Items{})
		if err != nil {
			return errors.Trace(err)
		}
		type Users struct {
			UserId  string   `gorm:"column:user_id;type:String"`
			Labels  string   `gorm:"column:labels;type:String;default:'[]'"`
			Comment string   `gorm:"column:comment;type:String"`
			Version struct{} `gorm:"column:version;type:DateTime"`
		}
		err = d.gormDB.Set("gorm:table_options", "ENGINE = ReplacingMergeTree(version) ORDER BY user_id").AutoMigrate(Users{})
		if err != nil {
			return errors.Trace(err)
		}
		// Raw feedback is kept in a plain MergeTree table; the aggregating tables below
		// are derived from it.
		type Feedback struct {
			FeedbackType string    `gorm:"column:feedback_type;type:String"`
			UserId       string    `gorm:"column:user_id;type:String"`
			ItemId       string    `gorm:"column:item_id;type:String"`
			Value        float64   `gorm:"column:value;type:Float64;default:0"`
			Timestamp    time.Time `gorm:"column:time_stamp;type:DateTime64(9,'UTC')"`
			Updated      time.Time `gorm:"column:updated;type:DateTime64(9,'UTC')"`
			Comment      string    `gorm:"column:comment;type:String"`
		}
		err = d.gormDB.Set("gorm:table_options", "ENGINE = MergeTree ORDER BY (feedback_type, user_id, item_id)").AutoMigrate(Feedback{})
		if err != nil {
			return errors.Trace(err)
		}
		// create materialized views
		// Each aggregating table is filled by a materialized view over the feedback table
		// that sums values, keeps the earliest time_stamp, the latest updated and the last
		// comment per (feedback_type, user_id, item_id).
		type AggregatingFeedback struct {
			FeedbackType string    `gorm:"column:feedback_type;type:String"`
			UserId       string    `gorm:"column:user_id;type:String"`
			ItemId       string    `gorm:"column:item_id;type:String"`
			Value        float64   `gorm:"column:value;type:SimpleAggregateFunction(sum, Float64)"`
			Timestamp    time.Time `gorm:"column:time_stamp;type:SimpleAggregateFunction(min, DateTime64(9,'UTC'))"`
			Updated      time.Time `gorm:"column:updated;type:SimpleAggregateFunction(max, DateTime64(9,'UTC'))"`
			Comment      string    `gorm:"column:comment;type:SimpleAggregateFunction(anyLast, String)"`
		}
		err = d.gormDB.Set("gorm:table_options", "ENGINE = AggregatingMergeTree() ORDER BY (user_id, item_id, feedback_type)").AutoMigrate(AggregatingFeedback{})
		if err != nil {
			return errors.Trace(err)
		}
		err = d.gormDB.Exec(fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s_mv TO %s AS "+
			"SELECT feedback_type, user_id, item_id, sum(value) AS value, min(time_stamp) AS time_stamp, max(updated) AS updated, anyLast(comment) AS comment "+
			"FROM %s GROUP BY feedback_type, user_id, item_id", d.AggregatingFeedbackTable(), d.AggregatingFeedbackTable(), d.FeedbackTable())).Error
		if err != nil {
			return errors.Trace(err)
		}
		// Same aggregation, sorted with user_id first.
		type UserFeedback AggregatingFeedback
		err = d.gormDB.Set("gorm:table_options", "ENGINE = AggregatingMergeTree() ORDER BY (user_id, item_id, feedback_type)").AutoMigrate(UserFeedback{})
		if err != nil {
			return errors.Trace(err)
		}
		err = d.gormDB.Exec(fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s_mv TO %s AS "+
			"SELECT feedback_type, user_id, item_id, sum(value) AS value, min(time_stamp) AS time_stamp, max(updated) AS updated, anyLast(comment) AS comment "+
			"FROM %s GROUP BY feedback_type, user_id, item_id", d.UserFeedbackTable(), d.UserFeedbackTable(), d.FeedbackTable())).Error
		if err != nil {
			return errors.Trace(err)
		}
		// Same aggregation, sorted with item_id first.
		type ItemFeedback AggregatingFeedback
		err = d.gormDB.Set("gorm:table_options", "ENGINE = AggregatingMergeTree() ORDER BY (item_id, user_id, feedback_type)").AutoMigrate(ItemFeedback{})
		if err != nil {
			return errors.Trace(err)
		}
		err = d.gormDB.Exec(fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s_mv TO %s AS "+
			"SELECT feedback_type, user_id, item_id, sum(value) AS value, min(time_stamp) AS time_stamp, max(updated) AS updated, anyLast(comment) AS comment "+
			"FROM %s GROUP BY feedback_type, user_id, item_id", d.ItemFeedbackTable(), d.ItemFeedbackTable(), d.FeedbackTable())).Error
		if err != nil {
			return errors.Trace(err)
		}
		// Copy of the items table sorted by (time_stamp, item_id).
		type LatestItems struct {
			ItemId     string    `gorm:"column:item_id;type:String"`
			IsHidden   int       `gorm:"column:is_hidden;type:Boolean;default:0"`
			Categories string    `gorm:"column:categories;type:String;default:'[]'"`
			Timestamp  time.Time `gorm:"column:time_stamp;type:Datetime64(9,'UTC')"`
			Labels     string    `gorm:"column:labels;type:String;default:'[]'"`
			Comment    string    `gorm:"column:comment;type:String"`
			Version    time.Time `gorm:"column:version;type:DateTime"`
		}
		err = d.gormDB.Set("gorm:table_options", "ENGINE = ReplacingMergeTree(version) ORDER BY (time_stamp, item_id) SETTINGS index_granularity = 8192").AutoMigrate(LatestItems{})
		if err != nil {
			return errors.Trace(err)
		}
		// Create materialized view for latest items ordered by timestamp
		// (every row inserted into the items table is copied into the latest-items table).
		err = d.gormDB.Exec(fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s_latest_mv TO %s AS "+
			"SELECT item_id, is_hidden, categories, time_stamp, labels, comment, version "+
			"FROM %s", d.ItemsTable(), d.LatestItemsTable(), d.ItemsTable())).Error
		if err != nil {
			return errors.Trace(err)
		}
	}
	return nil
}

// Ping checks the connection through the underlying *sql.DB.
func (d *SQLDatabase) Ping() error {
	return d.client.Ping()
}

// Close closes the underlying *sql.DB handle.
func (d *SQLDatabase) Close() error { return d.client.Close() } func (d *SQLDatabase) Purge() error { if d.driver == ClickHouse { tables := []string{d.ItemsTable(), d.FeedbackTable(), d.UsersTable(), d.UserFeedbackTable(), d.ItemFeedbackTable(), d.LatestItemsTable()} for _, tableName := range tables { err := d.gormDB.Exec(fmt.Sprintf("alter table %s delete where 1=1", tableName)).Error if err != nil { return errors.Trace(err) } } } else { tables := []string{d.ItemsTable(), d.FeedbackTable(), d.UsersTable()} for _, tableName := range tables { err := d.gormDB.Exec(fmt.Sprintf("DELETE FROM %s", tableName)).Error if err != nil { return errors.Trace(err) } } } return nil } // BatchInsertItems inserts a batch of items into MySQL. func (d *SQLDatabase) BatchInsertItems(ctx context.Context, items []Item) error { if len(items) == 0 { return nil } if d.driver == ClickHouse { rows := make([]ClickHouseItem, 0, len(items)) memo := mapset.NewSet[string]() for _, item := range items { if !memo.Contains(item.ItemId) { memo.Add(item.ItemId) rows = append(rows, NewClickHouseItem(item)) } } err := d.gormDB.Create(rows).Error return errors.Trace(err) } else { rows := make([]SQLItem, 0, len(items)) memo := mapset.NewSet[string]() for _, item := range items { if !memo.Contains(item.ItemId) { memo.Add(item.ItemId) row := NewSQLItem(item) if d.driver == SQLite { row.Timestamp = row.Timestamp.In(time.UTC) } rows = append(rows, row) } } err := d.gormDB.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "item_id"}}, DoUpdates: clause.AssignmentColumns([]string{"is_hidden", "categories", "time_stamp", "labels", "comment"}), }).Create(rows).Error return errors.Trace(err) } } func (d *SQLDatabase) BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error) { if len(itemIds) == 0 { return nil, nil } result, err := d.gormDB.WithContext(ctx). Table(d.ItemsTable()). Select("item_id, is_hidden, categories, time_stamp, labels, comment"). 
Where("item_id IN ?", itemIds).Rows() if err != nil { return nil, errors.Trace(err) } defer result.Close() var items []Item for result.Next() { var item Item if err = d.gormDB.ScanRows(result, &item); err != nil { return nil, errors.Trace(err) } items = append(items, item) } return items, nil } // DeleteItem deletes a item from MySQL. func (d *SQLDatabase) DeleteItem(ctx context.Context, itemId string) error { if err := d.gormDB.WithContext(ctx).Delete(&SQLItem{ItemId: itemId}).Error; err != nil { return errors.Trace(err) } if err := d.gormDB.WithContext(ctx).Delete(&Feedback{}, "item_id = ?", itemId).Error; err != nil { return errors.Trace(err) } if d.driver == ClickHouse { if err := d.gormDB.WithContext(ctx).Delete(&ItemFeedback{}, "item_id = ?", itemId).Error; err != nil { return errors.Trace(err) } if err := d.gormDB.WithContext(ctx).Delete(&UserFeedback{}, "item_id = ?", itemId).Error; err != nil { return errors.Trace(err) } } return nil } // GetItem get a item from MySQL. func (d *SQLDatabase) GetItem(ctx context.Context, itemId string) (Item, error) { var result *sql.Rows var err error result, err = d.gormDB.WithContext(ctx). Table(d.ItemsTable()). Select("item_id, is_hidden, categories, time_stamp, labels, comment"). Where("item_id = ?", itemId).Rows() if err != nil { return Item{}, errors.Trace(err) } defer result.Close() if result.Next() { var item Item if err = d.gormDB.ScanRows(result, &item); err != nil { return Item{}, errors.Trace(err) } return item, nil } return Item{}, errors.Annotate(ErrItemNotExist, itemId) } // ModifyItem modify an item in MySQL. 
func (d *SQLDatabase) ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error { // ignore empty patch if patch.IsHidden == nil && patch.Categories == nil && patch.Labels == nil && patch.Comment == nil && patch.Timestamp == nil { log.Logger().Debug("empty item patch") return nil } attributes := make(map[string]any) if patch.IsHidden != nil { if *patch.IsHidden { attributes["is_hidden"] = 1 } else { attributes["is_hidden"] = 0 } } if patch.Categories != nil { text, _ := jsonutil.Marshal(patch.Categories) attributes["categories"] = string(text) } if patch.Comment != nil { attributes["comment"] = *patch.Comment } if patch.Labels != nil { text, _ := jsonutil.Marshal(patch.Labels) attributes["labels"] = string(text) } if patch.Timestamp != nil { switch d.driver { case ClickHouse, SQLite: attributes["time_stamp"] = patch.Timestamp.In(time.UTC) default: attributes["time_stamp"] = patch.Timestamp } } err := d.gormDB.WithContext(ctx).Model(&SQLItem{ItemId: itemId}).Updates(attributes).Error return errors.Trace(err) } // GetItems returns items from MySQL. func (d *SQLDatabase) GetItems(ctx context.Context, cursor string, n int, timeLimit *time.Time) (string, []Item, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } cursorItem := string(buf) tx := d.gormDB.WithContext(ctx). Table(d.ItemsTable()). 
Select("item_id, is_hidden, categories, time_stamp, labels, comment") if cursorItem != "" { tx.Where("item_id >= ?", cursorItem) } if timeLimit != nil { tx.Where("time_stamp >= ?", *timeLimit) } result, err := tx.Order("item_id").Limit(n + 1).Rows() if err != nil { return "", nil, errors.Trace(err) } items := make([]Item, 0) defer result.Close() for result.Next() { var item Item if err = d.gormDB.ScanRows(result, &item); err != nil { return "", nil, errors.Trace(err) } items = append(items, item) } if len(items) == n+1 { return base64.StdEncoding.EncodeToString([]byte(items[len(items)-1].ItemId)), items[:len(items)-1], nil } return "", items, nil } // GetLatestItems returns the latest items from the database. func (d *SQLDatabase) GetLatestItems(ctx context.Context, n int, categories []string) ([]Item, error) { var tableName string if d.driver == ClickHouse { tableName = d.LatestItemsTable() } else { tableName = d.ItemsTable() } tx := d.gormDB.WithContext(ctx). Table(tableName). Select("item_id, is_hidden, categories, time_stamp, labels, comment"). Where("is_hidden = ?", false) if len(categories) > 0 { q, err := jsonutil.Marshal(categories) if err != nil { return nil, errors.Trace(err) } switch d.driver { case Postgres: tx = tx.Where("categories::jsonb @> ?::jsonb", string(q)) case MySQL, SQLite: tx = tx.Where("JSON_CONTAINS(categories,?)", string(q)) case ClickHouse: tx = tx.Where("hasAll(JSONExtractArrayRaw(categories),JSONExtractArrayRaw(?))", string(q)) } } result, err := tx.Order("time_stamp DESC").Limit(n).Rows() if err != nil { return nil, errors.Trace(err) } items := make([]Item, 0) defer result.Close() for result.Next() { var item Item if err = d.gormDB.ScanRows(result, &item); err != nil { return nil, errors.Trace(err) } items = append(items, item) } return items, nil } // GetItemStream reads items by stream. 
func (d *SQLDatabase) GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error) { itemChan := make(chan []Item, bufSize) errChan := make(chan error, 1) go func() { defer close(itemChan) defer close(errChan) // send query tx := d.gormDB.WithContext(ctx). Table(d.ItemsTable()). Select("item_id, is_hidden, categories, time_stamp, labels, comment") if timeLimit != nil { tx.Where("time_stamp >= ?", *timeLimit) } result, err := tx.Rows() if err != nil { errChan <- errors.Trace(err) return } // fetch result items := make([]Item, 0, batchSize) defer result.Close() for result.Next() { var item Item if err = d.gormDB.ScanRows(result, &item); err != nil { errChan <- errors.Trace(err) return } items = append(items, item) if len(items) == batchSize { itemChan <- items items = make([]Item, 0, batchSize) } } if len(items) > 0 { itemChan <- items } errChan <- nil }() return itemChan, errChan } // GetItemFeedback returns feedback of a item from MySQL. func (d *SQLDatabase) GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error) { tx := d.gormDB.WithContext(ctx) if d.driver == ClickHouse { tx = tx.Table(d.ItemFeedbackTable()). Select("user_id, item_id, feedback_type, sum(value) AS value, min(time_stamp) AS time_stamp, max(updated) AS updated, anyLast(comment) AS comment"). Group("user_id, item_id, feedback_type") } else { tx = tx.Table(d.FeedbackTable()). 
Select("user_id, item_id, feedback_type, value, time_stamp, updated, comment") } switch d.driver { case SQLite: tx.Where("time_stamp <= DATETIME()") case ClickHouse: tx.Having("time_stamp <= NOW('UTC')") default: tx.Where("time_stamp <= NOW()") } tx.Where("item_id = ?", itemId) if len(feedbackTypes) > 0 { db := d.gormDB for _, feedbackType := range feedbackTypes { db = db.Or("feedback_type = ?", feedbackType) } tx = tx.Where(db) } result, err := tx.Rows() if err != nil { return nil, errors.Trace(err) } feedbacks := make([]Feedback, 0) defer result.Close() for result.Next() { var feedback Feedback if err = d.gormDB.ScanRows(result, &feedback); err != nil { return nil, errors.Trace(err) } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // BatchInsertUsers inserts users into MySQL. func (d *SQLDatabase) BatchInsertUsers(ctx context.Context, users []User) error { if len(users) == 0 { return nil } if d.driver == ClickHouse { rows := make([]ClickhouseUser, 0, len(users)) memo := mapset.NewSet[string]() for _, user := range users { if !memo.Contains(user.UserId) { memo.Add(user.UserId) rows = append(rows, NewClickhouseUser(user)) } } err := d.gormDB.Create(rows).Error return errors.Trace(err) } else { rows := make([]SQLUser, 0, len(users)) memo := mapset.NewSet[string]() for _, user := range users { if !memo.Contains(user.UserId) { memo.Add(user.UserId) rows = append(rows, NewSQLUser(user)) } } err := d.gormDB.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "user_id"}}, DoUpdates: clause.AssignmentColumns([]string{"labels", "comment"}), }).Create(rows).Error return errors.Trace(err) } } // DeleteUser deletes a user from MySQL. 
func (d *SQLDatabase) DeleteUser(ctx context.Context, userId string) error { if err := d.gormDB.WithContext(ctx).Delete(&SQLUser{UserId: userId}).Error; err != nil { return errors.Trace(err) } if err := d.gormDB.WithContext(ctx).Delete(&Feedback{}, "user_id = ?", userId).Error; err != nil { return errors.Trace(err) } if d.driver == ClickHouse { if err := d.gormDB.WithContext(ctx).Delete(&ItemFeedback{}, "user_id = ?", userId).Error; err != nil { return errors.Trace(err) } if err := d.gormDB.WithContext(ctx).Delete(&UserFeedback{}, "user_id = ?", userId).Error; err != nil { return errors.Trace(err) } } return nil } // GetUser returns a user from MySQL. func (d *SQLDatabase) GetUser(ctx context.Context, userId string) (User, error) { var result *sql.Rows var err error result, err = d.gormDB.WithContext(ctx).Table(d.UsersTable()). Select("user_id, labels, comment"). Where("user_id = ?", userId).Rows() if err != nil { return User{}, errors.Trace(err) } defer result.Close() if result.Next() { var user User if err = d.gormDB.ScanRows(result, &user); err != nil { return User{}, errors.Trace(err) } return user, nil } return User{}, errors.Annotate(ErrUserNotExist, userId) } // ModifyUser modify a user in MySQL. func (d *SQLDatabase) ModifyUser(ctx context.Context, userId string, patch UserPatch) error { // ignore empty patch if patch.Labels == nil && patch.Comment == nil { log.Logger().Debug("empty user patch") return nil } attributes := make(map[string]any) if patch.Comment != nil { attributes["comment"] = *patch.Comment } if patch.Labels != nil { text, _ := jsonutil.Marshal(patch.Labels) attributes["labels"] = string(text) } err := d.gormDB.WithContext(ctx).Model(&SQLUser{UserId: userId}).Updates(attributes).Error return errors.Trace(err) } // GetUsers returns users from MySQL. 
func (d *SQLDatabase) GetUsers(ctx context.Context, cursor string, n int) (string, []User, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } cursorUser := string(buf) tx := d.gormDB.WithContext(ctx). Table(d.UsersTable()). Select("user_id, labels, comment") if cursorUser != "" { tx.Where("user_id >= ?", cursorUser) } result, err := tx.Order("user_id").Limit(n + 1).Rows() if err != nil { return "", nil, errors.Trace(err) } users := make([]User, 0) defer result.Close() for result.Next() { var user User if err = d.gormDB.ScanRows(result, &user); err != nil { return "", nil, errors.Trace(err) } users = append(users, user) } if len(users) == n+1 { return base64.StdEncoding.EncodeToString([]byte(users[len(users)-1].UserId)), users[:len(users)-1], nil } return "", users, nil } // GetUserStream read users by stream. func (d *SQLDatabase) GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error) { userChan := make(chan []User, bufSize) errChan := make(chan error, 1) go func() { defer close(userChan) defer close(errChan) // send query result, err := d.gormDB.WithContext(ctx).Table(d.UsersTable()).Select("user_id, labels, comment").Rows() if err != nil { errChan <- errors.Trace(err) return } // fetch result users := make([]User, 0, batchSize) defer result.Close() for result.Next() { var user User if err = d.gormDB.ScanRows(result, &user); err != nil { errChan <- errors.Trace(err) return } users = append(users, user) if len(users) == batchSize { userChan <- users users = make([]User, 0, batchSize) } } if len(users) > 0 { userChan <- users } errChan <- nil }() return userChan, errChan } // GetUserFeedback returns feedback of a user from MySQL. 
func (d *SQLDatabase) GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...expression.FeedbackTypeExpression) ([]Feedback, error) { tx := d.gormDB.WithContext(ctx) if d.driver == ClickHouse { tx = tx.Table(d.UserFeedbackTable()) } else { tx = tx.Table(d.FeedbackTable()) } if d.driver == ClickHouse { tx.Select("feedback_type, user_id, item_id, sum(value) AS value, min(time_stamp) AS time_stamp, max(updated) AS updated, anyLast(comment) AS comment"). Group("feedback_type, user_id, item_id") if endTime != nil { tx.Having("time_stamp <= ?", d.convertTimeZone(endTime)) } } else { tx.Select("feedback_type, user_id, item_id, value, time_stamp, updated, comment") if endTime != nil { tx.Where("time_stamp <= ?", d.convertTimeZone(endTime)) } } tx.Where("user_id = ?", userId) if len(feedbackTypes) > 0 { db := d.gormDB for _, feedbackType := range feedbackTypes { db = FeedbackTypeExpressionToSQL(db, feedbackType) } if d.driver == ClickHouse { tx.Having(db) } else { tx.Where(db) } } result, err := tx.Rows() if err != nil { return nil, errors.Trace(err) } feedbacks := make([]Feedback, 0) defer result.Close() for result.Next() { var feedback Feedback if err = d.gormDB.ScanRows(result, &feedback); err != nil { return nil, errors.Trace(err) } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // BatchInsertFeedback insert a batch feedback into MySQL. // If insertUser set, new users will be inserted to user table. // If insertItem set, new items will be inserted to item table. 
func (d *SQLDatabase) BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error { tx := d.gormDB.WithContext(ctx) // skip empty list if len(feedback) == 0 { return nil } // collect users and items users := mapset.NewSet[string]() items := mapset.NewSet[string]() for _, v := range feedback { users.Add(v.UserId) items.Add(v.ItemId) } // insert users if insertUser { userList := users.ToSlice() if d.driver == ClickHouse { err := tx.Create(lo.Map(userList, func(userId string, _ int) ClickhouseUser { return ClickhouseUser{ SQLUser: SQLUser{ UserId: userId, Labels: "[]", }, } })).Error if err != nil { return errors.Trace(err) } } else { err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "user_id"}}, DoNothing: true, }).Create(lo.Map(userList, func(userId string, _ int) SQLUser { return SQLUser{ UserId: userId, Labels: "null", } })).Error if err != nil { return errors.Trace(err) } } } else { for _, user := range users.ToSlice() { rs, err := tx.Table(d.UsersTable()).Select("user_id").Where("user_id = ?", user).Rows() if err != nil { return errors.Trace(err) } else if !rs.Next() { users.Remove(user) } if err = rs.Close(); err != nil { return errors.Trace(err) } } } // insert items if insertItem { itemList := items.ToSlice() if d.driver == ClickHouse { err := tx.Create(lo.Map(itemList, func(itemId string, _ int) ClickHouseItem { return ClickHouseItem{ SQLItem: SQLItem{ ItemId: itemId, Labels: "[]", Categories: "[]", }, } })).Error if err != nil { return errors.Trace(err) } } else { err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "item_id"}}, DoNothing: true, }).Create(lo.Map(itemList, func(itemId string, _ int) SQLItem { return SQLItem{ ItemId: itemId, Labels: "null", Categories: "null", } })).Error if err != nil { return errors.Trace(err) } } } else { for _, item := range items.ToSlice() { rs, err := tx.Table(d.ItemsTable()).Select("item_id").Where("item_id = ?", item).Rows() if err != 
nil { return errors.Trace(err) } else if !rs.Next() { items.Remove(item) } if err = rs.Close(); err != nil { return errors.Trace(err) } } } // insert feedback if d.driver == ClickHouse { rows := make([]Feedback, 0, len(feedback)) memo := make(map[lo.Tuple3[string, string, string]]struct{}) for _, f := range feedback { if users.Contains(f.UserId) && items.Contains(f.ItemId) { if _, exist := memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}]; !exist { memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}] = struct{}{} f.Timestamp = f.Timestamp.In(time.UTC) f.Updated = f.Timestamp rows = append(rows, f) } } } if len(rows) == 0 { return nil } err := tx.Create(rows).Error return errors.Trace(err) } else { rows := make([]Feedback, 0, len(feedback)) memo := make(map[lo.Tuple3[string, string, string]]struct{}) for _, f := range feedback { if users.Contains(f.UserId) && items.Contains(f.ItemId) { if _, exist := memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}]; !exist { memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}] = struct{}{} if d.driver == SQLite { f.Timestamp = f.Timestamp.In(time.UTC) } f.Updated = f.Timestamp if d.driver == SQLite { f.Updated = f.Updated.In(time.UTC) } rows = append(rows, f) } } } if len(rows) == 0 { return nil } var updates clause.Set if overwrite { updates = clause.AssignmentColumns([]string{"time_stamp", "updated", "comment", "value"}) } else { values := make(map[string]any) switch d.driver { case MySQL: values["value"] = clause.Column{Raw: true, Name: "value + VALUES(value)"} values["time_stamp"] = clause.Column{Raw: true, Name: "LEAST(time_stamp, VALUES(time_stamp))"} values["updated"] = clause.Column{Raw: true, Name: "GREATEST(updated, VALUES(updated))"} values["comment"] = clause.Column{Raw: true, Name: "VALUES(comment)"} case Postgres: values["value"] = clause.Column{Raw: true, Name: fmt.Sprintf("%s.value + EXCLUDED.value", d.FeedbackTable())} 
values["time_stamp"] = clause.Column{Raw: true, Name: fmt.Sprintf("LEAST(%s.time_stamp, EXCLUDED.time_stamp)", d.FeedbackTable())} values["updated"] = clause.Column{Raw: true, Name: fmt.Sprintf("GREATEST(%s.updated, EXCLUDED.updated)", d.FeedbackTable())} values["comment"] = clause.Column{Raw: true, Name: "EXCLUDED.comment"} case SQLite: values["value"] = clause.Column{Raw: true, Name: "value + excluded.value"} values["time_stamp"] = clause.Column{Raw: true, Name: "MIN(time_stamp, excluded.time_stamp)"} values["updated"] = clause.Column{Raw: true, Name: "MAX(updated, excluded.updated)"} values["comment"] = clause.Column{Raw: true, Name: "excluded.comment"} } updates = clause.Assignments(values) } err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "feedback_type"}, {Name: "user_id"}, {Name: "item_id"}}, DoUpdates: updates, }).Create(rows).Error return errors.Trace(err) } } // GetFeedback returns feedback from MySQL. func (d *SQLDatabase) GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error) { buf, err := base64.StdEncoding.DecodeString(cursor) if err != nil { return "", nil, errors.Trace(err) } tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).Select("feedback_type, user_id, item_id, value, time_stamp, updated, comment") if len(buf) > 0 { var cursorKey FeedbackKey if err := jsonutil.Unmarshal(buf, &cursorKey); err != nil { return "", nil, err } tx.Where("(feedback_type, user_id, item_id) >= (?,?,?)", cursorKey.FeedbackType, cursorKey.UserId, cursorKey.ItemId) } if len(feedbackTypes) > 0 { db := d.gormDB for _, feedbackType := range feedbackTypes { db = db.Or("feedback_type = ?", feedbackType) } tx.Where(db) } if beginTime != nil { tx.Where("time_stamp >= ?", d.convertTimeZone(beginTime)) } if endTime != nil { tx.Where("time_stamp <= ?", d.convertTimeZone(endTime)) } tx.Order("feedback_type, user_id, item_id").Limit(n + 1) result, err := tx.Rows() if err != 
nil { return "", nil, errors.Trace(err) } feedbacks := make([]Feedback, 0) defer result.Close() for result.Next() { var feedback Feedback if err = d.gormDB.ScanRows(result, &feedback); err != nil { return "", nil, errors.Trace(err) } feedbacks = append(feedbacks, feedback) } if len(feedbacks) == n+1 { nextCursorKey := feedbacks[len(feedbacks)-1].FeedbackKey nextCursor, err := jsonutil.Marshal(nextCursorKey) if err != nil { return "", nil, errors.Trace(err) } return base64.StdEncoding.EncodeToString(nextCursor), feedbacks[:len(feedbacks)-1], nil } return "", feedbacks, nil } // GetFeedbackStream reads feedback by stream. func (d *SQLDatabase) GetFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) (chan []Feedback, chan error) { scan := NewScanOptions(scanOptions...) feedbackChan := make(chan []Feedback, bufSize) errChan := make(chan error, 1) go func() { defer close(feedbackChan) defer close(errChan) // send query tx := d.gormDB.WithContext(ctx). Table(d.FeedbackTable()). 
Select("feedback_type, user_id, item_id, value, time_stamp, updated, comment") if len(scan.FeedbackTypes) > 0 { db := d.gormDB for _, feedbackType := range scan.FeedbackTypes { db = FeedbackTypeExpressionToSQL(db, feedbackType) } tx.Where(db) } if scan.BeginTime != nil { tx.Where("time_stamp >= ?", d.convertTimeZone(scan.BeginTime)) } if scan.EndTime != nil { tx.Where("time_stamp <= ?", d.convertTimeZone(scan.EndTime)) } if scan.BeginUserId != nil { tx.Where("user_id >= ?", scan.BeginUserId) } if scan.EndUserId != nil { tx.Where("user_id <= ?", scan.EndUserId) } if scan.BeginItemId != nil { tx.Where("item_id >= ?", scan.BeginItemId) } if scan.EndItemId != nil { tx.Where("item_id <= ?", scan.EndItemId) } if scan.OrderByItemId { tx.Order("item_id") } else { tx.Order("feedback_type, user_id, item_id") } result, err := tx.Rows() if err != nil { errChan <- errors.Trace(err) return } // fetch result feedbacks := make([]Feedback, 0, batchSize) defer result.Close() for result.Next() { var feedback Feedback if err = d.gormDB.ScanRows(result, &feedback); err != nil { errChan <- errors.Trace(err) return } feedbacks = append(feedbacks, feedback) if len(feedbacks) == batchSize { feedbackChan <- feedbacks feedbacks = make([]Feedback, 0, batchSize) } } if len(feedbacks) > 0 { feedbackChan <- feedbacks } errChan <- nil }() return feedbackChan, errChan } // GetUserItemFeedback gets a feedback by user id and item id from MySQL. func (d *SQLDatabase) GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error) { tx := d.gormDB.WithContext(ctx) if d.driver == ClickHouse { tx = tx.Table(d.UserFeedbackTable()). Select("feedback_type, user_id, item_id, sum(value) AS value, any(time_stamp) AS time_stamp, max(updated) AS updated, any(comment) AS comment"). Group("feedback_type, user_id, item_id") } else { tx = tx.Table(d.FeedbackTable()). Select("feedback_type, user_id, item_id, value, time_stamp, updated, comment") } tx.Where("user_id = ? 
AND item_id = ?", userId, itemId) if len(feedbackTypes) > 0 { db := d.gormDB for _, feedbackType := range feedbackTypes { db = db.Or("feedback_type = ?", feedbackType) } tx.Where(db) } result, err := tx.Rows() if err != nil { return nil, errors.Trace(err) } feedbacks := make([]Feedback, 0) defer result.Close() for result.Next() { var feedback Feedback if err = d.gormDB.ScanRows(result, &feedback); err != nil { return nil, errors.Trace(err) } feedbacks = append(feedbacks, feedback) } return feedbacks, nil } // DeleteUserItemFeedback deletes a feedback by user id and item id from MySQL. func (d *SQLDatabase) DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error) { deleteUserItemFeedback := func(value any) (int, error) { tx := d.gormDB.WithContext(ctx).Where("user_id = ? AND item_id = ?", userId, itemId) if len(feedbackTypes) > 0 { tx.Where("feedback_type IN ?", feedbackTypes) } tx.Delete(value) if tx.Error != nil { return 0, errors.Trace(tx.Error) } return int(tx.RowsAffected), nil } rowAffected, err := deleteUserItemFeedback(&Feedback{}) if err != nil { return 0, errors.Trace(err) } if d.driver == ClickHouse { _, err = deleteUserItemFeedback(&UserFeedback{}) if err != nil { return 0, errors.Trace(err) } _, err = deleteUserItemFeedback(&ItemFeedback{}) if err != nil { return 0, errors.Trace(err) } } return rowAffected, nil } func (d *SQLDatabase) convertTimeZone(timestamp *time.Time) time.Time { switch d.driver { case ClickHouse, SQLite: return timestamp.In(time.UTC) default: return *timestamp } } func (d *SQLDatabase) CountUsers(ctx context.Context) (int, error) { var ( count int64 err error ) switch d.driver { case MySQL: var tableStatus struct { Rows int64 } err = d.gormDB.WithContext(ctx). Raw(fmt.Sprintf("show table status like '%s'", d.UsersTable())). Scan(&tableStatus).Error count = tableStatus.Rows case Postgres: var pgCount float64 err = d.gormDB.WithContext(ctx). 
Raw(fmt.Sprintf("SELECT reltuples AS estimate FROM pg_class where relname = '%s'", d.UsersTable())). Scan(&pgCount).Error count = max(int64(pgCount), 0) default: err = d.gormDB.WithContext(ctx).Table(d.UsersTable()).Count(&count).Error } return int(count), errors.Trace(err) } func (d *SQLDatabase) CountItems(ctx context.Context) (int, error) { var ( count int64 err error ) switch d.driver { case MySQL: var tableStatus struct { Rows int64 } err = d.gormDB.WithContext(ctx). Raw(fmt.Sprintf("show table status like '%s'", d.ItemsTable())). Scan(&tableStatus).Error count = tableStatus.Rows case Postgres: var pgCount float64 err = d.gormDB.WithContext(ctx). Raw(fmt.Sprintf("SELECT reltuples AS estimate FROM pg_class where relname = '%s'", d.ItemsTable())). Scan(&pgCount).Error count = max(int64(pgCount), 0) default: err = d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Count(&count).Error } return int(count), errors.Trace(err) } func (d *SQLDatabase) CountFeedback(ctx context.Context) (int, error) { var ( count int64 err error ) switch d.driver { case MySQL: var tableStatus struct { Rows int64 } err = d.gormDB.WithContext(ctx). Raw(fmt.Sprintf("show table status like '%s'", d.FeedbackTable())). Scan(&tableStatus).Error count = tableStatus.Rows case Postgres: var pgCount float64 err = d.gormDB.WithContext(ctx). Raw(fmt.Sprintf("SELECT reltuples AS estimate FROM pg_class where relname = '%s'", d.FeedbackTable())). Scan(&pgCount).Error count = max(int64(pgCount), 0) default: err = d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).Count(&count).Error } return int(count), errors.Trace(err) } ================================================ FILE: storage/data/sql_test.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package data import ( "database/sql" "fmt" "os" "strings" "testing" "github.com/gorse-io/gorse/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) var ( mySqlDSN string postgresDSN string clickhouseDSN string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } mySqlDSN = env("MYSQL_URI", "mysql://root:password@tcp(127.0.0.1:3306)/") postgresDSN = env("POSTGRES_URI", "postgres://gorse:gorse_pass@127.0.0.1/") clickhouseDSN = env("CLICKHOUSE_URI", "clickhouse://127.0.0.1:8123/") } type MySQLTestSuite struct { baseTestSuite } func (suite *MySQLTestSuite) SetupSuite() { // create database databaseComm, err := sql.Open("mysql", mySqlDSN[len(storage.MySQLPrefix):]) suite.NoError(err) const dbName = "gorse_data_test" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) suite.NoError(err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) suite.NoError(err) err = databaseComm.Close() suite.NoError(err) // connect database suite.Database, err = Open(mySqlDSN+dbName, "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func (suite *MySQLTestSuite) TestInit() { name, err := storage.ProbeMySQLIsolationVariableName(mySqlDSN[len(storage.MySQLPrefix):]) suite.NoError(err) connection := suite.Database.(*SQLDatabase).client assertQuery(suite.T(), connection, fmt.Sprintf("SELECT @@%s", name), "READ-UNCOMMITTED") 
assertQuery(suite.T(), connection, "SELECT @@sql_mode", "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION") } func TestMySQL(t *testing.T) { suite.Run(t, new(MySQLTestSuite)) } type PostgresTestSuite struct { baseTestSuite } func (suite *PostgresTestSuite) SetupSuite() { var err error // create database databaseComm, err := sql.Open("postgres", postgresDSN+"?sslmode=disable") suite.NoError(err) const dbName = "gorse_data_test" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) suite.NoError(err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) suite.NoError(err) err = databaseComm.Close() suite.NoError(err) // connect database suite.Database, err = Open(postgresDSN+strings.ToLower(dbName)+"?sslmode=disable", "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func TestPostgres(t *testing.T) { suite.Run(t, new(PostgresTestSuite)) } type ClickHouseTestSuite struct { baseTestSuite } func (suite *ClickHouseTestSuite) SetupSuite() { var err error // create database databaseComm, err := sql.Open("chhttp", "http://"+clickhouseDSN[len(storage.ClickhousePrefix):]) suite.NoError(err) const dbName = "gorse_data_test" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) suite.NoError(err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) suite.NoError(err) err = databaseComm.Close() suite.NoError(err) // connect database suite.Database, err = Open(clickhouseDSN+dbName+"?mutations_sync=2", "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func TestClickHouse(t *testing.T) { suite.Run(t, new(ClickHouseTestSuite)) } type SQLiteTestSuite struct { baseTestSuite } func (suite *SQLiteTestSuite) SetupSuite() { var err error // create database path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) suite.Database, err = Open(path, "gorse_") suite.NoError(err) // create schema err = suite.Database.Init() 
suite.NoError(err) } func (suite *SQLiteTestSuite) TearDownSuite() { suite.NoError(suite.Database.Close()) } func TestSQLite(t *testing.T) { suite.Run(t, new(SQLiteTestSuite)) } func assertQuery(t *testing.T, connection *sql.DB, sql string, expected string) { rows, err := connection.Query(sql) assert.NoError(t, err) assert.True(t, rows.Next()) var result string err = rows.Scan(&result) assert.NoError(t, err) assert.Equal(t, expected, result) } func BenchmarkMySQL_CountItems(b *testing.B) { // create database database, err := Open(mySqlDSN, "gorse_") require.NoError(b, err) dbName := "gorse_data_test" databaseComm := database.(*SQLDatabase) _, err = databaseComm.client.Exec("DROP DATABASE IF EXISTS " + dbName) require.NoError(b, err) _, err = databaseComm.client.Exec("CREATE DATABASE " + dbName) require.NoError(b, err) database, err = Open(mySqlDSN+dbName, "gorse_") require.NoError(b, err) err = database.Init() require.NoError(b, err) // benchmark benchmarkCountItems(b, database) // close database err = database.Close() require.NoError(b, err) } func BenchmarkPostgres_CountItems(b *testing.B) { // create database database, err := Open(postgresDSN+"gorse_data_test?sslmode=disable", "gorse_") require.NoError(b, err) err = database.Init() require.NoError(b, err) // benchmark benchmarkCountItems(b, database) // close database err = database.Close() require.NoError(b, err) } func BenchmarkClickHouse_CountItems(b *testing.B) { // create database databaseComm, err := sql.Open("chhttp", "http://"+clickhouseDSN[len(storage.ClickhousePrefix):]) require.NoError(b, err) const dbName = "gorse_data_test" _, err = databaseComm.Exec("DROP DATABASE IF EXISTS " + dbName) require.NoError(b, err) _, err = databaseComm.Exec("CREATE DATABASE " + dbName) require.NoError(b, err) err = databaseComm.Close() require.NoError(b, err) database, err := Open(clickhouseDSN+"gorse_data_test?mutations_sync=2", "gorse_") require.NoError(b, err) err = database.Init() require.NoError(b, err) // 
benchmark benchmarkCountItems(b, database) // close database err = database.Close() require.NoError(b, err) } func BenchmarkSQLite_CountItems(b *testing.B) { // create database database, err := Open("sqlite://"+os.TempDir()+"/sqlite.db", "gorse_") require.NoError(b, err) err = database.Init() require.NoError(b, err) // benchmark benchmarkCountItems(b, database) // close database err = database.Close() require.NoError(b, err) } ================================================ FILE: storage/docker-compose.yml ================================================ version: "3" services: redis: image: redis/redis-stack:6.2.6-v9 ports: - 6379:6379 mysql: image: mysql:8.0 ports: - 3306:3306 environment: MYSQL_ROOT_PASSWORD: password MYSQL_DATABASE: gorse MYSQL_USER: gorse MYSQL_PASSWORD: gorse_pass postgres: image: postgres:10.0 ports: - 5432:5432 environment: POSTGRES_USER: gorse POSTGRES_PASSWORD: gorse_pass mongo: image: mongo:4.0 ports: - 27017:27017 environment: MONGO_INITDB_ROOT_USERNAME: root MONGO_INITDB_ROOT_PASSWORD: password clickhouse: image: clickhouse/clickhouse-server:23 ports: - 8123:8123 rustfs: image: rustfs/rustfs:alpha ports: - 9000:9000 environment: RUSTFS_ACCESS_KEY: rustfsadmin RUSTFS_SECRET_KEY: rustfsadmin azurite: image: mcr.microsoft.com/azure-storage/azurite:3.34.0 ports: - 10000:10000 command: azurite-blob --blobHost 0.0.0.0 --blobPort 10000 qdrant: image: qdrant/qdrant:latest ports: - 6333:6333 - 6334:6334 weaviate: image: cr.weaviate.io/semitechnologies/weaviate:1.35.7 ports: - 8080:8080 # Milvus etcd: image: quay.io/coreos/etcd:v3.5.25 environment: - ETCD_AUTO_COMPACTION_MODE=revision - ETCD_AUTO_COMPACTION_RETENTION=1000 - ETCD_QUOTA_BACKEND_BYTES=4294967296 - ETCD_SNAPSHOT_COUNT=50000 command: etcd -advertise-client-urls=http://etcd:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd healthcheck: test: ["CMD", "etcdctl", "endpoint", "health"] interval: 30s timeout: 20s retries: 3 minio: image: minio/minio:RELEASE.2024-12-18T13-15-44Z 
environment: MINIO_ACCESS_KEY: minioadmin MINIO_SECRET_KEY: minioadmin ports: - 9001:9001 - 9000:9000 command: minio server /minio_data --console-address ":9001" healthcheck: test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] interval: 30s timeout: 20s retries: 3 milvus: image: milvusdb/milvus:v2.6.9 command: ["milvus", "run", "standalone"] security_opt: - seccomp:unconfined environment: ETCD_ENDPOINTS: etcd:2379 MINIO_ADDRESS: minio:9000 MQ_TYPE: woodpecker healthcheck: test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] interval: 30s start_period: 90s timeout: 20s retries: 3 ports: - 19530:19530 - 9091:9091 depends_on: - etcd - minio ================================================ FILE: storage/meta/database.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package meta import ( "encoding/json" "strings" "time" "github.com/XSAM/otelsql" "github.com/gorse-io/gorse/model" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/samber/lo" semconv "go.opentelemetry.io/otel/semconv/v1.12.0" "golang.org/x/exp/maps" ) const ( COLLABORATIVE_FILTERING_MODEL = "COLLABORATIVE_FILTERING_MODEL" CLICK_THROUGH_RATE_MODEL = "CLICK_THROUGH_RATE_MODEL" RECOMMEND_CONFIG = "RECOMMEND_CONFIG" ) type Model[T any] struct { ID int64 Type string Params model.Params Score T } func (m *Model[T]) ToJSON() string { return string(lo.Must1(json.Marshal(m))) } func (m *Model[T]) FromJSON(data string) error { return json.Unmarshal([]byte(data), m) } // Equal checks if two models have the same type and parameters. func (m *Model[T]) Equal(other Model[T]) bool { return m.Type == other.Type && maps.Equal(m.Params, other.Params) } type Node struct { UUID string Hostname string Type string Version string UpdateTime time.Time } type Database interface { Close() error Init() error UpdateNode(node *Node) error ListNodes() ([]*Node, error) Put(key, value string) error Get(key string) (*string, error) Delete(key string) error } // Open a connection to a database. 
func Open(path string, ttl time.Duration) (Database, error) { var err error if strings.HasPrefix(path, storage.SQLitePrefix) { dataSourceName := path[len(storage.SQLitePrefix):] // append parameters if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{ {"_pragma", "busy_timeout(10000)"}, {"_pragma", "journal_mode(wal)"}, }); err != nil { return nil, errors.Trace(err) } // connect to database database := new(SQLite) database.ttl = ttl if database.db, err = otelsql.Open("sqlite", dataSourceName, otelsql.WithAttributes(semconv.DBSystemSqlite), otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), ); err != nil { return nil, errors.Trace(err) } return database, nil } return nil, errors.Errorf("Unknown database: %s", path) } ================================================ FILE: storage/meta/database_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package meta import ( "testing" "time" "github.com/gorse-io/gorse/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) type baseTestSuite struct { suite.Suite Database } func (suite *baseTestSuite) TestNodes() { // Add node err := suite.Database.UpdateNode(&Node{ UUID: "node-1", Hostname: "localhost", Type: "master", Version: "v0.1.0", UpdateTime: time.Now(), }) suite.NoError(err) // Add duplicate node err = suite.Database.UpdateNode(&Node{ UUID: "node-1", Hostname: "localhost", Type: "master", Version: "v0.1.1", UpdateTime: time.Now(), }) suite.NoError(err) // Add outdated node err = suite.Database.UpdateNode(&Node{ UUID: "node-2", Hostname: "localhost", Type: "master", Version: "v0.1.0", UpdateTime: time.Now().Add(-time.Hour), }) suite.NoError(err) // List nodes nodes, err := suite.Database.ListNodes() suite.NoError(err) if suite.Equal(1, len(nodes)) { suite.Equal("node-1", nodes[0].UUID) suite.Equal("localhost", nodes[0].Hostname) suite.Equal("master", nodes[0].Type) suite.Equal("v0.1.1", nodes[0].Version) } } func (suite *baseTestSuite) TestKeyValues() { err := suite.Database.Put("key1", "value1") suite.NoError(err) err = suite.Database.Put("key2", "value2") suite.NoError(err) err = suite.Database.Put("key3", "value3") suite.NoError(err) value, err := suite.Database.Get("key1") suite.NoError(err) suite.Equal("value1", *value) value, err = suite.Database.Get("key2") suite.NoError(err) suite.Equal("value2", *value) value, err = suite.Database.Get("key3") suite.NoError(err) suite.Equal("value3", *value) // Test overwrite err = suite.Database.Put("key1", "new_value1") suite.NoError(err) value, err = suite.Database.Get("key1") suite.NoError(err) suite.Equal("new_value1", *value) // Test non-existing key value, err = suite.Database.Get("non-existing-key") suite.NoError(err) suite.Nil(value) // Test delete existing key err = suite.Database.Delete("key2") suite.NoError(err) value, err = suite.Database.Get("key2") suite.NoError(err) 
suite.Nil(value) // Test delete non-existing key err = suite.Database.Delete("non-existing-key") suite.NoError(err) } func TestModel_Equal(t *testing.T) { a := Model[int]{ ID: 1, Type: "test", Params: map[model.ParamName]any{ "param1": 1, "param2": "value2", }, Score: 0, } b := Model[int]{ ID: 2, Type: "test", Params: map[model.ParamName]any{ "param1": 1, "param2": "value2", }, Score: 1, } assert.True(t, a.Equal(b)) a.Type = "different" assert.False(t, a.Equal(b)) a.Type = "test" a.Params["param2"] = "different" assert.False(t, a.Equal(b)) a.Params["param2"] = "value2" } ================================================ FILE: storage/meta/sqlite.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package meta import ( "database/sql" "errors" "fmt" "time" _ "modernc.org/sqlite" ) type SQLite struct { db *sql.DB ttl time.Duration } func (s *SQLite) Close() error { return s.db.Close() } func (s *SQLite) Init() error { // Create tables if _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS nodes ( uuid TEXT PRIMARY KEY, hostname TEXT, type TEXT, version TEXT, update_time DATETIME );`); err != nil { return err } if _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS cron_jobs ( name TEXT PRIMARY KEY, description TEXT, current INTEGER, total INTEGER, start_time TIMESTAMP, end_time TIMESTAMP, update_time TIMESTAMP );`); err != nil { return err } if _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS key_values ( key TEXT PRIMARY KEY, value TEXT );`); err != nil { return err } return nil } func (s *SQLite) UpdateNode(node *Node) error { _, err := s.db.Exec(` INSERT INTO nodes (uuid, hostname, type, version, update_time) VALUES (?, ?, ?, ?, ?) ON CONFLICT(uuid) DO UPDATE SET hostname = excluded.hostname, type = excluded.type, version = excluded.version, update_time = excluded.update_time `, node.UUID, node.Hostname, node.Type, node.Version, node.UpdateTime.UTC()) return err } func (s *SQLite) ListNodes() ([]*Node, error) { // List nodes within TTL rs, err := s.db.Query(` SELECT uuid, hostname, type, version, update_time FROM nodes WHERE update_time > datetime('now', ?) `, fmt.Sprintf("-%.0f seconds", s.ttl.Seconds())) if err != nil { return nil, err } defer rs.Close() var nodes []*Node for rs.Next() { var node Node if err = rs.Scan(&node.UUID, &node.Hostname, &node.Type, &node.Version, &node.UpdateTime); err != nil { return nil, err } nodes = append(nodes, &node) } // Delete outdated nodes if _, err = s.db.Exec(` DELETE FROM nodes WHERE update_time < datetime('now', ?) 
`, fmt.Sprintf("-%.0f seconds", s.ttl.Seconds())); err != nil { return nil, err } return nodes, nil } func (s *SQLite) Put(key, value string) error { _, err := s.db.Exec(` INSERT INTO key_values (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value `, key, value) return err } func (s *SQLite) Get(key string) (*string, error) { var value string err := s.db.QueryRow(` SELECT value FROM key_values WHERE key = ? `, key).Scan(&value) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil // key not found } return nil, err // other error } return &value, nil // key found } func (s *SQLite) Delete(key string) error { _, err := s.db.Exec(` DELETE FROM key_values WHERE key = ? `, key) return err } ================================================ FILE: storage/meta/sqlite_test.go ================================================ // Copyright 2024 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package meta import ( "fmt" "github.com/stretchr/testify/suite" "testing" "time" ) type SQLiteTestSuite struct { baseTestSuite } func (suite *SQLiteTestSuite) SetupTest() { var err error // create database path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) suite.Database, err = Open(path, time.Second) suite.NoError(err) // create schema err = suite.Database.Init() suite.NoError(err) } func (suite *SQLiteTestSuite) TearDownTest() { suite.NoError(suite.Database.Close()) } func TestSQLite(t *testing.T) { suite.Run(t, new(SQLiteTestSuite)) } ================================================ FILE: storage/options.go ================================================ package storage import ( "database/sql" "time" ) type Options struct { IsolationLevel string MaxOpenConns int MaxIdleConns int ConnMaxLifetime time.Duration MaxSearchResults int } type Option func(*Options) func WithIsolationLevel(isolationLevel string) Option { return func(o *Options) { o.IsolationLevel = isolationLevel } } func WithMaxOpenConns(maxOpenConns int) Option { return func(o *Options) { o.MaxOpenConns = maxOpenConns } } func WithMaxIdleConns(maxIdleConns int) Option { return func(o *Options) { o.MaxIdleConns = maxIdleConns } } func WithConnMaxLifetime(connMaxLifetime time.Duration) Option { return func(o *Options) { o.ConnMaxLifetime = connMaxLifetime } } func WithMaxSearchResults(limit int) Option { return func(o *Options) { o.MaxSearchResults = limit } } func ApplySQLPool(db *sql.DB, opt Options) { if opt.MaxOpenConns > 0 { db.SetMaxOpenConns(opt.MaxOpenConns) } if opt.MaxIdleConns > 0 { db.SetMaxIdleConns(opt.MaxIdleConns) } if opt.ConnMaxLifetime > 0 { db.SetConnMaxLifetime(opt.ConnMaxLifetime) } } func NewOptions(opts ...Option) Options { opt := Options{ IsolationLevel: "READ-UNCOMMITTED", MaxSearchResults: 10000, } for _, o := range opts { o(&opt) } return opt } ================================================ FILE: storage/schema_test.go 
================================================ package storage import ( "github.com/samber/lo" "github.com/stretchr/testify/assert" "testing" ) func TestAppendURLParams(t *testing.T) { // test windows path url, err := AppendURLParams(`c:\\sqlite.db`, []lo.Tuple2[string, string]{{"a", "b"}}) assert.NoError(t, err) assert.Equal(t, `c:\\sqlite.db?a=b`, url) // test no scheme url, err = AppendURLParams(`sqlite.db`, []lo.Tuple2[string, string]{{"a", "b"}}) assert.NoError(t, err) assert.Equal(t, `sqlite.db?a=b`, url) } ================================================ FILE: storage/scheme.go ================================================ // Copyright 2022 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package storage import ( "database/sql" "database/sql/driver" "encoding/json" "net/url" "strings" "github.com/go-sql-driver/mysql" "github.com/juju/errors" "github.com/samber/lo" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" "modernc.org/sqlite" ) func init() { sqlite.MustRegisterDeterministicScalarFunction("json_contains", 2, func(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { parse := func(arg driver.Value) (j []any, err error) { var data []byte switch argTyped := arg.(type) { case string: data = []byte(argTyped) case []byte: data = argTyped default: return nil, errors.Errorf("unsupported type %T", arg) } err = json.Unmarshal(data, &j) return } if args[0] == nil || args[1] == nil { return nil, nil } j1, err := parse(args[0]) if err != nil { return nil, err } j2, err := parse(args[1]) if err != nil { return nil, err } elements := make(map[any]struct{}, len(j1)) for _, e := range j1 { elements[e] = struct{}{} } for _, e := range j2 { if _, ok := elements[e]; !ok { return false, nil } } return true, nil }) } const ( MySQLPrefix = "mysql://" MongoPrefix = "mongodb://" MongoSrvPrefix = "mongodb+srv://" PostgresPrefix = "postgres://" PostgreSQLPrefix = "postgresql://" ClickhousePrefix = "clickhouse://" CHHTTPPrefix = "chhttp://" CHHTTPSPrefix = "chhttps://" SQLitePrefix = "sqlite://" RedisPrefix = "redis://" RedissPrefix = "rediss://" RedisClusterPrefix = "redis+cluster://" RedissClusterPrefix = "rediss+cluster://" QdrantPrefix = "qdrant://" WeaviatePrefix = "weaviate://" WeaviatesPrefix = "weaviates://" MilvusPrefix = "milvus://" ) func AppendURLParams(rawURL string, params []lo.Tuple2[string, string]) (string, error) { parsed, err := url.Parse(rawURL) if err != nil { return "", errors.Trace(err) } q := parsed.Query() for _, tuple := range params { q.Add(tuple.A, tuple.B) } parsed.RawQuery = q.Encode() return parsed.String(), nil } func AppendMySQLParams(dsn string, params map[string]string) (string, error) { cfg, err := 
mysql.ParseDSN(dsn) if err != nil { return "", errors.Trace(err) } if cfg.Params == nil { cfg.Params = make(map[string]string) } for key, value := range params { if _, exist := cfg.Params[key]; !exist { cfg.Params[key] = value } } return cfg.FormatDSN(), nil } func ProbeMySQLIsolationVariableName(dsn string) (string, error) { connection, err := sql.Open("mysql", dsn) if err != nil { return "", errors.Trace(err) } defer connection.Close() rows, err := connection.Query("SHOW VARIABLES WHERE variable_name = 'transaction_isolation' OR variable_name = 'tx_isolation'") if err != nil { return "", errors.Trace(err) } defer rows.Close() var name, value string if rows.Next() { if err = rows.Scan(&name, &value); err != nil { return "", errors.Trace(err) } } return name, nil } type TablePrefix string func (tp TablePrefix) ValuesTable() string { return string(tp) + "values" } func (tp TablePrefix) SetsTable() string { return string(tp) + "sets" } func (tp TablePrefix) MessageTable() string { return string(tp) + "message" } func (tp TablePrefix) DocumentTable() string { return string(tp) + "documents" } func (tp TablePrefix) PointsTable() string { return string(tp) + "time_series_points" } func (tp TablePrefix) UsersTable() string { return string(tp) + "users" } func (tp TablePrefix) ItemsTable() string { return string(tp) + "items" } // LatestItemsTable returns the materialized view for latest items. func (tp TablePrefix) LatestItemsTable() string { return string(tp) + "latest_items" } func (tp TablePrefix) FeedbackTable() string { return string(tp) + "feedback" } // AggregatingFeedbackTable returns the aggregating feedback table. func (tp TablePrefix) AggregatingFeedbackTable() string { return string(tp) + "aggregating_feedback" } // UserFeedbackTable returns the materialized view of user feedback. func (tp TablePrefix) UserFeedbackTable() string { return string(tp) + "user_feedback" } // ItemFeedbackTable returns the materialized view of item feedback. 
func (tp TablePrefix) ItemFeedbackTable() string { return string(tp) + "item_feedback" } func (tp TablePrefix) Key(key string) string { return string(tp) + key } func NewGORMConfig(tablePrefix string) *gorm.Config { return &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), CreateBatchSize: 1000, SkipDefaultTransaction: true, NamingStrategy: schema.NamingStrategy{ TablePrefix: tablePrefix, SingularTable: true, NameReplacer: strings.NewReplacer( "SQLValue", "Values", "SQLSet", "Sets", "SQLUser", "Users", "SQLItem", "Items", "SQLFeedback", "Feedback", "SQLDocument", "Documents", "PostgresDocument", "Documents", "TimeSeriesPoint", "time_series_points", "ClickhouseUser", "Users", "ClickHouseItem", "Items", "ClickHouseFeedback", "Feedback", ), }, } } ================================================ FILE: storage/vectors/database.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "context" "strings" "time" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" ) type Distance int const ( Cosine Distance = iota Euclidean Dot ) type Vector struct { Id string Vector []float32 IsHidden bool `json:"-"` Categories []string `json:"-" gorm:"type:text;serializer:json"` Timestamp time.Time `json:"-"` } type Database interface { Init() error Optimize() error Close() error ListCollections(ctx context.Context) ([]string, error) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error DeleteCollection(ctx context.Context, name string) error AddVectors(ctx context.Context, collection string, vectors []Vector) error DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) } // Creator creates a database instance. type Creator func(path, tablePrefix string, opts ...storage.Option) (Database, error) var creators = make(map[string]Creator) // Register a database creator. func Register(prefixes []string, creator Creator) { for _, p := range prefixes { creators[p] = creator } } // Open a connection to a database. func Open(path, tablePrefix string, opts ...storage.Option) (Database, error) { for prefix, creator := range creators { if strings.HasPrefix(path, prefix) { return creator(path, tablePrefix, opts...) } } return nil, errors.Errorf("Unknown database: %s", path) } ================================================ FILE: storage/vectors/database_test.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package vectors import ( "time" "github.com/stretchr/testify/suite" ) const defaultVectorSize = 4 type vectorsTestSuite struct { suite.Suite Database } func (suite *vectorsTestSuite) SetupTest() { // purge ctx := suite.T().Context() collections, err := suite.Database.ListCollections(ctx) suite.NoError(err) for _, collection := range collections { err = suite.Database.DeleteCollection(ctx, collection) suite.NoError(err) } } func (suite *vectorsTestSuite) TestCollections() { ctx := suite.T().Context() // list collections collections, err := suite.Database.ListCollections(ctx) suite.NoError(err) suite.Empty(collections) // create collection err = suite.Database.AddCollection(ctx, "test", defaultVectorSize, Cosine) suite.NoError(err) // list collections collections, err = suite.Database.ListCollections(ctx) suite.NoError(err) suite.Equal([]string{"test"}, collections) // delete collection err = suite.Database.DeleteCollection(ctx, "test") suite.NoError(err) // list collections collections, err = suite.Database.ListCollections(ctx) suite.NoError(err) suite.Empty(collections) // delete non-existent collection err = suite.Database.DeleteCollection(ctx, "non-existent") suite.Error(err) } func (suite *vectorsTestSuite) TestVectors() { ctx := suite.T().Context() err := suite.Database.AddCollection(ctx, "test", defaultVectorSize, Cosine) suite.NoError(err) vectorA := make([]float32, defaultVectorSize) vectorA[0] = 1 vectorB := make([]float32, defaultVectorSize) vectorB[0] = 0.9 vectorB[1] = 0.1 err = suite.Database.AddVectors(ctx, "test", []Vector{ { Id: "a", Vector: 
vectorA, Categories: []string{"cat-a", "common"}, }, { Id: "b", Vector: vectorB, Categories: []string{"cat-b", "common"}, }, }) suite.NoError(err) results, err := suite.Database.QueryVectors(ctx, "test", vectorA, []string{"cat-a"}, 10) suite.NoError(err) suite.Len(results, 1) suite.Equal("a", results[0].Id) suite.NotEmpty(results[0].Categories) results, err = suite.Database.QueryVectors(ctx, "test", vectorA, []string{"common"}, 10) suite.NoError(err) suite.Len(results, 2) ids := map[string]bool{} for _, result := range results { ids[result.Id] = true suite.NotEmpty(result.Categories) } suite.True(ids["a"]) suite.True(ids["b"]) results, err = suite.Database.QueryVectors(ctx, "test", vectorA, nil, 1) suite.NoError(err) suite.NotEmpty(results) for _, result := range results { suite.NotEmpty(result.Categories) } } func (suite *vectorsTestSuite) TestDeleteVectors() { ctx := suite.T().Context() err := suite.Database.AddCollection(ctx, "test", defaultVectorSize, Cosine) suite.NoError(err) vectorA := make([]float32, defaultVectorSize) vectorA[0] = 1 vectorB := make([]float32, defaultVectorSize) vectorB[0] = 0.9 vectorB[1] = 0.1 cutoff := time.Now().UTC().Truncate(time.Millisecond) err = suite.Database.AddVectors(ctx, "test", []Vector{ { Id: "old", Vector: vectorA, Categories: []string{"common"}, Timestamp: cutoff.Add(-time.Hour), }, { Id: "new", Vector: vectorB, Categories: []string{"common"}, Timestamp: cutoff, }, }) suite.NoError(err) err = suite.Database.DeleteVectors(ctx, "test", cutoff) suite.NoError(err) results, err := suite.Database.QueryVectors(ctx, "test", vectorA, []string{"common"}, 10) suite.NoError(err) suite.Len(results, 1) suite.Equal("new", results[0].Id) } ================================================ FILE: storage/vectors/milvus.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the 
License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package vectors import ( "context" "fmt" "net/url" "strings" "time" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/milvus-io/milvus-sdk-go/v2/client" "github.com/milvus-io/milvus-sdk-go/v2/entity" ) const ( milvusIdField = "id" milvusVectorField = "vector" milvusCategoriesField = "categories" milvusTimestampField = "timestamp" ) func init() { Register([]string{storage.MilvusPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(Milvus) u, err := url.Parse(path) if err != nil { return nil, errors.Trace(err) } database.client, err = client.NewClient(context.Background(), client.Config{ Address: u.Host, }) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type Milvus struct { client client.Client } func (db *Milvus) Init() error { return nil } func (db *Milvus) Optimize() error { return nil } func (db *Milvus) Close() error { return db.client.Close() } func (db *Milvus) ListCollections(ctx context.Context) ([]string, error) { collections, err := db.client.ListCollections(ctx) if err != nil { return nil, errors.Trace(err) } var names []string for _, collection := range collections { names = append(names, collection.Name) } return names, nil } func (db *Milvus) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error { schema := entity.NewSchema().WithName(name).WithDescription("gorse collection"). 
WithField(entity.NewField().WithName(milvusIdField).WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535).WithIsPrimaryKey(true)). WithField(entity.NewField().WithName(milvusCategoriesField).WithDataType(entity.FieldTypeArray).WithElementType(entity.FieldTypeVarChar).WithMaxCapacity(100).WithMaxLength(65535)). WithField(entity.NewField().WithName(milvusTimestampField).WithDataType(entity.FieldTypeInt64)). WithField(entity.NewField().WithName(milvusVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(int64(dimensions))) err := db.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) if err != nil { return errors.Trace(err) } // Create index var metricType entity.MetricType switch distance { case Cosine: metricType = entity.COSINE case Euclidean: metricType = entity.L2 case Dot: metricType = entity.IP default: return errors.NotSupportedf("distance method") } idx, err := entity.NewIndexHNSW(metricType, 8, 200) if err != nil { return errors.Trace(err) } err = db.client.CreateIndex(ctx, name, milvusVectorField, idx, false) if err != nil { return errors.Trace(err) } scalarIdx := entity.NewScalarIndex() err = db.client.CreateIndex(ctx, name, milvusTimestampField, scalarIdx, false) if err != nil { return errors.Trace(err) } // Load collection err = db.client.LoadCollection(ctx, name, false) return errors.Trace(err) } func (db *Milvus) DeleteCollection(ctx context.Context, name string) error { exists, err := db.client.HasCollection(ctx, name) if err != nil { return errors.Trace(err) } if !exists { return errors.NotFoundf("collection %s", name) } err = db.client.DropCollection(ctx, name) return errors.Trace(err) } func (db *Milvus) AddVectors(ctx context.Context, collection string, vectors []Vector) error { if len(vectors) == 0 { return nil } ids := make([]string, 0, len(vectors)) categories := make([][]string, 0, len(vectors)) timestamps := make([]int64, 0, len(vectors)) data := make([][]float32, 0, len(vectors)) for _, v := range vectors { ids = 
append(ids, v.Id) categories = append(categories, v.Categories) timestamps = append(timestamps, v.Timestamp.UnixMilli()) data = append(data, v.Vector) } idCol := entity.NewColumnVarChar(milvusIdField, ids) categoriesCol := entity.NewColumnVarCharArray(milvusCategoriesField, milvusStringsToBytes(categories)) timestampCol := entity.NewColumnInt64(milvusTimestampField, timestamps) vectorCol := entity.NewColumnFloatVector(milvusVectorField, len(data[0]), data) _, err := db.client.Upsert(ctx, collection, "", idCol, categoriesCol, timestampCol, vectorCol) return errors.Trace(err) } func (db *Milvus) DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error { err := db.client.Delete(ctx, collection, "", fmt.Sprintf("%s < %d", milvusTimestampField, timestamp.UnixMilli())) return errors.Trace(err) } func (db *Milvus) QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) { if topK <= 0 { return []Vector{}, nil } var expr string if len(categories) > 0 { var conditions []string for _, category := range categories { conditions = append(conditions, fmt.Sprintf("array_contains(%s, '%s')", milvusCategoriesField, category)) } expr = strings.Join(conditions, " or ") } searchParam, _ := entity.NewIndexHNSWSearchParam(64) results, err := db.client.Search(ctx, collection, []string{}, expr, []string{milvusIdField, milvusCategoriesField}, []entity.Vector{entity.FloatVector(q)}, milvusVectorField, entity.COSINE, topK, searchParam, client.WithSearchQueryConsistencyLevel(entity.ClStrong)) if err != nil { return nil, errors.Trace(err) } var vectors []Vector for _, result := range results { var idCol *entity.ColumnVarChar if col := result.Fields.GetColumn(milvusIdField); col != nil { idCol = col.(*entity.ColumnVarChar) } else if result.IDs != nil { idCol = result.IDs.(*entity.ColumnVarChar) } var categoriesCol *entity.ColumnVarCharArray if col := result.Fields.GetColumn(milvusCategoriesField); col != nil { 
categoriesCol = col.(*entity.ColumnVarCharArray) } for i := 0; i < result.ResultCount; i++ { var id string if idCol != nil { id, err = idCol.ValueByIdx(i) if err != nil { return nil, errors.Trace(err) } } var cats []string if categoriesCol != nil { catsValue, err := categoriesCol.ValueByIdx(i) if err != nil { return nil, errors.Trace(err) } cats = milvusBytesToStrings(catsValue) } vectors = append(vectors, Vector{ Id: id, Categories: cats, }) } } return vectors, nil } func milvusStringsToBytes(ss [][]string) [][][]byte { res := make([][][]byte, len(ss)) for i, s1 := range ss { res[i] = make([][]byte, len(s1)) for j, s2 := range s1 { res[i][j] = []byte(s2) } } return res } func milvusBytesToStrings(bs [][]byte) []string { res := make([]string, len(bs)) for i, b := range bs { res[i] = string(b) } return res } ================================================ FILE: storage/vectors/milvus_test.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "os" "testing" "github.com/stretchr/testify/suite" ) var ( milvusUri string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } milvusUri = env("MILVUS_URI", "milvus://127.0.0.1:19530") } type MilvusTestSuite struct { vectorsTestSuite } func (suite *MilvusTestSuite) SetupSuite() { var err error suite.Database, err = Open(milvusUri, "gorse_") suite.NoError(err) } func TestMilvus(t *testing.T) { suite.Run(t, new(MilvusTestSuite)) } ================================================ FILE: storage/vectors/proxy.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "context" "net" "time" "github.com/gorse-io/gorse/protocol" "github.com/juju/errors" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" ) type ProxyServer struct { protocol.UnimplementedVectorStoreServer database Database server *grpc.Server } func NewProxyServer(database Database) *ProxyServer { return &ProxyServer{database: database} } func (p *ProxyServer) Serve(lis net.Listener) error { p.server = grpc.NewServer() protocol.RegisterVectorStoreServer(p.server, p) return p.server.Serve(lis) } func (p *ProxyServer) Stop() { p.server.Stop() } func (p *ProxyServer) ListCollections(ctx context.Context, _ *protocol.ListCollectionsRequest) (*protocol.ListCollectionsResponse, error) { collections, err := p.database.ListCollections(ctx) if err != nil { return nil, err } return &protocol.ListCollectionsResponse{Collections: collections}, nil } func (p *ProxyServer) AddCollection(ctx context.Context, request *protocol.AddCollectionRequest) (*protocol.AddCollectionResponse, error) { distance, err := protoDistanceToDistance(request.GetDistance()) if err != nil { return nil, err } err = p.database.AddCollection(ctx, request.GetName(), int(request.GetDimensions()), distance) if err != nil { return nil, err } return &protocol.AddCollectionResponse{}, nil } func (p *ProxyServer) DeleteCollection(ctx context.Context, request *protocol.DeleteCollectionRequest) (*protocol.DeleteCollectionResponse, error) { err := p.database.DeleteCollection(ctx, request.GetName()) if err != nil { return nil, err } return &protocol.DeleteCollectionResponse{}, nil } func (p *ProxyServer) AddVectors(ctx context.Context, request *protocol.AddVectorsRequest) (*protocol.AddVectorsResponse, error) { vectors := make([]Vector, len(request.Vectors)) for i, vector := range request.Vectors { timestamp := time.Time{} if vector.GetTimestamp() != nil { timestamp = vector.GetTimestamp().AsTime() } vectors[i] = Vector{ Id: vector.GetId(), Vector: 
vector.GetValues(), Categories: vector.GetCategories(), Timestamp: timestamp, } } err := p.database.AddVectors(ctx, request.GetCollection(), vectors) if err != nil { return nil, err } return &protocol.AddVectorsResponse{}, nil } func (p *ProxyServer) DeleteVectors(ctx context.Context, request *protocol.DeleteVectorsRequest) (*protocol.DeleteVectorsResponse, error) { timestamp := time.Time{} if request.GetTimestamp() != nil { timestamp = request.GetTimestamp().AsTime() } err := p.database.DeleteVectors(ctx, request.GetCollection(), timestamp) if err != nil { return nil, err } return &protocol.DeleteVectorsResponse{}, nil } func (p *ProxyServer) QueryVectors(ctx context.Context, request *protocol.QueryVectorsRequest) (*protocol.QueryVectorsResponse, error) { results, err := p.database.QueryVectors(ctx, request.GetCollection(), request.GetQuery(), request.GetCategories(), int(request.GetTopK())) if err != nil { return nil, err } pbVectors := make([]*protocol.Vector, len(results)) for i, result := range results { pbVectors[i] = &protocol.Vector{ Id: result.Id, Values: result.Vector, Categories: result.Categories, } } return &protocol.QueryVectorsResponse{Vectors: pbVectors}, nil } type ProxyClient struct { protocol.VectorStoreClient } func NewProxyClient(conn *grpc.ClientConn) *ProxyClient { return &ProxyClient{ VectorStoreClient: protocol.NewVectorStoreClient(conn), } } func (p ProxyClient) Init() error { return nil } func (p ProxyClient) Optimize() error { return nil } func (p ProxyClient) Close() error { return nil } func (p ProxyClient) ListCollections(ctx context.Context) ([]string, error) { resp, err := p.VectorStoreClient.ListCollections(ctx, &protocol.ListCollectionsRequest{}) if err != nil { return nil, err } return resp.Collections, nil } func (p ProxyClient) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error { pbDistance, err := distanceToProtoDistance(distance) if err != nil { return err } _, err = 
p.VectorStoreClient.AddCollection(ctx, &protocol.AddCollectionRequest{ Name: name, Dimensions: int32(dimensions), Distance: pbDistance, }) return err } func (p ProxyClient) DeleteCollection(ctx context.Context, name string) error { _, err := p.VectorStoreClient.DeleteCollection(ctx, &protocol.DeleteCollectionRequest{Name: name}) return err } func (p ProxyClient) AddVectors(ctx context.Context, collection string, vectors []Vector) error { pbVectors := make([]*protocol.Vector, len(vectors)) for i, vector := range vectors { pbVectors[i] = &protocol.Vector{ Id: vector.Id, Values: vector.Vector, Categories: vector.Categories, Timestamp: timestamppb.New(vector.Timestamp), } } _, err := p.VectorStoreClient.AddVectors(ctx, &protocol.AddVectorsRequest{ Collection: collection, Vectors: pbVectors, }) return err } func (p ProxyClient) DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error { _, err := p.VectorStoreClient.DeleteVectors(ctx, &protocol.DeleteVectorsRequest{ Collection: collection, Timestamp: timestamppb.New(timestamp), }) return err } func (p ProxyClient) QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) { resp, err := p.VectorStoreClient.QueryVectors(ctx, &protocol.QueryVectorsRequest{ Collection: collection, Query: q, Categories: categories, TopK: int32(topK), }) if err != nil { return nil, err } results := make([]Vector, len(resp.Vectors)) for i, vector := range resp.Vectors { results[i] = Vector{ Id: vector.GetId(), Vector: vector.GetValues(), Categories: vector.GetCategories(), } } return results, nil } func distanceToProtoDistance(distance Distance) (protocol.Distance, error) { switch distance { case Cosine: return protocol.Distance_Cosine, nil case Euclidean: return protocol.Distance_Euclidean, nil case Dot: return protocol.Distance_Dot, nil default: return protocol.Distance_Unknown, errors.NotSupportedf("distance method") } } func protoDistanceToDistance(distance 
protocol.Distance) (Distance, error) { switch distance { case protocol.Distance_Cosine: return Cosine, nil case protocol.Distance_Euclidean: return Euclidean, nil case protocol.Distance_Dot: return Dot, nil default: return Cosine, errors.NotSupportedf("distance method") } } ================================================ FILE: storage/vectors/proxy_test.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "fmt" "net" "testing" "github.com/stretchr/testify/suite" "google.golang.org/grpc" ) type ProxyTestSuite struct { vectorsTestSuite sqlite Database server *ProxyServer clientConn *grpc.ClientConn } func (suite *ProxyTestSuite) SetupSuite() { var err error path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) suite.sqlite, err = Open(path, "gorse_") suite.NoError(err) lis, err := net.Listen("tcp", "localhost:0") suite.NoError(err) suite.server = NewProxyServer(suite.sqlite) go func() { err = suite.server.Serve(lis) suite.NoError(err) }() suite.clientConn, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) suite.NoError(err) suite.Database = NewProxyClient(suite.clientConn) } func (suite *ProxyTestSuite) TearDownSuite() { suite.server.Stop() suite.NoError(suite.clientConn.Close()) suite.NoError(suite.sqlite.Close()) } func TestProxy(t *testing.T) { suite.Run(t, new(ProxyTestSuite)) } ================================================ FILE: storage/vectors/qdrant.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "context" "net/url" "strconv" "time" "github.com/google/uuid" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/qdrant/go-client/qdrant" ) const ( qdrantPayloadCategoriesKey = "categories" qdrantPayloadIdKey = "id" qdrantPayloadTimestampKey = "timestamp" ) func init() { Register([]string{storage.QdrantPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(Qdrant) u, err := url.Parse(path) if err != nil { return nil, errors.Trace(err) } host := u.Hostname() port := u.Port() portInt, err := strconv.Atoi(port) if err != nil { return nil, errors.Trace(err) } database.client, err = qdrant.NewClient(&qdrant.Config{ Host: host, Port: portInt, }) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type Qdrant struct { client *qdrant.Client } func (db *Qdrant) Init() error { return nil } func (db *Qdrant) Optimize() error { return nil } func (db *Qdrant) Close() error { return db.client.Close() } func (db *Qdrant) ListCollections(ctx context.Context) ([]string, error) { return db.client.ListCollections(ctx) } func (db *Qdrant) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error { var qdrantDistance qdrant.Distance switch distance { case Cosine: qdrantDistance = qdrant.Distance_Cosine case Euclidean: qdrantDistance = qdrant.Distance_Euclid case Dot: qdrantDistance = qdrant.Distance_Dot default: return errors.NotSupportedf("distance method") } err := db.client.CreateCollection(ctx, &qdrant.CreateCollection{ CollectionName: name, VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ Size: uint64(dimensions), Distance: qdrantDistance, }), }) if err != nil { return errors.Trace(err) } _, err = db.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{ CollectionName: name, Wait: new(true), FieldName: qdrantPayloadTimestampKey, FieldType: qdrant.FieldType_FieldTypeInteger.Enum(), }) return errors.Trace(err) } func 
(db *Qdrant) DeleteCollection(ctx context.Context, name string) error { return db.client.DeleteCollection(ctx, name) } func (db *Qdrant) AddVectors(ctx context.Context, collection string, vectors []Vector) error { if len(vectors) == 0 { return nil } points := make([]*qdrant.PointStruct, 0, len(vectors)) for _, vector := range vectors { points = append(points, &qdrant.PointStruct{ Id: qdrant.NewID(uuid.NewMD5(uuid.NameSpaceURL, []byte(vector.Id)).String()), Payload: map[string]*qdrant.Value{ qdrantPayloadCategoriesKey: qdrantListValue(vector.Categories), qdrantPayloadIdKey: qdrant.NewValueString(vector.Id), qdrantPayloadTimestampKey: qdrant.NewValueInt(vector.Timestamp.UnixMilli()), }, Vectors: qdrant.NewVectorsDense(vector.Vector), }) } _, err := db.client.Upsert(ctx, &qdrant.UpsertPoints{ CollectionName: collection, Points: points, }) return errors.Trace(err) } func (db *Qdrant) DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error { lt := float64(timestamp.UnixMilli()) _, err := db.client.Delete(ctx, &qdrant.DeletePoints{ CollectionName: collection, Points: qdrant.NewPointsSelectorFilter(&qdrant.Filter{ Must: []*qdrant.Condition{ qdrant.NewRange(qdrantPayloadTimestampKey, &qdrant.Range{Lt: <}), }, }), }) return errors.Trace(err) } func (db *Qdrant) QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) { if topK <= 0 { return []Vector{}, nil } request := &qdrant.QueryPoints{ CollectionName: collection, Query: qdrant.NewQueryDense(q), Limit: new(uint64(topK)), WithPayload: qdrant.NewWithPayloadEnable(true), WithVectors: qdrant.NewWithVectorsEnable(true), } if len(categories) > 0 { request.Filter = &qdrant.Filter{ Must: []*qdrant.Condition{ qdrant.NewMatchKeywords(qdrantPayloadCategoriesKey, categories...), }, } } response, err := db.client.Query(ctx, request) if err != nil { return nil, errors.Trace(err) } results := make([]Vector, 0, len(response)) for _, scored := range 
response { results = append(results, Vector{ Id: qdrantId(scored.GetPayload()), Vector: qdrantVectorOutput(scored.GetVectors()), Categories: qdrantCategories(scored.GetPayload()), }) } return results, nil } func qdrantId(payload map[string]*qdrant.Value) string { if payload == nil { return "" } if value, ok := payload[qdrantPayloadIdKey]; ok { return value.GetStringValue() } return "" } func qdrantListValue(items []string) *qdrant.Value { values := make([]*qdrant.Value, 0, len(items)) for _, item := range items { values = append(values, qdrant.NewValueString(item)) } return qdrant.NewValueFromList(values...) } func qdrantCategories(payload map[string]*qdrant.Value) []string { if payload == nil { return []string{} } value, ok := payload[qdrantPayloadCategoriesKey] if !ok || value == nil { return []string{} } list := value.GetListValue() if list == nil { return []string{} } categories := make([]string, 0, len(list.GetValues())) for _, item := range list.GetValues() { if item == nil { continue } categories = append(categories, item.GetStringValue()) } return categories } func qdrantVectorOutput(output *qdrant.VectorsOutput) []float32 { if output == nil { return nil } vector := output.GetVector() if vector == nil { return nil } return vector.GetDenseVector().GetData() } ================================================ FILE: storage/vectors/qdrant_test.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package vectors import ( "os" "testing" "github.com/stretchr/testify/suite" ) var ( qdrantUri string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } qdrantUri = env("QDRANT_URI", "qdrant://127.0.0.1:6334") } type QdrantTestSuite struct { vectorsTestSuite } func (suite *QdrantTestSuite) SetupSuite() { var err error suite.Database, err = Open(qdrantUri, "gorse_") suite.NoError(err) } func TestQdrant(t *testing.T) { suite.Run(t, new(QdrantTestSuite)) } ================================================ FILE: storage/vectors/sqlite.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "context" "database/sql" "encoding/json" "fmt" "strings" "time" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" _ "modernc.org/sqlite/vec" ) func init() { Register([]string{storage.SQLitePrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(SQLite) // Strip sqlite:// prefix dbPath := strings.TrimPrefix(path, storage.SQLitePrefix) var err error database.db, err = sql.Open("sqlite", dbPath) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type SQLite struct { db *sql.DB } func (db *SQLite) Init() error { return nil } func (db *SQLite) Optimize() error { return nil } func (db *SQLite) Close() error { return db.db.Close() } func (db *SQLite) ListCollections(ctx context.Context) ([]string, error) { rows, err := db.db.QueryContext(ctx, "SELECT name FROM sqlite_master WHERE type='table' AND sql LIKE '%USING vec0%'") if err != nil { return nil, errors.Trace(err) } defer rows.Close() var names []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return nil, errors.Trace(err) } names = append(names, name) } return names, nil } func (db *SQLite) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error { var metric string switch distance { case Cosine: metric = "cosine" case Euclidean: metric = "l2" case Dot: metric = "ip" default: return errors.NotSupportedf("distance method") } _, err := db.db.ExecContext(ctx, fmt.Sprintf("CREATE VIRTUAL TABLE %s USING vec0(id TEXT, categories TEXT, timestamp INTEGER, vector FLOAT[%d] distance_metric=%s)", name, dimensions, metric)) return errors.Trace(err) } func (db *SQLite) DeleteCollection(ctx context.Context, name string) error { var count int err := db.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", name).Scan(&count) if err != nil { return errors.Trace(err) } if count == 0 { return errors.NotFoundf("collection %s", name) } 
_, err = db.db.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", name)) return errors.Trace(err) } func (db *SQLite) AddVectors(ctx context.Context, collection string, vectors []Vector) error { if len(vectors) == 0 { return nil } tx, err := db.db.BeginTx(ctx, nil) if err != nil { return errors.Trace(err) } stmt, err := tx.PrepareContext(ctx, fmt.Sprintf("INSERT INTO %s(id, categories, timestamp, vector) VALUES(?, ?, ?, ?)", collection)) if err != nil { _ = tx.Rollback() return errors.Trace(err) } defer stmt.Close() for _, v := range vectors { categories, err := json.Marshal(v.Categories) if err != nil { _ = tx.Rollback() return errors.Trace(err) } vectorJson, err := json.Marshal(v.Vector) if err != nil { _ = tx.Rollback() return errors.Trace(err) } timestamp := v.Timestamp.UnixMilli() _, err = stmt.ExecContext(ctx, v.Id, string(categories), timestamp, string(vectorJson)) if err != nil { _ = tx.Rollback() return errors.Trace(err) } } return errors.Trace(tx.Commit()) } func (db *SQLite) DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error { _, err := db.db.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE timestamp < ?", collection), timestamp.UnixMilli()) return errors.Trace(err) } func (db *SQLite) QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) { if topK <= 0 { return []Vector{}, nil } qJson, err := json.Marshal(q) if err != nil { return nil, errors.Trace(err) } query := fmt.Sprintf("SELECT id, categories FROM %s WHERE vector MATCH ? AND k = ? 
", collection) var args []any args = append(args, string(qJson), topK) if len(categories) > 0 { var categoryConditions []string for _, category := range categories { categoryConditions = append(categoryConditions, "json_contains(categories, ?)") args = append(args, fmt.Sprintf("[%q]", category)) } query += " AND (" + strings.Join(categoryConditions, " OR ") + ")" } query += " ORDER BY distance" rows, err := db.db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.Trace(err) } defer rows.Close() var results []Vector for rows.Next() { var v Vector var categoriesStr string if err := rows.Scan(&v.Id, &categoriesStr); err != nil { return nil, errors.Trace(err) } if err := json.Unmarshal([]byte(categoriesStr), &v.Categories); err != nil { return nil, errors.Trace(err) } results = append(results, v) } return results, nil } ================================================ FILE: storage/vectors/sqlite_test.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "testing" "github.com/stretchr/testify/suite" ) type SQLiteTestSuite struct { vectorsTestSuite } func (suite *SQLiteTestSuite) SetupSuite() { var err error suite.Database, err = Open("sqlite://:memory:", "gorse_") suite.NoError(err) } func TestSQLite(t *testing.T) { suite.Run(t, new(SQLiteTestSuite)) } ================================================ FILE: storage/vectors/weaviate.go ================================================ // Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package vectors import ( "context" "net/url" "strings" "time" "github.com/go-openapi/strfmt" "github.com/google/uuid" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" "github.com/weaviate/weaviate-go-client/v4/weaviate" "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" "github.com/weaviate/weaviate/entities/models" ) const ( weaviatePayloadCategoriesKey = "categories" weaviatePayloadTimestampKey = "timestamp" ) func init() { Register([]string{storage.WeaviatePrefix, storage.WeaviatesPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { database := new(Weaviate) u, err := url.Parse(path) if err != nil { return nil, errors.Trace(err) } scheme := "http" if strings.HasPrefix(path, storage.WeaviatesPrefix) { scheme = "https" } cfg := weaviate.Config{ Host: u.Host, Scheme: scheme, } database.client, err = weaviate.NewClient(cfg) if err != nil { return nil, errors.Trace(err) } return database, nil }) } type Weaviate struct { client *weaviate.Client } func (db *Weaviate) Init() error { return nil } func (db *Weaviate) Optimize() error { return nil } func (db *Weaviate) Close() error { return nil } func (db *Weaviate) ListCollections(ctx context.Context) ([]string, error) { s, err := db.client.Schema().Getter().Do(ctx) if err != nil { return nil, errors.Trace(err) } var names []string for _, class := range s.Classes { names = append(names, uncapitalize(class.Class)) } return names, nil } func (db *Weaviate) AddCollection(ctx context.Context, name string, dimensions int, distance Distance) error { var weaviateDistance string switch distance { case Cosine: weaviateDistance = "cosine" case Euclidean: weaviateDistance = "l2-squared" case Dot: weaviateDistance = "dot" default: return errors.NotSupportedf("distance method") } class := &models.Class{ Class: capitalize(name), Vectorizer: "none", Properties: []*models.Property{ { Name: "originalId", DataType: 
[]string{"string"}, }, { Name: weaviatePayloadCategoriesKey, DataType: []string{"string[]"}, }, { Name: weaviatePayloadTimestampKey, DataType: []string{"date"}, IndexFilterable: new(true), IndexRangeFilters: new(true), }, }, VectorIndexConfig: map[string]interface{}{ "distance": weaviateDistance, }, } err := db.client.Schema().ClassCreator().WithClass(class).Do(ctx) return errors.Trace(err) } func (db *Weaviate) DeleteCollection(ctx context.Context, name string) error { exists, err := db.client.Schema().ClassExistenceChecker().WithClassName(capitalize(name)).Do(ctx) if err != nil { return errors.Trace(err) } if !exists { return errors.NotFoundf("collection %s", name) } err = db.client.Schema().ClassDeleter().WithClassName(capitalize(name)).Do(ctx) return errors.Trace(err) } func (db *Weaviate) AddVectors(ctx context.Context, collection string, vectors []Vector) error { if len(vectors) == 0 { return nil } objects := make([]*models.Object, 0, len(vectors)) for _, vector := range vectors { objects = append(objects, &models.Object{ Class: capitalize(collection), ID: strfmt.UUID(uuid.NewMD5(uuid.NameSpaceURL, []byte(vector.Id)).String()), Properties: map[string]interface{}{ "originalId": vector.Id, weaviatePayloadCategoriesKey: vector.Categories, weaviatePayloadTimestampKey: vector.Timestamp, }, Vector: models.C11yVector(vector.Vector), }) } _, err := db.client.Batch().ObjectsBatcher().WithObjects(objects...).Do(ctx) return errors.Trace(err) } func (db *Weaviate) DeleteVectors(ctx context.Context, collection string, timestamp time.Time) error { _, err := db.client.Batch().ObjectsBatchDeleter(). WithClassName(capitalize(collection)). WithWhere(filters.Where(). WithPath([]string{weaviatePayloadTimestampKey}). WithOperator(filters.LessThan). WithValueDate(timestamp)). 
Do(ctx) return errors.Trace(err) } func (db *Weaviate) QueryVectors(ctx context.Context, collection string, q []float32, categories []string, topK int) ([]Vector, error) { if topK <= 0 { return []Vector{}, nil } fields := []graphql.Field{ {Name: "originalId"}, {Name: weaviatePayloadCategoriesKey}, } explore := db.client.GraphQL().NearVectorArgBuilder().WithVector(q) builder := db.client.GraphQL().Get(). WithClassName(capitalize(collection)). WithFields(fields...). WithNearVector(explore). WithLimit(topK) if len(categories) > 0 { operands := make([]*filters.WhereBuilder, 0, len(categories)) for _, category := range categories { operands = append(operands, filters.Where(). WithPath([]string{weaviatePayloadCategoriesKey}). WithOperator(filters.ContainsAny). WithValueString(category)) } var where *filters.WhereBuilder if len(operands) == 1 { where = operands[0] } else { where = filters.Where(). WithOperator(filters.Or). WithOperands(operands) } builder = builder.WithWhere(where) } result, err := builder.Do(ctx) if err != nil { return nil, errors.Trace(err) } if len(result.Errors) > 0 { return nil, errors.New(result.Errors[0].Message) } data := result.Data["Get"].(map[string]interface{}) items := data[capitalize(collection)].([]interface{}) results := make([]Vector, 0, len(items)) for _, item := range items { m := item.(map[string]interface{}) id := m["originalId"].(string) var cats []string if m[weaviatePayloadCategoriesKey] != nil { for _, c := range m[weaviatePayloadCategoriesKey].([]interface{}) { cats = append(cats, c.(string)) } } results = append(results, Vector{ Id: id, Categories: cats, }) } return results, nil } func capitalize(s string) string { if len(s) == 0 { return s } return strings.ToUpper(s[:1]) + s[1:] } func uncapitalize(s string) string { if len(s) == 0 { return s } return strings.ToLower(s[:1]) + s[1:] } ================================================ FILE: storage/vectors/weaviate_test.go ================================================ // 
Copyright 2026 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package vectors import ( "os" "testing" "github.com/stretchr/testify/suite" ) var ( weaviateUri string ) func init() { // get environment variables env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } weaviateUri = env("WEAVIATE_URI", "weaviate://127.0.0.1:8080") } type WeaviateTestSuite struct { vectorsTestSuite } func (suite *WeaviateTestSuite) SetupSuite() { var err error suite.Database, err = Open(weaviateUri, "gorse_") suite.NoError(err) } func TestWeaviate(t *testing.T) { suite.Run(t, new(WeaviateTestSuite)) } ================================================ FILE: worker/metrics.go ================================================ // Copyright 2021 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package worker import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) const ( LabelStep = "step" LabelData = "data" ) var ( UpdateUserRecommendTotal = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "worker", Name: "update_user_recommend_total", }) OfflineRecommendStepSecondsVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "worker", Name: "offline_recommend_step_seconds", }, []string{LabelStep}) OfflineRecommendTotalSeconds = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "worker", Name: "offline_recommend_total_seconds", }) MemoryInuseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "worker", Name: "memory_inuse_bytes", }, []string{LabelData}) ) ================================================ FILE: worker/pipeline.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package worker

import (
	"context"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/gorse-io/gorse/common/expression"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/parallel"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/logics"
	"github.com/gorse-io/gorse/model/ctr"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/juju/errors"
	"github.com/samber/lo"
	"go.uber.org/atomic"
	"go.uber.org/zap"
)

// Pipeline generates offline recommendations for users and writes them to the
// cache store.
type Pipeline struct {
	Config      *config.Config
	CacheClient cache.Database
	DataClient  data.Database
	Tracer      *monitor.Monitor
	// Jobs is the parallelism passed to parallel.Detachable and to batch CTR prediction.
	Jobs int

	// Collaborative-filtering model; both must be non-nil for the
	// collaborative-filtering cache to be refreshed.
	MatrixFactorizationItems *logics.MatrixFactorizationItems
	MatrixFactorizationUsers *logics.MatrixFactorizationUsers
	// ClickThroughRateModel ranks candidates when Recommend.Ranker.Type is "fm".
	ClickThroughRateModel ctr.FactorizationMachines

	// dontskipColdStartUsers keeps cold-start users and users without a
	// collaborative-filtering embedding instead of skipping them.
	dontskipColdStartUsers bool
}

// Recommend regenerates cached offline recommendations for users.
//
// Users are processed by parallel.Detachable with p.Jobs workers
// (OpenAI.ChatCompletionRPM is passed as its rate limit). For each user it:
//  1. skips the user if inactive or if the cached recommendations are fresh,
//  2. skips cold-start users unless dontskipColdStartUsers is set,
//  3. refreshes collaborative-filtering results when an MF model is loaded,
//  4. collects candidates from the configured recommenders,
//  5. optionally appends replacement items (previously interacted items),
//  6. ranks candidates with the CTR model ("fm") or an LLM reranker ("llm"),
//  7. stores the ranked list, its update time and its digest in the cache.
//
// Per-user failures are logged and do not stop other users. progress, when
// non-nil, is called about every 10 seconds with the number of completed users
// and the number completed since the previous call.
func (p *Pipeline) Recommend(ctx context.Context, users []data.User, progress func(completed, throughput int)) {
	startRecommendTime := time.Now()
	itemCache := NewItemCache(p.DataClient)
	log.Logger().Info("ranking recommendation",
		zap.Int("n_working_users", len(users)),
		zap.Int("n_jobs", p.Jobs),
		zap.Int("cache_size", p.Config.Recommend.CacheSize))

	// Progress tracker. The goroutine drains completed until the channel is
	// closed. It deliberately does not stop on ctx cancellation: otherwise
	// workers could block forever on a full buffer. On close it flushes the
	// remaining count into the span and closes progressDone, so the span is
	// ended only after the last Add.
	completed := make(chan struct{}, 1000)
	progressDone := make(chan struct{})
	_, span := p.Tracer.Start(ctx, "Generate recommendation", len(users))
	defer span.End()
	go func() {
		defer util.CheckPanic()
		defer close(progressDone)
		completedCount, previousCount := 0, 0
		ticker := time.NewTicker(10 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case _, ok := <-completed:
				if !ok {
					span.Add(completedCount - previousCount)
					return
				}
				completedCount++
			case <-ticker.C:
				throughput := completedCount - previousCount
				span.Add(throughput)
				if progress != nil {
					progress(completedCount, throughput)
				}
				previousCount = completedCount
			}
		}
	}()

	// recommendation
	startTime := time.Now()
	var (
		updateUserCount               atomic.Float64
		collaborativeRecommendSeconds atomic.Float64
		// NOTE(review): the following step timers are reported below but are
		// not updated anywhere in this function.
		userBasedRecommendSeconds atomic.Float64
		itemBasedRecommendSeconds atomic.Float64
		latestRecommendSeconds    atomic.Float64
		popularRecommendSeconds   atomic.Float64
	)
	defer MemoryInuseBytesVec.WithLabelValues("user_feedback_cache").Set(0)
	if err := parallel.Detachable(ctx, len(users), p.Jobs, p.Config.OpenAI.ChatCompletionRPM, func(pCtx *parallel.Context, jobId int) {
		defer func() {
			completed <- struct{}{}
		}()
		user := users[jobId]
		userId := user.UserId
		// skip inactive users before max recommend period
		if !p.checkUserActiveTime(ctx, userId) || !p.checkRecommendCacheOutOfDate(ctx, userId) {
			return
		}
		updateUserCount.Add(1)

		recommendTime := time.Now()
		recommender, err := logics.NewRecommender(p.Config.Recommend, p.CacheClient, p.DataClient, false, userId, nil)
		if err != nil {
			log.Logger().Error("failed to create recommender", zap.String("user_id", userId), zap.Error(err))
			return
		}
		if !p.dontskipColdStartUsers && recommender.IsColdStart() {
			// skip cold-start users without any positive feedback
			return
		}

		// Update collaborative filtering recommendation.
		if !strings.EqualFold(p.Config.Recommend.Collaborative.Type, "none") &&
			p.MatrixFactorizationUsers != nil && p.MatrixFactorizationItems != nil {
			if userEmbedding, ok := p.MatrixFactorizationUsers.Get(userId); ok {
				collaborativeStartTime := time.Now()
				err = p.updateCollaborativeRecommend(ctx, p.MatrixFactorizationItems, userId, userEmbedding, recommender.ExcludeSet(), itemCache)
				collaborativeRecommendSeconds.Add(time.Since(collaborativeStartTime).Seconds())
				if err != nil {
					log.Logger().Error("failed to recommend by collaborative filtering",
						zap.String("user_id", userId), zap.Error(err))
					return
				}
			} else if !p.dontskipColdStartUsers {
				// skip users without collaborative filtering embeddings
				return
			}
		}

		// Generate recommendation from recommenders.
		var (
			scores           []cache.Score
			digest           string
			recommenderNames []string
		)
		if len(p.Config.Recommend.Ranker.Recommenders) > 0 {
			recommenderNames = p.Config.Recommend.Ranker.Recommenders
		} else {
			recommenderNames = p.Config.Recommend.ListRecommenders()
		}
		scores, digest, err = recommender.RecommendSequential(ctx, scores, 0, recommenderNames...)
		if err != nil {
			log.Logger().Error("failed to recommend items", zap.String("user_id", userId), zap.Error(err))
			return
		}
		candidates := make([]cache.Score, 0, len(scores))
		candidateSet := mapset.NewSet[string]()
		items, err := itemCache.GetMap(ctx, lo.Map(scores, func(score cache.Score, _ int) string {
			return score.Id
		}))
		if err != nil {
			log.Logger().Error("failed to download items", zap.String("user_id", userId), zap.Error(err))
			return
		}
		// Keep only candidates whose items exist and are visible.
		for _, score := range scores {
			if _, exist := items[score.Id]; exist {
				score.Timestamp = recommendTime
				candidates = append(candidates, score)
				candidateSet.Add(score.Id)
			}
		}

		// Insert replacement items into the candidate set before ranking so all rankers (including LLM) can order them.
		var replacementPositiveItems, replacementNegativeItems mapset.Set[string]
		if p.Config.Recommend.Replacement.EnableReplacement && p.Config.Recommend.Ranker.Type != "none" {
			candidates, replacementPositiveItems, replacementNegativeItems, err = p.addReplacementCandidates(
				ctx, candidates, candidateSet, recommender.UserFeedback(), itemCache, recommendTime,
			)
			if err != nil {
				log.Logger().Error("failed to prepare replacement candidates", zap.String("user_id", userId), zap.Error(err))
				return
			}
		}

		// rank by click-through-rate
		var results []cache.Score
		if p.Config.Recommend.Ranker.Type == "fm" && p.ClickThroughRateModel != nil && !p.ClickThroughRateModel.Invalid() {
			results, err = p.rankByClickTroughRate(ctx, p.ClickThroughRateModel, &user, candidates, itemCache, recommendTime)
			if err != nil {
				log.Logger().Error("failed to rank items", zap.String("user_id", userId), zap.Error(err))
				return
			}
		} else if p.Config.Recommend.Ranker.Type == "llm" && p.Config.OpenAI.ChatCompletionModel != "" {
			ranker, err := logics.NewChatReranker(
				p.Config.Recommend.Ranker.RerankerAPI,
				p.Config.Recommend.Ranker.QueryTemplate,
				p.Config.Recommend.Ranker.DocumentTemplate)
			if err != nil {
				log.Logger().Error("failed to create LLM ranker", zap.String("user_id", userId), zap.Error(err))
				return
			}
			results, err = p.rankByLLM(ctx, pCtx, ranker, &user, recommender.UserFeedback(), candidates, itemCache, recommendTime)
			if err != nil {
				log.Logger().Error("failed to rank items by LLM", zap.String("user_id", userId), zap.Error(err))
				return
			}
		} else {
			results = candidates
		}

		// Apply replacement decay after ranking so weights don't bypass the ranker ordering.
		if p.Config.Recommend.Replacement.EnableReplacement && p.Config.Recommend.Ranker.Type != "none" {
			results = p.applyReplacementDecay(results, replacementPositiveItems, replacementNegativeItems)
		}

		// Cache the recommendation, then drop entries older than this run.
		if err = p.CacheClient.AddScores(ctx, cache.Recommend, userId, results); err != nil {
			log.Logger().Error("failed to cache recommendation", zap.String("user_id", userId), zap.Error(err))
			return
		}
		if err = p.CacheClient.DeleteScores(ctx, []string{cache.Recommend}, cache.ScoreCondition{
			Before: &recommendTime,
			Subset: new(userId),
		}); err != nil {
			log.Logger().Error("failed to delete stale recommendation", zap.String("user_id", userId), zap.Error(err))
			return
		}
		if err = p.CacheClient.Set(
			ctx,
			cache.Time(cache.Key(cache.RecommendUpdateTime, userId), recommendTime),
			cache.String(cache.Key(cache.RecommendDigest, userId), digest),
		); err != nil {
			log.Logger().Error("failed to cache recommendation time", zap.String("user_id", userId), zap.Error(err))
		}
	}); err != nil {
		log.Logger().Error("recommendation was cancelled", zap.Error(err))
	}
	close(completed)
	// Wait for the tracker to flush before the deferred span.End runs.
	<-progressDone
	log.Logger().Info("complete ranking recommendation",
		zap.String("used_time", time.Since(startTime).String()))
	UpdateUserRecommendTotal.Set(updateUserCount.Load())
	OfflineRecommendTotalSeconds.Set(time.Since(startRecommendTime).Seconds())
	OfflineRecommendStepSecondsVec.WithLabelValues("collaborative_recommend").Set(collaborativeRecommendSeconds.Load())
	OfflineRecommendStepSecondsVec.WithLabelValues("item_based_recommend").Set(itemBasedRecommendSeconds.Load())
	OfflineRecommendStepSecondsVec.WithLabelValues("user_based_recommend").Set(userBasedRecommendSeconds.Load())
	OfflineRecommendStepSecondsVec.WithLabelValues("latest_recommend").Set(latestRecommendSeconds.Load())
	OfflineRecommendStepSecondsVec.WithLabelValues("popular_recommend").Set(popularRecommendSeconds.Load())
}

// checkUserActiveTime checks if a user is active based on their last modification time.
func (p *Pipeline) checkUserActiveTime(ctx context.Context, userId string) bool { if p.Config.Recommend.ActiveUserTTL == 0 { return true } // read active time activeTime, err := p.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, userId)).Time() if err != nil { log.Logger().Error("failed to read last modify user time", zap.String("user_id", userId), zap.Error(err)) return true } if activeTime.IsZero() { return true } // check active time if time.Since(activeTime) < time.Duration(p.Config.Recommend.ActiveUserTTL*24)*time.Hour { return true } // remove recommend cache for inactive users if err := p.CacheClient.DeleteScores(ctx, []string{cache.Recommend}, cache.ScoreCondition{Subset: new(userId)}); err != nil { log.Logger().Error("failed to delete recommend cache", zap.String("user_id", userId), zap.Error(err)) } return false } // checkRecommendCacheOutOfDate checks if recommend cache stale. func (p *Pipeline) checkRecommendCacheOutOfDate(ctx context.Context, userId string) bool { var ( activeTime time.Time recommendTime time.Time err error ) // 1. If cache is empty, stale. items, err := p.CacheClient.SearchScores(ctx, cache.Recommend, userId, nil, 0, -1) if err != nil { log.Logger().Error("failed to load offline recommendation", zap.String("user_id", userId), zap.Error(err)) return true } else if len(items) == 0 { return true } // 2. If digest is empty or not match, stale. digest, err := p.CacheClient.Get(ctx, cache.Key(cache.RecommendDigest, userId)).String() if err != nil { log.Logger().Error("failed to read offline recommendation digest", zap.String("user_id", userId), zap.Error(err)) return true } if digest == "" { return true } if digest != p.Config.Recommend.Hash() { return true } // read active time activeTime, err = p.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, userId)).Time() if err != nil { log.Logger().Error("failed to read last modify user time", zap.String("user_id", userId), zap.Error(err)) } // 3. If update time is empty, stale. 
recommendTime, err = p.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, userId)).Time() if err != nil { log.Logger().Error("failed to read last update user recommend time", zap.Error(err)) return true } // 4. If update time + cache expire > current time, not stale. if recommendTime.Before(time.Now().Add(-p.Config.Recommend.CacheExpire)) { return true } // 5. If active time > recommend time, not stale. if activeTime.Before(recommendTime) { timeoutTime := recommendTime.Add(p.Config.Recommend.Ranker.CacheExpire) return timeoutTime.Before(time.Now()) } return true } func (p *Pipeline) updateCollaborativeRecommend( ctx context.Context, items *logics.MatrixFactorizationItems, userId string, userEmbedding []float32, excludeSet mapset.Set[string], itemCache *ItemCache, ) error { localStartTime := time.Now() scores := items.Search(userEmbedding, p.Config.Recommend.CacheSize+excludeSet.Cardinality()) // update categories itemsMap, err := itemCache.GetMap(ctx, lo.Map(scores, func(score cache.Score, _ int) string { return score.Id })) if err != nil { return errors.Trace(err) } // remove excluded items and non-existing items recommend := make([]cache.Score, 0, len(scores)) for i := range scores { if item, exist := itemsMap[scores[i].Id]; exist && !excludeSet.Contains(item.ItemId) { recommend = append(recommend, cache.Score{ Id: scores[i].Id, Score: scores[i].Score, Categories: item.Categories, // the scores use the timestamp of the ranking index, which is only refreshed every so often. // if we don't overwrite the timestamp here, the code below will delete all scores that were // just written. 
Timestamp: localStartTime, }) } } if err := p.CacheClient.AddScores(ctx, cache.CollaborativeFiltering, userId, recommend); err != nil { log.Logger().Error("failed to cache collaborative filtering recommendation result", zap.String("user_id", userId), zap.Error(err)) return errors.Trace(err) } if err := p.CacheClient.Set(ctx, cache.Time(cache.Key(cache.CollaborativeFilteringUpdateTime, userId), localStartTime), cache.String(cache.Key(cache.CollaborativeFilteringDigest, userId), p.Config.Recommend.Collaborative.Hash(&p.Config.Recommend)), ); err != nil { log.Logger().Error("failed to cache collaborative filtering recommendation time", zap.String("user_id", userId), zap.Error(err)) return errors.Trace(err) } if err := p.CacheClient.DeleteScores(ctx, []string{cache.CollaborativeFiltering}, cache.ScoreCondition{Before: &localStartTime, Subset: new(userId)}); err != nil { log.Logger().Error("failed to delete stale collaborative filtering recommendation result", zap.String("user_id", userId), zap.Error(err)) return errors.Trace(err) } return nil } // rankByClickTroughRate ranks items by predicted click-through-rate. 
func (p *Pipeline) rankByClickTroughRate( ctx context.Context, predictor ctr.FactorizationMachines, user *data.User, candidates []cache.Score, itemCache *ItemCache, recommendTime time.Time, ) ([]cache.Score, error) { // download items items, err := itemCache.GetSlice(ctx, lo.Map(candidates, func(score cache.Score, _ int) string { return score.Id })) if err != nil { return nil, errors.Trace(err) } // rank by CTR topItems := make([]cache.Score, 0, len(items)) if batchPredictor, ok := predictor.(ctr.BatchInference); ok { inputs := make([]lo.Tuple4[string, string, []ctr.Label, []ctr.Label], len(items)) embeddings := make([][]ctr.Embedding, len(items)) for i, item := range items { inputs[i].A = user.UserId inputs[i].B = item.ItemId inputs[i].C = ctr.ConvertLabels(user.Labels) inputs[i].D = ctr.ConvertLabels(item.Labels) embeddings[i] = ctr.ConvertEmbeddings(item.Labels) } output := batchPredictor.BatchPredict(inputs, embeddings, p.Jobs) for i, score := range output { topItems = append(topItems, cache.Score{ Id: items[i].ItemId, Score: float64(score), Categories: items[i].Categories, Timestamp: recommendTime, }) } } else { for _, item := range items { topItems = append(topItems, cache.Score{ Id: item.ItemId, Score: float64(predictor.Predict(user.UserId, item.ItemId, ctr.ConvertLabels(user.Labels), ctr.ConvertLabels(item.Labels))), Categories: item.Categories, Timestamp: recommendTime, }) } } cache.SortDocuments(topItems) return topItems, nil } func (p *Pipeline) rankByLLM( ctx context.Context, pCtx *parallel.Context, ranker *logics.ChatReranker, user *data.User, feedback []data.Feedback, candidates []cache.Score, itemCache *ItemCache, recommendTime time.Time, ) ([]cache.Score, error) { // download items items, err := itemCache.GetSlice(ctx, lo.Map(candidates, func(score cache.Score, _ int) string { return score.Id })) if err != nil { return nil, errors.Trace(err) } // convert feedback data.SortFeedbacks(feedback) contextUserFeedback := make([]data.Feedback, 0, 
p.Config.Recommend.ContextSize) for _, f := range feedback { if p.Config.Recommend.ContextSize <= len(contextUserFeedback) { break } if expression.MatchFeedbackTypeExpressions(p.Config.Recommend.DataSource.PositiveFeedbackTypes, f.FeedbackType, f.Value) { contextUserFeedback = append(contextUserFeedback, f) } } itemMap, err := itemCache.GetMap(ctx, lo.Map(contextUserFeedback, func(fb data.Feedback, _ int) string { return fb.ItemId })) if err != nil { return nil, errors.Trace(err) } feedbackItems := make([]*logics.FeedbackItem, 0, len(contextUserFeedback)) for _, fb := range contextUserFeedback { if item, exist := itemMap[fb.ItemId]; exist { feedbackItems = append(feedbackItems, &logics.FeedbackItem{ FeedbackType: fb.FeedbackType, Item: *item, }) } } // rank by LLM pCtx.Detach() parsed, err := ranker.Rank(ctx, user, feedbackItems, items) pCtx.Attach() if err != nil { return nil, errors.Trace(err) } // construct scores var topItems []cache.Score itemCategories := make(map[string][]string, len(items)) for _, item := range items { if item == nil { continue } itemCategories[item.ItemId] = item.Categories } for _, item := range parsed { topItems = append(topItems, cache.Score{ Id: item.Id, Score: item.Score, Timestamp: recommendTime, Categories: itemCategories[item.Id], }) } return topItems, nil } // replacement inserts historical items back to recommendation. // It now adds the replacement items before ranking, then applies decay after ranking. 
func (p *Pipeline) addReplacementCandidates( ctx context.Context, candidates []cache.Score, candidateSet mapset.Set[string], feedbacks []data.Feedback, itemCache *ItemCache, recommendTime time.Time, ) ([]cache.Score, mapset.Set[string], mapset.Set[string], error) { positiveItems := mapset.NewSet[string]() distinctItems := mapset.NewSet[string]() for _, feedback := range feedbacks { if expression.MatchFeedbackTypeExpressions(p.Config.Recommend.DataSource.PositiveFeedbackTypes, feedback.FeedbackType, feedback.Value) { positiveItems.Add(feedback.ItemId) distinctItems.Add(feedback.ItemId) } else if expression.MatchFeedbackTypeExpressions(p.Config.Recommend.DataSource.ReadFeedbackTypes, feedback.FeedbackType, feedback.Value) { distinctItems.Add(feedback.ItemId) } } if distinctItems.Cardinality() == 0 { return candidates, positiveItems, mapset.NewSet[string](), nil } items, err := itemCache.GetSlice(ctx, distinctItems.ToSlice()) if err != nil { return nil, nil, nil, errors.Trace(err) } // Only keep items that exist and aren't hidden in the cache. for _, item := range items { if candidateSet.Contains(item.ItemId) { continue } candidates = append(candidates, cache.Score{Id: item.ItemId, Categories: item.Categories, Timestamp: recommendTime}) candidateSet.Add(item.ItemId) } // Build negative items set from distinct minus positive, filtered by existence. existingSet := mapset.NewSet(lo.Map(items, func(it *data.Item, _ int) string { return it.ItemId })...) 
positiveExisting := positiveItems.Intersect(existingSet) negativeItems := distinctItems.Difference(positiveItems).Intersect(existingSet) return candidates, positiveExisting, negativeItems, nil } func (p *Pipeline) applyReplacementDecay( results []cache.Score, positiveItems mapset.Set[string], negativeItems mapset.Set[string], ) []cache.Score { if (positiveItems == nil || positiveItems.Cardinality() == 0) && (negativeItems == nil || negativeItems.Cardinality() == 0) { return results } updated := make([]cache.Score, len(results)) copy(updated, results) changed := false for i := range updated { switch { case positiveItems != nil && positiveItems.Contains(updated[i].Id): updated[i].Score *= p.Config.Recommend.Replacement.PositiveReplacementDecay changed = true case negativeItems != nil && negativeItems.Contains(updated[i].Id): updated[i].Score *= p.Config.Recommend.Replacement.ReadReplacementDecay changed = true } } if changed { cache.SortDocuments(updated) } return updated } // ItemCache is alias of map[string]data.Item. type ItemCache struct { Client data.Database Data sync.Map } // NewItemCache creates a new ItemCache. 
func NewItemCache(client data.Database) *ItemCache { return &ItemCache{ Client: client, Data: sync.Map{}, } } func (c *ItemCache) GetSlice(ctx context.Context, itemIds []string) ([]*data.Item, error) { requests := make([]string, 0, len(itemIds)) for _, itemId := range itemIds { if _, exist := c.Data.Load(itemId); !exist { requests = append(requests, itemId) } } response, err := c.Client.BatchGetItems(ctx, requests) if err != nil { return nil, errors.Trace(err) } for _, item := range response { c.Data.Store(item.ItemId, &item) } items := make([]*data.Item, 0, len(itemIds)) for _, itemId := range itemIds { if val, exist := c.Data.Load(itemId); exist { item := val.(*data.Item) if !item.IsHidden { items = append(items, item) } } } return items, nil } func (c *ItemCache) GetMap(ctx context.Context, itemIds []string) (map[string]*data.Item, error) { items, err := c.GetSlice(ctx, itemIds) if err != nil { return nil, errors.Trace(err) } return lo.SliceToMap(items, func(item *data.Item) (string, *data.Item) { return item.ItemId, item }), nil } ================================================ FILE: worker/pipeline_test.go ================================================ // Copyright 2025 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package worker

import (
	"fmt"
	"testing"

	"github.com/gorse-io/gorse/storage/data"
	"github.com/stretchr/testify/suite"
)

// PipelineTestSuite exercises ItemCache against a SQLite data store that is
// seeded with items "1" through "5".
type PipelineTestSuite struct {
	suite.Suite
	dataClient data.Database
}

// SetupSuite opens a temporary SQLite data store and seeds it with five items.
func (s *PipelineTestSuite) SetupSuite() {
	dsn := fmt.Sprintf("sqlite://%s/data.db", s.T().TempDir())
	client, err := data.Open(dsn, "")
	s.dataClient = client
	s.NoError(err)
	s.NoError(s.dataClient.Init())

	seed := make([]data.Item, 0, 5)
	for _, id := range []string{"1", "2", "3", "4", "5"} {
		seed = append(seed, data.Item{ItemId: id})
	}
	s.NoError(s.dataClient.BatchInsertItems(s.T().Context(), seed))
}

// TearDownSuite closes the data store.
func (s *PipelineTestSuite) TearDownSuite() {
	s.NoError(s.dataClient.Close())
}

// TestGetSlice checks that the unknown ID "6" is skipped.
func (s *PipelineTestSuite) TestGetSlice() {
	itemCache := NewItemCache(s.dataClient)
	got, err := itemCache.GetSlice(s.T().Context(), []string{"1", "2", "3", "4", "5", "6"})
	s.NoError(err)
	s.Len(got, 5)
}

// TestGetMap checks that the unknown ID "6" is absent from the map.
func (s *PipelineTestSuite) TestGetMap() {
	itemCache := NewItemCache(s.dataClient)
	got, err := itemCache.GetMap(s.T().Context(), []string{"1", "2", "3", "4", "5", "6"})
	s.NoError(err)
	s.Len(got, 5)
}

func TestPipeline(t *testing.T) {
	suite.Run(t, &PipelineTestSuite{})
}

================================================
FILE: worker/worker.go
================================================
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package worker

import (
	"context"
	"crypto/md5"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"math/rand"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/gorse-io/gorse/cmd/version"
	"github.com/gorse-io/gorse/common/log"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/logics"
	"github.com/gorse-io/gorse/model/ctr"
	"github.com/gorse-io/gorse/protocol"
	"github.com/gorse-io/gorse/storage"
	"github.com/gorse-io/gorse/storage/blob"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/juju/errors"
	"github.com/lafikl/consistent"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/samber/lo"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// batchSize is the number of users requested per batch from GetUserStream in pullUsers.
const batchSize = 10000

// Worker manages states of a worker node.
//
// It embeds the Pipeline that generates recommendations and adds the state
// needed to follow the master: connection settings, the store settings last
// applied by Sync, the IDs of the loaded and the latest models, and the peer list.
type Worker struct {
	Pipeline
	// testMode makes Sync and Pull return after a single iteration.
	testMode bool
	// IDs of the models currently loaded into the Pipeline.
	collaborativeFilteringModelId int64
	clickThroughRateModelId       int64

	// worker config
	workerName string
	httpHost   string
	httpPort   int
	masterHost string
	masterPort int
	tlsConfig  *util.TLSConfig
	// cacheFile is set by NewWorker; it is not read elsewhere in this file.
	cacheFile string

	// database connection path
	// (store settings last applied by Sync; a change triggers a reconnect)
	cachePath   string
	cachePrefix string
	dataPath    string
	dataPrefix  string
	blobConfig  string
	blobStore   blob.Store

	// master connection
	conn         *grpc.ClientConn
	masterClient protocol.MasterClient
	// IDs of the newest models announced by the master; Pull loads a model
	// when its latest ID is greater than the loaded one.
	latestCollaborativeFilteringModelId int64
	latestClickThroughRateModelId       int64

	randGenerator *rand.Rand

	// peers
	// (worker names from the master's meta; me is this worker's name among them)
	peers []string
	me    string

	// events
	tickDuration time.Duration
	ticker       *time.Ticker
	syncedChan   chan struct{} // meta synced events
	pulledChan   chan struct{} // model pulled events
}

// NewWorker creates a new worker node.
func NewWorker( masterHost string, masterPort int, httpHost string, httpPort int, jobs int, cacheFile string, tlsConfig *util.TLSConfig, interval time.Duration, ) *Worker { return &Worker{ Pipeline: Pipeline{ Config: config.GetDefaultConfig(), CacheClient: new(cache.NoDatabase), DataClient: new(data.NoDatabase), Jobs: jobs, }, randGenerator: util.NewRand(time.Now().UTC().UnixNano()), // config cacheFile: cacheFile, masterHost: masterHost, masterPort: masterPort, tlsConfig: tlsConfig, httpHost: httpHost, httpPort: httpPort, // events tickDuration: interval, ticker: time.NewTicker(interval), syncedChan: make(chan struct{}, 1), pulledChan: make(chan struct{}, 1), } } // Sync this worker to the master. func (w *Worker) Sync() { var nextBlobConfig config.BlobConfig log.Logger().Info("start meta sync", zap.Duration("meta_timeout", w.Config.Master.MetaTimeout)) for { var meta *protocol.Meta var err error if meta, err = w.masterClient.GetMeta(context.Background(), &protocol.NodeInfo{ NodeType: protocol.NodeType_Worker, Uuid: w.workerName, BinaryVersion: version.Version, Hostname: lo.Must(os.Hostname()), }); err != nil { log.Logger().Error("failed to get meta", zap.Error(err)) goto sleep } // load master config err = json.Unmarshal([]byte(meta.Config), &w.Config) if err != nil { log.Logger().Error("failed to parse master config", zap.Error(err)) goto sleep } // connect to data store if w.dataPath != w.Config.Database.DataStore || w.dataPrefix != w.Config.Database.DataTablePrefix { if strings.HasPrefix(w.Config.Database.DataStore, storage.SQLitePrefix) { log.Logger().Info("connect data store via master") w.DataClient = data.NewProxyClient(w.conn) } else { log.Logger().Info("connect data store", zap.String("database", log.RedactDBURL(w.Config.Database.DataStore))) dataOpts := w.Config.Database.StorageOptions(w.Config.Database.DataStore) if w.DataClient, err = data.Open(w.Config.Database.DataStore, w.Config.Database.DataTablePrefix, dataOpts...); err != nil { 
log.Logger().Error("failed to connect data store", zap.Error(err)) goto sleep } } w.dataPath = w.Config.Database.DataStore w.dataPrefix = w.Config.Database.DataTablePrefix } // connect to cache store if w.cachePath != w.Config.Database.CacheStore || w.cachePrefix != w.Config.Database.CacheTablePrefix { if strings.HasPrefix(w.Config.Database.CacheStore, storage.SQLitePrefix) { log.Logger().Info("connect cache store via master") w.CacheClient = cache.NewProxyClient(w.conn) } else { log.Logger().Info("connect cache store", zap.String("database", log.RedactDBURL(w.Config.Database.CacheStore))) cacheOpts := w.Config.Database.StorageOptions(w.Config.Database.CacheStore) if w.CacheClient, err = cache.Open(w.Config.Database.CacheStore, w.Config.Database.CacheTablePrefix, cacheOpts...); err != nil { log.Logger().Error("failed to connect cache store", zap.Error(err)) goto sleep } } w.cachePath = w.Config.Database.CacheStore w.cachePrefix = w.Config.Database.CacheTablePrefix } // connect to blob store nextBlobConfig = w.Config.Blob if w.blobConfig != nextBlobConfig.URI { w.blobStore, err = blob.NewStore(w.Config.Blob, w.conn) if err != nil { log.Logger().Error("failed to connect blob store", zap.Error(err)) goto sleep } w.blobConfig = nextBlobConfig.URI } // synchronize collaborative filtering model w.latestCollaborativeFilteringModelId = meta.CollaborativeFilteringModelId if w.latestCollaborativeFilteringModelId > w.collaborativeFilteringModelId { log.Logger().Info("new ranking model found", zap.Int64("old_version", w.collaborativeFilteringModelId), zap.Int64("new_version", w.latestCollaborativeFilteringModelId)) select { case w.syncedChan <- struct{}{}: default: } } // synchronize click-through rate model w.latestClickThroughRateModelId = meta.ClickThroughRateModelId if w.latestClickThroughRateModelId > w.clickThroughRateModelId { log.Logger().Info("new click model found", zap.Int64("old_version", w.clickThroughRateModelId), zap.Int64("new_version", 
w.latestClickThroughRateModelId)) select { case w.syncedChan <- struct{}{}: default: } } w.peers = meta.Workers w.me = meta.Me sleep: if w.testMode { return } time.Sleep(w.Config.Master.MetaTimeout) } } // Pull user index and ranking model from master. func (w *Worker) Pull() { for range w.syncedChan { pulled := false // pull ranking model if w.latestCollaborativeFilteringModelId > w.collaborativeFilteringModelId { log.Logger().Info("start pull collaborative filtering model") r, err := w.blobStore.Open(strconv.FormatInt(w.latestCollaborativeFilteringModelId, 10)) if err != nil { log.Logger().Error("failed to open collaborative filtering model", zap.Error(err)) } else { items := logics.NewMatrixFactorizationItems(time.Time{}) users := logics.NewMatrixFactorizationUsers() if err = items.Unmarshal(r); err != nil { log.Logger().Error("failed to unmarshal matrix factorization items", zap.Error(err)) } else if err = users.Unmarshal(r); err != nil { log.Logger().Error("failed to unmarshal matrix factorization users", zap.Error(err)) } else { w.MatrixFactorizationItems = items w.MatrixFactorizationUsers = users w.collaborativeFilteringModelId = w.latestCollaborativeFilteringModelId log.Logger().Info("synced collaborative filtering model", zap.Int64("id", w.collaborativeFilteringModelId)) pulled = true } } } // pull click model if w.latestClickThroughRateModelId > w.clickThroughRateModelId { log.Logger().Info("start pull click model") r, err := w.blobStore.Open(strconv.FormatInt(w.latestClickThroughRateModelId, 10)) if err != nil { log.Logger().Error("failed to open click-through rate model", zap.Error(err)) } else { model, err := ctr.UnmarshalModel(r) if err != nil { log.Logger().Error("failed to unmarshal click-through rate model", zap.Error(err)) } else { w.ClickThroughRateModel = model w.clickThroughRateModelId = w.latestClickThroughRateModelId log.Logger().Info("synced click-through rate model", zap.Int64("version", w.clickThroughRateModelId)) pulled = true } } } if 
w.testMode { return } if pulled { select { case w.pulledChan <- struct{}{}: default: } } } } // ServeHTTP serves Prometheus metrics and API. func (w *Worker) ServeHTTP() { http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/api/health/live", w.checkLive) http.HandleFunc("/api/health/ready", w.checkReady) err := http.ListenAndServe(fmt.Sprintf("%s:%d", w.httpHost, w.httpPort), nil) if err != nil { log.Logger().Fatal("failed to start http server", zap.Error(err)) } } func writeJSON(w http.ResponseWriter, content any) { w.WriteHeader(http.StatusOK) bytes, err := json.Marshal(content) if err != nil { writeError(w, err.Error(), http.StatusInternalServerError) } if _, err = w.Write(bytes); err != nil { writeError(w, err.Error(), http.StatusInternalServerError) } } func writeError(w http.ResponseWriter, error string, code int) { log.Logger().Error(strings.ToLower(http.StatusText(code)), zap.String("error", error)) http.Error(w, error, code) } // Serve as a worker node. func (w *Worker) Serve() { var err error if w.workerName, err = w.WorkerName(); err != nil { log.Logger().Fatal("failed to get worker name", zap.Error(err)) } // create progress tracer w.Tracer = monitor.NewTracer(w.workerName) // connect to master var opts []grpc.DialOption opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(512*1024*1024))) if w.tlsConfig != nil { c, err := util.NewClientCreds(w.tlsConfig) if err != nil { log.Logger().Fatal("failed to create credentials", zap.Error(err)) } opts = append(opts, grpc.WithTransportCredentials(c)) } else { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } w.conn, err = grpc.Dial(net.JoinHostPort(w.masterHost, strconv.Itoa(w.masterPort)), opts...) 
if err != nil { log.Logger().Fatal("failed to connect master", zap.Error(err)) } w.masterClient = protocol.NewMasterClient(w.conn) go w.Sync() go w.Pull() go w.ServeHTTP() loop := func() { // pull users workingUsers, err := w.pullUsers(w.peers, w.me) if err != nil { log.Logger().Error("failed to split users", zap.Error(err), zap.String("me", w.me), zap.Strings("workers", w.peers)) return } // recommendation w.Recommend(context.Background(), workingUsers, func(completed, throughput int) { log.Logger().Info("ranking recommendation", zap.Int("n_complete_users", completed), zap.Int("throughput", throughput)) if w.masterClient != nil { if _, err := w.masterClient.PushProgress(context.Background(), monitor.EncodeProgress(w.Tracer.List())); err != nil { log.Logger().Error("failed to report update task", zap.Error(err)) } } }) } for { select { case tick := <-w.ticker.C: if time.Since(tick) <= w.tickDuration { loop() } case <-w.pulledChan: loop() } } } func (w *Worker) WorkerName() (string, error) { hostname, err := os.Hostname() if err != nil { return "", err } hash := md5.New() hash.Write([]byte(hostname)) hash.Write([]byte(w.httpHost)) hash.Write([]byte(strconv.Itoa(w.httpPort))) b := hash.Sum(nil) return hex.EncodeToString(b), nil } func (w *Worker) pullUsers(peers []string, me string) ([]data.User, error) { ctx := context.Background() // locate me if !lo.Contains(peers, me) { return nil, errors.New("current node isn't in worker nodes") } // create consistent hash ring c := consistent.New() for _, peer := range peers { c.Add(peer) } // pull users from database var users []data.User userChan, errChan := w.DataClient.GetUserStream(ctx, batchSize) for batchUsers := range userChan { for _, user := range batchUsers { p, err := c.Get(user.UserId) if err != nil { return nil, errors.Trace(err) } if p == me { users = append(users, user) } } } if err := <-errChan; err != nil { return nil, errors.Trace(err) } return users, nil } type HealthStatus struct { Ready bool DataStoreError 
error CacheStoreError error DataStoreConnected bool CacheStoreConnected bool } func (w *Worker) checkHealth() HealthStatus { healthStatus := HealthStatus{} healthStatus.DataStoreError = w.DataClient.Ping() healthStatus.CacheStoreError = w.CacheClient.Ping() healthStatus.DataStoreConnected = healthStatus.DataStoreError == nil healthStatus.CacheStoreConnected = healthStatus.CacheStoreError == nil healthStatus.Ready = healthStatus.DataStoreConnected && healthStatus.CacheStoreConnected return healthStatus } func (w *Worker) checkLive(writer http.ResponseWriter, _ *http.Request) { healthStatus := w.checkHealth() writeJSON(writer, healthStatus) } func (w *Worker) checkReady(writer http.ResponseWriter, _ *http.Request) { healthStatus := w.checkHealth() if healthStatus.Ready { writeJSON(writer, healthStatus) } else { errReason, err := json.Marshal(healthStatus) if err != nil { writeError(writer, err.Error(), http.StatusInternalServerError) } else { writeError(writer, string(errReason), http.StatusServiceUnavailable) } } } ================================================ FILE: worker/worker_test.go ================================================ // Copyright 2020 gorse Project Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package worker

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/c-bata/goptuna"
	"github.com/gorse-io/gorse/common/expression"
	"github.com/gorse-io/gorse/common/monitor"
	"github.com/gorse-io/gorse/common/reranker"
	"github.com/gorse-io/gorse/common/util"
	"github.com/gorse-io/gorse/config"
	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/logics"
	"github.com/gorse-io/gorse/model"
	"github.com/gorse-io/gorse/model/cf"
	"github.com/gorse-io/gorse/model/ctr"
	"github.com/gorse-io/gorse/protocol"
	"github.com/gorse-io/gorse/storage/cache"
	"github.com/gorse-io/gorse/storage/data"
	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// WorkerTestSuite runs worker tests against an embedded Worker whose data and
// cache stores are SQLite databases opened in SetupSuite.
type WorkerTestSuite struct {
	suite.Suite
	Worker
}

// SetupSuite opens SQLite data and cache stores in temporary directories,
// initializes them, and sets a test tracer and the default configuration.
func (suite *WorkerTestSuite) SetupSuite() {
	// open database
	var err error
	suite.Tracer = monitor.NewTracer("test")
	suite.Config = config.GetDefaultConfig()
	suite.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir()), "")
	suite.NoError(err)
	suite.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", suite.T().TempDir()), "")
	suite.NoError(err)
	// init database
	err = suite.DataClient.Init()
	suite.NoError(err)
	err = suite.CacheClient.Init()
	suite.NoError(err)
}

// TearDownSuite closes the stores opened in SetupSuite.
func (suite *WorkerTestSuite) TearDownSuite() {
	err := suite.DataClient.Close()
	suite.NoError(err)
	err = suite.CacheClient.Close()
	suite.NoError(err)
}

// SetupTest purges both stores and resets the configuration, the job count,
// the random generator (fixed seed 0) and the models before every test.
func (suite *WorkerTestSuite) SetupTest() {
	err := suite.DataClient.Purge()
	suite.NoError(err)
	err = suite.CacheClient.Purge()
	suite.NoError(err)
	// configuration
	suite.Config = config.GetDefaultConfig()
	suite.Config.Recommend.Collaborative.Type = "mf"
	suite.Config.Recommend.Fallback.Recommenders = []string{"latest"}
	suite.Jobs = 1
	suite.dontskipColdStartUsers = true
	// reset random generator
	suite.randGenerator = rand.New(rand.NewSource(0))
	// reset index
	suite.MatrixFactorizationItems = nil
	suite.ClickThroughRateModel = nil
}

// TestPullUsers checks which users node "b" receives from a three-node ring,
// and that pulling for a node outside the ring ("d") fails.
func (suite *WorkerTestSuite) TestPullUsers() {
	ctx := suite.T().Context()
	// create user index
	err := suite.DataClient.BatchInsertUsers(ctx, []data.User{
		{UserId: "1"},
		{UserId: "2"},
		{UserId: "3"},
		{UserId: "4"},
		{UserId: "5"},
		{UserId: "6"},
		{UserId: "7"},
		{UserId: "8"},
	})
	suite.NoError(err)
	// create nodes
	nodes := []string{"a", "b", "c"}
	users, err := suite.pullUsers(nodes, "b")
	suite.NoError(err)
	suite.Equal([]data.User{{UserId: "1"}, {UserId: "3"}, {UserId: "6"}}, users)
	// "d" is not part of the ring
	_, err = suite.pullUsers(nodes, "d")
	suite.Error(err)
}

// TestCheckRecommendCacheTimeout walks checkRecommendCacheOutOfDate through a
// sequence of cache states for user "0"; only the state with matching digest
// and an update time in the future is expected to be up to date.
func (suite *WorkerTestSuite) TestCheckRecommendCacheTimeout() {
	ctx := suite.T().Context()

	// empty cache
	suite.True(suite.checkRecommendCacheOutOfDate(ctx, "0"))
	err := suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{{Id: "0", Score: 0, Categories: []string{""}}})
	suite.NoError(err)

	// digest mismatch
	suite.True(suite.checkRecommendCacheOutOfDate(ctx, "0"))
	err = suite.CacheClient.Set(ctx, cache.String(cache.Key(cache.RecommendDigest, "0"), suite.Config.Recommend.Hash()))
	suite.NoError(err)

	// user modified an hour ago, no recommend update time stored yet
	err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "0"), time.Now().Add(-time.Hour)))
	suite.NoError(err)
	suite.True(suite.checkRecommendCacheOutOfDate(ctx, "0"))
	// update time (100h ago) is older than the last modification (1h ago)
	err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.RecommendUpdateTime, "0"), time.Now().Add(-time.Hour*100)))
	suite.NoError(err)
	suite.True(suite.checkRecommendCacheOutOfDate(ctx, "0"))
	// update time in the future: cache is expected to be fresh
	err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.RecommendUpdateTime, "0"), time.Now().Add(time.Hour*100)))
	suite.NoError(err)
	suite.False(suite.checkRecommendCacheOutOfDate(ctx, "0"))
	// removing the cached scores makes it out of date again
	err = suite.CacheClient.DeleteScores(ctx, []string{cache.Recommend}, cache.ScoreCondition{Subset: new("0")})
	suite.NoError(err)
	suite.True(suite.checkRecommendCacheOutOfDate(ctx, "0"))
}

// TestRecommendCollaborative uses a one-dimensional mock matrix factorization
// model in which item i has embedding [i] and user "0" has embedding [1].
// Feedback on items 4-9 is timestamped in the past and on items 0-3 in the
// future; only items 0-3 are expected in the result, scored by their ID.
func (suite *WorkerTestSuite) TestRecommendCollaborative() {
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Recommenders = []string{"collaborative"}
	// insert feedbacks
	now := time.Now()
	err := suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "9"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "8"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "7"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "6"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "5"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "4"}, Timestamp: now.Add(-time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "3"}, Timestamp: now.Add(time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}, Timestamp: now.Add(time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "1"}, Timestamp: now.Add(time.Hour)},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "0"}, Timestamp: now.Add(time.Hour)},
	}, true, true, true)
	suite.NoError(err)
	// insert hidden items and categorized items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "10", IsHidden: true},
		{ItemId: "11", IsHidden: true},
		{ItemId: "3", Categories: []string{"*"}},
		{ItemId: "1", Categories: []string{"*"}},
	})
	suite.NoError(err)
	// create mock model
	suite.MatrixFactorizationItems = logics.NewMatrixFactorizationItems(time.Time{})
	for i := 0; i < 10; i++ {
		suite.MatrixFactorizationItems.Add(strconv.Itoa(i), []float32{float32(i)})
	}
	suite.MatrixFactorizationUsers = logics.NewMatrixFactorizationUsers()
	suite.MatrixFactorizationUsers.Add("0", []float32{1})
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read the full recommendation list (end index -1)
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, -1)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "3", Score: 3, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "2", Score: 2, Timestamp: recommendTime},
		{Id: "1", Score: 1, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "0", Score: 0, Timestamp: recommendTime},
	}, recommends)
}

// TestRecommendItemToItem seeds item-to-item neighbor lists for the four
// items user "0" gave positive feedback to (21-24). The expected scores equal
// the number of those lists containing the candidate (29 -> 4, 28 -> 3,
// 27 -> 2, 26 -> 1); items 21-24 and the hidden item 25 are expected to be
// absent even though their neighbor scores are higher.
func (suite *WorkerTestSuite) TestRecommendItemToItem() {
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Recommenders = []string{"item-to-item/default"}
	suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}}
	suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("a")}
	// insert feedback
	err := suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "21"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "22"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "23"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "24"}},
	}, true, true, true)
	suite.NoError(err)
	// insert similar items
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "21"), []cache.Score{
		{Id: "22", Score: 100000, Categories: []string{"*"}},
		{Id: "25", Score: 1000000},
		{Id: "29", Score: 1},
	})
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "22"), []cache.Score{
		{Id: "23", Score: 100000, Categories: []string{"*"}},
		{Id: "25", Score: 1000000},
		{Id: "28", Score: 1, Categories: []string{"*"}},
		{Id: "29", Score: 1},
	})
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "23"), []cache.Score{
		{Id: "24", Score: 100000, Categories: []string{"*"}},
		{Id: "25", Score: 1000000},
		{Id: "27", Score: 1},
		{Id: "28", Score: 1, Categories: []string{"*"}},
		{Id: "29", Score: 1},
	})
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "24"), []cache.Score{
		{Id: "21", Score: 100000},
		{Id: "25", Score: 1000000},
		{Id: "26", Score: 1, Categories: []string{"*"}},
		{Id: "27", Score: 1},
		{Id: "28", Score: 1, Categories: []string{"*"}},
		{Id: "29", Score: 1},
	})
	suite.NoError(err)
	// insert items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "21"}, {ItemId: "22"}, {ItemId: "23"}, {ItemId: "24"}, {ItemId: "25"}, {ItemId: "26"}, {ItemId: "27"}, {ItemId: "28"}, {ItemId: "29"}})
	suite.NoError(err)
	// insert hidden items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "25", IsHidden: true}})
	suite.NoError(err)
	// insert categorized items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "26", Categories: []string{"*"}}, {ItemId: "28", Categories: []string{"*"}}})
	suite.NoError(err)
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "29", Score: 4, Timestamp: recommendTime},
		{Id: "28", Score: 3, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "27", Score: 2, Timestamp: recommendTime},
	}, recommends)
	// read the "*" category only
	recommends, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{"*"}, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "28", Score: 3, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "26", Score: 1, Categories: []string{"*"}, Timestamp: recommendTime},
	}, recommends)
}

// TestRecommendUserToUser seeds three neighbors of user "0" with similarities
// 2, 1.5 and 1. The expected score of each candidate is the sum of the
// similarities of neighbors that gave it feedback (48 -> 1.5+1, 11 -> 2,
// 12 -> 1.5); the hidden item 10 is expected to be absent.
func (suite *WorkerTestSuite) TestRecommendUserToUser() {
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Recommenders = []string{"user-to-user/default"}
	suite.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default"}}
	suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("a")}
	// insert similar users
	err := suite.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "0"), []cache.Score{
		{Id: "1", Score: 2},
		{Id: "2", Score: 1.5},
		{Id: "3", Score: 1},
	})
	suite.NoError(err)
	// insert feedback
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "1", ItemId: "10"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "1", ItemId: "11"}},
	}, true, true, true)
	suite.NoError(err)
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "2", ItemId: "10"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "2", ItemId: "12"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "2", ItemId: "48"}},
	}, true, true, true)
	suite.NoError(err)
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "3", ItemId: "10"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "3", ItemId: "13"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "3", ItemId: "48"}},
	}, true, true, true)
	suite.NoError(err)
	// insert hidden items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "10", IsHidden: true}})
	suite.NoError(err)
	// insert categorized items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "12", Categories: []string{"*"}},
		{ItemId: "48", Categories: []string{"*"}},
	})
	suite.NoError(err)
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "48", Score: 2.5, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "11", Score: 2, Timestamp: recommendTime},
		{Id: "12", Score: 1.5, Categories: []string{"*"}, Timestamp: recommendTime},
	}, recommends)
	// read the "*" category only
	recommends, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{"*"}, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "48", Score: 2.5, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "12", Score: 1.5, Categories: []string{"*"}, Timestamp: recommendTime},
	}, recommends)
}

// TestRecommendLatest expects the "latest" recommender to score items by
// their Unix timestamp, skip the hidden item 21, and store the MD5 of the
// recommender list as the recommendation digest.
func (suite *WorkerTestSuite) TestRecommendLatest() {
	// use the suite's worker with only the "latest" recommender
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Recommenders = []string{"latest"}
	// insert items
	err := suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "21", Timestamp: time.Unix(21, 0)},
		{ItemId: "20", Timestamp: time.Unix(20, 0)},
		{ItemId: "19", Timestamp: time.Unix(19, 0)},
		{ItemId: "18", Timestamp: time.Unix(18, 0)},
		{ItemId: "10", Categories: []string{"*"}, Timestamp: time.Unix(10, 0)},
		{ItemId: "9", Categories: []string{"*"}, Timestamp: time.Unix(9, 0)},
		{ItemId: "8", Categories: []string{"*"}, Timestamp: time.Unix(8, 0)},
	})
	suite.NoError(err)
	// insert hidden items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "21", IsHidden: true}})
	suite.NoError(err)
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "20", Score: 20, Timestamp: recommendTime},
		{Id: "19", Score: 19, Timestamp: recommendTime},
		{Id: "18", Score: 18, Timestamp: recommendTime},
	}, recommends)
	recommends, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{"*"}, 0, -1)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "10", Score: 10, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "9", Score: 9, Categories: []string{"*"}, Timestamp: recommendTime},
		{Id: "8", Score: 8, Categories: []string{"*"}, Timestamp: recommendTime},
	}, recommends)
	// read recommend digest
	digest, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendDigest, "0")).String()
	suite.NoError(err)
	suite.Equal(util.MD5("latest"), digest)
}

// TestRecommendNonPersonalized expects the "non-personalized/popular"
// recommender to copy the cached popular scores (and their categories) into
// the user's recommendations, skipping the hidden item 11.
func (suite *WorkerTestSuite) TestRecommendNonPersonalized() {
	// use the suite's worker with only the popular recommender
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Recommenders = []string{"non-personalized/popular"}
	// insert items
	err := suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "11"},
		{ItemId: "10"},
		{ItemId: "9"},
		{ItemId: "8"},
		{ItemId: "20", Categories: []string{"*"}},
		{ItemId: "19", Categories: []string{"*"}},
		{ItemId: "18", Categories: []string{"*"}},
	})
	suite.NoError(err)
	// insert non-personalized recommendation
	err = suite.CacheClient.AddScores(ctx, cache.NonPersonalized, "popular", []cache.Score{
		{Id: "11", Score: 31, Categories: []string{""}},
		{Id: "10", Score: 30, Categories: []string{""}},
		{Id: "9", Score: 29, Categories: []string{""}},
		{Id: "8", Score: 28, Categories: []string{""}},
		{Id: "20", Score: 20, Categories: []string{"", "*"}},
		{Id: "19", Score: 19, Categories: []string{"", "*"}},
		{Id: "18", Score: 18, Categories: []string{"", "*"}},
	})
	suite.NoError(err)
	// insert hidden items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "11", IsHidden: true}})
	suite.NoError(err)
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "10", Score: 30, Categories: []string{""}, Timestamp: recommendTime},
		{Id: "9", Score: 29, Categories: []string{""}, Timestamp: recommendTime},
		{Id: "8", Score: 28, Categories: []string{""}, Timestamp: recommendTime},
	}, recommends)
	recommends, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{"*"}, 0, -1)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "20", Score: 20, Categories: []string{"", "*"}, Timestamp: recommendTime},
		{Id: "19", Score: 19, Categories: []string{"", "*"}, Timestamp: recommendTime},
		{Id: "18", Score: 18, Categories: []string{"", "*"}, Timestamp: recommendTime},
	}, recommends)
}

// TestRecommend runs the full pipeline with the "fm" ranker and every
// candidate source configured. Candidates are re-scored by
// mockFactorizationMachine (score = numeric item ID); the pre-existing stale
// entry "999" and the user's own item "0" are expected to be gone.
func (suite *WorkerTestSuite) TestRecommend() {
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Type = "fm"
	suite.Config.Recommend.Ranker.Recommenders = nil
	suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("a")}
	suite.Config.Recommend.CacheSize = 1
	suite.Config.Recommend.NonPersonalized = []config.NonPersonalizedConfig{{Name: "popular"}}
	suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}}
	suite.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default"}}
	suite.MatrixFactorizationItems = logics.NewMatrixFactorizationItems(time.Time{})
	suite.MatrixFactorizationItems.Add("4", []float32{4})
	suite.MatrixFactorizationUsers = logics.NewMatrixFactorizationUsers()
	suite.MatrixFactorizationUsers.Add("0", []float32{1})
	suite.ClickThroughRateModel = new(mockFactorizationMachine)

	// insert items
	err := suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "0", Timestamp: time.Unix(0, 0)},
		{ItemId: "1", Timestamp: time.Unix(1, 0)},
		{ItemId: "2", Timestamp: time.Unix(2, 0)},
		{ItemId: "3", Timestamp: time.Unix(3, 0)},
		{ItemId: "4", Timestamp: time.Unix(4, 0)},
		{ItemId: "5", Categories: []string{"a"}, Timestamp: time.Unix(5, 0)},
	})
	suite.NoError(err)
	// insert feedback
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "0"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "1", ItemId: "1"}},
	}, true, true, true)
	suite.NoError(err)
	// insert stale recommendation
	err = suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{{Id: "999", Score: 999}})
	suite.NoError(err)
	// insert non-personalized recommendation
	err = suite.CacheClient.AddScores(ctx, cache.NonPersonalized, "popular", []cache.Score{{Id: "3", Categories: []string{""}}})
	suite.NoError(err)
	// insert item-to-item recommendation
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "0"), []cache.Score{{Id: "2"}})
	suite.NoError(err)
	// insert user-to-user recommendation
	err = suite.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "0"), []cache.Score{{Id: "1"}})
	suite.NoError(err)

	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 5)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "5", Score: 5, Timestamp: recommendTime, Categories: []string{"a"}},
		{Id: "4", Score: 4, Timestamp: recommendTime},
		{Id: "3", Score: 3, Timestamp: recommendTime},
		{Id: "2", Score: 2, Timestamp: recommendTime},
		{Id: "1", Score: 1, Timestamp: recommendTime},
	}, recommends)
}

// TestRecommendRankerNone repeats the TestRecommend setup (without the stale
// entry) using ranker type "none", and asserts that the result differs from
// the CTR-ranked order produced by TestRecommend.
func (suite *WorkerTestSuite) TestRecommendRankerNone() {
	ctx := suite.T().Context()
	suite.Config.Recommend.Ranker.Type = "none"
	suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{expression.MustParseFeedbackTypeExpression("a")}
	suite.Config.Recommend.CacheSize = 1
	suite.Config.Recommend.NonPersonalized = []config.NonPersonalizedConfig{{Name: "popular"}}
	suite.Config.Recommend.ItemToItem = []config.ItemToItemConfig{{Name: "default"}}
	suite.Config.Recommend.UserToUser = []config.UserToUserConfig{{Name: "default"}}
	suite.MatrixFactorizationItems = logics.NewMatrixFactorizationItems(time.Time{})
	suite.MatrixFactorizationItems.Add("4", []float32{4})
	suite.MatrixFactorizationUsers = logics.NewMatrixFactorizationUsers()
	suite.MatrixFactorizationUsers.Add("0", []float32{1})
	suite.ClickThroughRateModel = new(mockFactorizationMachine)

	// insert items
	err := suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "0", Timestamp: time.Unix(0, 0)},
		{ItemId: "1", Timestamp: time.Unix(1, 0)},
		{ItemId: "2", Timestamp: time.Unix(2, 0)},
		{ItemId: "3", Timestamp: time.Unix(3, 0)},
		{ItemId: "4", Timestamp: time.Unix(4, 0)},
		{ItemId: "5", Categories: []string{"a"}, Timestamp: time.Unix(5, 0)},
	})
	suite.NoError(err)
	// insert feedback
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "0"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "1", ItemId: "1"}},
	}, true, true, true)
	suite.NoError(err)
	// insert non-personalized recommendation
	err = suite.CacheClient.AddScores(ctx, cache.NonPersonalized, "popular", []cache.Score{{Id: "3", Categories: []string{""}}})
	suite.NoError(err)
	// insert item-to-item recommendation
	err = suite.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key("default", "0"), []cache.Score{{Id: "2"}})
	suite.NoError(err)
	// insert user-to-user recommendation
	err = suite.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key("default", "0"), []cache.Score{{Id: "1"}})
	suite.NoError(err)

	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 5)
	suite.NoError(err)
	// NOTE(review): this only checks the result is NOT the CTR-ranked list; it
	// does not pin what the "none" ranker actually produces.
	suite.NotEqual([]cache.Score{
		{Id: "5", Score: 5, Timestamp: recommendTime, Categories: []string{"a"}},
		{Id: "4", Score: 4, Timestamp: recommendTime},
		{Id: "3", Score: 3, Timestamp: recommendTime},
		{Id: "2", Score: 2, Timestamp: recommendTime},
		{Id: "1", Score: 1, Timestamp: recommendTime},
	}, recommends)
}

// marshal JSON-encodes v, reporting (non-fatally) any encoding error on t.
func marshal(t *testing.T, v interface{}) string {
	s, err := json.Marshal(v)
	assert.NoError(t, err)
	return string(s)
}

// newRankingDataset returns two empty ranking datasets (train, test).
func newRankingDataset() (*dataset.Dataset, *dataset.Dataset) {
	return dataset.NewDataset(time.Now(), 0, 0), dataset.NewDataset(time.Now(), 0, 0)
}

// newClickDataset returns an empty CTR dataset twice; train and test are the
// same pointer.
func newClickDataset() (*ctr.Dataset, *ctr.Dataset) {
	dataSet := &ctr.Dataset{
		Index: dataset.NewUnifiedMapIndexBuilder().Build(),
	}
	return dataSet, dataSet
}

// mockMaster is a minimal gRPC master server used by TestWorker_Sync. Only
// GetMeta is implemented; other RPCs fall back to UnimplementedMasterServer.
type mockMaster struct {
	protocol.UnimplementedMasterServer
	addr          chan string
	grpcServer    *grpc.Server
	cacheFilePath string
	dataFilePath  string
	meta          *protocol.Meta
	rankingModel  []byte
	clickModel    []byte
	userIndex     []byte
}

// newMockMaster builds a mockMaster whose meta carries a default config that
// points at temporary SQLite stores, with click-through-rate model ID 1 and
// collaborative-filtering model ID 2. It also serializes zero-epoch AFM and
// BPR models and an empty user index.
func newMockMaster(t *testing.T) *mockMaster {
	cfg := config.GetDefaultConfig()
	cfg.Database.DataStore = fmt.Sprintf("sqlite://%s/data.db", t.TempDir())
	cfg.Database.CacheStore = fmt.Sprintf("sqlite://%s/cache.db", t.TempDir())

	// create click model
	train, test := newClickDataset()
	fm := ctr.NewAFM(model.Params{model.NEpochs: 0})
	fm.Fit(t.Context(), train, test, &ctr.FitConfig{})
	clickModelBuffer := bytes.NewBuffer(nil)
	err := ctr.MarshalModel(clickModelBuffer, fm)
	assert.NoError(t, err)

	// create ranking model
	trainSet, testSet := newRankingDataset()
	bpr := cf.NewBPR(model.Params{model.NEpochs: 0})
	bpr.Fit(t.Context(), trainSet, testSet, cf.NewFitConfig())
	rankingModelBuffer := bytes.NewBuffer(nil)
	err = cf.MarshalModel(rankingModelBuffer, bpr)
	assert.NoError(t, err)

	// create user index
	userIndexBuffer := bytes.NewBuffer(nil)
	err = dataset.MarshalIndex(userIndexBuffer, dataset.NewMapIndex())
	assert.NoError(t, err)

	return &mockMaster{
		addr: make(chan string),
		meta: &protocol.Meta{
			Config:                        marshal(t, cfg),
			ClickThroughRateModelId:       1,
			CollaborativeFilteringModelId: 2,
		},
		cacheFilePath: cfg.Database.CacheStore,
		dataFilePath:  cfg.Database.DataStore,
		userIndex:     userIndexBuffer.Bytes(),
		clickModel:    clickModelBuffer.Bytes(),
		rankingModel:  rankingModelBuffer.Bytes(),
	}
}

// GetMeta returns the meta prepared by newMockMaster.
func (m *mockMaster) GetMeta(_ context.Context, _ *protocol.NodeInfo) (*protocol.Meta, error) {
	return m.meta, nil
}

// Start listens on an ephemeral localhost port, sends the address on m.addr
// (which blocks until a receiver reads it), then serves gRPC until Stop.
//
// NOTE(review): grpcServer is assigned after the address is published, and
// Stop reads it from another goroutine; the race detector may flag this.
// Consider creating the server before sending on m.addr.
func (m *mockMaster) Start(t *testing.T) {
	listen, err := net.Listen("tcp", "localhost:0")
	assert.NoError(t, err)
	m.addr <- listen.Addr().String()
	var opts []grpc.ServerOption
	m.grpcServer = grpc.NewServer(opts...)
	protocol.RegisterMasterServer(m.grpcServer, m)
	err = m.grpcServer.Serve(listen)
	assert.NoError(t, err)
}

// Stop stops the gRPC server started by Start.
func (m *mockMaster) Stop() {
	m.grpcServer.Stop()
}

// TestWorker_Sync checks that Sync pulls the meta from the master. The worker
// is expected to adopt the advertised store paths and latest model IDs while
// its currently loaded model IDs stay zero.
//
// NOTE(review): conn is never closed, and grpc.Dial is deprecated in recent
// grpc-go releases in favour of grpc.NewClient.
func TestWorker_Sync(t *testing.T) {
	master := newMockMaster(t)
	go master.Start(t)
	address := <-master.addr
	conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	assert.NoError(t, err)
	serv := &Worker{
		Pipeline: Pipeline{
			Config:      config.GetDefaultConfig(),
			CacheClient: new(cache.NoDatabase),
			DataClient:  new(data.NoDatabase),
		},
		testMode:     true,
		masterClient: protocol.NewMasterClient(conn),
		syncedChan:   make(chan struct{}, 1),
		ticker:       time.NewTicker(time.Minute),
	}
	serv.Sync()
	assert.Equal(t, master.dataFilePath, serv.dataPath)
	assert.Equal(t, master.cacheFilePath, serv.cachePath)
	assert.NoError(t, serv.DataClient.Close())
	assert.NoError(t, serv.CacheClient.Close())
	assert.Equal(t, int64(1), serv.latestClickThroughRateModelId)
	assert.Equal(t, int64(2), serv.latestCollaborativeFilteringModelId)
	assert.Zero(t, serv.clickThroughRateModelId)
	assert.Zero(t, serv.collaborativeFilteringModelId)
	master.Stop()
}

// mockFactorizationMachine is a CTR model stub. Predict scores an item by
// parsing its ID as an integer (panicking on non-numeric IDs), Invalid always
// reports false, and every other method panics if called.
type mockFactorizationMachine struct {
	ctr.BaseFactorizationMachines
}

func (m mockFactorizationMachine) Complexity() int {
	panic("implement me")
}

func (m mockFactorizationMachine) SuggestParams(_ goptuna.Trial) model.Params {
	panic("implement me")
}

func (m mockFactorizationMachine) Clear() {
	panic("implement me")
}

func (m mockFactorizationMachine) Invalid() bool {
	return false
}

func (m mockFactorizationMachine) Predict(_, itemId string, _, _ []ctr.Label) float32 {
	score, err := strconv.Atoi(itemId)
	if err != nil {
		panic(err)
	}
	return float32(score)
}

func (m mockFactorizationMachine) InternalPredict(_ []int32, _ []float32) float32 {
	panic("implement me")
}

func (m mockFactorizationMachine) Fit(_ context.Context, _, _ dataset.CTRSplit, _ *ctr.FitConfig) ctr.Score {
	panic("implement me")
}

func (m mockFactorizationMachine) Marshal(_ io.Writer) error {
	panic("implement me")
}

// TestRankByClickTroughRate expects rankByClickTroughRate to order candidates
// by the mock CTR score (numeric ID) in decreasing order.
func (suite *WorkerTestSuite) TestRankByClickTroughRate() {
	ctx := suite.T().Context()
	// insert a user
	err := suite.DataClient.BatchInsertUsers(ctx, []data.User{{UserId: "1"}})
	suite.NoError(err)
	// insert items
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "1"},
		{ItemId: "2"},
		{ItemId: "3"},
		{ItemId: "4"},
		{ItemId: "5"},
	})
	suite.NoError(err)
	// rank items
	itemCache := NewItemCache(suite.DataClient)
	result, err := suite.rankByClickTroughRate(ctx, new(mockFactorizationMachine), &data.User{UserId: "1"}, []cache.Score{{Id: "1"}, {Id: "2"}, {Id: "3"}, {Id: "4"}, {Id: "5"}}, itemCache, time.Now())
	suite.NoError(err)
	suite.Equal([]string{"5", "4", "3", "2", "1"}, lo.Map(result, func(d cache.Score, _ int) string {
		return d.Id
	}))
	suite.IsDecreasing(lo.Map(result, func(d cache.Score, _ int) float64 {
		return d.Score
	}))
}

// TestRankByLLM ranks three candidates through a chat reranker backed by
// reranker.NewMockServer. It expects the order 1, 2, 3 with scores 1, 1/2 and
// 1/3, each stamped with the supplied recommendTime.
func (suite *WorkerTestSuite) TestRankByLLM() {
	ctx := suite.T().Context()
	mockAI := reranker.NewMockServer()
	go func() {
		_ = mockAI.Start()
	}()
	mockAI.Ready()
	defer mockAI.Close()

	// insert a user
	err := suite.DataClient.BatchInsertUsers(ctx, []data.User{{UserId: "u1"}})
	suite.NoError(err)
	// insert items used by candidates and feedback
	err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: "1"}, {ItemId: "2"}, {ItemId: "3"}, {ItemId: "4"}, {ItemId: "5"}})
	suite.NoError(err)

	suite.Config.Recommend.Ranker.RerankerAPI = config.RerankerAPIConfig{
		URL:       mockAI.URL(),
		AuthToken: mockAI.AuthToken(),
		Model:     "v1",
	}
	ranker, err := logics.NewChatReranker(suite.Config.Recommend.Ranker.RerankerAPI, "{{user.UserId}}", "{{item.ItemId}}")
	suite.NoError(err)

	itemCache := NewItemCache(suite.DataClient)
	recommendTime := time.Now()
	result, err := suite.rankByLLM(ctx, nil, ranker, &data.User{UserId: "u1"}, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "like", UserId: "u1", ItemId: "4"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "like", UserId: "u1", ItemId: "5"}},
	}, []cache.Score{{Id: "1"}, {Id: "2"}, {Id: "3"}}, itemCache, recommendTime)
	suite.NoError(err)
	suite.Equal([]string{"1", "2", "3"}, lo.Map(result, func(d cache.Score, _ int) string {
		return d.Id
	}))
	suite.Equal([]float64{1, 0.5, float64(1) / 3}, lo.Map(result, func(d cache.Score, _ int) float64 {
		return d.Score
	}))
	for _, scored := range result {
		suite.Equal(recommendTime, scored.Timestamp)
	}
}

// TestReplacement enables replacement of previously interacted items with
// positive decay 0.8 and read decay 0.7. The expected scores equal the mock
// CTR score (the item ID) times the decay: 10*0.8 = 8 for the positive item
// and 9*0.7 = 6.3 for the read item. Item 8, whose feedback type "i" is
// neither positive nor read, is expected to be absent.
func (suite *WorkerTestSuite) TestReplacement() {
	ctx := suite.T().Context()
	suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []expression.FeedbackTypeExpression{
		expression.MustParseFeedbackTypeExpression("p")}
	suite.Config.Recommend.DataSource.ReadFeedbackTypes = []expression.FeedbackTypeExpression{
		expression.MustParseFeedbackTypeExpression("n")}
	suite.Config.Recommend.Ranker.Type = "fm"
	suite.Config.Recommend.Ranker.Recommenders = []string{"collaborative"}
	suite.Config.Recommend.Replacement.EnableReplacement = true
	suite.Config.Recommend.Replacement.PositiveReplacementDecay = 0.8
	suite.Config.Recommend.Replacement.ReadReplacementDecay = 0.7
	suite.ClickThroughRateModel = new(mockFactorizationMachine)

	// 1. Insert historical items into empty recommendation.
	// insert items
	err := suite.DataClient.BatchInsertItems(ctx, []data.Item{
		{ItemId: "10"},
		{ItemId: "9"},
		{ItemId: "8"},
		{ItemId: "7"},
		{ItemId: "6"},
		{ItemId: "5"},
	})
	suite.NoError(err)
	// insert feedback
	err = suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{
		{FeedbackKey: data.FeedbackKey{FeedbackType: "p", UserId: "0", ItemId: "10"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "n", UserId: "0", ItemId: "9"}},
		{FeedbackKey: data.FeedbackKey{FeedbackType: "i", UserId: "0", ItemId: "8"}},
	}, true, false, true)
	suite.NoError(err)
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err := suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "10", Score: 8, Timestamp: recommendTime},
		{Id: "9", Score: 6.3, Timestamp: recommendTime},
	}, recommends)

	// 2. Insert historical items into non-empty recommendation.
	suite.Config.Recommend.CacheExpire = 0
	suite.Config.Recommend.Ranker.Recommenders = []string{"latest"}
	suite.Recommend(ctx, []data.User{{UserId: "0"}}, nil)
	// read recommend time
	recommendTime, err = suite.CacheClient.Get(ctx, cache.Key(cache.RecommendUpdateTime, "0")).Time()
	suite.NoError(err)
	// read recommend result
	recommends, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", nil, 0, 3)
	suite.NoError(err)
	suite.Equal([]cache.Score{
		{Id: "10", Score: 8, Timestamp: recommendTime},
		{Id: "7", Score: 7, Timestamp: recommendTime},
		{Id: "9", Score: 6.3, Timestamp: recommendTime},
	}, recommends)
}

func (suite *WorkerTestSuite) TestUserActivity() {
	ctx := suite.T().Context()
	err := suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "0"), time.Now().AddDate(0, 0, -1)))
	suite.NoError(err)
	err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "1"), time.Now().AddDate(0, 0, -10)))
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.Recommend, "0", []cache.Score{{Id: "0", Score: 1, Categories: []string{""}}})
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.Recommend, "1", []cache.Score{{Id: "1", Score: 1, Categories: []string{""}}})
	suite.NoError(err)
	err = suite.CacheClient.AddScores(ctx, cache.Recommend, "2", []cache.Score{{Id: "2", Score: 1, Categories: []string{""}}})
	suite.NoError(err)
	suite.True(suite.checkUserActiveTime(ctx, "0"))
	suite.True(suite.checkUserActiveTime(ctx, "1"))
	suite.True(suite.checkUserActiveTime(ctx, "2"))
	docs, err := suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{""}, 0, 1)
	suite.NoError(err)
	suite.NotEmpty(docs)
	docs, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "1", []string{""}, 0, 1)
	suite.NoError(err)
	suite.NotEmpty(docs)
	docs, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "2", []string{""}, 0, 1)
	suite.NoError(err)
	suite.NotEmpty(docs)
	suite.Config.Recommend.ActiveUserTTL = 5
suite.True(suite.checkUserActiveTime(ctx, "0")) suite.False(suite.checkUserActiveTime(ctx, "1")) suite.True(suite.checkUserActiveTime(ctx, "2")) docs, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "0", []string{""}, 0, 1) suite.NoError(err) suite.NotEmpty(docs) docs, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "1", []string{""}, 0, 1) suite.NoError(err) suite.Empty(docs) docs, err = suite.CacheClient.SearchScores(ctx, cache.Recommend, "2", []string{""}, 0, 1) suite.NoError(err) suite.NotEmpty(docs) } func (suite *WorkerTestSuite) TestHealth() { req := httptest.NewRequest("GET", "https://example.com/", nil) w := httptest.NewRecorder() suite.checkLive(w, req) suite.Equal(http.StatusOK, w.Code) suite.Equal(marshal(suite.T(), HealthStatus{ Ready: true, DataStoreError: nil, CacheStoreError: nil, DataStoreConnected: true, CacheStoreConnected: true, }), w.Body.String()) w = httptest.NewRecorder() suite.checkReady(w, req) suite.Equal(http.StatusOK, w.Code) suite.Equal(marshal(suite.T(), HealthStatus{ Ready: true, DataStoreError: nil, CacheStoreError: nil, DataStoreConnected: true, CacheStoreConnected: true, }), w.Body.String()) dataClient, cacheClient := suite.DataClient, suite.CacheClient suite.DataClient, suite.CacheClient = data.NoDatabase{}, cache.NoDatabase{} w = httptest.NewRecorder() suite.checkLive(w, req) suite.Equal(http.StatusOK, w.Code) suite.Equal(marshal(suite.T(), HealthStatus{ Ready: false, DataStoreError: data.ErrNoDatabase, CacheStoreError: cache.ErrNoDatabase, DataStoreConnected: false, CacheStoreConnected: false, }), w.Body.String()) w = httptest.NewRecorder() suite.checkReady(w, req) suite.Equal(http.StatusServiceUnavailable, w.Code) suite.DataClient, suite.CacheClient = dataClient, cacheClient } func TestWorker(t *testing.T) { suite.Run(t, new(WorkerTestSuite)) }