[
  {
    "path": ".github/CODEOWNERS",
    "content": "* @voxel51/developers\n\n# Aloha!\n.github/        @voxel51/aloha-shirts\npyproject.toml  @voxel51/aloha-shirts\nRELEASING.md    @voxel51/aloha-shirts\nsetup.py        @voxel51/aloha-shirts\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "---\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n      day: \"wednesday\"\n      time: \"14:00\"\n      timezone: \"UTC\"\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "# Rationale\n\n<!-- Explain why you are making this change. Describe the problem. -->\n\n## Changes\n\n<!-- Describe the changes. -->\n\n## Testing\n\n<!-- Describe the way the changes were tested. -->\n\n<!-- Optional Sections:\n\n## Screenshots\n## To Do\n## Notes\n## Related\n\n-->\n\n<!-- Template for collapsed sections\n<details>\n<summary></summary>\n</details>\n-->"
  },
  {
    "path": ".github/workflows/build.yml",
    "content": "name: Build\n\non:\n  pull_request:\n    branches:\n      - develop\n    types: [opened, synchronize]\n  push:\n    branches:\n      - develop\n    tags:\n      - v*\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Clone fiftyone-brain\n        uses: actions/checkout@v6\n        with:\n          submodules: true\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: 3.9\n      - name: Check Python version\n        run: |\n          python --version\n          pip --version\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip setuptools wheel\n          pip install -r requirements/build.txt\n      - name: Set environment\n        env:\n          RELEASE_TAG: ${{ github.ref }}\n        run: |\n          if [[ $RELEASE_TAG =~ ^refs\\/tags\\/v.* ]]; then\n             echo \"RELEASE_VERSION=$(echo '${{ github.ref }}' | sed 's/^refs\\/tags\\/v//')\" >> $GITHUB_ENV\n          fi\n      - name: Build wheel\n        run: |\n          python setup.py sdist bdist_wheel\n      - name: Upload wheel\n        uses: actions/upload-artifact@v7\n        with:\n          name: dist\n          path: dist/\n          retention-days: 1\n\n  test:\n    needs: [build]\n    runs-on: ubuntu-latest\n    env:\n      FIFTYONE_DATASET_ZOO_DIR: ${{ github.workspace }}/.fiftyone\n      FIFTYONE_DO_NOT_TRACK: true\n      FIFTYONE_MODEL_ZOO_DIR: ${{ github.workspace }}/.fiftyone\n    permissions:\n      contents: read\n      id-token: write\n    strategy:\n      fail-fast: false\n      matrix:\n        python:\n          - \"3.9\"\n          - \"3.10\"\n          - \"3.11\"\n    steps:\n      - name: Clone fiftyone-brain\n        uses: actions/checkout@v6\n        with:\n          submodules: true\n      - name: Clone fiftyone\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 1\n          path: fiftyone-src\n          ref: develop\n          
repository: voxel51/fiftyone\n      - name: Clone voxel51-eta\n        uses: actions/checkout@v6\n        if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }}\n        with:\n          fetch-depth: 1\n          path: eta\n          ref: develop\n          repository: voxel51/eta\n      # ETA tests will create a storage client which, \n      # in its __init__, tries to log in to GCP.\n      # See tests/tests_uniqueness.py\n      - name: Authenticate to Google Cloud\n        uses: google-github-actions/auth@v3\n        with:\n          project_id: ${{ secrets.REPO_GCP_PROJECT }}\n          service_account: ${{ secrets.REPO_GCP_SERVICE_ACCOUNT }}\n          workload_identity_provider: ${{ secrets.REPO_GOOGLE_WORKLOAD_IDP }}\n      - name: Set Up Cloud SDK\n        uses: google-github-actions/setup-gcloud@v3\n      - name: Set up Python ${{ matrix.python }}\n        uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python }}\n      - name: Free Disk Space (Ubuntu) # standard runner's 14 GB available disk size isn't enough. 
Need at least 22 GB free.\n        uses: jlumbroso/free-disk-space@v1.3.1\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip setuptools wheel\n      - name: Download fiftyone-brain wheel\n        uses: actions/download-artifact@v8\n        with:\n          name: dist\n          path: dist/\n      - name: Install fiftyone\n        working-directory: fiftyone-src\n        run: |\n          python setup.py bdist_wheel\n          pip install voxel51-eta[storage] fiftyone-db\n          pip install ./dist/*.whl\n      - name: Install ETA from source\n        working-directory: eta\n        # Don't install from source if this is a release.\n        # Install from PyPI\n        if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }}\n        run: |\n          echo \"Installing ETA from source because github.ref = ${{ github.ref }} (not a release)\"\n          python setup.py bdist_wheel\n          pip install ./dist/*.whl --force-reinstall\n      - name: Reinstall fiftyone-brain\n        run: |\n          pip install --force-reinstall --no-deps dist/*.whl\n      - name: Install test dependencies\n        run: |\n          pip install imageio pytest torch torchvision\n      - name: Cache Zoo\n        id: fiftyone-cache\n        uses: actions/cache@v5\n        with:\n          path: |\n            .fiftyone\n          key: zoo-${{ hashFiles('tests/**') }}\n      - name: Run tests\n        run: |\n          pytest --verbose tests/ --ignore tests/intensive/\n\n  publish:\n    needs: [build, test]\n    if: startsWith(github.ref, 'refs/tags/v')\n    runs-on: ubuntu-latest\n    environment: release # For trusted publishing. 
See below.\n    permissions:\n      contents: read\n      id-token: write\n    steps:\n      - name: Download wheels\n        uses: actions/download-artifact@v8\n        with:\n          name: dist\n          path: dist/\n      # Utilize\n      # [trusted publishers](https://docs.pypi.org/trusted-publishers/)\n      # This will use OIDC to publish the dists/ package to pypi.\n      # See\n      # [fiftyone-brain](https://pypi.org/manage/project/fiftyone-brain/settings/publishing/)\n      - name: Publish\n        uses: pypa/gh-action-pypi-publish@v1.14.0\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\n.DS_store\n.ipynb_checkpoints\n\n*~\n*.egg-info\n*.py[cod]\n*.pth\n*.swp\n\n.idea\n.project\n.pydevproject\n\nbuild/\ndist/\n\n/fiftyone/brain/internal/models/cache/**/*\n!/fiftyone/brain/internal/models/cache/manifest.json\n*.pth\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/asottile/blacken-docs\n    rev: v1.12.0\n    hooks:\n      - id: blacken-docs\n        additional_dependencies: [black==21.12b0]\n        args: [\"-l 79\"]\n  - repo: https://github.com/ambv/black\n    rev: 22.3.0\n    hooks:\n      - id: black\n        language_version: python3\n        args: [\"-l 79\"]\n  - repo: local\n    hooks:\n      - id: pylint\n        name: pylint\n        language: system\n        files: \\.py$\n        entry: pylint\n        args: [\"--errors-only\"]\n  - repo: local\n    hooks:\n      - id: ipynb-strip\n        name: ipynb-strip\n        language: system\n        files: \\.ipynb$\n        entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True\n        args: [\"--log-level=ERROR\"]\n  - repo: https://github.com/pre-commit/mirrors-prettier\n    rev: v2.6.2\n    hooks:\n      - id: prettier\n        language_version: system\n"
  },
  {
    "path": ".prettierrc",
    "content": "{\n  \"overrides\": [\n    {\n      \"files\": \"*.md\",\n      \"options\": {\n        \"printWidth\": 79,\n        \"proseWrap\": \"always\",\n        \"tabWidth\": 4\n      }\n    },\n    {\n      \"files\": \"*.json\",\n      \"options\": {\n        \"tabWidth\": 4\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to FiftyOne Brain\n\nAll Brain contributions should follow the practices established in\n[FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/CONTRIBUTING.md).\n\n## Adding new public methods to the Brain package\n\nThe `fiftyone.brain` package should expose all core user-functionality at the\nbase level. For example, for hardness, the user should be able to execute calls\nin the following way:\n\n```py\n# Users should be able to do this\nimport fiftyone.brain as fob\n\nfob.compute_hardness(...)\n\n# And NOT have to do this\nimport fiftyone.brain.hardness as fobh\n\nfobh.compute_hardness(...)\n```\n\nTo achieve this, follow the existing pattern of declaring new public methods in\n[`fiftyone/brain/__init__.py`](https://github.com/voxel51/fiftyone-brain/blob/develop/fiftyone/brain/__init__.py).\n\nBe sure to include a detailed docstring for all methods in this file, as they\nare pulled in by FiftyOne documentation builds and are made available in the\n[public docs](https://docs.voxel51.com/api/fiftyone.brain.html).\n"
  },
  {
    "path": "LICENSE",
    "content": "\nApache License\nVersion 2.0, January 2004\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n\"License\" shall mean the terms and conditions for use, reproduction,\nand distribution as defined by Sections 1 through 9 of this document.\n\n\"Licensor\" shall mean the copyright owner or entity authorized by\nthe copyright owner that is granting the License.\n\n\"Legal Entity\" shall mean the union of the acting entity and all\nother entities that control, are controlled by, or are under common\ncontrol with that entity. For the purposes of this definition,\n\"control\" means (i) the power, direct or indirect, to cause the\ndirection or management of such entity, whether by contract or\notherwise, or (ii) ownership of fifty percent (50%) or more of the\noutstanding shares, or (iii) beneficial ownership of such entity.\n\n\"You\" (or \"Your\") shall mean an individual or Legal Entity\nexercising permissions granted by this License.\n\n\"Source\" form shall mean the preferred form for making modifications,\nincluding but not limited to software source code, documentation\nsource, and configuration files.\n\n\"Object\" form shall mean any form resulting from mechanical\ntransformation or translation of a Source form, including but\nnot limited to compiled object code, generated documentation,\nand conversions to other media types.\n\n\"Work\" shall mean the work of authorship, whether in Source or\nObject form, made available under the License, as indicated by a\ncopyright notice that is included in or attached to the work\n(an example is provided in the Appendix below).\n\n\"Derivative Works\" shall mean any work, whether in Source or Object\nform, that is based on (or derived from) the Work and for which the\neditorial revisions, annotations, elaborations, or other modifications\nrepresent, as a whole, an original work of authorship. 
For the purposes\nof this License, Derivative Works shall not include works that remain\nseparable from, or merely link (or bind by name) to the interfaces of,\nthe Work and Derivative Works thereof.\n\n\"Contribution\" shall mean any work of authorship, including\nthe original version of the Work and any modifications or additions\nto that Work or Derivative Works thereof, that is intentionally\nsubmitted to Licensor for inclusion in the Work by the copyright owner\nor by an individual or Legal Entity authorized to submit on behalf of\nthe copyright owner. For the purposes of this definition, \"submitted\"\nmeans any form of electronic, verbal, or written communication sent\nto the Licensor or its representatives, including but not limited to\ncommunication on electronic mailing lists, source code control systems,\nand issue tracking systems that are managed by, or on behalf of, the\nLicensor for the purpose of discussing and improving the Work, but\nexcluding communication that is conspicuously marked or otherwise\ndesignated in writing by the copyright owner as \"Not a Contribution.\"\n\n\"Contributor\" shall mean Licensor and any individual or Legal Entity\non behalf of whom a Contribution has been received by Licensor and\nsubsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of\nthis License, each Contributor hereby grants to You a perpetual,\nworldwide, non-exclusive, no-charge, royalty-free, irrevocable\ncopyright license to reproduce, prepare Derivative Works of,\npublicly display, publicly perform, sublicense, and distribute the\nWork and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. 
Subject to the terms and conditions of\nthis License, each Contributor hereby grants to You a perpetual,\nworldwide, non-exclusive, no-charge, royalty-free, irrevocable\n(except as stated in this section) patent license to make, have made,\nuse, offer to sell, sell, import, and otherwise transfer the Work,\nwhere such license applies only to those patent claims licensable\nby such Contributor that are necessarily infringed by their\nContribution(s) alone or by combination of their Contribution(s)\nwith the Work to which such Contribution(s) was submitted. If You\ninstitute patent litigation against any entity (including a\ncross-claim or counterclaim in a lawsuit) alleging that the Work\nor a Contribution incorporated within the Work constitutes direct\nor contributory patent infringement, then any patent licenses\ngranted to You under this License for that Work shall terminate\nas of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the\nWork or Derivative Works thereof in any medium, with or without\nmodifications, and in Source or Object form, provided that You\nmeet the following conditions:\n\n(a) You must give any other recipients of the Work or\nDerivative Works a copy of this License; and\n\n(b) You must cause any modified files to carry prominent notices\nstating that You changed the files; and\n\n(c) You must retain, in the Source form of any Derivative Works\nthat You distribute, all copyright, patent, trademark, and\nattribution notices from the Source form of the Work,\nexcluding those notices that do not pertain to any part of\nthe Derivative Works; and\n\n(d) If the Work includes a \"NOTICE\" text file as part of its\ndistribution, then any Derivative Works that You distribute must\ninclude a readable copy of the attribution notices contained\nwithin such NOTICE file, excluding those notices that do not\npertain to any part of the Derivative Works, in at least one\nof the following places: within a 
NOTICE text file distributed\nas part of the Derivative Works; within the Source form or\ndocumentation, if provided along with the Derivative Works; or,\nwithin a display generated by the Derivative Works, if and\nwherever such third-party notices normally appear. The contents\nof the NOTICE file are for informational purposes only and\ndo not modify the License. You may add Your own attribution\nnotices within Derivative Works that You distribute, alongside\nor as an addendum to the NOTICE text from the Work, provided\nthat such additional attribution notices cannot be construed\nas modifying the License.\n\nYou may add Your own copyright statement to Your modifications and\nmay provide additional or different license terms and conditions\nfor use, reproduction, or distribution of Your modifications, or\nfor any such Derivative Works as a whole, provided Your use,\nreproduction, and distribution of the Work otherwise complies with\nthe conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise,\nany Contribution intentionally submitted for inclusion in the Work\nby You to the Licensor shall be under the terms and conditions of\nthis License, without any additional terms or conditions.\nNotwithstanding the above, nothing herein shall supersede or modify\nthe terms of any separate license agreement you may have executed\nwith Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade\nnames, trademarks, service marks, or product names of the Licensor,\nexcept as required for reasonable and customary use in describing the\norigin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. 
Unless required by applicable law or\nagreed to in writing, Licensor provides the Work (and each\nContributor provides its Contributions) on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\nimplied, including, without limitation, any warranties or conditions\nof TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\nPARTICULAR PURPOSE. You are solely responsible for determining the\nappropriateness of using or redistributing the Work and assume any\nrisks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory,\nwhether in tort (including negligence), contract, or otherwise,\nunless required by applicable law (such as deliberate and grossly\nnegligent acts) or agreed to in writing, shall any Contributor be\nliable to You for damages, including any direct, indirect, special,\nincidental, or consequential damages of any character arising as a\nresult of this License or out of the use or inability to use the\nWork (including but not limited to damages for loss of goodwill,\nwork stoppage, computer failure or malfunction, or any and all\nother commercial damages or losses), even if such Contributor\nhas been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing\nthe Work or Derivative Works thereof, You may choose to offer,\nand charge a fee for, acceptance of support, warranty, indemnity,\nor other liability obligations and/or rights consistent with this\nLicense. 
However, in accepting such obligations, You may act only\non Your own behalf and on Your sole responsibility, not on behalf\nof any other Contributor, and only if You agree to indemnify,\ndefend, and hold each Contributor harmless for any liability\nincurred by, or claims asserted against, such Contributor by reason\nof your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\nTo apply the Apache License to your work, attach the following\nboilerplate notice, with the fields enclosed by brackets \"[]\"\nreplaced with your own identifying information. (Don't include\nthe brackets!)  The text should be enclosed in the appropriate\ncomment syntax for the file format. We also recommend that a\nfile or class name and description of purpose be included on the\nsame \"printed page\" as the copyright notice for easier\nidentification within third-party archives.\n\nCopyright 2017-2026, Voxel51, Inc.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License."
  },
  {
    "path": "MANIFEST.in",
    "content": "global-include *\n\nprune fiftyone/brain/internal/models/cache/\ninclude fiftyone/brain/internal/models/cache/manifest.json\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<p align=\"center\">\n\n<img src=\"https://github.com/user-attachments/assets/17afdf93-289c-40f1-805c-06344f095cf6\" height=\"55px\">\n\n**Open Source AI from [Voxel51](https://voxel51.com)**\n\n<!-- prettier-ignore -->\n<a href=\"https://voxel51.com/fiftyone\">FiftyOne Website</a> •\n<a href=\"https://voxel51.com/docs/fiftyone\">FiftyOne Docs</a> •\n<a href=\"https://docs.voxel51.com/brain.html\">FiftyOne Brain Docs</a> •\n<a href=\"https://voxel51.com/blog/\">Blog</a> •\n<a href=\"https://slack.voxel51.com\">Community</a>\n\n[![PyPI python](https://img.shields.io/pypi/pyversions/fiftyone-brain)](https://pypi.org/project/fiftyone-brain)\n[![PyPI version](https://badge.fury.io/py/fiftyone-brain.svg)](https://pypi.org/project/fiftyone-brain)\n[![Downloads](https://static.pepy.tech/badge/fiftyone-brain)](https://pepy.tech/project/fiftyone-brain)\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)\n\n[![Discord](https://img.shields.io/badge/Discord-7289DA?logo=discord&logoColor=white)](https://discord.gg/fiftyone-community)\n[![Hugging Face](https://img.shields.io/badge/Hugging_Face-purple?style=flat&logo=huggingface)](https://huggingface.co/Voxel51)\n[![Voxel51 Blog](https://img.shields.io/badge/Voxel51_Blog-ff6d04?style=flat)](https://voxel51.com/blog)\n[![Newsletter](https://img.shields.io/badge/Newsletter-BE5B25?logo=mail.ru&logoColor=white)](https://share.hsforms.com/1zpJ60ggaQtOoVeBqIZdaaA2ykyk)\n[![LinkedIn](https://img.shields.io/badge/In-white?style=flat&label=Linked&labelColor=blue)](https://www.linkedin.com/company/voxel51)\n[![Twitter](https://img.shields.io/badge/Twitter-000000?logo=x&logoColor=white)](https://x.com/voxel51)\n[![Medium](https://img.shields.io/badge/Medium-12100E?logo=medium&logoColor=white)](https://medium.com/voxel51)\n\n</p>\n</div>\n\n---\n\nFiftyOne Brain contains the open source AI/ML capabilities for the\n[FiftyOne ecosystem](https://github.com/voxel51/fiftyone), 
enabling users to\nautomatically analyze and manipulate their datasets and models. FiftyOne Brain\nincludes features like visual similarity search, query by text, finding unique\nand representative samples, finding media quality problems and annotation\nmistakes, and more 🚀\n\n## Documentation\n\nPublic documentation for the FiftyOne Brain is\n[available here](https://docs.voxel51.com/user_guide/brain.html).\n\n## Installation\n\nThe FiftyOne Brain is distributed via the `fiftyone-brain` package, and a\nsuitable version is automatically included with every `fiftyone` install:\n\n```shell\npip install fiftyone\npip show fiftyone-brain\n```\n\n### Installing from source\n\nIf you wish to do a source install of the latest FiftyOne Brain version, simply\nclone this repository:\n\n```shell\ngit clone https://github.com/voxel51/fiftyone-brain\ncd fiftyone-brain\n```\n\nand run the install script:\n\n```shell\n# Mac or Linux\nbash install.sh\n\n# Windows\n.\\install.bat\n```\n\n### Developer installation\n\nIf you are a developer contributing to this repository, you should perform a\ndeveloper installation using the `-d` flag of the install script:\n\n```shell\n# Mac or Linux\nbash install.sh -d\n\n# Windows\n.\\install.bat -d\n```\n\nCheck out the [contribution guide](CONTRIBUTING.md) to get started.\n\n## Uninstallation\n\n```shell\npip uninstall fiftyone-brain\n```\n\n## Repository layout\n\n-   `fiftyone/brain/` definition of the `fiftyone.brain` namespace\n-   `requirements/` Python requirements for the project\n-   `tests/` tests for the various components of the Brain\n\n## Citation\n\nIf you use the FiftyOne Brain in your research, please cite the project:\n\n```bibtex\n@article{moore2020fiftyone,\n  title={FiftyOne},\n  author={Moore, B. E. and Corso, J. J.},\n  journal={GitHub. Note: https://github.com/voxel51/fiftyone-brain},\n  year={2020}\n}\n```\n"
  },
  {
    "path": "RELEASING.md",
    "content": "# Releasing the Brain package\n\n> [!NOTE]\n> These steps are to be performed by authorized Voxel51 engineers.\n\nThe `fiftyone-brain` repository follows `Gitflow`.\nReleases will be initiated when a teammate submits a \npull request from their respective `release/v*` branch to `main`.\nWe can see an example PR for\n[version 0.21.4](https://github.com/voxel51/fiftyone-brain/pull/265). \nReviewers should always check that the version in the `setup.py`\nmatches the branch version.\n\nThe release engineer will merge the pull request once it is approved.\n\nThe PyPI uploads will be triggered when a release tag is pushed to the\nrepository:\n\n1. Navigate to the\n   [releases page](https://github.com/voxel51/fiftyone-brain/releases).\n\n1. Select `Draft a new release`.\n\n1. Select `Create new tag` with the appropriate version and set the target to\n   `main`.\n\n    1. The tag format is `v<semantic-version>`.\n       For example, `v0.21.4`. \n       This should match the `setup.py` and release branch.\n\n1. Select `Generate release notes`.\n\n1. Select `Set as the latest release`.\n\n1. Select `Publish release`.\n\nThis will create a new tag in the repository and will trigger the\n[build/publish workflow](https://github.com/voxel51/fiftyone-brain/actions/workflows/build.yml).\nThis workflow will build the `.whl` artifacts and publish them to\n[PyPI](https://pypi.org/project/fiftyone-brain/).\n\nOnce the builds are finished, submit a PR from `main` to `develop` to complete\nthe `Gitflow` process."
  },
  {
    "path": "STYLE_GUIDE.md",
    "content": "# FiftyOne Brain Style Guide\n\nThe Brain follows the same style guidelines as\n[FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/STYLE_GUIDE.md).\n"
  },
  {
    "path": "fiftyone/__init__.py",
    "content": "from pkgutil import extend_path\n\n#\n# This statement allows multiple `fiftyone.XXX` packages to be installed in the\n# same environment and used simultaneously.\n#\n# https://docs.python.org/3/library/pkgutil.html#pkgutil.extend_path\n#\n__path__ = extend_path(__path__, __name__)\n\nfrom fiftyone.__public__ import *\n"
  },
  {
    "path": "fiftyone/brain/__init__.py",
    "content": "\"\"\"\nThe brains behind FiftyOne: a powerful package for dataset curation, analysis,\nand visualization.\n\nSee https://github.com/voxel51/fiftyone for more information.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport fiftyone.brain.config as _foc\n\nfrom .similarity import (\n    Similarity,\n    SimilarityConfig,\n    SimilarityIndex,\n)\nfrom .visualization import (\n    Visualization,\n    VisualizationConfig,\n    VisualizationResults,\n)\n\n\nbrain_config = _foc.load_brain_config()\n\n\ndef compute_hardness(\n    samples,\n    label_field,\n    hardness_field=\"hardness\",\n    progress=None,\n):\n    \"\"\"Adds a hardness field to each sample scoring the difficulty that the\n    specified label field observed in classifying the sample.\n\n    Hardness is a measure computed based on model prediction output (through\n    logits) that summarizes a measure of the uncertainty the model had with the\n    sample. This makes hardness quantitative and can be used to detect things\n    like hard samples, annotation errors during noisy training, and more.\n\n    All classifications must have their\n    :attr:`logits <fiftyone.core.labels.Classification.logits>` attributes\n    populated in order to use this method.\n\n    .. 
note::\n\n        Runs of this method can be referenced later via brain key\n        ``hardness_field``.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        label_field: the :class:`fiftyone.core.labels.Classification` or\n            :class:`fiftyone.core.labels.Classifications` field to use from\n            each sample\n        hardness_field (\"hardness\"): the field name to use to store the\n            hardness value for each sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n    \"\"\"\n    import fiftyone.brain.internal.core.hardness as fbh\n\n    return fbh.compute_hardness(samples, label_field, hardness_field, progress)\n\n\ndef compute_mistakenness(\n    samples,\n    pred_field,\n    label_field,\n    mistakenness_field=\"mistakenness\",\n    missing_field=\"possible_missing\",\n    spurious_field=\"possible_spurious\",\n    use_logits=False,\n    copy_missing=False,\n    progress=None,\n):\n    \"\"\"Computes the mistakenness (likelihood of being incorrect) of the labels\n    in ``label_field`` based on the predicted labels in ``pred_field``.\n\n    Mistakenness is measured based on either the ``confidence`` or ``logits``\n    of the predictions in ``pred_field``. This measure can be used to detect\n    things like annotation errors and unusually hard samples.\n\n    For classifications, a ``mistakenness_field`` field is populated on each\n    sample that quantifies the likelihood that the label in the ``label_field``\n    of that sample is incorrect.\n\n    For objects (detections, polylines, keypoints, etc), the mistakenness of\n    each object in ``label_field`` is computed, using\n    :meth:`fiftyone.core.collections.SampleCollection.evaluate_detections` to\n    locate corresponding objects in ``pred_field``. 
Three types of mistakes\n    are identified:\n\n    -   **(Mistakes)** Objects in ``label_field`` with a match in\n        ``pred_field`` are assigned a mistakenness value in their\n        ``mistakenness_field`` that captures the likelihood that the class\n        label of the object in ``label_field`` is a mistake. A\n        ``mistakenness_field + \"_loc\"`` field is also populated that captures\n        the likelihood that the object in ``label_field`` is a mistake due\n        to its localization (bounding box).\n\n    -   **(Missing)** Objects in ``pred_field`` with no matches in\n        ``label_field`` but which are likely to be correct will have their\n        ``missing_field`` attribute set to True. In addition, if\n        ``copy_missing`` is True, copies of these objects are *added* to the\n        ground truth ``label_field``.\n\n    -   **(Spurious)** Objects in ``label_field`` with no matches in\n        ``pred_field`` but which are likely to be incorrect will have their\n        ``spurious_field`` attribute set to True.\n\n    In addition, for objects, the following sample-level fields are populated:\n\n    -   **(Mistakes)** The ``mistakenness_field`` of each sample is populated\n        with the maximum mistakenness of the objects in ``label_field``\n\n    -   **(Missing)** The ``missing_field`` of each sample is populated with\n        the number of objects that were deemed missing from\n        ``label_field``.\n\n    -   **(Spurious)** The ``spurious_field`` of each sample is populated with\n        the number of objects in ``label_field`` that were deemed\n        spurious.\n\n    .. note::\n\n        Runs of this method can be referenced later via brain key\n        ``mistakenness_field``.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        pred_field: the name of the predicted label field to use from each\n            sample. 
Can be of type\n            :class:`fiftyone.core.labels.Classification`,\n            :class:`fiftyone.core.labels.Classifications`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polylines`,\n            :class:`fiftyone.core.labels.Keypoints`, or\n            :class:`fiftyone.core.labels.TemporalDetections`\n        label_field: the name of the \"ground truth\" label field that you want\n            to test for mistakes with respect to the predictions in\n            ``pred_field``. Must have the same type as ``pred_field``\n        mistakenness_field (\"mistakenness\"): the field name to use to store the\n            mistakenness value for each sample\n        missing_field (\"possible_missing\"): the field in which to store\n            per-sample counts of potential missing objects\n        spurious_field (\"possible_spurious\"): the field in which to store\n            per-sample counts of potential spurious objects\n        use_logits (False): whether to use logits (True) or confidence (False)\n            to compute mistakenness. 
Logits typically yield better results,\n            when they are available\n        copy_missing (False): whether to copy predicted objects that were\n            deemed to be missing into ``label_field``\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n    \"\"\"\n    import fiftyone.brain.internal.core.mistakenness as fbm\n\n    return fbm.compute_mistakenness(\n        samples,\n        pred_field,\n        label_field,\n        mistakenness_field,\n        missing_field,\n        spurious_field,\n        use_logits,\n        copy_missing,\n        progress,\n    )\n\n\ndef compute_uniqueness(\n    samples,\n    uniqueness_field=\"uniqueness\",\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"Adds a uniqueness field to each sample scoring how unique it is with\n    respect to the rest of the samples.\n\n    This function only uses the pixel data and can therefore process labeled or\n    unlabeled samples.\n\n    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a\n    default model is used to generate embeddings.\n\n    .. 
note::\n\n        Runs of this method can be referenced later via brain key\n        ``uniqueness_field``.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        uniqueness_field (\"uniqueness\"): the field name to use to store the\n            uniqueness value for each sample\n        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines` field defining a region of\n            interest within each image to use to compute uniqueness\n        embeddings (None): if no ``model`` is provided, this argument specifies\n            pre-computed embeddings to use, which can be any of the following:\n\n            -   a ``num_samples x num_dims`` array of embeddings\n            -   if ``roi_field`` is specified,  a dict mapping sample IDs to\n                ``num_patches x num_dims`` arrays of patch embeddings\n            -   the name of a dataset field containing the embeddings to use\n\n            If a ``model`` is provided, this argument specifies the name of a\n            field in which to store the computed embeddings. In either case,\n            when working with patch embeddings, you can provide either the\n            fully-qualified path to the patch embeddings or just the name of\n            the label attribute in ``roi_field``\n        similarity_index (None): a\n            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key\n            of a similarity index to use to load pre-computed embeddings\n        model (None): a :class:`fiftyone.core.models.Model` or the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/models.html>`_\n            to use to generate embeddings. 
The model must expose embeddings\n            (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``roi_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``roi_field`` are specified\n        batch_size (None): a batch size to use when computing embeddings. 
Only\n            applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n    \"\"\"\n    import fiftyone.brain.internal.core.uniqueness as fbu\n\n    return fbu.compute_uniqueness(\n        samples,\n        uniqueness_field,\n        roi_field,\n        embeddings,\n        similarity_index,\n        model,\n        model_kwargs,\n        force_square,\n        alpha,\n        batch_size,\n        num_workers,\n        skip_failures,\n        progress,\n    )\n\n\ndef compute_representativeness(\n    samples,\n    representativeness_field=\"representativeness\",\n    method=\"cluster-center\",\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"Adds a representativeness field to each sample scoring how representative\n    of nearby samples it is.\n\n    This function only uses the pixel data and can therefore process labeled or\n    unlabeled samples.\n\n    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a\n    default model is used to generate embeddings.\n\n    .. 
note::\n\n        Runs of this method can be referenced later via brain key\n        ``representativeness_field``.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        representativeness_field (\"representativeness\"): the field name to use\n            to store the representativeness value for each sample\n        method (\"cluster-center\"): the name of the method to use to compute the\n            representativeness. The supported values are\n            ``[\"cluster-center\", \"cluster-center-downweight\"]``.\n            ``\"cluster-center\"`` will make a sample's representativeness\n            proportional to its proximity to cluster centers, while\n            ``\"cluster-center-downweight\"`` will ensure more diversity in\n            representative samples\n        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines` field defining a region of\n            interest within each image to use to compute representativeness\n        embeddings (None): if no ``model`` is provided, this argument specifies\n            pre-computed embeddings to use, which can be any of the following:\n\n            -   a ``num_samples x num_dims`` array of embeddings\n            -   if ``roi_field`` is specified,  a dict mapping sample IDs to\n                ``num_patches x num_dims`` arrays of patch embeddings\n            -   the name of a dataset field containing the embeddings to use\n\n            If a ``model`` is provided, this argument specifies the name of a\n            field in which to store the computed embeddings. 
In either case,\n            when working with patch embeddings, you can provide either the\n            fully-qualified path to the patch embeddings or just the name of\n            the label attribute in ``roi_field``\n        similarity_index (None): a\n            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key\n            of a similarity index to use to load pre-computed embeddings\n        model (None): a :class:`fiftyone.core.models.Model` or the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/models.html>`_\n            to use to generate embeddings. The model must expose embeddings\n            (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``roi_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``roi_field`` are specified\n        batch_size (None): a batch size to use when computing embeddings. 
Only\n            applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n    \"\"\"\n    import fiftyone.brain.internal.core.representativeness as fbr\n\n    return fbr.compute_representativeness(\n        samples,\n        representativeness_field,\n        method,\n        roi_field,\n        embeddings,\n        similarity_index,\n        model,\n        model_kwargs,\n        force_square,\n        alpha,\n        batch_size,\n        num_workers,\n        skip_failures,\n        progress,\n    )\n\n\ndef compute_visualization(\n    samples,\n    patches_field=None,\n    embeddings=None,\n    points=None,\n    create_index=False,\n    points_field=None,\n    brain_key=None,\n    num_dims=2,\n    method=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n    **kwargs,\n):\n    \"\"\"Computes a low-dimensional representation of the samples' media or their\n    patches that can be interactively visualized.\n\n    The representation can be visualized by calling the\n    :meth:`visualize() <fiftyone.brain.visualization.VisualizationResults.visualize>`\n    method of the returned\n    :class:`fiftyone.brain.visualization.VisualizationResults` object.\n\n    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a\n    default model is used to generate embeddings.\n\n    You 
can use the ``method`` parameter to select the dimensionality reduction\n    method to use, and you can optionally customize the method by passing\n    additional parameters for the method's\n    :class:`fiftyone.brain.visualization.VisualizationConfig` class as\n    ``kwargs``.\n\n    The builtin ``method`` values and their associated config classes are:\n\n    -   ``\"umap\"``: :class:`fiftyone.brain.visualization.UMAPVisualizationConfig`\n    -   ``\"tsne\"``: :class:`fiftyone.brain.visualization.TSNEVisualizationConfig`\n    -   ``\"pca\"``: :class:`fiftyone.brain.visualization.PCAVisualizationConfig`\n    -   ``\"manual\"``: :class:`fiftyone.brain.visualization.ManualVisualizationConfig`\n\n    You can pass ``create_index=True`` to create a spatial index of the\n    computed points on your dataset's samples. This is highly recommended for\n    large datasets as it enables efficient querying when lassoing points in\n    embeddings plots. By default, spatial indexes are created in a field with\n    name ``points_field=brain_key``, but you can customize this by manually\n    providing a ``points_field``.\n\n    You can also provide a ``points_field`` with ``create_index=False`` to\n    store the points on your dataset without explicitly creating a database\n    index. This will allow lasso callbacks to leverage point data rather than\n    relying on ID selection, but without the added benefit of a database index\n    to further optimize performance.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        patches_field (None): a sample field defining the image patches in each\n            sample that have been/will be embedded. 
Must be of type\n            :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines`\n        embeddings (None): if no ``model`` is provided, this argument specifies\n            pre-computed embeddings to use, which can be any of the following:\n\n            -   a dict mapping sample IDs to embedding vectors\n            -   a ``num_samples x num_embedding_dims`` array of embeddings\n                corresponding to the samples in ``samples``\n            -   if ``patches_field`` is specified, a dict mapping label IDs to\n                embedding vectors\n            -   if ``patches_field`` is specified,  a dict mapping sample IDs\n                to ``num_patches x num_embedding_dims`` arrays of patch\n                embeddings\n            -   the name of a dataset field containing the embeddings to use\n\n            If a ``model`` is provided, this argument specifies the name of a\n            field in which to store the computed embeddings. In either case,\n            when working with patch embeddings, you can provide either the\n            fully-qualified path to the patch embeddings or just the name of\n            the label attribute in ``patches_field``\n        points (None): a pre-computed low-dimensional representation to use. If\n            provided, no embeddings will be used/computed. 
Can be any of the\n            following:\n\n            -   a dict mapping sample IDs to points vectors\n            -   a ``num_samples x num_dims`` array of points corresponding to\n                the samples in ``samples``\n            -   if ``patches_field`` is specified, a dict mapping label IDs to\n                points vectors\n            -   if ``patches_field`` is specified, a ``num_patches x num_dims``\n                array of points whose rows correspond to the flattened list of\n                patches whose IDs are shown below::\n\n                    # The list of patch IDs that the rows of `points` must match\n                    _, id_field = samples._get_label_field_path(patches_field, \"id\")\n                    patch_ids = samples.values(id_field, unwind=True)\n\n        create_index (False): whether to create a spatial index for the\n            computed points on your dataset\n        points_field (None): an optional field name in which to store the\n            spatial index. When ``create_index=True``, this defaults to\n            ``points_field=brain_key``. When working with patches, you can\n            provide either the fully-qualified path to the points field or just\n            the name of the label attribute in ``patches_field``\n        brain_key (None): a brain key under which to store the results of this\n            method\n        num_dims (2): the dimension of the visualization space\n        method (None): the dimensionality reduction method to use. 
The\n            supported values are\n            ``fiftyone.brain.brain_config.visualization_methods.keys()`` and\n            the default is\n            ``fiftyone.brain.brain_config.default_visualization_method``\n        similarity_index (None): a\n            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key\n            of a similarity index to use to load pre-computed embeddings\n        model (None): a :class:`fiftyone.core.models.Model` or the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_\n            to use to generate embeddings. The model must expose embeddings\n            (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``patches_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``patches_field`` are specified\n        batch_size (None): an optional batch size to use when computing\n            embeddings. 
Only applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n        **kwargs: optional keyword arguments for the constructor of the\n            :class:`fiftyone.brain.visualization.VisualizationConfig`\n            being used\n\n    Returns:\n        a :class:`fiftyone.brain.visualization.VisualizationResults`\n    \"\"\"\n    import fiftyone.brain.visualization as fbv\n\n    return fbv.compute_visualization(\n        samples,\n        patches_field,\n        embeddings,\n        points,\n        create_index,\n        points_field,\n        brain_key,\n        num_dims,\n        method,\n        similarity_index,\n        model,\n        model_kwargs,\n        force_square,\n        alpha,\n        batch_size,\n        num_workers,\n        skip_failures,\n        progress,\n        **kwargs,\n    )\n\n\ndef compute_similarity(\n    samples,\n    patches_field=None,\n    roi_field=None,\n    embeddings=None,\n    brain_key=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n    backend=None,\n    **kwargs,\n):\n    \"\"\"Uses embeddings to index the samples or their patches so that you can\n    query/sort by similarity.\n\n    Calling this method only creates the index. 
You can then call the methods\n    exposed on the returned :class:`fiftyone.brain.similarity.SimilarityIndex`\n    object to perform the following operations:\n\n    -   :meth:`sort_by_similarity() <fiftyone.brain.similarity.SimilarityIndex.sort_by_similarity>`:\n        Sort the samples in the collection by similarity to a specific example\n        or example(s)\n\n    All indexes support querying by image similarity by passing sample IDs to\n    :meth:`sort_by_similarity() <fiftyone.brain.similarity.SimilarityIndex.sort_by_similarity>`.\n    In addition, if you pass the name of a model from the\n    `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_\n    like ``model=\"clip-vit-base32-torch\"`` that can embed prompts to this\n    method, then you can query the index by text similarity as well.\n\n    In addition, if the backend supports it, you can call the following\n    duplicate detection methods:\n\n    -   :meth:`find_duplicates() <fiftyone.brain.similarity.DuplicatesMixin.find_duplicates>`:\n        Query the index to find all examples with near-duplicates in the\n        collection\n\n    -   :meth:`find_unique() <fiftyone.brain.similarity.DuplicatesMixin.find_unique>`:\n        Query the index to select a subset of examples of a specified size that\n        are maximally unique with respect to each other\n\n    If no ``embeddings`` or ``model`` is provided, a default model is used to\n    generate embeddings.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        patches_field (None): a sample field defining the image patches in each\n            sample that have been/will be embedded. 
Must be of type\n            :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines`\n        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines` field defining a region of\n            interest within each image to use to compute embeddings\n        embeddings (None): embeddings to feed the index. This argument's\n            behavior depends on whether a ``model`` is provided, as described\n            below.\n\n            If no ``model`` is provided, this argument specifies pre-computed\n            embeddings to use:\n\n            -   a ``num_samples x num_dims`` array of embeddings\n            -   if ``patches_field``/``roi_field`` is specified,  a dict\n                mapping sample IDs to ``num_patches x num_dims`` arrays of\n                patch embeddings\n            -   the name of a dataset field from which to load embeddings\n            -   ``None``: use the default model to compute embeddings\n            -   ``False``: **do not** compute embeddings right now\n\n            If a ``model`` is provided, this argument specifies where to store\n            the model's embeddings:\n\n            -   the name of a field in which to store the computed embeddings\n            -   ``False``: **do not** compute embeddings right now\n\n            In either case, when working with patch embeddings, you can provide\n            either the fully-qualified path to the patch embeddings or just the\n            name of the label attribute in ``patches_field``/``roi_field``\n        brain_key (None): a brain key under which to store the results of this\n            method\n        model (None): a :class:`fiftyone.core.models.Model` or 
the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_\n            to use, or that was already used, to generate embeddings. The model\n            must expose embeddings (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``patches_field``/``roi_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``patches_field``/``roi_field`` are specified\n        batch_size (None): an optional batch size to use when computing\n            embeddings. Only applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n        backend (None): the similarity backend to use. 
The supported values are\n            ``fiftyone.brain.brain_config.similarity_backends.keys()`` and the\n            default is\n            ``fiftyone.brain.brain_config.default_similarity_backend``\n        **kwargs: keyword arguments for the\n            :class:`fiftyone.brain.SimilarityConfig` subclass of the backend\n            being used\n\n    Returns:\n        a :class:`fiftyone.brain.similarity.SimilarityIndex`\n    \"\"\"\n    import fiftyone.brain.similarity as fbs\n\n    return fbs.compute_similarity(\n        samples,\n        patches_field,\n        roi_field,\n        embeddings,\n        brain_key,\n        model,\n        model_kwargs,\n        force_square,\n        alpha,\n        batch_size,\n        num_workers,\n        skip_failures,\n        progress,\n        backend,\n        **kwargs,\n    )\n\n\ndef compute_near_duplicates(\n    samples,\n    threshold=0.2,\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"Detects potential duplicates in the given sample collection.\n\n    Calling this method only initializes the index. 
You can then call the\n    methods exposed on the returned object to perform the following operations:\n\n    -   :meth:`duplicate_ids <fiftyone.brain.similarity.DuplicatesMixin.duplicate_ids>`:\n        A list of duplicate IDs\n\n    -   :meth:`neighbors_map <fiftyone.brain.similarity.DuplicatesMixin.neighbors_map>`:\n        A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples\n\n    -   :meth:`duplicates_view() <fiftyone.brain.similarity.DuplicatesMixin.duplicates_view>`:\n        Returns a view of all duplicates in the input collection\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        threshold (0.2): the similarity distance threshold to use when\n            detecting duplicates. Values in ``[0.1, 0.25]`` work well for the\n            default setup\n        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines` field defining a region of\n            interest within each image to use to compute embeddings\n        embeddings (None): if no ``model`` is provided, this argument specifies\n            pre-computed embeddings to use, which can be any of the following:\n\n            -   a ``num_samples x num_dims`` array of embeddings\n            -   if ``roi_field`` is specified,  a dict mapping sample IDs to\n                ``num_patches x num_dims`` arrays of patch embeddings\n            -   the name of a dataset field containing the embeddings to use\n\n            If a ``model`` is provided, this argument specifies the name of a\n            field in which to store the computed embeddings. 
In either case,\n            when working with patch embeddings, you can provide either the\n            fully-qualified path to the patch embeddings or just the name of\n            the label attribute in ``roi_field``\n        similarity_index (None): a\n            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key\n            of a similarity index to use to load pre-computed embeddings\n        model (None): a :class:`fiftyone.core.models.Model` or the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/models.html>`_\n            to use to generate embeddings. The model must expose embeddings\n            (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``roi_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``roi_field`` are specified\n        batch_size (None): a batch size to use when computing embeddings. 
Only\n            applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n\n    Returns:\n        a :class:`fiftyone.brain.similarity.SimilarityIndex`\n    \"\"\"\n    import fiftyone.brain.internal.core.duplicates as fbd\n\n    return fbd.compute_near_duplicates(\n        samples,\n        threshold=threshold,\n        roi_field=roi_field,\n        embeddings=embeddings,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n        force_square=force_square,\n        alpha=alpha,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        skip_failures=skip_failures,\n        progress=progress,\n    )\n\n\ndef compute_exact_duplicates(\n    samples,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"Detects duplicate media in a sample collection.\n\n    This method detects exact duplicates with the same filehash. 
Use\n    :meth:`compute_near_duplicates` to detect near-duplicates.\n\n    If duplicates are found, the first instance in ``samples`` will be the key\n    in the returned dictionary, while the subsequent duplicates will be the\n    values in the corresponding list.\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        num_workers (None): an optional number of processes to use\n        skip_failures (True): whether to gracefully ignore samples whose\n            filehash cannot be computed\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n\n    Returns:\n        a dictionary mapping IDs of samples with exact duplicates to lists of\n        IDs of the duplicates for the corresponding sample\n    \"\"\"\n    import fiftyone.brain.internal.core.duplicates as fbd\n\n    return fbd.compute_exact_duplicates(\n        samples, num_workers, skip_failures, progress\n    )\n\n\ndef compute_leaky_splits(\n    samples,\n    splits,\n    threshold=0.2,\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"Computes potential leaks between splits of the given sample collection.\n\n    Calling this method only initializes the index. 
You can then call the\n    methods exposed on the returned object to perform the following operations:\n\n    -   :meth:`leaks_view() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.leaks_view>`:\n        Returns a view of all leaks in the input collection\n\n    -   :meth:`no_leaks_view() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.no_leaks_view>`:\n        Returns the subset of the input collection without any leaks\n\n    -   :meth:`leaks_for_sample() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.leaks_for_sample>`:\n        Returns a view with leaks corresponding to the given sample\n\n    -   :meth:`tag_leaks() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.tag_leaks>`:\n        Tags leaks in the dataset as leaks\n\n    Args:\n        samples: a :class:`fiftyone.core.collections.SampleCollection`\n        splits: the dataset splits, specified in one of the following ways:\n\n            -   a list of tag strings\n            -   the name of a string/list field that encodes the split\n                memberships\n            -   a dict mapping split names to\n                :class:`fiftyone.core.view.DatasetView` instances\n        threshold (0.2): the similarity distance threshold to use when\n            detecting leaks. 
Values in ``[0.1, 0.25]`` work well for the\n            default setup\n        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,\n            :class:`fiftyone.core.labels.Detections`,\n            :class:`fiftyone.core.labels.Polyline`, or\n            :class:`fiftyone.core.labels.Polylines` field defining a region of\n            interest within each image to use to compute leaks\n        embeddings (None): if no ``model`` is provided, this argument specifies\n            pre-computed embeddings to use, which can be any of the following:\n\n            -   a ``num_samples x num_dims`` array of embeddings\n            -   if ``roi_field`` is specified,  a dict mapping sample IDs to\n                ``num_patches x num_dims`` arrays of patch embeddings\n            -   the name of a dataset field containing the embeddings to use\n\n            If a ``model`` is provided, this argument specifies the name of a\n            field in which to store the computed embeddings. In either case,\n            when working with patch embeddings, you can provide either the\n            fully-qualified path to the patch embeddings or just the name of\n            the label attribute in ``roi_field``\n        similarity_index (None): a\n            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key\n            of a similarity index to use to load pre-computed embeddings\n        model (None): a :class:`fiftyone.core.models.Model` or the name of a\n            model from the\n            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/models.html>`_\n            to use to generate embeddings. 
The model must expose embeddings\n            (``model.has_embeddings = True``)\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        force_square (False): whether to minimally manipulate the patch\n            bounding boxes into squares prior to extraction. Only applicable\n            when a ``model`` and ``roi_field`` are specified\n        alpha (None): an optional expansion/contraction to apply to the patches\n            before extracting them, in ``[-1, inf)``. If provided, the length\n            and width of the box are expanded (or contracted, when\n            ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n            ``alpha = 0.1`` to expand the boxes by 10%, and set\n            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when\n            a ``model`` and ``roi_field`` are specified\n        batch_size (None): a batch size to use when computing embeddings. 
Only\n            applicable when a ``model`` is provided\n        num_workers (None): the number of workers to use when loading images.\n            Only applicable when a Torch-based model is being used to compute\n            embeddings\n        skip_failures (True): whether to gracefully continue without raising an\n            error if embeddings cannot be generated for a sample\n        progress (None): whether to render a progress bar (True/False), use the\n            default value ``fiftyone.config.show_progress_bars`` (None), or a\n            progress callback function to invoke instead\n\n    Returns:\n        a :class:`fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex`\n    \"\"\"\n    import fiftyone.brain.internal.core.leaky_splits as fbl\n\n    return fbl.compute_leaky_splits(\n        samples,\n        splits,\n        threshold=threshold,\n        roi_field=roi_field,\n        embeddings=embeddings,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n        force_square=force_square,\n        alpha=alpha,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        skip_failures=skip_failures,\n        progress=progress,\n    )\n"
  },
  {
    "path": "fiftyone/brain/config.py",
    "content": "\"\"\"\nBrain config.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport os\n\nfrom fiftyone.core.config import EnvConfig\n\n\nclass BrainConfig(EnvConfig):\n    \"\"\"FiftyOne brain configuration settings.\"\"\"\n\n    _BUILTIN_SIMILARITY_BACKENDS = {\n        \"sklearn\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.sklearn.SklearnSimilarityConfig\",\n        },\n        \"pinecone\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.pinecone.PineconeSimilarityConfig\",\n        },\n        \"qdrant\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.qdrant.QdrantSimilarityConfig\",\n        },\n        \"milvus\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.milvus.MilvusSimilarityConfig\",\n        },\n        \"lancedb\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.lancedb.LanceDBSimilarityConfig\",\n        },\n        \"redis\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.redis.RedisSimilarityConfig\",\n        },\n        \"mongodb\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.mongodb.MongoDBSimilarityConfig\",\n        },\n        \"elasticsearch\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.elasticsearch.ElasticsearchSimilarityConfig\",\n        },\n        \"pgvector\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.pgvector.PgVectorSimilarityConfig\",\n        },\n        \"mosaic\": {\n            \"config_cls\": \"fiftyone.brain.internal.core.mosaic.MosaicSimilarityConfig\",\n        },\n    }\n\n    _BUILTIN_VISUALIZATION_METHODS = {\n        \"umap\": {\n            \"config_cls\": \"fiftyone.brain.visualization.UMAPVisualizationConfig\",\n        },\n        \"tsne\": {\n            \"config_cls\": \"fiftyone.brain.visualization.TSNEVisualizationConfig\",\n        },\n        \"pca\": {\n            \"config_cls\": 
\"fiftyone.brain.visualization.PCAVisualizationConfig\",\n        },\n        \"manual\": {\n            \"config_cls\": \"fiftyone.brain.visualization.ManualVisualizationConfig\",\n        },\n    }\n\n    def __init__(self, d=None):\n        if d is None:\n            d = {}\n\n        self.default_similarity_backend = self.parse_string(\n            d,\n            \"default_similarity_backend\",\n            env_var=\"FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND\",\n            default=\"sklearn\",\n        )\n\n        self.similarity_backends = self._parse_similarity_backends(d)\n        if self.default_similarity_backend not in self.similarity_backends:\n            self.default_similarity_backend = next(\n                iter(sorted(self.similarity_backends.keys())), None\n            )\n\n        self.default_visualization_method = self.parse_string(\n            d,\n            \"default_visualization_method\",\n            env_var=\"FIFTYONE_BRAIN_DEFAULT_VISUALIZATION_METHOD\",\n            default=\"umap\",\n        )\n\n        self.visualization_methods = self._parse_visualization_methods(d)\n        if self.default_visualization_method not in self.visualization_methods:\n            self.default_visualization_method = next(\n                iter(sorted(self.visualization_methods.keys())), None\n            )\n\n    def _parse_similarity_backends(self, d):\n        d = d.get(\"similarity_backends\", {})\n        env_vars = dict(os.environ)\n\n        #\n        # `FIFTYONE_BRAIN_SIMILARITY_BACKENDS` can be used to declare which\n        # backends are exposed. 
This may exclude builtin backends and/or\n        # declare new backends\n        #\n\n        if \"FIFTYONE_BRAIN_SIMILARITY_BACKENDS\" in env_vars:\n            backends = env_vars[\"FIFTYONE_BRAIN_SIMILARITY_BACKENDS\"].split(\n                \",\"\n            )\n\n            # Special syntax to append rather than override default backends\n            if \"*\" in backends:\n                backends = set(b for b in backends if b != \"*\")\n                backends |= set(self._BUILTIN_SIMILARITY_BACKENDS.keys())\n\n            d = {backend: d.get(backend, {}) for backend in backends}\n        else:\n            backends = self._BUILTIN_SIMILARITY_BACKENDS.keys()\n            for backend in backends:\n                if backend not in d:\n                    d[backend] = {}\n\n        #\n        # Extract parameters from any environment variables of the form\n        # `FIFTYONE_BRAIN_SIMILARITY_<BACKEND>_<PARAMETER>`\n        #\n\n        for backend, d_backend in d.items():\n            prefix = \"FIFTYONE_BRAIN_SIMILARITY_%s_\" % backend.upper()\n            for env_name, env_value in env_vars.items():\n                if env_name.startswith(prefix):\n                    name = env_name[len(prefix) :].lower()\n                    value = _parse_env_value(env_value)\n                    d_backend[name] = value\n\n        #\n        # Set default parameters for builtin similarity backends\n        #\n\n        for backend, defaults in self._BUILTIN_SIMILARITY_BACKENDS.items():\n            if backend not in d:\n                continue\n\n            d_backend = d[backend]\n            for name, value in defaults.items():\n                if name not in d_backend:\n                    d_backend[name] = value\n\n        return d\n\n    def _parse_visualization_methods(self, d):\n        d = d.get(\"visualization_methods\", {})\n        env_vars = dict(os.environ)\n\n        #\n        # `FIFTYONE_BRAIN_VISUALIZATION_METHODS` can be used to declare which\n     
   # methods are exposed. This may exclude builtin methods and/or declare\n        # new methods\n        #\n\n        if \"FIFTYONE_BRAIN_VISUALIZATION_METHODS\" in env_vars:\n            methods = env_vars[\"FIFTYONE_BRAIN_VISUALIZATION_METHODS\"].split(\n                \",\"\n            )\n\n            # Special syntax to append rather than override default methods\n            if \"*\" in methods:\n                methods = set(m for m in methods if m != \"*\")\n                methods |= set(self._BUILTIN_VISUALIZATION_METHODS.keys())\n\n            d = {method: d.get(method, {}) for method in methods}\n        else:\n            methods = self._BUILTIN_VISUALIZATION_METHODS.keys()\n            for method in methods:\n                if method not in d:\n                    d[method] = {}\n\n        #\n        # Extract parameters from any environment variables of the form\n        # `FIFTYONE_BRAIN_VISUALIZATION_<METHOD>_<PARAMETER>`\n        #\n\n        for method, d_method in d.items():\n            prefix = \"FIFTYONE_BRAIN_VISUALIZATION_%s_\" % method.upper()\n            for env_name, env_value in env_vars.items():\n                if env_name.startswith(prefix):\n                    name = env_name[len(prefix) :].lower()\n                    value = _parse_env_value(env_value)\n                    d_method[name] = value\n\n        #\n        # Set default parameters for builtin visualization methods\n        #\n\n        for method, defaults in self._BUILTIN_VISUALIZATION_METHODS.items():\n            if method not in d:\n                continue\n\n            d_method = d[method]\n            for name, value in defaults.items():\n                if name not in d_method:\n                    d_method[name] = value\n\n        return d\n\n\ndef locate_brain_config():\n    \"\"\"Returns the path to the :class:`BrainConfig` on disk.\n\n    The default location is ``~/.fiftyone/brain_config.json``, but you can\n    override this path by setting the 
``FIFTYONE_BRAIN_CONFIG_PATH``\n    environment variable.\n\n    Note that a config file may not actually exist on disk.\n\n    Returns:\n        the path to the :class:`BrainConfig` on disk\n    \"\"\"\n    if \"FIFTYONE_BRAIN_CONFIG_PATH\" not in os.environ:\n        return os.path.join(\n            os.path.expanduser(\"~\"), \".fiftyone\", \"brain_config.json\"\n        )\n\n    return os.environ[\"FIFTYONE_BRAIN_CONFIG_PATH\"]\n\n\ndef load_brain_config():\n    \"\"\"Loads the FiftyOne brain config.\n\n    Returns:\n        a :class:`BrainConfig` instance\n    \"\"\"\n    brain_config_path = locate_brain_config()\n    if os.path.isfile(brain_config_path):\n        return BrainConfig.from_json(brain_config_path)\n\n    return BrainConfig()\n\n\ndef _parse_env_value(value):\n    try:\n        return int(value)\n    except:\n        pass\n\n    try:\n        return float(value)\n    except:\n        pass\n\n    if value in (\"True\", \"true\"):\n        return True\n\n    if value in (\"False\", \"false\"):\n        return False\n\n    if value in (\"None\", \"\"):\n        return None\n\n    if \",\" in value:\n        return [_parse_env_value(v) for v in value.split(\",\")]\n\n    return value\n"
  },
  {
    "path": "fiftyone/brain/internal/__init__.py",
    "content": "\"\"\"\nInternal FiftyOne Brain package.\n\nContains all non-public code powering the ``fiftyone.brain`` public namespace.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\n"
  },
  {
    "path": "fiftyone/brain/internal/core/__init__.py",
    "content": "\"\"\"\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\n"
  },
  {
    "path": "fiftyone/brain/internal/core/duplicates.py",
    "content": "\"\"\"\nDuplicates methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nfrom collections import defaultdict\nimport itertools\nimport logging\nimport multiprocessing\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.media as fom\nimport fiftyone.core.utils as fou\nimport fiftyone.core.validation as fov\n\nimport fiftyone.brain as fb\nimport fiftyone.brain.similarity as fbs\nimport fiftyone.brain.internal.core.utils as fbu\n\n\nlogger = logging.getLogger(__name__)\n\n_DEFAULT_MODEL = \"resnet18-imagenet-torch\"\n\n\ndef compute_near_duplicates(\n    samples,\n    threshold=None,\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    fov.validate_collection(samples)\n\n    if etau.is_str(embeddings):\n        embeddings_field, embeddings_exist = fbu.parse_data_field(\n            samples,\n            embeddings,\n            data_type=\"embeddings\",\n        )\n        embeddings = None\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if etau.is_str(similarity_index):\n        similarity_index = samples.load_brain_results(similarity_index)\n\n    if (\n        model is None\n        and embeddings is None\n        and similarity_index is None\n        and not embeddings_exist\n    ):\n        model = _DEFAULT_MODEL\n\n    if similarity_index is None:\n        similarity_index = fb.compute_similarity(\n            samples,\n            backend=\"sklearn\",\n            roi_field=roi_field,\n            embeddings=embeddings_field or embeddings,\n            model=model,\n            model_kwargs=model_kwargs,\n            force_square=force_square,\n            alpha=alpha,\n            batch_size=batch_size,\n   
         num_workers=num_workers,\n            skip_failures=skip_failures,\n            progress=progress,\n        )\n    elif not isinstance(similarity_index, fbs.DuplicatesMixin):\n        raise ValueError(\n            \"This method only supports similarity indexes that implement the \"\n            \"%s mixin\" % fbs.DuplicatesMixin\n        )\n\n    similarity_index.find_duplicates(thresh=threshold)\n\n    return similarity_index\n\n\ndef compute_exact_duplicates(samples, num_workers, skip_failures, progress):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    fov.validate_collection(samples)\n\n    if num_workers is None:\n        if samples.media_type == fom.VIDEO:\n            num_workers = multiprocessing.cpu_count()\n        else:\n            num_workers = 1\n\n    logger.info(\"Computing filehashes...\")\n\n    method = \"md5\" if samples.media_type == fom.VIDEO else None\n\n    if num_workers <= 1:\n        hashes = _compute_filehashes(samples, method, progress)\n    else:\n        hashes = _compute_filehashes_multi(\n            samples, method, num_workers, progress\n        )\n\n    num_missing = sum(h is None for h in hashes)\n    if num_missing > 0:\n        msg = \"Failed to compute %d filehashes\" % num_missing\n        if skip_failures:\n            logger.warning(msg)\n        else:\n            raise ValueError(msg)\n\n    neighbors_map = defaultdict(list)\n\n    observed_hashes = {}\n    for _id, _hash in hashes.items():\n        if _hash is None:\n            continue\n\n        if _hash in observed_hashes:\n            neighbors_map[observed_hashes[_hash]].append(_id)\n        else:\n            observed_hashes[_hash] = _id\n\n    return dict(neighbors_map)\n\n\ndef _compute_filehashes(samples, method, progress):\n    ids, filepaths = samples.values([\"id\", \"filepath\"])\n\n    with fou.ProgressBar(total=len(ids), progress=progress) as pb:\n        return {\n            _id: _compute_filehash(filepath, method)\n            for 
_id, filepath in pb(zip(ids, filepaths))\n        }\n\n\ndef _compute_filehashes_multi(samples, method, num_workers, progress):\n    ids, filepaths = samples.values([\"id\", \"filepath\"])\n\n    methods = itertools.repeat(method)\n\n    inputs = list(zip(ids, filepaths, methods))\n\n    with fou.ProgressBar(total=len(inputs), progress=progress) as pb:\n        with multiprocessing.Pool(processes=num_workers) as pool:\n            return {\n                k: v\n                for k, v in pb(\n                    pool.imap_unordered(_do_compute_filehash, inputs)\n                )\n            }\n\n\ndef _compute_filehash(filepath, method):\n    try:\n        filehash = fou.compute_filehash(filepath, method=method)\n    except:\n        filehash = None\n\n    return filehash\n\n\ndef _do_compute_filehash(args):\n    _id, filepath, method = args\n    try:\n        filehash = fou.compute_filehash(filepath, method=method)\n    except:\n        filehash = None\n\n    return _id, filehash\n"
  },
  {
    "path": "fiftyone/brain/internal/core/elasticsearch.py",
    "content": "\"\"\"\nElasticsearch similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nfrom fiftyone import ViewField as F\nimport fiftyone.core.utils as fou\nimport fiftyone.brain.internal.core.utils as fbu\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\n\nes = fou.lazy_import(\"elasticsearch\")\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = {\n    \"cosine\": \"cosine\",\n    \"dotproduct\": \"dot_product\",\n    \"euclidean\": \"l2_norm\",\n    \"innerproduct\": \"max_inner_product\",\n}\n\n\nclass ElasticsearchSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for an Elasticsearch similarity instance.\n\n    Args:\n        index_name (None): the name of the Elasticsearch index to use or\n            create. If none is provided, a new index will be created\n        metric (\"cosine\"): the embedding distance metric to use when creating a\n            new index. Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\", \"innerproduct\")``\n        hosts (None): the full Elasticsearch server address(es) to use. 
Can be\n            a string or list of strings\n        cloud_id (None): the Cloud ID of an Elastic Cloud to connect to\n        username (None): a username to use\n        password (None): a password to use\n        api_key (None): an API key to use\n        ca_certs (None): a path to a CA certificate\n        bearer_auth (None): a bearer token to use\n        ssl_assert_fingerprint (None): a SHA256 fingerprint to use\n        verify_certs (None): whether to verify SSL certificates\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        index_name=None,\n        metric=\"cosine\",\n        hosts=None,\n        cloud_id=None,\n        username=None,\n        password=None,\n        api_key=None,\n        ca_certs=None,\n        bearer_auth=None,\n        ssl_assert_fingerprint=None,\n        verify_certs=None,\n        **kwargs,\n    ):\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.index_name = index_name\n        self.metric = metric\n\n        self._hosts = hosts\n        self._cloud_id = cloud_id\n        self._username = username\n        self._password = password\n        self._api_key = api_key\n        self._ca_certs = ca_certs\n        self._bearer_auth = bearer_auth\n        self._ssl_assert_fingerprint = ssl_assert_fingerprint\n        self._verify_certs = verify_certs\n\n    @property\n    def method(self):\n        return \"elasticsearch\"\n\n    @property\n    def hosts(self):\n        return self._hosts\n\n    @hosts.setter\n    def hosts(self, value):\n        self._hosts = value\n\n    @property\n    def cloud_id(self):\n        return self._cloud_id\n\n    @cloud_id.setter\n    def cloud_id(self, value):\n        self._cloud_id = value\n\n    @property\n    def username(self):\n        return self._username\n\n    @username.setter\n    def username(self, value):\n        self._username = value\n\n    @property\n    def password(self):\n        return self._password\n\n    @password.setter\n    def password(self, value):\n        self._password = value\n\n    @property\n    def api_key(self):\n        return self._api_key\n\n    @api_key.setter\n    def api_key(self, value):\n        self._api_key = value\n\n    @property\n    def ca_certs(self):\n        return self._ca_certs\n\n    @ca_certs.setter\n    def ca_certs(self, value):\n        self._ca_certs = value\n\n    @property\n    def bearer_auth(self):\n        return self._bearer_auth\n\n    @bearer_auth.setter\n    def bearer_auth(self, value):\n        self._bearer_auth = value\n\n    @property\n    def ssl_assert_fingerprint(self):\n        return self._ssl_assert_fingerprint\n\n    @ssl_assert_fingerprint.setter\n    def ssl_assert_fingerprint(self, value):\n        self._ssl_assert_fingerprint = value\n\n    
@property\n    def verify_certs(self):\n        return self._verify_certs\n\n    @verify_certs.setter\n    def verify_certs(self, value):\n        self._verify_certs = value\n\n    @property\n    def max_k(self):\n        return 10000  # Elasticsearch limit\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self,\n        hosts=None,\n        cloud_id=None,\n        username=None,\n        password=None,\n        api_key=None,\n        ca_certs=None,\n        bearer_auth=None,\n        ssl_assert_fingerprint=None,\n        verify_certs=None,\n    ):\n        self._load_parameters(\n            hosts=hosts,\n            cloud_id=cloud_id,\n            username=username,\n            password=password,\n            api_key=api_key,\n            ca_certs=ca_certs,\n            bearer_auth=bearer_auth,\n            ssl_assert_fingerprint=ssl_assert_fingerprint,\n            verify_certs=verify_certs,\n        )\n\n\nclass ElasticsearchSimilarity(Similarity):\n    \"\"\"Elasticsearch similarity factory.\n\n    Args:\n        config: a :class:`ElasticsearchSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"elasticsearch\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"elasticsearch\")\n\n    def initialize(self, samples, brain_key):\n        return ElasticsearchSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass ElasticsearchSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Elasticsearch similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`ElasticsearchSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`ElasticsearchSimilarity` instance\n    \"\"\"\n\n    
def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._client = None\n        self._metric = None\n        self._initialize()\n\n    @property\n    def total_index_size(self):\n        try:\n            return self._client.count(index=self.config.index_name)[\"count\"]\n        except:\n            return 0\n\n    @property\n    def client(self):\n        \"\"\"The ``elasticsearch.Elasticsearch`` instance for this index.\"\"\"\n        return self._client\n\n    def _initialize(self):\n        kwargs = {}\n\n        for key in (\n            \"hosts\",\n            \"cloud_id\",\n            \"username\",\n            \"password\",\n            \"api_key\",\n            \"ca_certs\",\n            \"bearer_auth\",\n            \"ssl_assert_fingerprint\",\n            \"verify_certs\",\n        ):\n            value = getattr(self.config, key, None)\n            if value is not None:\n                kwargs[key] = value\n\n        username = kwargs.pop(\"username\", None)\n        password = kwargs.pop(\"password\", None)\n        if username is not None and password is not None:\n            kwargs[\"basic_auth\"] = (username, password)\n\n        try:\n            self._client = es.Elasticsearch(**kwargs)\n        except Exception as e:\n            raise ValueError(\n                \"Failed to connect to Elasticsearch backend. 
Refer to \"\n                \"https://docs.voxel51.com/integrations/elasticsearch.html for more \"\n                \"information\"\n            ) from e\n\n        if self.config.index_name is None:\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            index_name = fbu.get_unique_name(root, self._get_index_names())\n\n            self.config.index_name = index_name\n            self.save_config()\n\n    def _get_index_names(self):\n        return self._client.indices.get_alias().keys()\n\n    def _get_index_ids(self, batch_size=1000):\n        sample_ids = []\n        label_ids = []\n        for batch in range(0, self.total_index_size, batch_size):\n            response = self._client.search(\n                index=self.config.index_name,\n                body={\n                    \"fields\": [\"sample_id\"],\n                    \"from\": batch,\n                    \"query\": {\n                        \"bool\": {\n                            \"must\": [\n                                {\"exists\": {\"field\": \"vector\"}},\n                                {\"exists\": {\"field\": \"sample_id\"}},\n                            ]\n                        }\n                    },\n                },\n                source=False,\n                size=batch_size,\n            )\n            for doc in response[\"hits\"][\"hits\"]:\n                sample_id = doc[\"fields\"][\"sample_id\"][0]\n                sample_or_label_id = doc[\"_id\"]\n                sample_ids.append(sample_id)\n                label_ids.append(sample_or_label_id)\n\n        return sample_ids, label_ids\n\n    def _get_dimension(self):\n        if self.total_index_size == 0:\n            return None\n\n        if self.config.patches_field is not None:\n            embeddings, _, _ = self.get_embeddings(\n                label_ids=self._label_ids[:1]\n            )\n        else:\n            embeddings, _, _ = self.get_embeddings(\n             
   sample_ids=self._sample_ids[:1]\n            )\n\n        return embeddings.shape[1]\n\n    def _get_metric(self):\n        if self._metric is None:\n            try:\n                # We must ask ES rather than using `self.config.metric` because\n                # we may be working with a preexisting index\n                self._metric = self._client.indices.get_mapping(\n                    index=self.config.index_name\n                )[self.config.index_name][\"mappings\"][\"properties\"][\"vector\"][\n                    \"similarity\"\n                ]\n            except:\n                logger.warning(\n                    \"Failed to infer similarity metric from index '%s'\",\n                    self.config.index_name,\n                )\n\n        return self._metric\n\n    def _index_exists(self):\n        if self.config.index_name is None:\n            return False\n\n        return self.config.index_name in self._get_index_names()\n\n    def _create_index(self, dimension):\n        metric = _SUPPORTED_METRICS[self.config.metric]\n        mappings = {\n            \"properties\": {\n                \"vector\": {\n                    \"type\": \"dense_vector\",\n                    \"dims\": dimension,\n                    \"index\": \"true\",\n                    \"similarity\": metric,\n                }\n            }\n        }\n        self._client.indices.create(\n            index=self.config.index_name, mappings=mappings\n        )\n        self._metric = metric\n\n    def _get_existing_ids(self, ids):\n        docs = [{\"_index\": self.config.index_name, \"_id\": i} for i in ids]\n        resp = self._client.mget(docs=docs)\n        return [d[\"_id\"] for d in resp[\"docs\"] if d[\"found\"]]\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=500,\n    ):\n    
    if not self._index_exists():\n            self._create_index(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            existing_ids = self._get_existing_ids(ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n\n        if self._get_metric() == _SUPPORTED_METRICS[\"dotproduct\"]:\n            embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis]\n\n        embeddings = [e.tolist() for e in embeddings]\n        sample_ids = list(sample_ids)\n        if label_ids is not None:\n            ids = list(label_ids)\n        else:\n            ids = list(sample_ids)\n\n        for _embeddings, _ids, _sample_ids in zip(\n    
        fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            operations = []\n            for _e, _id, _sid in zip(_embeddings, _ids, _sample_ids):\n                operations.append(\n                    {\"index\": {\"_index\": self.config.index_name, \"_id\": _id}}\n                )\n                operations.append({\"sample_id\": _sid, \"vector\": _e})\n\n            self._client.bulk(\n                index=self.config.index_name,\n                operations=operations,\n                refresh=True,\n            )\n\n        if reload:\n            self.reload()\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or warn_missing:\n            existing_ids = self._get_existing_ids(ids)\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        operations = [\n            {\"delete\": {\"_index\": self.config.index_name, \"_id\": i}}\n            for i in ids\n        ]\n        self._client.bulk(body=operations, refresh=True)\n\n        if reload:\n            self.reload()\n\n    def 
get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def _parse_embeddings_response(self, response, 
label_id=True):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n        for r in response:\n            if r.get(\"found\", True):\n                found_embeddings.append(r[\"_source\"][\"vector\"])\n                if label_id:\n                    found_sample_ids.append(r[\"_source\"][\"sample_id\"])\n                    found_label_ids.append(r[\"_id\"])\n                else:\n                    found_sample_ids.append(r[\"_id\"])\n\n        return found_embeddings, found_sample_ids, found_label_ids\n\n    def _get_sample_embeddings(self, sample_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n\n        if sample_ids is None:\n            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            response = self._client.mget(\n                index=self.config.index_name, ids=batch_ids, source=True\n            )\n\n            (\n                _found_embeddings,\n                _found_sample_ids,\n                _,\n            ) = self._parse_embeddings_response(\n                response[\"docs\"], label_id=False\n            )\n            found_embeddings += _found_embeddings\n            found_sample_ids += _found_sample_ids\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        if label_ids is None:\n            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)\n\n        for batch_ids in fou.iter_batches(label_ids, batch_size):\n            response = self._client.mget(\n                index=self.config.index_name, ids=batch_ids, source=True\n            )\n\n            (\n                
_found_embeddings,\n                _found_sample_ids,\n                _found_label_ids,\n            ) = self._parse_embeddings_response(response[\"docs\"])\n            found_embeddings += _found_embeddings\n            found_sample_ids += _found_sample_ids\n            found_label_ids += _found_label_ids\n\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(\n        self, sample_ids, batch_size=100\n    ):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        if sample_ids is None:\n            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            response = self._client.search(\n                index=self.config.index_name,\n                body={\"query\": {\"terms\": {\"sample_id\": batch_ids}}},\n            )\n\n            (\n                _found_embeddings,\n                _found_sample_ids,\n                _found_label_ids,\n            ) = self._parse_embeddings_response(response[\"hits\"][\"hits\"])\n            found_embeddings += _found_embeddings\n            found_sample_ids += _found_sample_ids\n            found_label_ids += _found_label_ids\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def cleanup(self):\n        self._client.indices.delete(\n            index=self.config.index_name, ignore_unavailable=True\n        )\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\n                \"Elasticsearch does not support full index neighbors\"\n            )\n\n        if 
reverse is True:\n            raise ValueError(\n                \"Elasticsearch does not support least similarity queries\"\n            )\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\n                f\"Elasticsearch does not support {aggregation} aggregation\"\n            )\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = self.current_label_ids\n            else:\n                index_ids = self.current_sample_ids\n\n            _filter = {\"terms\": {\"_id\": list(index_ids)}}\n        else:\n            _filter = None\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            if self._get_metric() == _SUPPORTED_METRICS[\"dotproduct\"]:\n                q /= np.linalg.norm(q)\n\n            knn = {\n                \"field\": \"vector\",\n                \"query_vector\": q.tolist(),\n                \"k\": k,\n                \"num_candidates\": 10 * k,\n            }\n            if _filter:\n                knn[\"filter\"] = _filter\n\n            source = self.config.patches_field is not None\n            response = self._client.search(\n                index=self.config.index_name,\n                knn=knn,\n                size=k,\n                source=source,\n            )\n\n            if self.config.patches_field is not None:\n                sample_ids.append(\n                    [\n                        r[\"_source\"][\"sample_id\"]\n                        for r in response[\"hits\"][\"hits\"]\n                    ]\n                )\n                label_ids.append([r[\"_id\"] for r in 
response[\"hits\"][\"hits\"]])\n            else:\n                sample_ids.append([r[\"_id\"] for r in response[\"hits\"][\"hits\"]])\n\n            if return_dists:\n                dists.append([r[\"_score\"] for r in response[\"hits\"][\"hits\"]])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        response = self._client.mget(\n            index=self.config.index_name, ids=query_ids, source=True\n        )\n        query = np.array(\n            [r[\"_source\"][\"vector\"] for r in response[\"docs\"] if r[\"found\"]]\n        )\n\n        if query.size == 0:\n            raise ValueError(\n                \"Query IDs %s were not found in the index\" % query_ids\n            )\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/hardness.py",
    "content": "\"\"\"\nHardness methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\nfrom scipy.special import softmax\nfrom scipy.stats import entropy\n\nimport fiftyone.core.brain as fob\nimport fiftyone.core.labels as fol\nimport fiftyone.core.media as fom\nimport fiftyone.core.utils as fou\nimport fiftyone.core.validation as fov\n\n\nlogger = logging.getLogger(__name__)\n\n\n_ALLOWED_TYPES = (fol.Classification, fol.Classifications)\n\n\ndef compute_hardness(samples, label_field, hardness_field, progress):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    #\n    # Algorithm\n    #\n    # Hardness is computed directly as the entropy of the logits\n    #\n\n    fov.validate_collection(samples)\n    fov.validate_collection_label_fields(samples, label_field, _ALLOWED_TYPES)\n\n    if samples.media_type == fom.VIDEO:\n        hardness_field, _ = samples._handle_frame_field(hardness_field)\n\n    config = HardnessConfig(label_field, hardness_field)\n    brain_key = hardness_field\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n    brain_method.register_run(samples, brain_key, cleanup=False)\n    brain_method.register_samples(samples)\n\n    view = samples.select_fields(label_field)\n    processing_frames = samples._is_frame_field(label_field)\n\n    logger.info(\"Computing hardness...\")\n    for sample in view.iter_samples(progress=progress):\n        if processing_frames:\n            images = sample.frames.values()\n        else:\n            images = [sample]\n\n        sample_hardness = []\n        for image in images:\n            hardness = brain_method.process_image(image)\n\n            if hardness is not None:\n                sample_hardness.append(hardness)\n\n            if processing_frames:\n                image[hardness_field] = hardness\n\n        if sample_hardness:\n            sample[hardness_field] = 
np.max(sample_hardness)\n        else:\n            sample[hardness_field] = None\n\n        sample.save()\n\n    brain_method.save_run_results(samples, brain_key, None)\n\n    logger.info(\"Hardness computation complete\")\n\n\n# @todo move to `fiftyone/brain/hardness.py`\nclass HardnessConfig(fob.BrainMethodConfig):\n    def __init__(self, label_field, hardness_field, **kwargs):\n        self.label_field = label_field\n        self.hardness_field = hardness_field\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"hardness\"\n\n    @property\n    def method(self):\n        return \"entropy\"\n\n\nclass Hardness(fob.BrainMethod):\n    def __init__(self, config):\n        super().__init__(config)\n        self.label_field = None\n\n    def ensure_requirements(self):\n        pass\n\n    def register_samples(self, samples):\n        self.label_field, _ = samples._handle_frame_field(\n            self.config.label_field\n        )\n\n    def process_image(self, sample_or_frame):\n        label = _get_data(sample_or_frame, self.label_field)\n\n        if label is None:\n            return None\n\n        return entropy(softmax(np.asarray(label.logits)))\n\n    def get_fields(self, samples, brain_key):\n        label_field = self.config.label_field\n        hardness_field = self.config.hardness_field\n\n        fields = [label_field, hardness_field]\n\n        if samples._is_frame_field(label_field):\n            fields.append(samples._FRAMES_PREFIX + hardness_field)\n\n        return fields\n\n    def cleanup(self, samples, brain_key):\n        label_field = self.config.label_field\n        hardness_field = self.config.hardness_field\n\n        samples._dataset.delete_sample_fields(hardness_field, error_level=1)\n\n        if samples._is_frame_field(label_field):\n            samples._dataset.delete_frame_fields(hardness_field, error_level=1)\n\n    def _validate_run(self, samples, brain_key, existing_info):\n        
self._validate_fields_match(brain_key, \"hardness_field\", existing_info)\n\n\ndef _get_data(sample, label_field):\n    label = sample[label_field]\n    if label is None:\n        return None\n\n    if label.logits is None:\n        raise ValueError(\n            \"Sample '%s' field '%s' has no logits\" % (sample.id, label_field)\n        )\n\n    return label\n"
  },
  {
    "path": "fiftyone/brain/internal/core/lancedb.py",
    "content": "\"\"\"\nLanceDB similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.storage as fos\nimport fiftyone.core.utils as fou\nimport fiftyone.brain.internal.core.utils as fbu\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\n\nlancedb = fou.lazy_import(\"lancedb\")\npa = fou.lazy_import(\"pyarrow\")\n\n\n_SUPPORTED_METRICS = {\n    \"cosine\": \"cosine\",\n    \"euclidean\": \"l2\",\n}\n\nlogger = logging.getLogger(__name__)\n\n\nclass LanceDBSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for a LanceDB similarity instance.\n\n    Args:\n        table_name (None): the name of the LanceDB table to use. If none is\n            provided, a new table will be created\n        metric (\"cosine\"): the embedding distance metric to use when creating a\n            new index. Supported values are ``(\"cosine\", \"euclidean\")``\n        uri (\"/tmp/lancedb\"): the database URI to use\n        **kwargs: keyword arguments for :class:`SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        table_name=None,\n        metric=\"cosine\",\n        uri=\"/tmp/lancedb\",\n        **kwargs,\n    ):\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.table_name = table_name\n        self.metric = metric\n\n        # store privately so these aren't serialized\n        self._uri = uri\n\n    @property\n    def method(self):\n        return \"lancedb\"\n\n    @property\n    def uri(self):\n        return self._uri\n\n    @uri.setter\n    def uri(self, value):\n        self._uri = value\n\n    @property\n    def max_k(self):\n        return None\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(self, uri=None):\n        self._load_parameters(uri=uri)\n\n\nclass LanceDBSimilarity(Similarity):\n    \"\"\"LanceDB similarity factory.\n\n    Args:\n        config: a :class:`LanceDBSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"lancedb\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"lancedb\")\n\n    def initialize(self, samples, brain_key):\n        return LanceDBSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass LanceDBSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with LanceDB similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`LanceDBSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`LanceDBSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._table = None\n        self._db = None\n        self._initialize()\n\n    def _initialize(self):\n        try:\n            db = lancedb.connect(self.config.uri)\n        except 
Exception as e:\n            raise ValueError(\n                \"Failed to connect to LanceDB backend at URI '%s'. Refer to \"\n                \"https://docs.voxel51.com/integrations/lancedb.html for more \"\n                \"information\" % self.config.uri\n            ) from e\n\n        table_names = db.table_names()\n\n        if self.config.table_name is None:\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            table_name = fbu.get_unique_name(root, table_names)\n\n            self.config.table_name = table_name\n            self.save_config()\n\n        if self.config.table_name in table_names:\n            table = db.open_table(self.config.table_name)\n        else:\n            table = None\n\n        self._db = db\n        self._table = table\n\n    @property\n    def table(self):\n        \"\"\"The ``lancedb.LanceTable`` instance for this index.\"\"\"\n        return self._table\n\n    @property\n    def total_index_size(self):\n        if self._table is None:\n            return 0\n\n        return len(self._table)\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n    ):\n        if self._table is None:\n            pa_table = pa.Table.from_arrays(\n                [[], [], []], names=[\"id\", \"sample_id\", \"vector\"]\n            )\n        else:\n            pa_table = self._table.to_arrow()\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            existing_ids = set(pa_table[\"id\"].to_pylist()) & set(ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) 
that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n\n        if label_ids is not None:\n            ids = list(label_ids)\n        else:\n            ids = list(sample_ids)\n\n        dim = embeddings.shape[1]\n\n        if self._table:\n            prev_embeddings = np.concatenate(\n                pa_table[\"vector\"].to_numpy()\n            ).reshape(-1, dim)\n            embeddings = np.concatenate([prev_embeddings, embeddings])\n            ids = pa_table[\"id\"].to_pylist() + ids\n            sample_ids = pa_table[\"sample_id\"].to_pylist() + list(sample_ids)\n\n        embeddings = pa.array(embeddings.reshape(-1), type=pa.float32())\n        embeddings = pa.FixedSizeListArray.from_arrays(embeddings, dim)\n        sample_ids = list(sample_ids)\n        pa_table = pa.Table.from_arrays(\n            [ids, sample_ids, embeddings], names=[\"id\", \"sample_id\", \"vector\"]\n        )\n        self._table = self._db.create_table(\n            self.config.table_name, pa_table, 
mode=\"overwrite\"\n        )\n\n        if reload:\n            self.reload()\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or warn_missing:\n            index_ids = set(self._table.to_arrow()[\"id\"].to_pylist())\n            existing_ids = [_id for _id in ids if _id in index_ids]\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        df = self._table.to_pandas()\n        df = df[~df[\"id\"].isin(ids)]\n        self._table = self._db.create_table(\n            self.config.table_name, df, mode=\"overwrite\"\n        )\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        df = self._table.to_pandas()\n\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids 
= []\n        missing_ids = []\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            df.set_index(\"sample_id\", drop=False, inplace=True)\n\n            if not etau.is_container(sample_ids):\n                sample_ids = [sample_ids]\n\n            for sample_id in sample_ids:\n                if sample_id in df.index:\n                    found_embeddings.append(df.loc[sample_id][\"vector\"])\n                    found_sample_ids.append(sample_id)\n                    found_label_ids.append(df.loc[sample_id][\"id\"])\n                else:\n                    missing_ids.append(sample_id)\n        elif self.config.patches_field is not None:\n            df.set_index(\"id\", drop=False, inplace=True)\n\n            if label_ids is None:\n                label_ids = list(df.index)\n            elif not etau.is_container(label_ids):\n                label_ids = [label_ids]\n\n            for label_id in label_ids:\n                if label_id in df.index:\n                    found_embeddings.append(df.loc[label_id][\"vector\"])\n                    found_sample_ids.append(df.loc[label_id][\"sample_id\"])\n                    found_label_ids.append(label_id)\n                else:\n                    missing_ids.append(label_id)\n        else:\n            df.set_index(\"id\", drop=False, inplace=True)\n\n            if sample_ids is None:\n                sample_ids = list(df.index)\n            elif not etau.is_container(sample_ids):\n                sample_ids = [sample_ids]\n\n            for sample_id in sample_ids:\n                if sample_id in df.index:\n                    found_embeddings.append(df.loc[sample_id][\"vector\"])\n                    found_sample_ids.append(sample_id)\n                else:\n                    missing_ids.append(sample_id)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n              
      \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(found_embeddings)\n        sample_ids = np.array(found_sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(found_label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def cleanup(self):\n        if self._db is None:\n            return\n\n        for tbl in (\n            self.config.table_name,\n            self.config.table_name + \"_filter\",\n        ):\n            if tbl in self._db.table_names():\n                self._db.drop_table(tbl)\n\n        self._table = None\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"LanceDB does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\n                \"LanceDB does not support least similarity queries\"\n            )\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\n                f\"LanceDB does not support {aggregation} aggregation\"\n            )\n\n        if k is None:\n            k = self.index_size\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        table = self._table\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = list(self.current_label_ids)\n            else:\n                
index_ids = list(self.current_sample_ids)\n\n            df = table.to_pandas()\n            df = df[df[\"id\"].isin(index_ids)]\n            table = self._db.create_table(\n                self.config.table_name + \"_filter\", df, mode=\"overwrite\"\n            )\n\n        metric = _SUPPORTED_METRICS[self.config.metric]\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            results = table.search(q).metric(metric).limit(k).to_df()\n\n            if self.config.patches_field is not None:\n                sample_ids.append(results.sample_id.tolist())\n                label_ids.append(results.id.tolist())\n            else:\n                sample_ids.append(results.id.tolist())\n\n            if return_dists:\n                dists.append(results._distance.tolist())\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        df = self._table.to_pandas()\n        df = df[df[\"id\"].isin(query_ids)]\n        query = np.array([v for v in df[\"vector\"]])\n\n        if query.size == 0:\n            raise ValueError(\n                \"Query IDs %s were not found in the index\" % query_ids\n            )\n\n        if single_query:\n            query = query[0, :]\n\n        
return query\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/leaky_splits.py",
    "content": "\"\"\"\nFinds leaks between splits.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.brain as fob\nimport fiftyone.core.fields as fof\nimport fiftyone.core.validation as fov\nimport fiftyone.zoo as foz\nfrom fiftyone import ViewField as F\n\nimport fiftyone.brain as fb\nimport fiftyone.brain.similarity as fbs\nimport fiftyone.brain.internal.core.utils as fbu\n\n\nlogger = logging.getLogger(__name__)\n\n_DEFAULT_MODEL = \"resnet18-imagenet-torch\"\n_DEFAULT_BATCH_SIZE = None\n\n\ndef compute_leaky_splits(\n    samples,\n    splits,\n    threshold=None,\n    roi_field=None,\n    embeddings=None,\n    similarity_index=None,\n    model=None,\n    model_kwargs=None,\n    force_square=False,\n    alpha=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    fov.validate_collection(samples)\n\n    if etau.is_str(embeddings):\n        embeddings_field, embeddings_exist = fbu.parse_data_field(\n            samples,\n            embeddings,\n            data_type=\"embeddings\",\n        )\n        embeddings = None\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if etau.is_str(similarity_index):\n        similarity_index = samples.load_brain_results(similarity_index)\n\n    if (\n        model is None\n        and embeddings is None\n        and similarity_index is None\n        and not embeddings_exist\n    ):\n        model = foz.load_zoo_model(_DEFAULT_MODEL)\n        if batch_size is None:\n            batch_size = _DEFAULT_BATCH_SIZE\n\n    config = LeakySplitsConfig(\n        splits=splits,\n        embeddings_field=embeddings_field,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n    )\n\n    brain_method = config.build()\n    
brain_method.ensure_requirements()\n\n    if similarity_index is None:\n        similarity_index = fb.compute_similarity(\n            samples,\n            backend=\"sklearn\",\n            roi_field=roi_field,\n            embeddings=embeddings_field or embeddings,\n            model=model,\n            model_kwargs=model_kwargs,\n            force_square=force_square,\n            alpha=alpha,\n            batch_size=batch_size,\n            num_workers=num_workers,\n            skip_failures=skip_failures,\n            progress=progress,\n        )\n    elif not isinstance(similarity_index, fbs.DuplicatesMixin):\n        raise ValueError(\n            \"This method only supports similarity indexes that implement the \"\n            \"%s mixin\" % fbs.DuplicatesMixin\n        )\n\n    split_views = _to_split_views(samples, splits)\n\n    index = brain_method.initialize(samples, similarity_index, split_views)\n\n    if threshold is not None:\n        index.find_leaks(threshold)\n\n    return index\n\n\nclass LeakySplitsConfig(fob.BrainMethodConfig):\n    def __init__(\n        self,\n        splits=None,\n        embeddings_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        **kwargs,\n    ):\n        if isinstance(splits, dict):\n            splits = None\n\n        if similarity_index is not None and not etau.is_str(similarity_index):\n            similarity_index = similarity_index.key\n\n        if model is not None and not etau.is_str(model):\n            model = etau.get_class_name(model)\n\n        self.splits = splits\n        self.embeddings_field = embeddings_field\n        self.similarity_index = similarity_index\n        self.model = model\n        self.model_kwargs = model_kwargs\n\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"leakage\"\n\n    @property\n    def method(self):\n        return \"similarity\"\n\n\nclass LeakySplits(fob.BrainMethod):\n    def 
initialize(self, samples, similarity_index, split_views):\n        return LeakySplitsIndex(\n            samples, self.config, similarity_index, split_views\n        )\n\n    def get_fields(self, samples, _):\n        fields = []\n        if self.config.embeddings_field is not None:\n            fields.append(self.config.embeddings_field)\n\n        return fields\n\n\nclass LeakySplitsIndex(fob.BrainResults):\n    def __init__(self, samples, config, similarity_index, split_views):\n        super().__init__(samples, config, None)\n\n        self._similarity_index = similarity_index\n        self._split_views = split_views\n        self._id2split = None\n        self._thresh = None\n        self._leak_ids = None\n\n        self._initialize()\n\n    @property\n    def split_views(self):\n        \"\"\"A dict mapping split names to views.\"\"\"\n        return self._split_views\n\n    @property\n    def thresh(self):\n        \"\"\"The threshold used by the last call to :meth:`find_leaks`.\"\"\"\n        return self._thresh\n\n    @property\n    def leak_ids(self):\n        \"\"\"The list of leaky sample IDs from the last call to\n        :meth:`find_leaks`.\n        \"\"\"\n        return self._leak_ids\n\n    def find_leaks(self, thresh):\n        \"\"\"Scans the index for leaks between splits.\n\n        Args:\n            thresh: the similarity distance threshold to use when detecting\n                potential leaks\n        \"\"\"\n        if thresh == self._thresh:\n            return\n\n        # Find duplicates\n        self._thresh = thresh\n        if self._similarity_index.thresh != self._thresh:\n            self._similarity_index.find_duplicates(self._thresh)\n\n        # Filter duplicates to just those with neighbors in different splits\n        leak_ids = []\n        neighbors_map = self._similarity_index.neighbors_map\n        for sample_id, neighbors in neighbors_map.items():\n            _leak_ids = []\n\n            sample_split = 
self._id2split.get(sample_id, None)\n            if sample_split is None:\n                continue\n\n            for n in neighbors:\n                neighbor_id = n[0]\n                neighbor_split = self._id2split.get(neighbor_id, None)\n                if neighbor_split is None:\n                    continue\n\n                if neighbor_split != sample_split:\n                    _leak_ids.append(neighbor_id)\n\n            if _leak_ids:\n                leak_ids.append(sample_id)\n                leak_ids.extend(_leak_ids)\n\n        self._leak_ids = leak_ids\n\n    def leaks_view(self):\n        \"\"\"Returns a view containing all potential leaks generated by the last\n        call to :meth:`find_leaks`.\n\n        Returns:\n            a :class:`fiftyone.core.view.DatasetView`\n        \"\"\"\n        if self._thresh is None:\n            raise ValueError(\"You must first call `find_leaks()`\")\n\n        return self.samples.select(self._leak_ids, ordered=True)\n\n    def leaks_for_sample(self, sample_or_id):\n        \"\"\"Returns a view that contains all leaks related to the given sample.\n\n        The given sample is always first in the returned view, followed by any\n        related leaks.\n\n        Args:\n            sample_or_id: a :class:`fiftyone.core.sample.Sample` or sample ID\n\n        Returns:\n            a :class:`fiftyone.core.view.DatasetView`\n        \"\"\"\n        if self._thresh is None:\n            raise ValueError(\"You must first call `find_leaks()`\")\n\n        if etau.is_str(sample_or_id):\n            sample_id = sample_or_id\n        else:\n            sample_id = sample_or_id.id\n\n        sample_split = self._id2split[sample_id]\n        neighbors_map = self._similarity_index.neighbors_map\n\n        leak_ids = []\n        if sample_id in neighbors_map.keys():\n            neighbors = neighbors_map[sample_id]\n            leak_ids = [\n                n[0] for n in neighbors if self._id2split[n[0]] != sample_split\n      
      ]\n        else:\n            for unique_id, neighbors in neighbors_map.items():\n                if sample_id in [n[0] for n in neighbors]:\n                    leak_ids = [\n                        n[0]\n                        for n in neighbors\n                        if self._id2split[n[0]] != sample_split\n                    ]\n                    leak_ids.append(unique_id)\n                    break\n\n        return self.samples.select([sample_id] + leak_ids, ordered=True)\n\n    def no_leaks_view(self, view=None):\n        \"\"\"Returns a view with leaks excluded.\n\n        Args:\n            view (None): an optional :class:`fiftyone.core.view.DatasetView`\n                from which to exclude. By default, :meth:`samples` is used\n        \"\"\"\n        if self._thresh is None:\n            raise ValueError(\"You must first call `find_leaks()`\")\n\n        if view is None:\n            view = self.samples\n\n        return view.exclude(self._leak_ids)\n\n    def tag_leaks(self, tag=\"leak\"):\n        \"\"\"Tags all potential leaks in :meth:`leaks_view` with the given tag.\n\n        Args:\n            tag (\"leak\"): the tag string to apply\n        \"\"\"\n        self.leaks_view().tag_samples(tag)\n\n    def _initialize(self):\n        id2split = {}\n\n        split_ids = {}\n        for split_name, split_view in self.split_views.items():\n            sample_ids = set(split_view.values(\"id\"))\n            split_ids[split_name] = sample_ids\n            id2split.update({sid: split_name for sid in sample_ids})\n\n        # Check for overlapping splits\n        split_names = list(split_ids.keys())\n        for idx, split1 in enumerate(split_names):\n            for split2 in split_names[idx + 1 :]:\n                overlap = split_ids[split1] & split_ids[split2]\n                if overlap:\n                    logger.warning(\n                        \"The '%s' and '%s' splits contain %d overlapping samples.\"\n                        \"Use 
dataset.match_tags('%s').match_tags('%s') to \"\n                        \"identify them\",\n                        split1,\n                        split2,\n                        len(overlap),\n                        split1,\n                        split2,\n                    )\n\n        # Check for samples not in index\n        index_ids = self._similarity_index.sample_ids\n        if index_ids is not None:\n            index_ids = set(index_ids)\n            all_split_ids = set(id2split.keys())\n\n            missing_ids = all_split_ids - index_ids\n            if missing_ids:\n                logger.warning(\n                    \"The provided splits contain %d samples (eg '%s') that \"\n                    \"are not present in the index\",\n                    len(missing_ids),\n                    next(iter(missing_ids)),\n                )\n\n        self._id2split = id2split\n\n\ndef _to_split_views(samples, splits):\n    if etau.is_container(splits):\n        return {tag: samples.match_tags(tag) for tag in splits}\n\n    if isinstance(splits, str):\n        field = samples.get_field(splits)\n        if isinstance(field, fof.ListField):\n            return {\n                value: samples.exists(splits).match(F(splits).contains(value))\n                for value in samples.distinct(splits)\n            }\n        else:\n            return {\n                value: samples.match(F(splits) == value)\n                for value in samples.distinct(splits)\n            }\n"
  },
  {
    "path": "fiftyone/brain/internal/core/milvus.py",
    "content": "\"\"\"\nMilvus similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\nfrom uuid import uuid4\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\npymilvus = fou.lazy_import(\"pymilvus\")\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = {\n    \"cosine\": \"COSINE\",\n    \"dotproduct\": \"IP\",\n    \"euclidean\": \"L2\",\n}\n\n\nclass MilvusSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the Milvus similarity backend.\n\n    Args:\n        collection_name (None): the name of a Milvus collection to use or\n            create. If none is provided, a new collection will be created\n        metric (\"dotproduct\"): the embedding distance metric to use when\n            creating a new index. Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\")``\n        consistency_level (\"Session\"): the consistency level to use. 
Supported\n            values are ``(\"Session\", \"Strong\", \"Bounded\", \"Eventually\")``\n        uri (None): a full Milvus server address to use, like\n            ``\"http://localhost:19530\"``,\n            ``\"tcp:localhost:19530\"``, or\n            ``\"https://ok.s3.south.com:19530\"``\n        user (None): a username to use\n        password (None): a password to use\n        secure (None): whether to enable TLS (True)\n        token (None): a header token for RPC calls\n        db_name (None): a database name for the connection\n        client_key_path (None): a client.key path for TLS two-way\n        client_pem_path (None): a client.pem path for TLS two-way\n        ca_pem_path (None): a ca.pem path for TLS two-way\n        server_pem_path (None): a server.pem path for TLS one-way\n        server_name (None): the server name, for TLS\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        collection_name=None,\n        metric=\"dotproduct\",\n        consistency_level=\"Session\",\n        uri=None,\n        user=None,\n        password=None,\n        secure=None,\n        token=None,\n        db_name=None,\n        client_key_path=None,\n        client_pem_path=None,\n        ca_pem_path=None,\n        server_pem_path=None,\n        server_name=None,\n        **kwargs,\n    ):\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.collection_name = collection_name\n        self.metric = metric\n        self.consistency_level = consistency_level\n\n        # store privately so these aren't serialized\n        self._uri = uri\n        self._user = user\n        self._password = password\n        self._secure = secure\n        self._token = token\n        self._db_name = db_name\n        self._client_key_path = client_key_path\n        self._client_pem_path = client_pem_path\n        self._ca_pem_path = ca_pem_path\n        self._server_pem_path = server_pem_path\n        self._server_name = server_name\n\n    @property\n    def method(self):\n        return \"milvus\"\n\n    @property\n    def uri(self):\n        return self._uri\n\n    @uri.setter\n    def uri(self, value):\n        self._uri = value\n\n    @property\n    def user(self):\n        return self._user\n\n    @user.setter\n    def user(self, value):\n        self._user = value\n\n    @property\n    def password(self):\n        return self._password\n\n    @password.setter\n    def password(self, value):\n        self._password = value\n\n    @property\n    def secure(self):\n        return self._secure\n\n    @secure.setter\n    def secure(self, value):\n        self._secure = value\n\n    @property\n    def token(self):\n        return self._token\n\n    @token.setter\n    def token(self, value):\n        self._token = value\n\n    @property\n    def db_name(self):\n        return self._db_name\n\n    @db_name.setter\n    def db_name(self, value):\n        self._db_name = value\n\n    @property\n    def client_key_path(self):\n        return self._client_key_path\n\n    @client_key_path.setter\n    def client_key_path(self, value):\n        self._client_key_path = value\n\n    @property\n    def client_pem_path(self):\n        return self._client_pem_path\n\n    
@client_pem_path.setter\n    def client_pem_path(self, value):\n        self._client_pem_path = value\n\n    @property\n    def ca_pem_path(self):\n        return self._ca_pem_path\n\n    @ca_pem_path.setter\n    def ca_pem_path(self, value):\n        self._ca_pem_path = value\n\n    @property\n    def server_pem_path(self):\n        return self._server_pem_path\n\n    @server_pem_path.setter\n    def server_pem_path(self, value):\n        self._server_pem_path = value\n\n    @property\n    def server_name(self):\n        return self._server_name\n\n    @server_name.setter\n    def server_name(self, value):\n        self._server_name = value\n\n    @property\n    def max_k(self):\n        return 16384\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    @property\n    def index_params(self):\n        return {\n            \"metric_type\": _SUPPORTED_METRICS[self.metric],\n            \"index_type\": \"HNSW\",\n            \"params\": {\"M\": 8, \"efConstruction\": 64},\n        }\n\n    @property\n    def search_params(self):\n        return {\n            \"HNSW\": {\n                \"metric_type\": _SUPPORTED_METRICS[self.metric],\n                \"params\": {\"ef\": 10},\n            },\n        }\n\n    def load_credentials(\n        self,\n        uri=None,\n        user=None,\n        password=None,\n        secure=None,\n        token=None,\n        db_name=None,\n        client_key_path=None,\n        client_pem_path=None,\n        ca_pem_path=None,\n        server_pem_path=None,\n        server_name=None,\n    ):\n        self._load_parameters(\n            uri=uri,\n            user=user,\n            password=password,\n            secure=secure,\n            token=token,\n            db_name=db_name,\n            client_key_path=client_key_path,\n            client_pem_path=client_pem_path,\n            ca_pem_path=ca_pem_path,\n      
      server_pem_path=server_pem_path,\n            server_name=server_name,\n        )\n\n\nclass MilvusSimilarity(Similarity):\n    \"\"\"Milvus similarity factory.\n\n    Args:\n        config: a :class:`MilvusSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"pymilvus\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"pymilvus\")\n\n    def initialize(self, samples, brain_key):\n        return MilvusSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass MilvusSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Milvus similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`MilvusSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`MilvusSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._alias = None\n        self._collection = None\n        self._initialize()\n\n    def _initialize(self):\n        kwargs = {}\n\n        for key in (\n            \"uri\",\n            \"user\",\n            \"password\",\n            \"secure\",\n            \"token\",\n            \"db_name\",\n            \"client_key_path\",\n            \"client_pem_path\",\n            \"ca_pem_path\",\n            \"server_pem_path\",\n            \"server_name\",\n        ):\n            value = getattr(self.config, key, None)\n            if value is not None:\n                kwargs[key] = value\n\n        alias = uuid4().hex if kwargs else \"default\"\n\n        try:\n            pymilvus.connections.connect(alias=alias, **kwargs)\n        except pymilvus.MilvusException as e:\n            raise ValueError(\n                \"Failed to connect to Milvus backend at URI '%s'. 
Refer to \"\n                \"https://docs.voxel51.com/integrations/milvus.html for more \"\n                \"information\" % self.config.uri\n            ) from e\n\n        collection_names = pymilvus.utility.list_collections(using=alias)\n\n        if self.config.collection_name is None:\n            # Milvus only supports numbers, letters and underscores\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            root = root.replace(\"-\", \"_\")\n            collection_name = fbu.get_unique_name(root, collection_names)\n            collection_name = collection_name.replace(\"-\", \"_\")\n\n            self.config.collection_name = collection_name\n            self.save_config()\n\n        if self.config.collection_name in collection_names:\n            collection = pymilvus.Collection(\n                self.config.collection_name, using=alias\n            )\n            collection.load()\n        else:\n            collection = None\n\n        self._alias = alias\n        self._collection = collection\n\n    def _create_collection(self, dimension):\n        schema = pymilvus.CollectionSchema(\n            [\n                pymilvus.FieldSchema(\n                    \"pk\",\n                    pymilvus.DataType.VARCHAR,\n                    is_primary=True,\n                    auto_id=False,\n                    max_length=64000,\n                ),\n                pymilvus.FieldSchema(\n                    \"vector\", pymilvus.DataType.FLOAT_VECTOR, dim=dimension\n                ),\n                pymilvus.FieldSchema(\n                    \"sample_id\", pymilvus.DataType.VARCHAR, max_length=64000\n                ),\n            ]\n        )\n\n        collection = pymilvus.Collection(\n            self.config.collection_name,\n            schema,\n            consistency_level=self.config.consistency_level,\n            using=self._alias,\n        )\n        collection.create_index(\n            \"vector\", 
index_params=self.config.index_params\n        )\n        collection.load()\n\n        self._collection = collection\n\n    @property\n    def collection(self):\n        \"\"\"The ``pymilvus.Collection`` instance for this index.\"\"\"\n        return self._collection\n\n    @property\n    def total_index_size(self):\n        if self._collection is None:\n            return 0\n\n        return self._collection.num_entities\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=100,\n    ):\n        if self._collection is None:\n            self._create_collection(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            existing_ids = self._get_existing_ids(ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in 
enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n        elif existing_ids and overwrite:\n            self._delete_ids(existing_ids)\n\n        embeddings = [e.tolist() for e in embeddings]\n        sample_ids = list(sample_ids)\n        ids = list(ids)\n\n        for _embeddings, _ids, _sample_ids in zip(\n            fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            insert_data = [\n                list(_ids),\n                list(_embeddings),\n                list(_sample_ids),\n            ]\n            self._collection.insert(insert_data)\n\n        self._collection.flush()\n\n        if reload:\n            self.reload()\n\n    def _get_existing_ids(self, ids):\n        ids = ['\"' + str(entry) + '\"' for entry in ids]\n        expr = f\"\"\"pk in [{','.join(ids)}]\"\"\"\n        return self._collection.query(expr)\n\n    def _delete_ids(self, ids):\n        ids = ['\"' + str(entry) + '\"' for entry in ids]\n        expr = f\"\"\"pk in [{','.join(ids)}]\"\"\"\n        self._collection.delete(expr)\n        self._collection.flush()\n\n    def _get_embeddings(self, ids):\n        ids = ['\"' + str(entry) + '\"' for entry in ids]\n        expr = f\"\"\"pk in [{','.join(ids)}]\"\"\"\n        return self._collection.query(\n            expr, output_fields=[\"pk\", \"sample_id\", \"vector\"]\n        )\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or 
warn_missing:\n            existing_ids = self._get_existing_ids(ids)\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        self._delete_ids(ids=ids)\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = 
self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def cleanup(self):\n        pymilvus.utility.drop_collection(\n            self.config.collection_name, using=self._alias\n        )\n        self._collection = None\n\n    def _get_sample_embeddings(self, sample_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n\n        if sample_ids is None:\n            raise ValueError(\n                \"Milvus does not support retrieving all vectors in an index\"\n            )\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            response = self._get_embeddings(list(batch_ids))\n\n            for r in response:\n                found_embeddings.append(r[\"vector\"])\n                found_sample_ids.append(r[\"sample_id\"])\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        if label_ids is None:\n            raise ValueError(\n                \"Milvus does not support retrieving all vectors in an index\"\n            )\n\n        for 
batch_ids in fou.iter_batches(label_ids, batch_size):\n            response = self._get_embeddings(list(batch_ids))\n\n            for r in response:\n                found_embeddings.append(r[\"vector\"])\n                found_sample_ids.append(r[\"sample_id\"])\n                found_label_ids.append(r[\"pk\"])\n\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(\n        self, sample_ids, batch_size=100\n    ):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        query_vector = [0.0] * self._get_dimension()\n        top_k = min(batch_size, self.config.max_k)\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            ids = ['\"' + str(entry) + '\"' for entry in batch_ids]\n            expr = f\"\"\"pk in [{','.join(ids)}]\"\"\"\n            response = self._collection.search(\n                data=[query_vector],\n                anns_field=\"vector\",\n                param=self.config.search_params,\n                expr=expr,\n                limit=top_k,\n            )\n            ids = [x.id for x in response[0]]\n            response = self._get_embeddings(ids)\n            for r in response:\n                found_embeddings.append(r[\"vector\"])\n                found_sample_ids.append(r[\"sample_id\"])\n                found_label_ids.append(r[\"pk\"])\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"Milvus does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\n  
              \"Milvus does not support least similarity queries\"\n            )\n\n        if k is None or k > self.config.max_k:\n            raise ValueError(\"Milvus requires k<=%s\" % self.config.max_k)\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = self.current_label_ids\n            else:\n                index_ids = self.current_sample_ids\n\n            expr = ['\"' + str(entry) + '\"' for entry in index_ids]\n            expr = f\"\"\"pk in [{','.join(expr)}]\"\"\"\n        else:\n            expr = None\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            if self.config.patches_field is not None:\n                output_fields = [\"sample_id\"]\n            else:\n                output_fields = None\n\n            response = self._collection.search(\n                data=[q.tolist()],\n                anns_field=\"vector\",\n                limit=k,\n                expr=expr,\n                param=self.config.search_params,\n                output_fields=output_fields,\n            )\n\n            if self.config.patches_field is not None:\n                sample_ids.append(\n                    [r.entity.get(\"sample_id\") for r in response[0]]\n                )\n                label_ids.append([r.id for r in response[0]])\n            else:\n                sample_ids.append([r.id for r in response[0]])\n\n            if return_dists:\n                dists.append([r.score for r in response[0]])\n\n        
if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        response = self._get_embeddings(query_ids)\n        query = np.array([x[\"vector\"] for x in response])\n\n        if query.size == 0:\n            raise ValueError(\n                \"Query IDs %s were not found in the index\" % query_ids\n            )\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    def _get_dimension(self):\n        if self._collection is None:\n            return None\n\n        for field in self._collection.describe()[\"fields\"]:\n            if field[\"name\"] == \"vector\":\n                return field[\"params\"][\"dim\"]\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/mistakenness.py",
    "content": "\"\"\"\nMistakenness methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\nfrom math import exp\n\nimport numpy as np\nfrom scipy.special import softmax\nfrom scipy.stats import entropy\n\nfrom fiftyone import ViewField as F\nimport fiftyone.core.brain as fob\nimport fiftyone.core.labels as fol\nimport fiftyone.core.media as fom\nimport fiftyone.core.utils as fou\nimport fiftyone.core.validation as fov\n\n\nlogger = logging.getLogger(__name__)\n\n\n_ALLOWED_TYPES = (\n    fol.Classification,\n    fol.Classifications,\n    fol.Detections,\n    fol.Polylines,\n    fol.Keypoints,\n    fol.TemporalDetections,\n)\n_MISSED_CONFIDENCE_THRESHOLD = 0.95\n_DETECTION_IOU = 0.5\n\n\ndef compute_mistakenness(\n    samples,\n    pred_field,\n    label_field,\n    mistakenness_field,\n    missing_field,\n    spurious_field,\n    use_logits,\n    copy_missing,\n    progress,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    #\n    # Algorithm\n    #\n    # The chance of a mistake is related to how confident the model prediction\n    # was as well as whether or not the prediction is correct. A prediction\n    # that is highly confident and incorrect is likely to be a mistake. A\n    # prediction that is low confidence and incorrect is not likely to be a\n    # mistake.\n    #\n    # Let us compute a confidence measure based on negative entropy of logits:\n    # $c = -entropy(logits)$. This value is large when there is low uncertainty\n    # and small when there is high uncertainty. Let us define modulator, $m$,\n    # based on whether or not the answer is correct. $m = -1$ when the label is\n    # correct and $1$ otherwise. 
Then, mistakenness is computed as\n    # $(m * exp(c) + 1) / 2$ so that high confidence correct predictions result\n    # in low mistakenness, high confidence incorrect predictions result in high\n    # mistakenness, and low confidence predictions result in middling\n    # mistakenness.\n    #\n\n    fov.validate_collection_label_fields(\n        samples, (pred_field, label_field), _ALLOWED_TYPES, same_type=True\n    )\n\n    if samples.media_type == fom.VIDEO:\n        mistakenness_field, _ = samples._handle_frame_field(mistakenness_field)\n        missing_field, _ = samples._handle_frame_field(missing_field)\n        spurious_field, _ = samples._handle_frame_field(spurious_field)\n\n    is_objects = samples._is_label_field(\n        pred_field,\n        (fol.Detections, fol.Polylines, fol.Keypoints, fol.TemporalDetections),\n    )\n    if is_objects:\n        eval_key = _make_eval_key(samples, mistakenness_field)\n        config = DetectionMistakennessConfig(\n            pred_field,\n            label_field,\n            mistakenness_field,\n            missing_field,\n            spurious_field,\n            use_logits,\n            copy_missing,\n            eval_key,\n        )\n    else:\n        eval_key = None\n        config = ClassificationMistakennessConfig(\n            pred_field, label_field, mistakenness_field, use_logits\n        )\n\n    brain_key = mistakenness_field\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n    brain_method.register_run(samples, brain_key, cleanup=False)\n    brain_method.register_samples(samples)\n\n    if is_objects:\n        samples.evaluate_detections(\n            pred_field,\n            gt_field=label_field,\n            eval_key=eval_key,\n            classwise=False,\n            iou=_DETECTION_IOU,\n            progress=progress,\n        )\n\n    view = samples.select_fields([label_field, pred_field])\n    processing_frames = samples._is_frame_field(label_field)\n\n    
logger.info(\"Computing mistakenness...\")\n    for sample in view.iter_samples(progress=progress):\n        if processing_frames:\n            images = sample.frames.values()\n        else:\n            images = [sample]\n\n        sample_mistakenness = []\n        num_missing = 0\n        num_spurious = 0\n        for image in images:\n            if is_objects:\n                (\n                    img_mistakenness,\n                    img_missing,\n                    img_spurious,\n                ) = brain_method.process_image(image, eval_key)\n\n                num_missing += img_missing\n                num_spurious += img_spurious\n                if processing_frames:\n                    image[missing_field] = img_missing\n                    image[spurious_field] = img_spurious\n            else:\n                img_mistakenness = brain_method.process_image(image)\n\n            if img_mistakenness is not None:\n                sample_mistakenness.append(img_mistakenness)\n\n            if processing_frames:\n                image[mistakenness_field] = img_mistakenness\n\n        if sample_mistakenness:\n            sample[mistakenness_field] = np.max(sample_mistakenness)\n        else:\n            sample[mistakenness_field] = None\n\n        if is_objects:\n            sample[missing_field] = num_missing\n            sample[spurious_field] = num_spurious\n\n        sample.save()\n\n    if eval_key is not None:\n        samples.delete_evaluation(eval_key)\n\n    brain_method.save_run_results(samples, brain_key, None)\n\n    logger.info(\"Mistakenness computation complete\")\n\n\n# @todo move to `fiftyone/brain/mistakenness.py`\n# Don't do this hastily; `get_brain_info()` on existing datasets has this\n# class's full path in it and may need migration\nclass MistakennessMethodConfig(fob.BrainMethodConfig):\n    def __init__(self, pred_field, label_field, mistakenness_field, **kwargs):\n        super().__init__(**kwargs)\n        self.pred_field = 
pred_field\n        self.label_field = label_field\n        self.mistakenness_field = mistakenness_field\n\n    @property\n    def type(self):\n        return \"mistakenness\"\n\n\nclass MistakennessMethod(fob.BrainMethod):\n    def __init__(self, config):\n        super().__init__(config)\n        self.pred_field = None\n        self.label_field = None\n        self.label_type = None\n\n    def ensure_requirements(self):\n        pass\n\n    def register_samples(self, samples):\n        self.pred_field, _ = samples._handle_frame_field(\n            self.config.pred_field\n        )\n        self.label_field, _ = samples._handle_frame_field(\n            self.config.label_field\n        )\n        self.label_type = samples._get_label_field_type(self.config.pred_field)\n\n    def _validate_run(self, samples, brain_key, existing_info):\n        self._validate_fields_match(brain_key, \"pred_field\", existing_info)\n        self._validate_fields_match(brain_key, \"label_field\", existing_info)\n        self._validate_fields_match(\n            brain_key, \"mistakenness_field\", existing_info\n        )\n\n\n# @todo move to `fiftyone/brain/mistakenness.py`\n# Don't do this hastily; `get_brain_info()` on existing datasets has this\n# class's full path in it and may need migration\nclass ClassificationMistakennessConfig(MistakennessMethodConfig):\n    def __init__(\n        self, pred_field, label_field, mistakenness_field, use_logits, **kwargs\n    ):\n        super().__init__(pred_field, label_field, mistakenness_field, **kwargs)\n        self.use_logits = use_logits\n\n    @property\n    def method(self):\n        return \"classification\"\n\n\nclass ClassificationMistakenness(MistakennessMethod):\n    def process_image(self, sample_or_frame):\n        use_logits = self.config.use_logits\n\n        pred_label, gt_label = _get_data(\n            sample_or_frame, self.pred_field, self.label_field, use_logits\n        )\n\n        if pred_label is None and gt_label is 
None:\n            return None\n\n        if pred_label is None or gt_label is None:\n            m = 1.0\n        elif isinstance(pred_label, fol.Classifications):\n            # For multilabel problems, all labels must match\n            pred_labels = set(c.label for c in pred_label.classifications)\n            gt_labels = set(c.label for c in gt_label.classifications)\n            m = float(pred_labels == gt_labels)\n        else:\n            m = float(pred_label.label == gt_label.label)\n\n        if pred_label is None:\n            mistakenness = 1.0\n        elif use_logits:\n            mistakenness = _compute_mistakenness_class(pred_label.logits, m)\n        else:\n            mistakenness = _compute_mistakenness_class_conf(\n                pred_label.confidence, m\n            )\n\n        return mistakenness\n\n    def get_fields(self, samples, brain_key):\n        pred_field = self.config.pred_field\n        label_field = self.config.label_field\n        mistakenness_field = self.config.mistakenness_field\n\n        fields = [pred_field, label_field, mistakenness_field]\n\n        if samples._is_frame_field(label_field):\n            fields.append(samples._FRAMES_PREFIX + mistakenness_field)\n\n        return fields\n\n    def cleanup(self, samples, brain_key):\n        label_field = self.config.label_field\n        mistakenness_field = self.config.mistakenness_field\n\n        samples._dataset.delete_sample_fields(\n            mistakenness_field, error_level=1\n        )\n\n        if samples._is_frame_field(label_field):\n            samples._dataset.delete_frame_fields(\n                mistakenness_field, error_level=1\n            )\n\n\n# @todo move to `fiftyone/brain/mistakenness.py`\n# Don't do this hastily; `get_brain_info()` on existing datasets has this\n# class's full path in it and may need migration\nclass DetectionMistakennessConfig(MistakennessMethodConfig):\n    def __init__(\n        self,\n        pred_field,\n        
label_field,\n        mistakenness_field,\n        missing_field,\n        spurious_field,\n        use_logits,\n        copy_missing,\n        eval_key,\n        **kwargs\n    ):\n        super().__init__(pred_field, label_field, mistakenness_field, **kwargs)\n        self.missing_field = missing_field\n        self.spurious_field = spurious_field\n        self.use_logits = use_logits\n        self.copy_missing = copy_missing\n        self.eval_key = eval_key\n\n    @property\n    def method(self):\n        return \"detection\"\n\n\nclass DetectionMistakenness(MistakennessMethod):\n    def process_image(self, sample_or_frame, eval_key):\n        missing_field = self.config.missing_field\n        spurious_field = self.config.spurious_field\n        mistakenness_field = self.config.mistakenness_field\n        copy_missing = self.config.copy_missing\n        use_logits = self.config.use_logits\n\n        pred_label, gt_label = _get_data(\n            sample_or_frame, self.pred_field, self.label_field, use_logits\n        )\n\n        list_field = self.label_type._LABEL_LIST_FIELD\n\n        if pred_label is None:\n            pred_label = self.label_type()\n\n        if gt_label is None:\n            gt_label = self.label_type()\n\n        num_spurious = 0\n        num_missing = 0\n        missing_objects = {}\n        image_mistakenness = []\n        pred_map = {}\n        for pred_obj in pred_label[list_field]:\n            pred_map[pred_obj.id] = pred_obj\n            gt_id = pred_obj[eval_key + \"_id\"]\n            conf = pred_obj.confidence\n            if gt_id == \"\" and conf > _MISSED_CONFIDENCE_THRESHOLD:\n                # Unmatched FP with high confidence are missing\n                pred_obj[missing_field] = True\n                num_missing += 1\n                missing_objects[pred_obj.id] = pred_obj\n\n        for gt_obj in gt_label[list_field]:\n            # Avoid adding the same unmatched FP predictions upon multiple runs\n            # of this 
method\n            if copy_missing and gt_obj.has_field(missing_field):\n                if gt_obj.id in missing_objects:\n                    del missing_objects[gt_obj.id]\n\n                continue\n\n            pred_id = gt_obj[eval_key + \"_id\"]\n            if pred_id == \"\":\n                # FN may be spurious\n                gt_obj[spurious_field] = True\n                num_spurious += 1\n            else:\n                # For matched FP, compute mistakenness\n                iou = gt_obj[eval_key + \"_iou\"]\n                pred_obj = pred_map[pred_id]\n                m = float(gt_obj.label == pred_obj.label)\n                if use_logits:\n                    mistakenness_class = _compute_mistakenness_class(\n                        pred_obj.logits, m\n                    )\n                    mistakenness_loc = _compute_mistakenness_loc(\n                        pred_obj.logits, iou\n                    )\n                else:\n                    mistakenness_class = _compute_mistakenness_class_conf(\n                        pred_obj.confidence, m\n                    )\n                    mistakenness_loc = _compute_mistakenness_loc_conf(\n                        pred_obj.confidence, iou\n                    )\n\n                gt_obj[mistakenness_field] = mistakenness_class\n                gt_obj[mistakenness_field + \"_loc\"] = mistakenness_loc\n                image_mistakenness.append(mistakenness_class)\n\n        if copy_missing:\n            gt_label[list_field].extend(missing_objects.values())\n            sample_or_frame[self.label_field] = gt_label\n\n        if image_mistakenness:\n            mistakenness = np.max(image_mistakenness)\n        else:\n            mistakenness = -1\n\n        return mistakenness, num_missing, num_spurious\n\n    def get_fields(self, samples, brain_key):\n        pred_field = self.config.pred_field\n        label_field = self.config.label_field\n        mistakenness_field = 
self.config.mistakenness_field\n        missing_field = self.config.missing_field\n        spurious_field = self.config.spurious_field\n\n        label_type = samples._get_label_field_type(pred_field)\n        list_field = label_type._LABEL_LIST_FIELD\n\n        fields = [\n            mistakenness_field,\n            missing_field,\n            spurious_field,\n            \"%s.%s.%s\" % (label_field, list_field, mistakenness_field),\n            \"%s.%s.%s_loc\" % (label_field, list_field, mistakenness_field),\n            \"%s.%s.%s\" % (pred_field, list_field, missing_field),\n            \"%s.%s.%s\" % (label_field, list_field, spurious_field),\n        ]\n\n        if samples._is_frame_field(pred_field):\n            fields.extend(\n                [\n                    samples._FRAMES_PREFIX + mistakenness_field,\n                    samples._FRAMES_PREFIX + missing_field,\n                    samples._FRAMES_PREFIX + spurious_field,\n                ]\n            )\n\n        return fields\n\n    def cleanup(self, samples, brain_key):\n        pred_field = self.config.pred_field\n        label_field = self.config.label_field\n        mistakenness_field = self.config.mistakenness_field\n        missing_field = self.config.missing_field\n        spurious_field = self.config.spurious_field\n        eval_key = self.config.eval_key\n\n        label_type = samples._get_label_field_type(pred_field)\n        list_field = label_type._LABEL_LIST_FIELD\n\n        pred_field, is_frame_field = samples._handle_frame_field(pred_field)\n        label_field, _ = samples._handle_frame_field(label_field)\n\n        fields = [\n            mistakenness_field,\n            missing_field,\n            spurious_field,\n            \"%s.%s.%s\" % (label_field, list_field, mistakenness_field),\n            \"%s.%s.%s_loc\" % (label_field, list_field, mistakenness_field),\n            \"%s.%s.%s\" % (pred_field, list_field, missing_field),\n            \"%s.%s.%s\" % (label_field, 
list_field, spurious_field),\n        ]\n\n        if self.config.copy_missing:\n            # Remove objects that were added to `label_field`\n            samples._dataset.filter_labels(\n                self.config.label_field, F(missing_field).exists(False)\n            ).save()\n\n        if is_frame_field:\n            samples._dataset.delete_sample_fields(\n                [mistakenness_field, spurious_field, missing_field],\n                error_level=1,\n            )\n            samples._dataset.delete_frame_fields(fields, error_level=1)\n        else:\n            samples._dataset.delete_sample_fields(fields, error_level=1)\n\n        if eval_key in samples.list_evaluations():\n            samples.delete_evaluation(eval_key)\n\n    def _validate_run(self, samples, brain_key, existing_info):\n        super()._validate_run(samples, brain_key, existing_info)\n        self._validate_fields_match(brain_key, \"missing_field\", existing_info)\n        self._validate_fields_match(brain_key, \"spurious_field\", existing_info)\n        self._validate_fields_match(brain_key, \"copy_missing\", existing_info)\n\n\ndef _make_eval_key(samples, brain_key):\n    existing_eval_keys = samples.list_evaluations()\n    eval_key = brain_key + \"_eval\"\n    if eval_key not in existing_eval_keys:\n        return eval_key\n\n    idx = 2\n    while eval_key + str(idx) in existing_eval_keys:\n        idx += 1\n\n    return eval_key + str(idx)\n\n\ndef _get_data(sample, pred_field, label_field, use_logits):\n    pred_label = sample[pred_field]\n    label = sample[label_field]\n\n    if pred_label is None:\n        return pred_label, label\n\n    if isinstance(pred_label, fol.Detections):\n        for det in pred_label.detections:\n            if det.confidence is None:\n                raise ValueError(\n                    \"Detection '%s' in sample '%s' field '%s' has no \"\n                    \"confidence\" % (det.id, sample.id, pred_field)\n                )\n    elif 
isinstance(pred_label, fol.Polylines):\n        for poly in pred_label.polylines:\n            if poly.confidence is None:\n                raise ValueError(\n                    \"Polyline '%s' in sample '%s' field '%s' has no \"\n                    \"confidence\" % (poly.id, sample.id, pred_field)\n                )\n    elif use_logits:\n        if pred_label.logits is None:\n            raise ValueError(\n                \"Sample '%s' field '%s' has no logits\"\n                % (sample.id, pred_field)\n            )\n    else:\n        if pred_label.confidence is None:\n            raise ValueError(\n                \"Sample '%s' field '%s' has no confidence\"\n                % (sample.id, pred_field)\n            )\n\n    return pred_label, label\n\n\ndef _compute_mistakenness_class(logits, m):\n    # constrain m to either 1 (incorrect) or -1 (correct)\n    m = m * -2.0 + 1.0\n\n    c = -1.0 * entropy(softmax(np.asarray(logits)))\n    mistakenness = (m * exp(c) + 1.0) / 2.0\n\n    return mistakenness\n\n\ndef _compute_mistakenness_loc(logits, iou):\n    # i = 0 for high iou, i = 1 for low iou\n    i = (1.0 / (1.0 - _DETECTION_IOU)) * (1.0 - iou)\n\n    # c = 0 for low confidence, c = 1 for high confidence\n    c = exp(-1.0 * entropy(softmax(np.asarray(logits))))\n\n    # mistakenness = i when c = i, mistakenness = 0.5 if c = 0\n    # mistakenness is higher with lower IoU and closer to 0 or 1 with higher\n    # confidence\n    mistakenness = (c * ((2.0 * i) - 1.0) + 1.0) / 2.0\n\n    return mistakenness\n\n\ndef _compute_mistakenness_class_conf(confidence, m):\n    # constrain m to either 1 (incorrect) or -1 (correct)\n    m = m * -2.0 + 1.0\n\n    mistakenness = (m * confidence + 1.0) / 2.0\n\n    return mistakenness\n\n\ndef _compute_mistakenness_loc_conf(confidence, iou):\n    # i = 0 for high iou, i = 1 for low iou\n    i = (1.0 / (1.0 - _DETECTION_IOU)) * (1.0 - iou)\n\n    # c = 0 for low confidence, c = 1 for high confidence\n    c = confidence\n\n   
 # mistakenness = i when c = i, mistakenness = 0.5 if c = 0\n    # mistakenness is higher with lower IoU and closer to 0 or 1 with higher\n    # confidence\n    mistakenness = (c * ((2.0 * i) - 1.0) + 1.0) / 2.0\n\n    return mistakenness\n"
  },
  {
    "path": "fiftyone/brain/internal/core/mongodb.py",
    "content": "\"\"\"\nMongoDB similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nfrom bson import ObjectId\nimport numpy as np\nfrom pymongo.errors import OperationFailure\n\nimport eta.core.utils as etau\n\nfrom fiftyone import ViewField as F\nimport fiftyone.core.fields as fof\nimport fiftyone.core.media as fom\nimport fiftyone.core.utils as fou\nimport fiftyone.brain.internal.core.utils as fbu\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = {\n    \"cosine\": \"cosine\",\n    \"dotproduct\": \"dotProduct\",\n    \"euclidean\": \"euclidean\",\n}\n\n\nclass MongoDBSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for a MongoDB similarity instance.\n\n    Args:\n        index_name (None): the name of the MongoDB vector index to use or\n            create. If none is provided, a new index will be created\n        metric (\"cosine\"): the embedding distance metric to use when creating a\n            new index. Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\")``\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(self, index_name=None, metric=\"cosine\", **kwargs):\n        if kwargs.get(\"embeddings_field\") is None and index_name is None:\n            raise ValueError(\n                \"You must provide either the name of a field to read/write \"\n                \"embeddings for this index by passing the `embeddings` \"\n                \"parameter, or you must provide the name of an existing \"\n                \"vector search index via the `index_name` parameter\"\n            )\n\n        # @todo support this. 
Will likely require copying embeddings to a new\n        # collection as vector search indexes do not yet support array fields\n        if kwargs.get(\"patches_field\") is not None:\n            raise ValueError(\n                \"The MongoDB backend does not yet support patch embeddings\"\n            )\n\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.index_name = index_name\n        self.metric = metric\n\n    @property\n    def method(self):\n        return \"mongodb\"\n\n    @property\n    def max_k(self):\n        return 10000  # MongoDB limit\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n\nclass MongoDBSimilarity(Similarity):\n    \"\"\"MongoDB similarity factory.\n\n    Args:\n        config: a :class:`MongoDBSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        #\n        # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.create_search_index\n        #\n        # Could also validate that user is connected to an Atlas cluster here\n        # eg Atlas clusters generally have hostnames which end in \"mongodb.net\"\n        # https://stackoverflow.com/q/73180110\n        #\n        fou.ensure_package(\"pymongo>=4.7\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"pymongo>=4.7\")\n\n    def initialize(self, samples, brain_key):\n        return MongoDBSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass MongoDBSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with MongoDB similarity indexes.\n\n    Args:\n        samples: the 
:class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`MongoDBSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`MongoDBSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n\n        self._dataset = samples._dataset\n        self._sample_ids = None\n        self._label_ids = None\n        self._index = None\n        self._initialize()\n\n    @property\n    def is_external(self):\n        return False\n\n    @property\n    def total_index_size(self):\n        if self._sample_ids is not None:\n            return len(self._sample_ids)\n\n        if self._dataset.media_type == fom.GROUP:\n            samples = self._dataset.select_group_slices(_allow_mixed=True)\n        else:\n            samples = self._dataset\n\n        patches_field = self.config.patches_field\n        embeddings_field = self.config.embeddings_field\n\n        if patches_field is not None:\n            _, embeddings_path = self._dataset._get_label_field_path(\n                patches_field, embeddings_field\n            )\n            samples = samples.filter_labels(\n                patches_field, F(embeddings_field).exists()\n            )\n            return samples.count(embeddings_path)\n\n        if samples.has_field(embeddings_field):\n            return samples.exists(embeddings_field).count()\n\n        return 0\n\n    def _initialize(self):\n        coll = self._dataset._sample_collection\n\n        try:\n            indexes = {\n                i[\"name\"]: i\n                for i in coll.aggregate([{\"$listSearchIndexes\": {}}])\n            }\n        except OperationFailure:\n            # https://www.mongodb.com/docs/manual/release-notes/7.0/#atlas-search-index-management\n            # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview\n            if 
self.config.index_name is None:\n                raise ValueError(\n                    \"You must be running MongoDB Atlas 7.0 or later in order \"\n                    \"to use vector search indexes\"\n                )\n\n            # Must assume index exists because we can't use pymongo to check...\n            self._index = True\n\n            return\n\n        if self.config.index_name is None:\n            root = self.config.embeddings_field\n            index_name = fbu.get_unique_name(root, list(indexes.keys()))\n\n            self.config.index_name = index_name\n            self.save_config()\n        elif self.config.embeddings_field is None:\n            info = indexes.get(self.config.index_name, None)\n            if info is None:\n                raise ValueError(\n                    \"Index '%s' does not exist\" % self.config.index_name\n                )\n\n            self.config.embeddings_field = next(\n                iter(info[\"latestDefinition\"][\"mappings\"][\"fields\"].keys())\n            )\n            self.save_config()\n\n        if self.config.index_name in indexes:\n            # Index already exists\n            self._index = True\n        elif self.total_index_size > 0:\n            # Embeddings already exist but the index hasn't been declared yet\n            dimension = self._get_dimension()\n            self._create_index(dimension)\n        else:\n            # Index will be created when add_to_index() is called\n            pass\n\n    def _get_dimension(self):\n        if self._dataset.media_type == fom.GROUP:\n            samples = self._dataset.select_group_slices(_allow_mixed=True)\n        else:\n            samples = self._dataset\n\n        patches_field = self.config.patches_field\n        embeddings_field = self.config.embeddings_field\n\n        if patches_field is not None:\n            _, embeddings_path = self._dataset._get_label_field_path(\n                patches_field, embeddings_field\n            )\n       
     view = samples.filter_labels(\n                patches_field, F(embeddings_field).exists()\n            ).limit(1)\n            embeddings = view.values(embeddings_path, unwind=True)\n        else:\n            view = samples.exists(embeddings_field).limit(1)\n            embeddings = view.values(embeddings_field)\n\n        embedding = next(iter(embeddings), None)\n        if embedding is None:\n            return None\n\n        return len(embedding)  # MongoDB requires list fields\n\n    def _create_index(self, dimension):\n        # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage\n        # https://www.mongodb.com/docs/languages/python/pymongo-driver/current/indexes/atlas-search-index/\n        from pymongo.operations import SearchIndexModel\n\n        field = self._dataset.get_field(self.config.embeddings_field)\n        if field is not None and not isinstance(field, fof.ListField):\n            raise ValueError(\n                \"MongoDB vector search indexes require embeddings to be \"\n                \"stored in list fields\"\n            )\n\n        metric = _SUPPORTED_METRICS[self.config.metric]\n\n        fields = [\n            {\n                \"type\": \"vector\",\n                \"numDimensions\": dimension,\n                \"path\": self.config.embeddings_field,\n                \"similarity\": metric,\n            },\n            {\n                \"type\": \"filter\",\n                \"path\": \"_id\",\n            },\n        ]\n\n        \"\"\"\n        if self._dataset.media_type == fom.GROUP:\n            fields.append(\n                {\n                    \"type\": \"filter\",\n                    \"path\": self._dataset.group_field + \".name\",\n                }\n            )\n        \"\"\"\n\n        model = SearchIndexModel(\n            name=self.config.index_name,\n            type=\"vectorSearch\",  # requires pymongo>=4.7\n            definition={\"fields\": fields},\n        )\n\n        
coll = self._dataset._sample_collection\n        coll.create_search_index(model=model)\n\n        self._index = True\n\n    @property\n    def ready(self):\n        \"\"\"Returns True/False whether the vector search index is ready to be\n        queried.\n        \"\"\"\n        if self._index is None:\n            return False\n\n        try:\n            coll = self._dataset._sample_collection\n            indexes = {\n                i[\"name\"]: i\n                for i in coll.aggregate([{\"$listSearchIndexes\": {}}])\n            }\n        except OperationFailure:\n            # requires MongoDB Atlas 7.0 or later\n            return None\n\n        info = indexes.get(self.config.index_name, {})\n        return info.get(\"status\", None) == \"READY\"\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n    ):\n        if self._index is None:\n            self._create_index(embeddings.shape[1])\n\n        sample_ids = np.asarray(sample_ids)\n        label_ids = np.asarray(label_ids) if label_ids is not None else None\n\n        if not overwrite or not allow_existing or warn_existing:\n            if self._sample_ids is not None:\n                _sample_ids, _label_ids = self._sample_ids, self._label_ids\n            else:\n                _sample_ids, _label_ids = self._parse_data(\n                    self._dataset, self.config\n                )\n\n            index_sample_ids, index_label_ids, ii, _ = fbu.add_ids(\n                sample_ids,\n                label_ids,\n                _sample_ids,\n                _label_ids,\n                patches_field=self.config.patches_field,\n                overwrite=overwrite,\n                allow_existing=allow_existing,\n                warn_existing=warn_existing,\n            )\n\n            self._sample_ids = index_sample_ids\n            
self._label_ids = index_label_ids\n\n            if ii.size == 0:\n                return\n\n            embeddings = embeddings[ii, :]\n            sample_ids = sample_ids[ii]\n            label_ids = label_ids[ii] if label_ids is not None else None\n        else:\n            index_sample_ids = None\n            index_label_ids = None\n\n        fbu.add_embeddings(\n            self._dataset,\n            embeddings.tolist(),  # MongoDB requires list fields\n            sample_ids,\n            label_ids,\n            self.config.embeddings_field,\n            patches_field=self.config.patches_field,\n        )\n\n        if reload:\n            super().reload()\n\n        self._sample_ids = index_sample_ids\n        self._label_ids = index_label_ids\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if not allow_missing or warn_missing:\n            if self._sample_ids is not None:\n                _sample_ids, _label_ids = self._sample_ids, self._label_ids\n            else:\n                _sample_ids, _label_ids = self._parse_data(\n                    self._dataset, self.config\n                )\n\n            index_sample_ids, index_label_ids, rm_inds = fbu.remove_ids(\n                sample_ids,\n                label_ids,\n                _sample_ids,\n                _label_ids,\n                patches_field=self.config.patches_field,\n                allow_missing=allow_missing,\n                warn_missing=warn_missing,\n            )\n\n            self._sample_ids = index_sample_ids\n            self._label_ids = index_label_ids\n\n            if rm_inds.size == 0:\n                return\n\n            if self.config.patches_field is not None:\n                sample_ids = None\n                label_ids = _label_ids[rm_inds]\n            else:\n                sample_ids = _sample_ids[rm_inds]\n          
      label_ids = None\n        else:\n            index_sample_ids = None\n            index_label_ids = None\n\n        fbu.remove_embeddings(\n            self._dataset,\n            self.config.embeddings_field,\n            sample_ids=sample_ids,\n            label_ids=label_ids,\n            patches_field=self.config.patches_field,\n        )\n\n        if reload:\n            super().reload()\n\n        self._sample_ids = index_sample_ids\n        self._label_ids = index_label_ids\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if self._dataset.media_type == fom.GROUP:\n            samples = self._dataset.select_group_slices(_allow_mixed=True)\n        else:\n            samples = self._dataset\n\n        if sample_ids is not None:\n            samples = samples.select(sample_ids)\n        elif label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n            samples = samples.select_labels(\n                ids=label_ids, fields=self.config.patches_field\n            )\n\n        _embeddings, _sample_ids, _label_ids = fbu.get_embeddings(\n            samples,\n            patches_field=self.config.patches_field,\n            embeddings_field=self.config.embeddings_field,\n        )\n\n        if label_ids is not None:\n            inds = _get_inds(\n                label_ids,\n                _label_ids,\n                \"label\",\n                allow_missing,\n                warn_missing,\n            )\n\n            embeddings = _embeddings[inds, :]\n            sample_ids = _sample_ids[inds]\n            label_ids = np.asarray(label_ids)\n        elif sample_ids is not 
None:\n            if etau.is_str(sample_ids):\n                sample_ids = [sample_ids]\n\n            if self.config.patches_field is not None:\n                sample_ids = set(sample_ids)\n                bools = [_id in sample_ids for _id in _sample_ids]\n                inds = np.nonzero(bools)[0]\n            else:\n                inds = _get_inds(\n                    sample_ids,\n                    _sample_ids,\n                    \"sample\",\n                    allow_missing,\n                    warn_missing,\n                )\n\n            embeddings = _embeddings[inds, :]\n            sample_ids = _sample_ids[inds]\n            if self.config.patches_field is not None:\n                label_ids = _label_ids[inds]\n            else:\n                label_ids = None\n        else:\n            embeddings = _embeddings\n            sample_ids = _sample_ids\n            label_ids = _label_ids\n\n        return embeddings, sample_ids, label_ids\n\n    def reload(self):\n        self._sample_ids = None\n        self._label_ids = None\n\n        super().reload()\n\n    def cleanup(self):\n        if self._index is None:\n            return\n\n        try:\n            coll = self._dataset._sample_collection\n            coll.drop_search_index(self.config.index_name)\n        except OperationFailure:\n            # requires MongoDB Atlas 7.0 or later\n            pass\n\n        self._index = None\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"MongoDB does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\n                \"MongoDB does not support least similarity queries\"\n            )\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\n                f\"MongoDB does not support {aggregation} 
aggregation\"\n            )\n\n        if k is None:\n            k = min(self.index_size, self.config.max_k)\n\n        # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage\n        num_candidates = min(10 * k, self.config.max_k)\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            index_ids = self.current_sample_ids\n            # if self.config.patches_field is not None:\n            #     index_ids = self.current_label_ids\n        else:\n            index_ids = None\n\n        dataset = self._dataset\n\n        sample_ids = []\n        label_ids = None\n        # if self.config.patches_field is not None:\n        #     label_ids = []\n        dists = []\n\n        for q in query:\n            search = {\n                \"index\": self.config.index_name,\n                \"path\": self.config.embeddings_field,\n                \"limit\": k,\n                \"numCandidates\": num_candidates,\n                \"queryVector\": q.tolist(),\n            }\n\n            if index_ids is not None:\n                search[\"filter\"] = {\n                    \"_id\": {\"$in\": [ObjectId(_id) for _id in index_ids]}\n                }\n\n            \"\"\"\n            elif dataset.media_type == fom.GROUP:\n                # $vectorSearch must be the first stage in all pipelines, so we\n                # have to incorporate slice selection as a $filter\n                name_field = dataset.group_field + \".name\"\n                group_slice = self.view.group_slice or dataset.group_slice\n                search[\"filter\"] = {name_field: {\"$eq\": group_slice}}\n            \"\"\"\n\n            project = {\"_id\": 1}\n            # if self.config.patches_field is not None:\n            #     
project[\"_sample_id\"] = 1\n            if return_dists:\n                project[\"score\"] = {\"$meta\": \"vectorSearchScore\"}\n\n            # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage\n            pipeline = [{\"$vectorSearch\": search}, {\"$project\": project}]\n\n            try:\n                matches = list(\n                    dataset._aggregate(\n                        pipeline=pipeline, manual_group_select=True\n                    )\n                )\n            except OperationFailure as e:\n                if index_ids is None:\n                    raise e\n\n                logger.warning(\n                    \"This legacy search index does not yet support views. \"\n                    \"Please follow the instructions at \"\n                    \"https://github.com/voxel51/fiftyone-brain/pull/248 \"\n                    \"to upgrade it.\\n\\nIn the meantime, the full index will \"\n                    \"instead be queried, which may result in fewer \"\n                    \"matches in your current view\"\n                )\n\n                search.pop(\"filter\")\n                matches = list(\n                    dataset._aggregate(\n                        pipeline=pipeline, manual_group_select=True\n                    )\n                )\n\n            sample_ids.append([str(m[\"_id\"]) for m in matches])\n            # if self.config.patches_field is not None:\n            #     sample_ids.append([str(m[\"_sample_id\"]) for m in matches])\n            #     label_ids.append([str(m[\"_id\"]) for m in matches])\n\n            if return_dists:\n                dists.append([m[\"score\"] for m in matches])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        
return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        embeddings = self._get_embeddings(query_ids)\n        num_missing = len(query_ids) - len(embeddings)\n        for e in embeddings:\n            num_missing += int(e is None)\n\n        if num_missing > 0:\n            if single_query:\n                raise ValueError(\"The query ID does not exist in this index\")\n            else:\n                raise ValueError(\n                    f\"{num_missing} query IDs do not exist in this index\"\n                )\n\n        query = np.array(embeddings)\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    def _get_embeddings(self, query_ids):\n        if self._dataset.media_type == fom.GROUP:\n            samples = self._dataset.select_group_slices(_allow_mixed=True)\n        else:\n            samples = self._dataset\n\n        patches_field = self.config.patches_field\n        embeddings_field = self.config.embeddings_field\n        if patches_field is not None:\n            _, embeddings_path = self._dataset._get_label_field_path(\n                patches_field, embeddings_field\n            )\n            view = samples.filter_labels(\n                patches_field, F(\"_id\").is_in(query_ids)\n            )\n            embeddings = view.values(embeddings_path, unwind=True)\n        else:\n            view = samples.select(query_ids)\n            embeddings = view.values(embeddings_field)\n\n        return embeddings\n\n    @staticmethod\n    def _parse_data(samples, config):\n        if samples.media_type == fom.GROUP:\n     
       samples = samples.select_group_slices(_allow_mixed=True)\n\n        if config.patches_field is not None:\n            samples = samples.filter_labels(\n                config.patches_field, F(config.embeddings_field).exists()\n            )\n        else:\n            samples = samples.exists(config.embeddings_field)\n\n        return fbu.get_ids(samples, patches_field=config.patches_field)\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n\n\ndef _get_inds(ids, index_ids, ftype, allow_missing, warn_missing):\n    if etau.is_str(ids):\n        ids = [ids]\n\n    ids_map = {_id: i for i, _id in enumerate(index_ids)}\n\n    inds = []\n    bad_ids = []\n\n    for _id in ids:\n        idx = ids_map.get(_id, None)\n        if idx is not None:\n            inds.append(idx)\n        else:\n            bad_ids.append(_id)\n\n    num_missing = len(bad_ids)\n\n    if num_missing > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"Found %d %s IDs (eg '%s') that are not present in the index\"\n                % (num_missing, ftype, bad_ids[0])\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %d %s IDs that are not present in the index\",\n                num_missing,\n                ftype,\n            )\n\n    return np.array(inds)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/mosaic.py",
    "content": "\"\"\"\nMosaic similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\nvector_search_client = fou.lazy_import(\"databricks.vector_search.client\")\n\n\nlogger = logging.getLogger(__name__)\n\n# Todo: add in required for arguments that are necessary to create the index table\nclass MosaicSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the Mosaic similarity backend.\n\n    Args:\n        endpoint_name (None): the name of the vector search endpoint that was created in the Databricks workspace\n        workspace_url (None): the URL of the Databricks workspace\n        catalog_name (None): the name of the catalog in the Databricks workspace\n        schema_name (None): the name of the schema in the Databricks workspace\n        index_name (None): the name of the index to use, if one is not provided, a unique name will be generated\n        service_principal_client_id (None): the client ID of the service principal created for authentication\n        service_principal_client_secret (None): the client secret of the service principal created for authentication\n        personal_access_token (None): the personal access token created for authentication\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        endpoint_name=None,\n        workspace_url=None,\n        catalog_name=None,\n        schema_name=None,\n        index_name=None,\n        service_principal_client_id=None,\n        service_principal_client_secret=None,\n        personal_access_token=None,\n        **kwargs,\n    ):\n        
super().__init__(**kwargs)\n\n        self.index_name = index_name\n        self.endpoint_name = endpoint_name\n        self.catalog_name = catalog_name\n        self.schema_name = schema_name\n\n        # store privately so these aren't serialized\n        self._workspace_url = workspace_url\n        self._service_principal_client_id = service_principal_client_id\n        self._service_principal_client_secret = service_principal_client_secret\n        self._personal_access_token = personal_access_token\n\n    @property\n    def method(self):\n        return \"mosaic\"\n\n    @property\n    def workspace_url(self):\n        return self._workspace_url\n\n    @workspace_url.setter\n    def workspace_url(self, workspace_url):\n        self._workspace_url = workspace_url\n\n    @property\n    def service_principal_client_id(self):\n        return self._service_principal_client_id\n\n    @service_principal_client_id.setter\n    def service_principal_client_id(self, service_principal_client_id):\n        self._service_principal_client_id = service_principal_client_id\n\n    @property\n    def service_principal_client_secret(self):\n        return self._service_principal_client_secret\n\n    @service_principal_client_secret.setter\n    def service_principal_client_secret(self, service_principal_client_secret):\n        self._service_principal_client_secret = service_principal_client_secret\n\n    @property\n    def personal_access_token(self):\n        return self._personal_access_token\n\n    @personal_access_token.setter\n    def personal_access_token(self, personal_access_token):\n        self._personal_access_token = personal_access_token\n\n    @property\n    def max_k(self):\n        return None\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self,\n        workspace_url=None,\n        
service_principal_client_id=None,\n        service_principal_client_secret=None,\n        personal_access_token=None,\n    ):\n        self._load_parameters(\n            workspace_url=workspace_url,\n            service_principal_client_id=service_principal_client_id,\n            service_principal_client_secret=service_principal_client_secret,\n            personal_access_token=personal_access_token,\n        )\n\n\nclass MosaicSimilarity(Similarity):\n    \"\"\"Mosaic similarity factory.\n\n    Args:\n        config: a :class:`MosaicSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"databricks-vectorsearch\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"databricks-vectorsearch\")\n\n    def initialize(self, samples, brain_key):\n        return MosaicSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass MosaicSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Mosaic similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`MosaicSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`MosaicSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._client = None\n        self._index = None\n        self._initialize()\n\n    def _initialize(self):\n        self._client = vector_search_client.VectorSearchClient(\n            workspace_url=self.config.workspace_url,\n            service_principal_client_id=self.config.service_principal_client_id,\n            service_principal_client_secret=self.config.service_principal_client_secret,\n            personal_access_token=self.config.personal_access_token,\n        )\n\n        try:\n            index_names_result = 
self._client.list_indexes(\n                self.config.endpoint_name\n            )\n        except Exception as e:\n            raise ValueError(\n                f\"Failed to list indexes from endpoint :{self.config.endpoint_name}\"\n            ) from e\n\n        index_prefix = f\"{self.config.catalog_name}.{self.config.schema_name}.\"\n        if not index_names_result:\n            index_names = []\n        else:\n            index_names = [\n                ind[\"name\"].replace(index_prefix, \"\")\n                for ind in index_names_result[\"vector_indexes\"]\n                if ind[\"name\"].startswith(index_prefix)\n            ]\n\n        if self.config.index_name is None:\n            root = \"fiftyone-\" + fou.to_slug(self._samples._root_dataset.name)\n            index_name = fbu.get_unique_name(root, index_names)\n\n            self.config.index_name = index_name\n            self.save_config()\n\n        if self.config.index_name in index_names:\n            index = self._client.get_index(\n                endpoint_name=self.config.endpoint_name,\n                index_name=f\"{index_prefix}{self.config.index_name}\",\n            )\n        else:\n            index = None\n\n        self._index = index\n\n    def _create_index(self, dimension):\n        self._index = self._client.create_direct_access_index(\n            endpoint_name=self.config.endpoint_name,\n            index_name=f\"{self.config.catalog_name}.{self.config.schema_name}.{self.config.index_name}\",\n            primary_key=\"foid\",\n            embedding_dimension=dimension,\n            embedding_vector_column=\"embedding_vector\",\n            schema={\n                \"foid\": \"string\",\n                \"sample_id\": \"string\",\n                \"embedding_vector\": \"array<float>\",\n            },\n        )\n\n    @property\n    def client(self):\n        \"\"\"The ``databricks.vector_search.client.VectorSearchClient`` instance for this index.\"\"\"\n        
return self._client\n\n    @property\n    def total_index_size(self):\n        if self._index is None:\n            return 0\n        return self._index.describe()[\"status\"][\"indexed_row_count\"]\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=200,\n    ):\n        if self._index is None:\n            self._create_index(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            index_ids = self._get_index_ids()\n\n            existing_ids = set(ids) & set(index_ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is 
not None:\n                label_ids = np.delete(label_ids, del_inds)\n\n        for _embeddings, _ids, _sample_ids in zip(\n            fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            result = [\n                {\"foid\": f, \"sample_id\": s, \"embedding_vector\": list(e)}\n                for f, s, e in zip(_ids, _sample_ids, _embeddings)\n            ]\n            self._index.upsert(result)\n\n        if reload:\n            self.reload()\n\n    def _get_index_ids(self, batch_size=200):\n        ids = set()\n        result = self._index.scan(num_results=batch_size)\n        while len(result) > 0:\n            ids.update(\n                [\n                    doc[\"fields\"][0][\"value\"][\"string_value\"]\n                    for doc in result[\"data\"]\n                ]\n            )\n            last_primary_key = result[\"last_primary_key\"]\n            result = self._index.scan(\n                num_results=batch_size, last_primary_key=last_primary_key\n            )\n        return list(ids)\n\n    def _get_values(self, ids, batch_size=200):\n        embeddings = []\n        result = self._index.scan(num_results=batch_size)\n        while len(result) > 0:\n            for doc in result[\"data\"]:\n                foid = doc[\"fields\"][0][\"value\"][\"string_value\"]\n                if foid in ids:\n                    embedding = [\n                        d[\"number_value\"]\n                        for d in doc[\"fields\"][2][\"value\"][\"list_value\"][\n                            \"values\"\n                        ]\n                    ]\n                    embeddings.append(embedding)\n            last_primary_key = result[\"last_primary_key\"]\n            result = self._index.scan(\n                num_results=batch_size, last_primary_key=last_primary_key\n            )\n\n        return embeddings\n\n    def 
remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or warn_missing:\n            existing_ids = self._get_index_ids()\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        self._index.delete(ids)\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                
sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    # Note: might be an arg in delete_brain_run?\n    def cleanup(self):\n        if self._index is not None:\n            self._client.delete_index(\n                self.config.endpoint_name,\n                f\"{self.config.catalog_name}.{self.config.schema_name}.{self.config.index_name}\",\n            )\n            self._index = None\n\n    def _get_sample_embeddings(self, sample_ids, batch_size=200):\n        found_embeddings = []\n        found_sample_ids = []\n\n        if sample_ids is None:\n            sample_ids = self._get_index_ids()\n\n        result = self._index.scan(num_results=batch_size)\n        while len(result) > 0:\n            for doc in result[\"data\"]:\n                sample_id = doc[\"fields\"][1][\"value\"][\"string_value\"]\n                if sample_id in sample_ids:\n                    embedding = [\n                        d[\"number_value\"]\n                        for d in 
doc[\"fields\"][2][\"value\"][\"list_value\"][\n                            \"values\"\n                        ]\n                    ]\n                    found_embeddings.append(embedding)\n                    found_sample_ids.append(sample_id)\n            last_primary_key = result[\"last_primary_key\"]\n            result = self._index.scan(\n                num_results=batch_size, last_primary_key=last_primary_key\n            )\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=200):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        if label_ids is None:\n            label_ids = self._get_index_ids()\n\n        result = self._index.scan(num_results=batch_size)\n        while len(result) > 0:\n            for doc in result[\"data\"]:\n                label_id = doc[\"fields\"][0][\"value\"][\"string_value\"]\n                if label_id in label_ids:\n                    embedding = [\n                        d[\"number_value\"]\n                        for d in doc[\"fields\"][2][\"value\"][\"list_value\"][\n                            \"values\"\n                        ]\n                    ]\n                    found_embeddings.append(embedding)\n                    found_label_ids.append(label_id)\n                    found_sample_ids.append(\n                        doc[\"fields\"][1][\"value\"][\"string_value\"]\n                    )\n            last_primary_key = result[\"last_primary_key\"]\n            result = self._index.scan(\n                num_results=batch_size, last_primary_key=last_primary_key\n            )\n\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(\n        
self, sample_ids, batch_size=200\n    ):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        result = self._index.scan(num_results=batch_size)\n        while len(result) > 0:\n            for doc in result[\"data\"]:\n                sample_id = doc[\"fields\"][1][\"value\"][\"string_value\"]\n                if sample_id in sample_ids:\n                    embedding = [\n                        d[\"number_value\"]\n                        for d in doc[\"fields\"][2][\"value\"][\"list_value\"][\n                            \"values\"\n                        ]\n                    ]\n                    found_embeddings.append(embedding)\n                    found_sample_ids.append(sample_id)\n                    found_label_ids.append(\n                        doc[\"fields\"][0][\"value\"][\"string_value\"]\n                    )\n            last_primary_key = result[\"last_primary_key\"]\n            result = self._index.scan(\n                num_results=batch_size, last_primary_key=last_primary_key\n            )\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"Mosaic does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\n                \"Mosaic does not support least similarity queries\"\n            )\n\n        if k is None:\n            k = self.index_size\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n 
       single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = self.current_label_ids\n            else:\n                index_ids = self.current_sample_ids\n\n            # @todo apply filtering in similarity_search(), not post-hoc\n            # As of this writing, filtering is supported in Mosaic but it is\n            # not robust and cannot handle a large number of IDs\n            logger.warning(\n                \"The Mosaic backend does not yet support view filters; the \"\n                \"full index will instead be queried, which may result in \"\n                \"fewer matches in your current view\"\n            )\n\n            _filter = {\"foid\": set(index_ids)}\n        else:\n            _filter = None\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            results = self._index.similarity_search(\n                columns=[\"foid\", \"sample_id\"],\n                query_vector=[float(i) for i in list(q)],\n                num_results=k,\n            )[\"result\"][\"data_array\"]\n\n            if _filter is not None:\n                results = [r for r in results if r[0] in _filter[\"foid\"]]\n\n            if self.config.patches_field is not None:\n                sample_ids.append([r[1] for r in results])\n                label_ids.append([r[0] for r in results])\n            else:\n                sample_ids.append([r[0] for r in results])\n\n            if return_dists:\n                dists.append([r[2] for r in results])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, 
label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        embeddings = self._get_values(query_ids)\n        if len(embeddings) == 0:\n            raise ValueError(\n                \"Query IDs %s do not exist in this index\" % query_ids\n            )\n        query = np.array(embeddings)\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/pgvector.py",
    "content": "\"\"\"\nPGVector similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\npsycopg2 = fou.lazy_import(\"psycopg2\")\npsy_extras = fou.lazy_import(\"psycopg2.extras\")\n\nlogger = logging.getLogger(__name__)\n\n# Supported metrics for pgvector\n_SUPPORTED_METRICS = {\n    \"cosine\": \"vector_cosine_ops\",\n    \"dotproduct\": \"vector_ip_ops\",\n    \"euclidean\": \"vector_l2_ops\",\n    \"l1\": \"vector_l1_ops\",\n    \"jaccard\": \"vector_jaccard_ops\",\n    \"hamming\": \"vector_hamming_ops\",\n}\n\n\nclass PgVectorSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the PGVector similarity backend.\n\n    Args:\n        index_name (None): the name of the PGVector index to use or create.\n            If none is provided, a default index name will be used.\n        table_name (None): the name of the table to use or create. If none is\n            provided, a default table name will be used.\n        metric (\"cosine\"): the similarity metric to use. 
Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\", \"l1\", \"jaccard\", \"hamming\")``\n        connection_string (None): the connection string to the PostgreSQL database\n        ssl_cert (None): the path to the SSL certificate file\n        ssl_key (None): the path to the secret key used for the client certificate\n        ssl_root_cert (None): the path to the file containing SSL certificate\n            authority (CA) certificate(s).\n        work_mem (\"64MB\"): the base maximum amount of memory to be used by a query operation\n            (such as a sort or hash table) before writing to temporary disk files\n        hnsw_m (16): the max number of connections per layer in the HNSW index\n        hnsw_ef_construction (64): the size of the dynamic candidate list for constructing the graph for the HNSW index\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        index_name=None,\n        table_name=None,\n        metric=\"cosine\",\n        connection_string=None,\n        ssl_cert=None,\n        ssl_key=None,\n        ssl_root_cert=None,\n        work_mem=\"64MB\",\n        hnsw_m=16,\n        hnsw_ef_construction=64,\n        **kwargs,\n    ):\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                f\"Unsupported metric '{metric}'. 
\"\n                f\"Supported values are {_SUPPORTED_METRICS}\"\n            )\n\n        super().__init__(**kwargs)\n\n        self.metric = metric\n        self.ssl_cert = ssl_cert\n        self.ssl_key = ssl_key\n        self.ssl_root_cert = ssl_root_cert\n        self.work_mem = work_mem\n        self.index_name = index_name\n        self.table_name = table_name\n        self.hnsw_m = hnsw_m\n        self.hnsw_ef_construction = hnsw_ef_construction\n\n        self._connection_string = connection_string\n\n    @property\n    def method(self):\n        return \"pgvector\"\n\n    @property\n    def connection_string(self):\n        return self._connection_string\n\n    @connection_string.setter\n    def connection_string(self, connection_string):\n        self._connection_string = connection_string\n\n    @property\n    def max_k(self):\n        return 10000\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self,\n        connection_string=None,\n    ):\n        self._load_parameters(connection_string=connection_string)\n\n\nclass PgVectorSimilarity(Similarity):\n    \"\"\"PGVector similarity factory.\n\n    Args:\n        config: a :class:`PgVectorSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"psycopg2|psycopg2-binary\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"psycopg2|psycopg2-binary\")\n\n    def initialize(self, samples, brain_key):\n        return PgVectorSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass PgVectorSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with PGVector similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`PgVectorSimilarityConfig` used\n        
brain_key: the brain key\n        backend (None): a :class:`PgVectorSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._conn = None\n        self._cur = None\n        self._initialize()\n\n    @property\n    def total_index_size(self):\n        if self._conn.closed:\n            self._initialize()\n        try:\n            self._cur.execute(\n                f\"\"\"SELECT COUNT(*) FROM \"{self.config.table_name}\";\"\"\"\n            )\n            return self._cur.fetchone()[0]\n        except Exception as e:\n            logger.error(f\"Error getting index size: {str(e)}\")\n            return 0\n\n    def _initialize(self):\n        ssl_options = {}\n        if self.config.ssl_cert:\n            ssl_options[\"sslcert\"] = self.config.ssl_cert\n        if self.config.ssl_key:\n            ssl_options[\"sslkey\"] = self.config.ssl_key\n        if self.config.ssl_root_cert:\n            ssl_options[\"sslrootcert\"] = self.config.ssl_root_cert\n\n        logger.info(f\"Connecting to PostgreSQL database\")\n        self._conn = psycopg2.connect(\n            self.config.connection_string, **ssl_options\n        )\n        self._cur = self._conn.cursor()\n        try:\n            self._cur.execute(\"CREATE EXTENSION IF NOT EXISTS vector\")\n            self._conn.commit()\n        except Exception as e:\n            logger.error(f\"Error creating vector extension: {str(e)}\")\n            raise\n\n        if self.config.table_name is None:\n            table_names = self._get_table_names()\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            table_name = fbu.get_unique_name(root, table_names)\n\n            self.config.table_name = table_name\n            self.save_config()\n            existing_indexes = []\n        else:\n            existing_indexes = 
self._get_index_names(self.config.table_name)\n\n        if self.config.index_name is None:\n            root = \"fiftyone-index-\" + fou.to_slug(\n                self.samples._root_dataset.name\n            )\n            index_name = fbu.get_unique_name(root, existing_indexes)\n            self.config.index_name = index_name\n            self.save_config()\n\n    def _get_table_names(self):\n        self._cur.execute(\n            \"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';\"\n        )\n        return [row[0] for row in self._cur.fetchall()]\n\n    def _get_index_names(self, table_name):\n        self._cur.execute(\n            f\"SELECT indexname FROM pg_indexes WHERE tablename = '{table_name}' AND schemaname = 'public';\"\n        )\n        return [row[0] for row in self._cur.fetchall()]\n\n    def _create_table(self, dimension):\n        try:\n            self._cur.execute(\n                f\"\"\"\n                    CREATE TABLE IF NOT EXISTS \"{self.config.table_name}\" (\n                    id TEXT PRIMARY KEY,\n                    sample_id TEXT,\n                    embedding_vector VECTOR({dimension})\n                );\n                \"\"\"\n            )\n            self._conn.commit()\n        except Exception as e:\n            logger.error(\n                f\"Error creating table: {self.config.table_name} with dimension {dimension}: {str(e)}\"\n            )\n            raise\n\n    def create_hnsw_index(self):\n        operator_class = _SUPPORTED_METRICS[self.config.metric]\n        try:\n            self._cur.execute(\n                f\"\"\"DROP INDEX IF EXISTS \"{self.config.index_name}\";\"\"\"\n            )\n            self._conn.commit()\n            self._cur.execute(\n                f\"\"\"\n                CREATE INDEX \"{self.config.index_name}\"\n                ON \"{self.config.table_name}\" USING hnsw (embedding_vector {operator_class})\n                WITH (m = %s, ef_construction 
= %s);\n                \"\"\",\n                (self.config.hnsw_m, self.config.hnsw_ef_construction),\n            )\n            self._conn.commit()\n        except Exception as e:\n            logger.error(\n                f\"Error creating HNSW index on table {self.config.table_name}:{str(e)}\"\n            )\n            raise\n\n    def _get_index_ids(self, batch_size=1000):\n        named_cursor = self._conn.cursor(\n            name=\"id_cursor\"\n        )  # Named cursor for server-side query\n        named_cursor.execute(f\"\"\"SELECT id FROM \"{self.config.table_name}\";\"\"\")\n\n        existing_ids = []\n        while True:\n            rows = named_cursor.fetchmany(batch_size)\n            if not rows:\n                break\n\n            ids = [row[0] for row in rows]\n            existing_ids.extend(ids)\n\n        named_cursor.close()\n        return existing_ids\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=5000,\n        close_conn=True,\n    ):\n        if self._conn.closed:\n            self._initialize()\n        self._cur.execute(f\"SET work_mem TO '{self.config.work_mem}'\")\n\n        if self.config.table_name not in self._get_table_names():\n            self._create_table(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            index_ids = self._get_index_ids()\n\n            existing_ids = set(ids) & set(index_ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, 
next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            query = f\"\"\"\n                INSERT INTO \"{self.config.table_name}\" (id, sample_id, embedding_vector)\n                VALUES %s\n                ON CONFLICT (id) DO NOTHING;\n                \"\"\"\n        else:\n            query = f\"\"\"\n                INSERT INTO \"{self.config.table_name}\" (id, sample_id, embedding_vector)\n                VALUES %s\n                ON CONFLICT (id) DO UPDATE\n                SET sample_id = EXCLUDED.sample_id,\n                    embedding_vector = EXCLUDED.embedding_vector;\n                \"\"\"\n\n        embeddings = [e.tolist() for e in embeddings]\n        sample_ids = list(sample_ids)\n        if label_ids is not None:\n            ids = list(label_ids)\n        else:\n            ids = list(sample_ids)\n\n        for _embeddings, _ids, _sample_ids in zip(\n            fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            data = list(zip(_ids, _sample_ids, _embeddings))\n            psy_extras.execute_values(self._cur, query, data)\n            self._conn.commit()\n\n        self.create_hnsw_index()\n\n        if close_conn:\n            self.close_connections()\n\n        if reload:\n            self.reload()\n\n    def 
remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if self._conn.closed:\n            self._initialize()\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_missing or not allow_missing:\n            response = self.get_embeddings_by_id(ids)\n            existing_ids = [id for id, emb in response]\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing_ids = len(missing_ids)\n\n            if num_missing_ids > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that do not exist in the index\"\n                        % (num_missing_ids, next(iter(missing_ids)))\n                    )\n                if warn_missing and not allow_missing:\n                    logger.warning(\n                        \"Skipping %d IDs that do not exist in the index\",\n                        num_missing_ids,\n                    )\n        try:\n            # Use parameterized query to delete multiple IDs\n            self._cur.execute(\n                f\"\"\"DELETE FROM \"{self.config.table_name}\" WHERE id IN %s;\"\"\",\n                (tuple(ids),),\n            )\n        except Exception as e:\n            self._conn.rollback()\n            logger.error(f\"Error removing embeddings for ids {ids}: {str(e)}\")\n            raise\n\n        deleted_count = self._cur.rowcount\n        self._conn.commit()\n        logger.info(f\"Deleted {deleted_count} embeddings from the index.\")\n\n        if reload:\n            self.reload()\n\n    def close_connections(self):\n        if not self._cur.closed:\n            self._cur.close()\n        if not self._conn.closed:\n            self._conn.close()\n\n    def get_embeddings_by_id(self, sample_ids=None, label_ids=None):\n        
if self._conn.closed:\n            self._initialize()\n        if label_ids is not None:\n            try:\n                self._cur.execute(\n                    f\"\"\"SELECT id, sample_id, embedding_vector FROM \"{self.config.table_name}\" WHERE id = ANY(%s)\"\"\",\n                    (list(label_ids),),\n                )\n            except Exception as e:\n                logger.error(\n                    f\"Error fetching embeddings for labels {label_ids}: {str(e)}\"\n                )\n                raise\n        elif sample_ids is not None:\n            try:\n                self._cur.execute(\n                    f\"\"\"SELECT id, sample_id, embedding_vector FROM \"{self.config.table_name}\" WHERE sample_id = ANY(%s)\"\"\",\n                    (list(sample_ids),),\n                )\n            except Exception as e:\n                logger.error(\n                    f\"Error fetching embeddings for samples {sample_ids}: {str(e)}\"\n                )\n                raise\n        else:\n            try:\n                self._cur.execute(\n                    f\"\"\"SELECT id, sample_id, embedding_vector FROM \"{self.config.table_name}\";\"\"\"\n                )\n            except Exception as e:\n                logger.error(\n                    f\"Error fetching embeddings for all samples: {str(e)}\"\n                )\n                raise\n\n        results = self._cur.fetchall()\n        fo_id = []\n        sample_id = []\n        embeddings = []\n        for result in results:\n            # Convert string \"[1.2,3.4,5.6]\" to float array\n            if isinstance(result[2], str):\n                emb = np.array(\n                    [float(x) for x in result[2].strip(\"[]\").split(\",\")],\n                    dtype=np.float32,\n                )\n                embeddings.append(emb)\n            else:\n                # Already numeric\n                emb = np.array(result[2], dtype=np.float32)\n                
embeddings.append(emb)\n            fo_id.append(result[0])\n            sample_id.append(result[1])\n\n        return fo_id, sample_id, embeddings\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                label_ids,\n                found_sample_ids,\n                embeddings,\n            ) = self.get_embeddings_by_id(sample_ids=sample_ids)\n            missing_ids = list(set(sample_ids) - set(found_sample_ids))\n            sample_ids = found_sample_ids\n        elif self.config.patches_field is not None:\n            (\n                found_label_ids,\n                sample_ids,\n                embeddings,\n            ) = self.get_embeddings_by_id(label_ids=label_ids)\n            missing_ids = (\n                list(set(label_ids) - set(found_label_ids))\n                if label_ids is not None\n                else []\n            )\n            label_ids = found_label_ids\n        else:\n            (\n                label_ids,\n                found_sample_ids,\n                embeddings,\n            ) = self.get_embeddings_by_id(sample_ids=sample_ids)\n            missing_ids = (\n                list(set(sample_ids) - set(found_sample_ids))\n                if sample_ids is not None\n                else []\n            )\n            sample_ids = found_sample_ids\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n 
                   \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n        close_conn=True,\n    ):\n        if self._conn.closed:\n            self._initialize()\n\n        if query is None:\n            raise ValueError(\"Postgres does not support full index neighbors\")\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        if k is None:\n            k = self.index_size\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        index_ids = None\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = list(self.current_label_ids)\n            else:\n                index_ids = list(self.current_sample_ids)\n\n            _filter = True\n        else:\n            _filter = False\n\n        sort_order = \"DESC\" if reverse else \"ASC\"\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            if _filter:\n                self._cur.execute(\n                    f\"\"\"\n                   
 SELECT id, sample_id, embedding_vector <-> %s::vector AS distance\n                    FROM \"{self.config.table_name}\"\n                    WHERE id = ANY(%s)\n                    ORDER BY distance {sort_order}\n                    LIMIT %s;\n                    \"\"\",\n                    (q.tolist(), index_ids, k),\n                )\n            else:\n                self._cur.execute(\n                    f\"\"\"\n                    SELECT id, sample_id, embedding_vector <-> %s::vector AS distance\n                    FROM \"{self.config.table_name}\"\n                    ORDER BY distance {sort_order}\n                    LIMIT %s;\n                    \"\"\",\n                    (q.tolist(), k),\n                )\n\n            results = self._cur.fetchall()\n\n            if self.config.patches_field is not None:\n                sample_ids.append([r[1] for r in results])\n                label_ids.append([r[0] for r in results])\n            else:\n                sample_ids.append([r[0] for r in results])\n\n            if return_dists:\n                dists.append([r[2] for r in results])\n\n        if close_conn:\n            self.close_connections()\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        _, _, embeddings = self.get_embeddings_by_id(label_ids=query_ids)\n        
if len(embeddings) == 0:\n            raise ValueError(\n                \"Query IDs %s do not exist in this index\" % query_ids\n            )\n        query = np.array(embeddings)\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    def cleanup(self, drop_table=False):\n        \"\"\"\n        Clean up the database by dropping the HNSW index and optionally the embeddings table.\n        \"\"\"\n        logger.info(\n            f\"Cleaning up: Deleting HNSW index '{self.config.index_name}'\"\n        )\n        self._cur.execute(\n            f\"\"\"DROP INDEX IF EXISTS \"{self.config.index_name}\";\"\"\"\n        )\n\n        if self._conn.closed:\n            self._initialize()\n\n        if drop_table:\n            self._cur.execute(\n                f\"\"\"DROP TABLE IF EXISTS \"{self.config.table_name}\";\"\"\"\n            )\n            logger.info(\n                f\"{self.config.table_name} table deleted successfully.\"\n            )\n\n        self._conn.commit()\n        # Close the database connection\n        self.close_connections()\n        logger.info(\"Database connection closed.\")\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/pinecone.py",
    "content": "\"\"\"\nPiencone similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\npinecone = fou.lazy_import(\"pinecone\")\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = (\"cosine\", \"dotproduct\", \"euclidean\")\n\n\nclass PineconeSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the Pinecone similarity backend.\n\n    Args:\n        index_name (None): the name of a Pinecone index to use or create. If\n            none is provided, a new index will be created\n        index_type (None): the index type to use when creating a new index.\n            The supported values are ``[\"serverless\", \"pod\"]`` and the default\n            is ``\"serverless\"``\n        namespace (None): a namespace under which to store vectors added to the\n            index\n        metric (None): the embedding distance metric to use when creating a\n            new index. 
Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\")``\n        replicas (None): an optional number of replicas when creating a new\n            pod-based index\n        shards (None): an optional number of shards when creating a new\n            pod-based index\n        pods (None): an optional number of pods when creating a new pod-based\n            index\n        pod_type (None): an optional pod type when creating a new pod-based\n            index\n        api_key (None): a Pinecone API key to use\n        cloud (None): a cloud to use when creating serverless indexes\n        region (None): a region to use when creating serverless indexes\n        environment (None): an environment to use when creating pod-based\n            indexes\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        index_name=None,\n        index_type=None,\n        namespace=None,\n        metric=None,\n        replicas=None,\n        shards=None,\n        pods=None,\n        pod_type=None,\n        api_key=None,\n        cloud=None,\n        region=None,\n        environment=None,\n        **kwargs,\n    ):\n        if metric is not None and metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, _SUPPORTED_METRICS)\n            )\n\n        super().__init__(**kwargs)\n\n        self.index_name = index_name\n        self.index_type = index_type\n        self.namespace = namespace\n        self.metric = metric\n        self.replicas = replicas\n        self.shards = shards\n        self.pods = pods\n        self.pod_type = pod_type\n\n        # store privately so these aren't serialized\n        self._api_key = api_key\n        self._cloud = cloud\n        self._region = region\n        self._environment = environment\n\n    @property\n    def method(self):\n        return \"pinecone\"\n\n    @property\n    def api_key(self):\n        return self._api_key\n\n    @api_key.setter\n    def api_key(self, value):\n        self._api_key = value\n\n    @property\n    def cloud(self):\n        return self._cloud\n\n    @cloud.setter\n    def cloud(self, value):\n        self._cloud = value\n\n    @property\n    def region(self):\n        return self._region\n\n    @region.setter\n    def region(self, value):\n        self._region = value\n\n    @property\n    def environment(self):\n        return self._environment\n\n    @environment.setter\n    def environment(self, value):\n        self._environment = value\n\n    @property\n    def max_k(self):\n        return 10000  # Pinecone limit\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self, api_key=None, cloud=None, region=None, environment=None\n    ):\n        self._load_parameters(\n            api_key=api_key,\n            cloud=cloud,\n            region=region,\n            environment=environment,\n        )\n\n\nclass PineconeSimilarity(Similarity):\n    \"\"\"Pinecone similarity factory.\n\n    Args:\n        config: a :class:`PineconeSimilarityConfig`\n    \"\"\"\n\n    def 
ensure_requirements(self):\n        fou.ensure_package(\"pinecone-client\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"pinecone-client>=3.2\")\n\n    def initialize(self, samples, brain_key):\n        return PineconeSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass PineconeSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Pinecone similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`PineconeSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`PineconeSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._pinecone = None\n        self._index = None\n        self._initialize()\n\n    def _initialize(self):\n        self._pinecone = pinecone.Pinecone(api_key=self.config.api_key)\n\n        try:\n            index_names = [d[\"name\"] for d in self._pinecone.list_indexes()]\n        except Exception as e:\n            raise ValueError(\n                \"Failed to connect to Pinecone backend. 
\"\n                \"Refer to https://docs.voxel51.com/integrations/pinecone.html \"\n                \"for more information\"\n            ) from e\n\n        if self.config.index_name is None:\n            # https://docs.pinecone.io/troubleshooting/restrictions-on-index-names\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            index_name = fbu.get_unique_name(root, index_names, max_len=45)\n\n            self.config.index_name = index_name\n            self.save_config()\n\n        if self.config.index_name in index_names:\n            index = self._pinecone.Index(self.config.index_name)\n        else:\n            index = None\n\n        self._index = index\n\n    def _create_index(self, dimension):\n        index_type = self.config.index_type or \"serverless\"\n\n        if index_type == \"serverless\":\n            spec = pinecone.ServerlessSpec(\n                self.config.cloud,\n                self.config.region,\n            )\n        elif index_type == \"pod\":\n            kwargs = dict(\n                pod_type=self.config.pod_type,\n                pods=self.config.pods,\n                replicas=self.config.replicas,\n                shards=self.config.shards,\n            )\n            kwargs = {k: v for k, v in kwargs.items() if v is not None}\n            spec = pinecone.PodSpec(self.config.environment, **kwargs)\n        else:\n            raise TypeError(\n                f\"Invalid index_type='{index_type}'. 
The supported values are \"\n                \"['serverless', 'pod']\"\n            )\n\n        metric = self.config.metric or \"cosine\"\n        self._pinecone.create_index(\n            name=self.config.index_name,\n            dimension=dimension,\n            metric=metric,\n            spec=spec,\n        )\n\n        self._index = self._pinecone.Index(self.config.index_name)\n\n    @property\n    def index(self):\n        \"\"\"The ``pinecone.Index`` instance for this index.\"\"\"\n        return self._index\n\n    @property\n    def total_index_size(self):\n        if self._index is None:\n            return 0\n\n        return self._index.describe_index_stats()[\"total_vector_count\"]\n\n    @property\n    def ready(self):\n        return self._pinecone.describe_index(self.config.index_name).status[\n            \"ready\"\n        ]\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=100,\n        namespace=None,\n    ):\n        if namespace is None:\n            namespace = self.config.namespace\n\n        if self._index is None:\n            self._create_index(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            existing_ids = self._get_existing_ids(ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                        
    \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n\n        embeddings = [e.tolist() for e in embeddings]\n        sample_ids = list(sample_ids)\n        if label_ids is not None:\n            ids = list(label_ids)\n        else:\n            ids = list(sample_ids)\n\n        for _embeddings, _ids, _sample_ids in zip(\n            fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            _id_dicts = [\n                {\"id\": _id, \"sample_id\": _sid}\n                for _id, _sid in zip(_ids, _sample_ids)\n            ]\n            self._index.upsert(\n                list(zip(_ids, _embeddings, _id_dicts)),\n                namespace=namespace,\n            )\n\n        if reload:\n            self.reload()\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or warn_missing:\n            existing_ids = list(self._index.fetch(ids).vectors.keys())\n            
missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        self._index.delete(ids=ids)\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n      
  if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def cleanup(self):\n        self._pinecone.delete_index(self.config.index_name)\n        self._index = None\n\n    def _get_existing_ids(self, ids, batch_size=1000):\n        existing_ids = set()\n        for batch_ids in fou.iter_batches(ids, batch_size):\n            response = self._index.fetch(ids=list(batch_ids))[\"vectors\"]\n            existing_ids.update(response.keys())\n\n        return existing_ids\n\n    def _get_sample_embeddings(self, sample_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n\n        if sample_ids is None:\n            raise ValueError(\n                \"Pinecone does not support retrieving all vectors in an index\"\n            )\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            response = self._index.fetch(ids=list(batch_ids))[\"vectors\"]\n\n            for r in response.values():\n                found_embeddings.append(r[\"values\"])\n                found_sample_ids.append(r[\"id\"])\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n   
     if label_ids is None:\n            raise ValueError(\n                \"Pinecone does not support retrieving all vectors in an index\"\n            )\n\n        for batch_ids in fou.iter_batches(label_ids, batch_size):\n            response = self._index.fetch(ids=list(batch_ids))[\"vectors\"]\n\n            for r in response.values():\n                found_embeddings.append(r[\"values\"])\n                found_sample_ids.append(r[\"metadata\"][\"sample_id\"])\n                found_label_ids.append(r[\"id\"])\n\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(\n        self, sample_ids, batch_size=100\n    ):\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        query_vector = [0.0] * self._get_dimension()\n        top_k = min(batch_size, self.config.max_k)\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            response = self._index.query(\n                vector=query_vector,\n                filter={\"sample_id\": {\"$in\": list(batch_ids)}},\n                top_k=top_k,\n                include_values=True,\n                include_metadata=True,\n            )\n\n            for r in response[\"matches\"]:\n                found_embeddings.append(r[\"values\"])\n                found_sample_ids.append(r[\"metadata\"][\"sample_id\"])\n                found_label_ids.append(r[\"id\"])\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"Pinecone does not support full index neighbors\")\n\n        if reverse is 
True:\n            raise ValueError(\n                \"Pinecone does not support least similarity queries\"\n            )\n\n        if k is None or k > self.config.max_k:\n            raise ValueError(\"Pinecone requires k<=%s\" % self.config.max_k)\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = self.current_label_ids\n            else:\n                index_ids = self.current_sample_ids\n\n            _filter = {\"id\": {\"$in\": list(index_ids)}}\n        else:\n            _filter = None\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            include_metadata = self.config.patches_field is not None\n            response = self._index.query(\n                vector=q.tolist(),\n                top_k=k,\n                filter=_filter,\n                include_metadata=include_metadata,\n            )\n\n            if self.config.patches_field is not None:\n                sample_ids.append(\n                    [r[\"metadata\"][\"sample_id\"] for r in response[\"matches\"]]\n                )\n                label_ids.append([r[\"id\"] for r in response[\"matches\"]])\n            else:\n                sample_ids.append([r[\"id\"] for r in response[\"matches\"]])\n\n            if return_dists:\n                dists.append([r[\"score\"] for r in response[\"matches\"]])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n   
         if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        response = self._index.fetch(query_ids)[\"vectors\"]\n        query = np.array([response[_id][\"values\"] for _id in query_ids])\n\n        if query.size == 0:\n            raise ValueError(\n                \"Query IDs %s were not found in the index\" % query_ids\n            )\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    def _get_dimension(self):\n        if self._index is None:\n            return None\n\n        return self._index.describe_index_stats().dimension\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/qdrant.py",
    "content": "\"\"\"\nQdrant similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\nqdrant = fou.lazy_import(\"qdrant_client\")\nqmodels = fou.lazy_import(\"qdrant_client.http.models\")\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = {\n    \"cosine\": qmodels.Distance.COSINE,\n    \"dotproduct\": qmodels.Distance.DOT,\n    \"euclidean\": qmodels.Distance.EUCLID,\n}\n\n\nclass QdrantSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the Qdrant similarity backend.\n\n    Args:\n        collection_name (None): the name of a Qdrant collection to use or\n            create. If none is provided, a new collection will be created\n        metric (None): the embedding distance metric to use when creating a\n            new index. 
Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\")``\n        replication_factor (None): an optional replication factor to use when\n            creating a new index\n        shard_number (None): an optional number of shards to use when creating\n            a new index\n        write_consistency_factor (None): an optional write consistency factor\n            to use when creating a new index\n        hnsw_config (None): an optional dict of HNSW config parameters to use\n            when creating a new index\n        optimizers_config (None): an optional dict of optimizer parameters to\n            use when creating a new index\n        wal_config (None): an optional dict of WAL config parameters to use\n            when creating a new index\n        url (None): a Qdrant server URL to use\n        api_key (None): a Qdrant API key to use\n        grpc_port (None): Port of Qdrant gRPC interface\n        prefer_grpc (None): If `true`, use gRPC interface when possible\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        collection_name=None,\n        metric=None,\n        replication_factor=None,\n        shard_number=None,\n        write_consistency_factor=None,\n        hnsw_config=None,\n        optimizers_config=None,\n        wal_config=None,\n        url=None,\n        api_key=None,\n        grpc_port=None,\n        prefer_grpc=None,\n        **kwargs,\n    ):\n        if metric is not None and metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.collection_name = collection_name\n        self.metric = metric\n        self.replication_factor = replication_factor\n        self.shard_number = shard_number\n        self.write_consistency_factor = write_consistency_factor\n        self.hnsw_config = hnsw_config\n        self.optimizers_config = optimizers_config\n        self.wal_config = wal_config\n\n        # store privately so these aren't serialized\n        self._url = url\n        self._api_key = api_key\n        self._grpc_port = grpc_port\n        self._prefer_grpc = prefer_grpc\n\n    @property\n    def method(self):\n        return \"qdrant\"\n\n    @property\n    def url(self):\n        return self._url\n\n    @url.setter\n    def url(self, value):\n        self._url = value\n\n    @property\n    def api_key(self):\n        return self._api_key\n\n    @api_key.setter\n    def api_key(self, value):\n        self._api_key = value\n\n    @property\n    def grpc_port(self):\n        return self._grpc_port\n\n    @grpc_port.setter\n    def grpc_port(self, value):\n        self._grpc_port = value\n\n    @property\n    def prefer_grpc(self):\n        return self._prefer_grpc\n\n    @prefer_grpc.setter\n    def prefer_grpc(self, value):\n        self._prefer_grpc = value\n\n    @property\n    def max_k(self):\n        return None\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self, url=None, api_key=None, grpc_port=None, prefer_grpc=None\n    ):\n        self._load_parameters(\n            url=url,\n            api_key=api_key,\n            grpc_port=grpc_port,\n            prefer_grpc=prefer_grpc,\n        )\n\n\nclass QdrantSimilarity(Similarity):\n    \"\"\"Qdrant similarity factory.\n\n    
Args:\n        config: a :class:`QdrantSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"qdrant-client\")\n\n    def ensure_usage_requirements(self):\n        fou.ensure_package(\"qdrant-client\")\n\n    def initialize(self, samples, brain_key):\n        return QdrantSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass QdrantSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Qdrant similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`QdrantSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`QdrantSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._client = None\n        self._initialize()\n\n    def _initialize(self):\n        # QdrantClient does not appear to like passing None as defaults\n        grpc_port = (\n            self.config.grpc_port\n            if self.config.grpc_port is not None\n            else 6334\n        )\n        prefer_grpc = (\n            self.config.prefer_grpc\n            if self.config.prefer_grpc is not None\n            else False\n        )\n\n        self._client = qdrant.QdrantClient(\n            url=self.config.url,\n            api_key=self.config.api_key,\n            grpc_port=grpc_port,\n            prefer_grpc=prefer_grpc,\n        )\n\n        try:\n            collection_names = self._get_collection_names()\n        except Exception as e:\n            raise ValueError(\n                \"Failed to connect to Qdrant backend at URL '%s'. 
Refer to \"\n                \"https://docs.voxel51.com/integrations/qdrant.html for more \"\n                \"information\" % self.config.url\n            ) from e\n\n        if self.config.collection_name is None:\n            root = \"fiftyone-\" + fou.to_slug(self.samples._root_dataset.name)\n            collection_name = fbu.get_unique_name(root, collection_names)\n\n            self.config.collection_name = collection_name\n            self.save_config()\n\n    def _get_collection_names(self):\n        return [c.name for c in self._client.get_collections().collections]\n\n    def _create_collection(self, dimension):\n        if self.config.metric:\n            metric = self.config.metric\n        else:\n            metric = \"cosine\"\n\n        vectors_config = qmodels.VectorParams(\n            size=dimension,\n            distance=_SUPPORTED_METRICS[metric],\n        )\n\n        if self.config.hnsw_config:\n            hnsw_config = qmodels.HnswConfig(**self.config.hnsw_config)\n        else:\n            hnsw_config = None\n\n        if self.config.optimizers_config:\n            optimizers_config = qmodels.OptimizersConfig(\n                **self.config.optimizers_config\n            )\n        else:\n            optimizers_config = None\n\n        if self.config.wal_config:\n            wal_config = qmodels.WalConfig(**self.config.wal_config)\n        else:\n            wal_config = None\n\n        self._client.recreate_collection(\n            collection_name=self.config.collection_name,\n            vectors_config=vectors_config,\n            shard_number=self.config.shard_number,\n            replication_factor=self.config.replication_factor,\n            hnsw_config=hnsw_config,\n            optimizers_config=optimizers_config,\n            wal_config=wal_config,\n        )\n\n    def _get_index_ids(self, batch_size=1000):\n        ids = []\n\n        offset = 0\n        while offset is not None:\n            response = self._client.scroll(\n     
           collection_name=self.config.collection_name,\n                offset=offset,\n                limit=batch_size,\n                with_payload=True,\n                with_vectors=False,\n            )\n            ids.extend([self._to_fiftyone_id(r.id) for r in response[0]])\n            offset = response[-1]\n\n        return ids\n\n    @property\n    def total_index_size(self):\n        try:\n            return self._client.count(self.config.collection_name).count\n        except:\n            return 0\n\n    @property\n    def client(self):\n        \"\"\"The ``qdrant.QdrantClient`` instance for this index.\"\"\"\n        return self._client\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n        batch_size=1000,\n    ):\n        if self.config.collection_name not in self._get_collection_names():\n            self._create_collection(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            index_ids = self._get_index_ids()\n\n            existing_ids = set(ids) & set(index_ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                    )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n        
                logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n\n        embeddings = [e.tolist() for e in embeddings]\n        sample_ids = list(sample_ids)\n        if label_ids is not None:\n            ids = list(label_ids)\n        else:\n            ids = list(sample_ids)\n\n        for _embeddings, _ids, _sample_ids in zip(\n            fou.iter_batches(embeddings, batch_size),\n            fou.iter_batches(ids, batch_size),\n            fou.iter_batches(sample_ids, batch_size),\n        ):\n            self._client.upsert(\n                collection_name=self.config.collection_name,\n                points=qmodels.Batch(\n                    ids=self._to_qdrant_ids(_ids),\n                    payloads=[{\"sample_id\": _id} for _id in _sample_ids],\n                    vectors=_embeddings,\n                ),\n            )\n\n        if reload:\n            self.reload()\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        qids = self._to_qdrant_ids(ids)\n\n        if warn_missing or not allow_missing:\n            response = self._retrieve_points(qids, with_vectors=False)\n            existing_ids = self._to_fiftyone_ids([r.id for r in response])\n            missing_ids = set(ids) - 
set(existing_ids)\n            num_missing_ids = len(missing_ids)\n\n            if num_missing_ids > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that do not exist in the index\"\n                        % (num_missing_ids, next(iter(missing_ids)))\n                    )\n                if warn_missing and not allow_missing:\n                    logger.warning(\n                        \"Skipping %d IDs that do not exist in the index\",\n                        num_missing_ids,\n                    )\n\n            qids = self._to_qdrant_ids(existing_ids)\n\n        self._client.delete(\n            collection_name=self.config.collection_name,\n            points_selector=qmodels.PointIdsList(points=qids),\n        )\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                
label_ids,\n                missing_ids,\n            ) = self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def cleanup(self):\n        self._client.delete_collection(self.config.collection_name)\n\n    def _retrieve_points(self, qids, with_vectors=True, with_payload=True):\n        # @todo add batching?\n        return self._client.retrieve(\n            collection_name=self.config.collection_name,\n            ids=qids,\n            with_vectors=with_vectors,\n            with_payload=with_payload,\n        )\n\n    def _get_sample_embeddings(self, sample_ids):\n        if sample_ids is None:\n            sample_ids = self._get_index_ids()\n\n        response = self._retrieve_points(\n            self._to_qdrant_ids(sample_ids),\n            with_vectors=True,\n        )\n\n        found_embeddings = [r.vector for r in response]\n        found_sample_ids = self._to_fiftyone_ids([r.id for r in response])\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids):\n        if label_ids is None:\n            label_ids = self._get_index_ids()\n\n        response = self._retrieve_points(\n            
self._to_qdrant_ids(label_ids),\n            with_vectors=True,\n            with_payload=True,\n        )\n\n        found_embeddings = [r.vector for r in response]\n        found_sample_ids = [r.payload[\"sample_id\"] for r in response]\n        found_label_ids = self._to_fiftyone_ids([r.id for r in response])\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(self, sample_ids):\n        _filter = qmodels.Filter(\n            should=[\n                qmodels.FieldCondition(\n                    key=\"sample_id\", match=qmodels.MatchValue(value=sid)\n                )\n                for sid in sample_ids\n            ]\n        )\n\n        response = self._client.scroll(\n            collection_name=self.config.collection_name,\n            scroll_filter=_filter,\n            with_vectors=True,\n            with_payload=True,\n        )[0]\n\n        found_embeddings = [r.vector for r in response]\n        found_sample_ids = [r.payload[\"sample_id\"] for r in response]\n        found_label_ids = [r.id for r in response]\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if query is None:\n            raise ValueError(\"Qdrant does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\n                \"Qdrant does not support least similarity queries\"\n            )\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        if k is None:\n            k = self.index_size\n\n        query = 
self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = self.current_label_ids\n            else:\n                index_ids = self.current_sample_ids\n\n            _filter = qmodels.Filter(\n                must=[\n                    qmodels.HasIdCondition(\n                        has_id=self._to_qdrant_ids(index_ids)\n                    )\n                ]\n            )\n        else:\n            _filter = None\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            with_payload = self.config.patches_field is not None\n            results = self._client.search(\n                collection_name=self.config.collection_name,\n                query_vector=q,\n                query_filter=_filter,\n                with_payload=with_payload,\n                limit=k,\n            )\n\n            if self.config.patches_field is not None:\n                sample_ids.append(\n                    self._to_fiftyone_ids(\n                        [r.payload[\"sample_id\"] for r in results]\n                    )\n                )\n                label_ids.append(\n                    self._to_fiftyone_ids([r.id for r in results])\n                )\n            else:\n                sample_ids.append(\n                    self._to_fiftyone_ids([r.id for r in results])\n                )\n\n            if return_dists:\n                dists.append([r.score for r in results])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = 
dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        qids = self._to_qdrant_ids(query_ids)\n        response = self._retrieve_points(qids, with_vectors=True)\n        query = np.array([r.vector for r in response])\n\n        if query.size == 0:\n            raise ValueError(\n                \"Query IDs %s were not found in the index\" % query_ids\n            )\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    def _to_qdrant_id(self, _id):\n        return _id + \"00000000\"\n\n    def _to_qdrant_ids(self, ids):\n        return [self._to_qdrant_id(_id) for _id in ids]\n\n    def _to_fiftyone_id(self, qid):\n        return qid.replace(\"-\", \"\")[:-8]\n\n    def _to_fiftyone_ids(self, qids):\n        return [self._to_fiftyone_id(qid) for qid in qids]\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/redis.py",
    "content": "\"\"\"\nRedis similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.utils as fou\nfrom fiftyone.brain.similarity import (\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\nredis = fou.lazy_import(\"redis\")\n\n\nlogger = logging.getLogger(__name__)\n\n_SUPPORTED_METRICS = {\n    \"cosine\": \"COSINE\",\n    \"dotproduct\": \"IP\",\n    \"euclidean\": \"L2\",\n}\n\n\nclass RedisSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the Redis similarity backend.\n\n    Args:\n        index_name (None): the name of a Redis index to use or create. If none\n            is provided, a new index will be created\n        metric (\"cosine\"): the embedding distance metric to use when creating a\n            new index. Supported values are\n            ``(\"cosine\", \"dotproduct\", \"euclidean\")``\n        algorithm (\"FLAT\"): the search algorithm to use. The supported values\n            are ``(\"FLAT\", \"HNSW\")``\n        host (\"localhost\"): the host to use\n        port (6379): the port to use\n        db (0): the database to use\n        username (None): a username to use\n        password (None): a password to use\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(\n        self,\n        index_name=None,\n        metric=\"cosine\",\n        algorithm=\"FLAT\",\n        host=\"localhost\",\n        port=6379,\n        db=0,\n        username=None,\n        password=None,\n        **kwargs,\n    ):\n        if metric not in _SUPPORTED_METRICS:\n            raise ValueError(\n                \"Unsupported metric '%s'. 
Supported values are %s\"\n                % (metric, tuple(_SUPPORTED_METRICS.keys()))\n            )\n\n        super().__init__(**kwargs)\n\n        self.index_name = index_name\n        self.metric = metric\n        self.algorithm = algorithm\n\n        # store privately so these aren't serialized\n        self._host = host\n        self._port = port\n        self._db = db\n        self._username = username\n        self._password = password\n\n    @property\n    def method(self):\n        return \"redis\"\n\n    @property\n    def host(self):\n        return self._host\n\n    @host.setter\n    def host(self, value):\n        self._host = value\n\n    @property\n    def port(self):\n        return self._port\n\n    @port.setter\n    def port(self, value):\n        self._port = value\n\n    @property\n    def db(self):\n        return self._db\n\n    @db.setter\n    def db(self, value):\n        self._db = value\n\n    @property\n    def username(self):\n        return self._username\n\n    @username.setter\n    def username(self, value):\n        self._username = value\n\n    @property\n    def password(self):\n        return self._password\n\n    @password.setter\n    def password(self, value):\n        self._password = value\n\n    @property\n    def max_k(self):\n        return None\n\n    @property\n    def supports_least_similarity(self):\n        return False\n\n    @property\n    def supported_aggregations(self):\n        return (\"mean\",)\n\n    def load_credentials(\n        self, host=None, port=None, db=None, username=None, password=None\n    ):\n        self._load_parameters(\n            host=host, port=port, db=db, username=username, password=password\n        )\n\n\nclass RedisSimilarity(Similarity):\n    \"\"\"Redis similarity factory.\n\n    Args:\n        config: a :class:`RedisSimilarityConfig`\n    \"\"\"\n\n    def ensure_requirements(self):\n        fou.ensure_package(\"redis\")\n\n    def ensure_usage_requirements(self):\n        
fou.ensure_package(\"redis\")\n\n    def initialize(self, samples, brain_key):\n        return RedisSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass RedisSimilarityIndex(SimilarityIndex):\n    \"\"\"Class for interacting with Redis similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`RedisSimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`RedisSimilarity` instance\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n        self._client = None\n        self._index = None\n        self._initialize()\n\n    def _initialize(self):\n        client = redis.Redis(\n            host=self.config.host,\n            port=self.config.port,\n            db=self.config.db,\n            username=self.config.username,\n            password=self.config.password,\n            decode_responses=True,\n        )\n\n        if self.config.index_name is None:\n\n            def index_exists(index_name):\n                try:\n                    client.ft(index_name).info()\n                    return True\n                except:\n                    return False\n\n            root = \"fiftyone-\" + fou.to_slug(self._samples._root_dataset.name)\n            index_name = fbu.get_unique_name(root, index_exists)\n\n            self.config.index_name = index_name\n            self.save_config()\n\n        try:\n            index = client.ft(self.config.index_name)\n            index.info()\n        except:\n            index = None\n\n        self._client = client\n        self._index = index\n\n    def _create_index(self, dimension):\n        from redis.commands.search.field import TagField, VectorField\n        from redis.commands.search.indexDefinition import (\n            IndexDefinition,\n            
IndexType,\n        )\n\n        schema = (\n            TagField(\"$.foid\", as_name=\"foid\"),\n            TagField(\"$.sample_id\", as_name=\"sample_id\"),\n            VectorField(\n                \"$.vector\",\n                self.config.algorithm,\n                {\n                    \"TYPE\": \"FLOAT32\",\n                    \"DIM\": dimension,\n                    \"DISTANCE_METRIC\": _SUPPORTED_METRICS[self.config.metric],\n                },\n                as_name=\"vector\",\n            ),\n        )\n        definition = IndexDefinition(\n            prefix=[self.config.index_name + \":\"],\n            index_type=IndexType.JSON,\n        )\n        index = self._client.ft(self.config.index_name)\n        index.create_index(fields=schema, definition=definition)\n\n        self._index = index\n\n    @property\n    def client(self):\n        \"\"\"The ``redis.client.Redis`` instance for this index.\"\"\"\n        return self._client\n\n    @property\n    def total_index_size(self):\n        try:\n            return int(self._index.info()[\"num_docs\"])\n        except:\n            return 0\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n    ):\n        if self._index is None:\n            self._create_index(embeddings.shape[1])\n\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if warn_existing or not allow_existing or not overwrite:\n            existing_ids = self._get_existing_ids(ids)\n            num_existing = len(existing_ids)\n\n            if num_existing > 0:\n                if not allow_existing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that already exist in the index\"\n                        % (num_existing, next(iter(existing_ids)))\n                 
   )\n\n                if warn_existing:\n                    if overwrite:\n                        logger.warning(\n                            \"Overwriting %d IDs that already exist in the \"\n                            \"index\",\n                            num_existing,\n                        )\n                    else:\n                        logger.warning(\n                            \"Skipping %d IDs that already exist in the index\",\n                            num_existing,\n                        )\n        else:\n            existing_ids = set()\n\n        if existing_ids and not overwrite:\n            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]\n            embeddings = np.delete(embeddings, del_inds, axis=0)\n            sample_ids = np.delete(sample_ids, del_inds)\n            if label_ids is not None:\n                label_ids = np.delete(label_ids, del_inds)\n        elif existing_ids and overwrite:\n            self._delete_ids(existing_ids)\n\n        pipeline = self._client.pipeline()\n        for e, id, sample_id in zip(embeddings, ids, sample_ids):\n            key = f\"{self.config.index_name}:{id}\"\n            d = {\n                \"foid\": id,\n                \"sample_id\": sample_id,\n                \"vector\": e.astype(np.float32).tolist(),\n            }\n            pipeline.json().set(key, \"$\", d)\n\n        pipeline.execute()\n\n        if reload:\n            self.reload()\n\n    def _get_existing_ids(self, ids):\n        return [d[\"foid\"] for d in self._get_values(ids)]\n\n    def _delete_ids(self, ids):\n        keys = [f\"{self.config.index_name}:{id}\" for id in ids]\n        self._client.delete(*keys)\n\n    def _get_values(self, ids):\n        pipeline = self._client.pipeline()\n        for id in ids:\n            pipeline.json().get(f\"{self.config.index_name}:{id}\")\n\n        return [d for d in pipeline.execute() if d is not None]\n\n    def remove_from_index(\n        self,\n  
      sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        if label_ids is not None:\n            ids = label_ids\n        else:\n            ids = sample_ids\n\n        if not allow_missing or warn_missing:\n            existing_ids = self._get_existing_ids(ids)\n            missing_ids = set(ids) - set(existing_ids)\n            num_missing = len(missing_ids)\n\n            if num_missing > 0:\n                if not allow_missing:\n                    raise ValueError(\n                        \"Found %d IDs (eg %s) that are not present in the \"\n                        \"index\" % (num_missing, next(iter(missing_ids)))\n                    )\n\n                if warn_missing:\n                    logger.warning(\n                        \"Ignoring %d IDs that are not present in the index\",\n                        num_missing,\n                    )\n\n                ids = existing_ids\n\n        self._delete_ids(ids=ids)\n\n        if reload:\n            self.reload()\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n        if sample_ids is not None and self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)\n        elif self.config.patches_field is not None:\n            (\n                embeddings,\n                sample_ids,\n                
label_ids,\n                missing_ids,\n            ) = self._get_patch_embeddings_from_label_ids(label_ids)\n        else:\n            (\n                embeddings,\n                sample_ids,\n                label_ids,\n                missing_ids,\n            ) = self._get_sample_embeddings(sample_ids)\n\n        num_missing_ids = len(missing_ids)\n        if num_missing_ids > 0:\n            if not allow_missing:\n                raise ValueError(\n                    \"Found %d IDs (eg %s) that do not exist in the index\"\n                    % (num_missing_ids, missing_ids[0])\n                )\n\n            if warn_missing:\n                logger.warning(\n                    \"Skipping %d IDs that do not exist in the index\",\n                    num_missing_ids,\n                )\n\n        embeddings = np.array(embeddings)\n        sample_ids = np.array(sample_ids)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return embeddings, sample_ids, label_ids\n\n    def cleanup(self):\n        if self._index is None:\n            return\n\n        self._index.dropindex(delete_documents=True)\n        self._index = None\n\n    def _get_sample_embeddings(self, sample_ids, batch_size=1000):\n        found_embeddings = []\n        found_sample_ids = []\n\n        if sample_ids is None:\n            get_id = lambda key: key.rsplit(\":\", 1)[1]\n            keys = self._client.keys(f\"{self.config.index_name}:*\")\n            sample_ids = map(get_id, keys)\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            for d in self._get_values(batch_ids):\n                found_embeddings.append(d[\"vector\"])\n                found_sample_ids.append(d[\"sample_id\"])\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, None, missing_ids\n\n    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):\n    
    found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        if label_ids is None:\n            get_id = lambda key: key.rsplit(\":\", 1)[1]\n            keys = self._client.keys(f\"{self.config.index_name}:*\")\n            label_ids = map(get_id, keys)\n\n        for batch_ids in fou.iter_batches(label_ids, batch_size):\n            for d in self._get_values(batch_ids):\n                found_embeddings.append(d[\"vector\"])\n                found_sample_ids.append(d[\"sample_id\"])\n                found_label_ids.append(d[\"foid\"])\n\n        missing_ids = list(set(label_ids) - set(found_label_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _get_patch_embeddings_from_sample_ids(\n        self, sample_ids, batch_size=100\n    ):\n        from redis.commands.search.query import Query\n\n        found_embeddings = []\n        found_sample_ids = []\n        found_label_ids = []\n\n        for batch_ids in fou.iter_batches(sample_ids, batch_size):\n            filter = \"@sample_id:{ \" + \" | \".join(batch_ids) + \" }\"\n            query = Query(filter).dialect(2)\n            for doc in self._index.search(query).docs:\n                found_embeddings.append(doc.embeddings)\n                found_sample_ids.append(doc.sample_id)\n                found_label_ids.append(doc.foid)\n\n        missing_ids = list(set(sample_ids) - set(found_sample_ids))\n\n        return found_embeddings, found_sample_ids, found_label_ids, missing_ids\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        from redis.commands.search.query import Query\n\n        if query is None:\n            raise ValueError(\"Redis does not support full index neighbors\")\n\n        if reverse is True:\n            raise ValueError(\"Redis does not support least similarity queries\")\n\n        
if k is None:\n            k = self.index_size\n\n        if aggregation not in (None, \"mean\"):\n            raise ValueError(\"Unsupported aggregation '%s'\" % aggregation)\n\n        query = self._parse_neighbors_query(query)\n        if aggregation == \"mean\" and query.ndim == 2:\n            query = query.mean(axis=0)\n\n        single_query = query.ndim == 1\n        if single_query:\n            query = [query]\n\n        if self.has_view:\n            if self.config.patches_field is not None:\n                index_ids = list(self.current_label_ids)\n            else:\n                index_ids = list(self.current_sample_ids)\n\n            filter = \"@foid:{ \" + \" | \".join(index_ids) + \" }\"\n        else:\n            filter = \"*\"\n\n        sample_ids = []\n        label_ids = [] if self.config.patches_field is not None else None\n        dists = []\n        for q in query:\n            _query = (\n                Query(f\"({filter})=>[KNN {k} @vector $query AS score]\")\n                .sort_by(\"score\")\n                .return_fields(\"score\", \"foid\", \"sample_id\")\n                .dialect(2)\n                .paging(0, k)\n            )\n            _q = q.astype(np.float32).tobytes()\n            docs = self._index.search(_query, {\"query\": _q}).docs\n\n            if self.config.patches_field is not None:\n                sample_ids.append([doc.sample_id for doc in docs])\n                label_ids.append([doc.foid for doc in docs])\n            else:\n                sample_ids.append([doc.foid for doc in docs])\n\n            if return_dists:\n                dists.append([doc.score for doc in docs])\n\n        if single_query:\n            sample_ids = sample_ids[0]\n            if label_ids is not None:\n                label_ids = label_ids[0]\n            if return_dists:\n                dists = dists[0]\n\n        if return_dists:\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    
def _parse_neighbors_query(self, query):\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query by vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                return query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Query by ID(s)\n        dicts = self._get_values(query_ids)\n        if not dicts:\n            raise ValueError(\n                \"Query IDs %s do not exist in this index\" % query_ids\n            )\n\n        query = np.array([d[\"vector\"] for d in dicts])\n\n        if single_query:\n            query = query[0, :]\n\n        return query\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        return cls(samples, config, brain_key)\n"
  },
  {
    "path": "fiftyone/brain/internal/core/representativeness.py",
    "content": "\"\"\"\nRepresentativeness methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\nimport copy\n\nimport numpy as np\nimport sklearn.cluster as skc\nfrom scipy.spatial import cKDTree\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.brain as fob\nimport fiftyone.core.fields as fof\nimport fiftyone.core.labels as fol\nimport fiftyone.core.validation as fov\n\nimport fiftyone.brain.internal.core.utils as fbu\nimport fiftyone.brain.internal.models as fbm\n\n\nlogger = logging.getLogger(__name__)\n\n_ALLOWED_ROI_FIELD_TYPES = (\n    fol.Detection,\n    fol.Detections,\n    fol.Polyline,\n    fol.Polylines,\n)\n\n_DEFAULT_MODEL = \"simple-resnet-cifar10\"\n_DEFAULT_BATCH_SIZE = 16\n\n\ndef compute_representativeness(\n    samples,\n    representativeness_field,\n    method,\n    roi_field,\n    embeddings,\n    similarity_index,\n    model,\n    model_kwargs,\n    force_square,\n    alpha,\n    batch_size,\n    num_workers,\n    skip_failures,\n    progress,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    #\n    # Algorithm\n    #\n    # Compute cluster centers with MeanShift. The representativeness will\n    # then be a scaled distance to the nearest cluster center. 
This puts\n    # cluster centers which should represent the data the highest with a high\n    # ranking and points on the outliers with low ranking.\n    #\n\n    fov.validate_collection(samples)\n\n    if roi_field is not None:\n        fov.validate_collection_label_fields(\n            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES\n        )\n\n    if etau.is_str(embeddings):\n        embeddings_field, embeddings_exist = fbu.parse_data_field(\n            samples,\n            embeddings,\n            patches_field=roi_field,\n            data_type=\"embeddings\",\n        )\n        embeddings = None\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if etau.is_str(similarity_index):\n        similarity_index = samples.load_brain_results(similarity_index)\n\n    if (\n        model is None\n        and embeddings is None\n        and similarity_index is None\n        and not embeddings_exist\n    ):\n        model = fbm.load_model(_DEFAULT_MODEL)\n        if batch_size is None:\n            batch_size = _DEFAULT_BATCH_SIZE\n\n    config = RepresentativenessConfig(\n        representativeness_field,\n        method=method,\n        roi_field=roi_field,\n        embeddings_field=embeddings_field,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n    )\n    brain_key = representativeness_field\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n    brain_method.register_run(samples, brain_key, cleanup=False)\n\n    if roi_field is not None:\n        # @todo experiment with mean(), max(), abs().max(), etc\n        agg_fcn = lambda e: np.mean(e, axis=0)\n    else:\n        agg_fcn = None\n\n    embeddings, sample_ids, _ = fbu.get_embeddings(\n        samples,\n        model=model,\n        model_kwargs=model_kwargs,\n        patches_field=roi_field,\n        embeddings_field=embeddings_field,\n        embeddings=embeddings,\n        
similarity_index=similarity_index,\n        force_square=force_square,\n        alpha=alpha,\n        handle_missing=\"image\",\n        agg_fcn=agg_fcn,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        skip_failures=skip_failures,\n        progress=progress,\n    )\n\n    logger.info(\"Computing representativeness...\")\n    representativeness = _compute_representativeness(embeddings, method=method)\n\n    # Ensure field exists, even if `representativeness` is empty\n    samples._dataset.add_sample_field(representativeness_field, fof.FloatField)\n\n    representativeness = {\n        _id: u for _id, u in zip(sample_ids, representativeness)\n    }\n    if representativeness:\n        samples.set_values(\n            representativeness_field, representativeness, key_field=\"id\"\n        )\n\n    brain_method.save_run_results(samples, brain_key, None)\n\n    logger.info(\"Representativeness computation complete\")\n\n\ndef _compute_representativeness(embeddings, method=\"cluster-center\"):\n    #\n    # @todo experiment on which method for assessing representativeness\n    #\n    num_embeddings = len(embeddings)\n    logger.info(\n        \"Computing clusters for %d embeddings; this may take awhile...\",\n        num_embeddings,\n    )\n\n    initial_ranking, _ = _cluster_ranker(embeddings)\n\n    if method == \"cluster-center\":\n        final_ranking = initial_ranking\n    elif method == \"cluster-center-downweight\":\n        logger.info(\"Applying iterative downweighting...\")\n        final_ranking = _adjust_rankings(\n            embeddings, initial_ranking, ball_radius=0.5\n        )\n    else:\n        raise ValueError(\n            (\n                \"Method '%s' not supported. 
Please use one of \"\n                \"['cluster-center', 'cluster-center-downweight']\"\n            )\n            % method\n        )\n\n    return final_ranking\n\n\ndef _cluster_ranker(\n    embeddings, cluster_algorithm=\"kmeans\", N=20, norm_method=\"local\"\n):\n    # Cluster\n    if cluster_algorithm == \"meanshift\":\n        bandwidth = skc.estimate_bandwidth(\n            embeddings, quantile=0.8, n_samples=500\n        )\n        clusterer = skc.MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(\n            embeddings\n        )\n    elif cluster_algorithm == \"kmeans\":\n        clusterer = skc.KMeans(n_clusters=N, random_state=1234).fit(embeddings)\n    else:\n        raise ValueError(\n            (\n                \"Clustering algorithm '%s' not supported. Please use one of \"\n                \"['meanshift', 'kmeans']\"\n            )\n            % cluster_algorithm\n        )\n\n    cluster_centers = clusterer.cluster_centers_\n    cluster_ids = clusterer.labels_\n\n    # Get distance from each point to it's closest cluster center\n    sample_dists = np.linalg.norm(\n        embeddings - cluster_centers[cluster_ids], axis=1\n    )\n\n    centerness_ranking = 1 / (1 + sample_dists)\n\n    # Normalize per cluster vs globally\n    norm_method = \"local\"\n    if norm_method == \"global\":\n        centerness_ranking = centerness_ranking / centerness_ranking.max()\n    elif norm_method == \"local\":\n        unique_ids = np.unique(cluster_ids)\n        for unique_id in unique_ids:\n            cluster_indices = np.where(cluster_ids == unique_id)[0]\n            cluster_dists = sample_dists[cluster_indices]\n            cluster_dists /= cluster_dists.max()\n            sample_dists[cluster_indices] = cluster_dists\n        centerness_ranking = sample_dists\n\n    return centerness_ranking, clusterer\n\n\n# Step 3: Adjust rankings to avoid redundancy\ndef _adjust_rankings(embeddings, initial_ranking, ball_radius=0.5):\n    tree = 
cKDTree(embeddings)\n    new_ranking = copy.deepcopy(initial_ranking)\n\n    ordered_ranking = np.argsort(new_ranking)[::-1]\n    visited_indices = set()\n\n    for ranked_index in ordered_ranking:\n        visited_indices.add(ranked_index)\n        query_embedding = embeddings[ranked_index, :]\n        nearby_indices = tree.query_ball_point(\n            query_embedding, ball_radius, return_sorted=True\n        )\n        filtered_indices = [\n            idx for idx in nearby_indices if idx not in visited_indices\n        ]\n        visited_indices |= set(filtered_indices)\n        new_ranking[filtered_indices] = new_ranking[filtered_indices] * 0.7\n\n    new_ranking = new_ranking / new_ranking.max()\n    return new_ranking\n\n\n# @todo move to `fiftyone/brain/representativeness.py`\n# Don't do this hastily; `get_brain_info()` on existing datasets has this\n# class's full path in it and may need migration\nclass RepresentativenessConfig(fob.BrainMethodConfig):\n    def __init__(\n        self,\n        representativeness_field,\n        method=None,\n        roi_field=None,\n        embeddings_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        **kwargs,\n    ):\n        if similarity_index is not None and not etau.is_str(similarity_index):\n            similarity_index = similarity_index.key\n\n        if model is not None and not etau.is_str(model):\n            model = etau.get_class_name(model)\n\n        self.representativeness_field = representativeness_field\n        self._method = method\n        self.roi_field = roi_field\n        self.embeddings_field = embeddings_field\n        self.similarity_index = similarity_index\n        self.model = model\n        self.model_kwargs = model_kwargs\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"representativeness\"\n\n    @property\n    def method(self):\n        return self._method\n\n    @classmethod\n    def 
_virtual_attributes(cls):\n        # By default 'method' is virtual but we omit so it *IS* serialized\n        return [\"cls\", \"type\"]\n\n\nclass Representativeness(fob.BrainMethod):\n    def ensure_requirements(self):\n        pass\n\n    def get_fields(self, samples, brain_key):\n        fields = [self.config.representativeness_field]\n        if self.config.roi_field is not None:\n            fields.append(self.config.roi_field)\n\n        if self.config.embeddings_field is not None:\n            fields.append(self.config.embeddings_field)\n\n        return fields\n\n    def cleanup(self, samples, brain_key):\n        representativeness_field = self.config.representativeness_field\n        samples._dataset.delete_sample_fields(\n            representativeness_field, error_level=1\n        )\n\n    def _validate_run(self, samples, brain_key, existing_info):\n        self._validate_fields_match(\n            brain_key, \"representativeness_field\", existing_info\n        )\n"
  },
  {
    "path": "fiftyone/brain/internal/core/sklearn.py",
    "content": "\"\"\"\nSklearn similarity backend.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\nimport sklearn.metrics as skm\nimport sklearn.neighbors as skn\nimport sklearn.preprocessing as skp\n\nimport eta.core.utils as etau\n\nimport fiftyone.core.media as fom\nfrom fiftyone.brain.similarity import (\n    DuplicatesMixin,\n    SimilarityConfig,\n    Similarity,\n    SimilarityIndex,\n)\nimport fiftyone.brain.internal.core.utils as fbu\n\n\nlogger = logging.getLogger(__name__)\n\n_AGGREGATIONS = {\n    \"mean\": np.mean,\n    \"post-mean\": np.nanmean,\n    \"post-min\": np.nanmin,\n    \"post-max\": np.nanmax,\n}\n\n_MAX_PRECOMPUTE_DISTS = 15000  # ~1.7GB to store distance matrix in-memory\n_COSINE_HACK_ATTR = \"_cosine_hack\"\n\n\nclass SklearnSimilarityConfig(SimilarityConfig):\n    \"\"\"Configuration for the sklearn similarity backend.\n\n    Args:\n        metric (\"cosine\"): the embedding distance metric to use. 
See\n            ``sklearn.metrics.pairwise_distance`` for supported values\n        **kwargs: keyword arguments for\n            :class:`fiftyone.brain.similarity.SimilarityConfig`\n    \"\"\"\n\n    def __init__(self, metric=\"cosine\", **kwargs):\n        super().__init__(**kwargs)\n        self.metric = metric\n\n    @property\n    def method(self):\n        return \"sklearn\"\n\n    @property\n    def max_k(self):\n        return None\n\n    @property\n    def supports_least_similarity(self):\n        return True\n\n    @property\n    def supported_aggregations(self):\n        return tuple(_AGGREGATIONS.keys())\n\n\nclass SklearnSimilarity(Similarity):\n    \"\"\"Sklearn similarity factory.\n\n    Args:\n        config: an :class:`SklearnSimilarityConfig`\n    \"\"\"\n\n    def initialize(self, samples, brain_key):\n        return SklearnSimilarityIndex(\n            samples, self.config, brain_key, backend=self\n        )\n\n\nclass SklearnSimilarityIndex(SimilarityIndex, DuplicatesMixin):\n    \"\"\"Class for interacting with sklearn similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`SklearnSimilarityConfig` used\n        brain_key: the brain key\n        embeddings (None): a ``num_embeddings x num_dims`` array of embeddings\n        sample_ids (None): a ``num_embeddings`` array of sample IDs\n        label_ids (None): a ``num_embeddings`` array of label IDs, if\n            applicable\n        backend (None): a :class:`SklearnSimilarity` instance\n    \"\"\"\n\n    def __init__(\n        self,\n        samples,\n        config,\n        brain_key,\n        embeddings=None,\n        sample_ids=None,\n        label_ids=None,\n        backend=None,\n    ):\n        embeddings, sample_ids, label_ids = self._parse_data(\n            samples,\n            config,\n            embeddings=embeddings,\n            sample_ids=sample_ids,\n            label_ids=label_ids,\n       
 )\n\n        self._dataset = samples._dataset\n        self._embeddings = embeddings\n        self._sample_ids = sample_ids\n        self._label_ids = label_ids\n        self._ids_to_inds = None\n        self._curr_ids_to_inds = None\n        self._neighbors_helper = None\n\n        SimilarityIndex.__init__(\n            self, samples, config, brain_key, backend=backend\n        )\n        DuplicatesMixin.__init__(self)\n\n    @property\n    def is_external(self):\n        return self.config.embeddings_field is None\n\n    @property\n    def embeddings(self):\n        return self._embeddings\n\n    @property\n    def sample_ids(self):\n        return self._sample_ids\n\n    @property\n    def label_ids(self):\n        return self._label_ids\n\n    @property\n    def total_index_size(self):\n        return len(self._sample_ids)\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n    ):\n        sample_ids = np.asarray(sample_ids)\n        label_ids = np.asarray(label_ids) if label_ids is not None else None\n\n        _sample_ids, _label_ids, ii, jj = fbu.add_ids(\n            sample_ids,\n            label_ids,\n            self._sample_ids,\n            self._label_ids,\n            patches_field=self.config.patches_field,\n            overwrite=overwrite,\n            allow_existing=allow_existing,\n            warn_existing=warn_existing,\n        )\n\n        if ii.size == 0:\n            return\n\n        _embeddings = embeddings[ii, :]\n\n        if self.config.embeddings_field is not None:\n            fbu.add_embeddings(\n                self._dataset,\n                _embeddings,\n                sample_ids[ii],\n                label_ids[ii] if label_ids is not None else None,\n                self.config.embeddings_field,\n                patches_field=self.config.patches_field,\n            
)\n\n        _e = self._embeddings\n\n        n = _e.shape[0]\n        if n == 0:\n            _e = np.empty((0, embeddings.shape[1]), dtype=embeddings.dtype)\n        d = _e.shape[1]\n        m = max(jj) - n + 1\n\n        if m > 0:\n            if _e.size > 0:\n                _e = np.concatenate((_e, np.empty((m, d), dtype=_e.dtype)))\n            else:\n                _e = np.empty_like(_embeddings)\n\n        _e[jj, :] = _embeddings\n\n        self._embeddings = _e\n        self._sample_ids = _sample_ids\n        self._label_ids = _label_ids\n        self._ids_to_inds = None\n        self._curr_ids_to_inds = None\n        self._neighbors_helper = None\n\n        if reload:\n            super().reload()\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        _sample_ids, _label_ids, rm_inds = fbu.remove_ids(\n            sample_ids,\n            label_ids,\n            self._sample_ids,\n            self._label_ids,\n            patches_field=self.config.patches_field,\n            allow_missing=allow_missing,\n            warn_missing=warn_missing,\n        )\n\n        if rm_inds.size == 0:\n            return\n\n        if self.config.embeddings_field is not None:\n            if self.config.patches_field is not None:\n                rm_sample_ids = None\n                rm_label_ids = self._label_ids[rm_inds]\n            else:\n                rm_sample_ids = self._sample_ids[rm_inds]\n                rm_label_ids = None\n\n            fbu.remove_embeddings(\n                self._dataset,\n                self.config.embeddings_field,\n                sample_ids=rm_sample_ids,\n                label_ids=rm_label_ids,\n                patches_field=self.config.patches_field,\n            )\n\n        _embeddings = np.delete(self._embeddings, rm_inds, axis=0)\n\n        self._embeddings = _embeddings\n        
self._sample_ids = _sample_ids\n        self._label_ids = _label_ids\n        self._ids_to_inds = None\n        self._curr_ids_to_inds = None\n        self._neighbors_helper = None\n\n        if reload:\n            super().reload()\n\n    def use_view(self, *args, **kwargs):\n        self._curr_ids_to_inds = None\n        return super().use_view(*args, **kwargs)\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        if label_ids is not None:\n            if self.config.patches_field is None:\n                raise ValueError(\"This index does not support label IDs\")\n\n            if sample_ids is not None:\n                logger.warning(\n                    \"Ignoring sample IDs when label IDs are provided\"\n                )\n\n            inds = _get_inds(\n                label_ids,\n                self.label_ids,\n                \"label\",\n                allow_missing,\n                warn_missing,\n            )\n\n            embeddings = self._embeddings[inds, :]\n            sample_ids = self.sample_ids[inds]\n            label_ids = np.asarray(label_ids)\n        elif sample_ids is not None:\n            if etau.is_str(sample_ids):\n                sample_ids = [sample_ids]\n\n            if self.config.patches_field is not None:\n                sample_ids = set(sample_ids)\n                bools = [_id in sample_ids for _id in self.sample_ids]\n                inds = np.nonzero(bools)[0]\n            else:\n                inds = _get_inds(\n                    sample_ids,\n                    self.sample_ids,\n                    \"sample\",\n                    allow_missing,\n                    warn_missing,\n                )\n\n            embeddings = self._embeddings[inds, :]\n            sample_ids = self.sample_ids[inds]\n            if self.config.patches_field is not None:\n                label_ids = 
self.label_ids[inds]\n            else:\n                label_ids = None\n        else:\n            embeddings = self._embeddings.copy()\n            sample_ids = self.sample_ids.copy()\n            if self.config.patches_field is not None:\n                label_ids = self.label_ids.copy()\n            else:\n                label_ids = None\n\n        return embeddings, sample_ids, label_ids\n\n    def reload(self):\n        if self.config.embeddings_field is not None:\n            embeddings, sample_ids, label_ids = self._parse_data(\n                self._dataset, self.config\n            )\n\n            self._embeddings = embeddings\n            self._sample_ids = sample_ids\n            self._label_ids = label_ids\n            self._ids_to_inds = None\n            self._curr_ids_to_inds = None\n            self._neighbors_helper = None\n\n        super().reload()\n\n    def cleanup(self):\n        pass\n\n    def attributes(self):\n        attrs = super().attributes()\n\n        if self.config.embeddings_field is None:\n            attrs.extend([\"embeddings\", \"sample_ids\", \"label_ids\"])\n\n        return attrs\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        if aggregation is not None:\n            return self._kneighbors_aggregate(\n                query, k, reverse, aggregation, return_dists\n            )\n\n        (\n            query,\n            query_inds,\n            full_index,\n            single_query,\n        ) = self._parse_neighbors_query(query)\n\n        can_use_dists = full_index or query_inds is not None\n        neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists)\n\n        if dists is not None:\n            # Use pre-computed distances\n            if query_inds is not None:\n                _dists = dists[query_inds, :]\n            else:\n                _dists = dists\n\n            # 
note: this must gracefully ignore nans\n            inds = _nanargmin(_dists, k=k)\n\n            if return_dists:\n                dists = [d[i] for i, d in zip(inds, _dists)]\n            else:\n                dists = None\n        else:\n            if return_dists:\n                dists, inds = neighbors.kneighbors(\n                    X=query, n_neighbors=k, return_distance=True\n                )\n                inds = list(inds)\n                dists = list(dists)\n            else:\n                inds = neighbors.kneighbors(\n                    X=query, n_neighbors=k, return_distance=False\n                )\n                inds = list(inds)\n                dists = None\n\n        return self._format_output(\n            inds, dists, full_index, single_query, return_dists\n        )\n\n    def _radius_neighbors(self, query=None, thresh=None, return_dists=False):\n        (\n            query,\n            query_inds,\n            full_index,\n            single_query,\n        ) = self._parse_neighbors_query(query)\n\n        can_use_dists = full_index or query_inds is not None\n        neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists)\n\n        # When not using brute force, we approximate cosine distance by\n        # computing Euclidean distance on unit-norm embeddings.\n        # ED = sqrt(2 * CD), so we need to scale the threshold appropriately\n        if getattr(neighbors, _COSINE_HACK_ATTR, False):\n            thresh = np.sqrt(2.0 * thresh)\n\n        if dists is not None:\n            # Use pre-computed distances\n            if query_inds is not None:\n                _dists = dists[query_inds, :]\n            else:\n                _dists = dists\n\n            # note: this must gracefully ignore nans\n            inds = [np.nonzero(d <= thresh)[0] for d in _dists]\n\n            if return_dists:\n                dists = [d[i] for i, d in zip(inds, _dists)]\n            else:\n                dists = None\n        
else:\n            if return_dists:\n                dists, inds = neighbors.radius_neighbors(\n                    X=query, radius=thresh, return_distance=True\n                )\n            else:\n                dists = None\n                inds = neighbors.radius_neighbors(\n                    X=query, radius=thresh, return_distance=False\n                )\n\n        return self._format_output(\n            inds, dists, full_index, single_query, return_dists\n        )\n\n    def _kneighbors_aggregate(\n        self, query, k, reverse, aggregation, return_dists\n    ):\n        if query is None:\n            raise ValueError(\"Full index queries do not support aggregation\")\n\n        if aggregation not in _AGGREGATIONS:\n            raise ValueError(\n                \"Unsupported aggregation method '%s'. Supported values are %s\"\n                % (aggregation, tuple(_AGGREGATIONS.keys()))\n            )\n\n        query, query_inds, _, _ = self._parse_neighbors_query(query)\n\n        # Pre-aggregation\n        if aggregation == \"mean\":\n            if query.shape[0] > 1:\n                query = query.mean(axis=0, keepdims=True)\n                query_inds = None\n\n            aggregation = None\n\n        can_use_dists = query_inds is not None\n        _, dists = self._get_neighbors(\n            can_use_neighbors=False, can_use_dists=can_use_dists\n        )\n\n        if dists is not None:\n            # Use pre-computed distances\n            dists = dists[query_inds, :]\n        else:\n            keep_inds = self._current_inds\n            index_embeddings = self._embeddings\n            if keep_inds is not None:\n                index_embeddings = index_embeddings[keep_inds]\n\n            dists = skm.pairwise_distances(\n                query, index_embeddings, metric=self.config.metric\n            )\n\n        # Post-aggregation\n        if aggregation is not None:\n            # note: this must gracefully ignore nans\n            agg_fcn 
= _AGGREGATIONS[aggregation]\n            dists = agg_fcn(dists, axis=0)\n        else:\n            dists = dists[0, :]\n\n        if can_use_dists:\n            dists[np.isnan(dists)] = 0.0\n\n        inds = np.argsort(dists)\n        if reverse:\n            inds = np.flip(inds)\n\n        if k is not None:\n            inds = inds[:k]\n\n        sample_ids = list(self.current_sample_ids[inds])\n\n        if self.config.patches_field is not None:\n            label_ids = list(self.current_label_ids[inds])\n        else:\n            label_ids = None\n\n        if return_dists:\n            dists = list(dists[inds])\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    def _parse_neighbors_query(self, query):\n        # Full index\n        if query is None:\n            return None, None, True, False\n\n        if etau.is_str(query):\n            query_ids = [query]\n            single_query = True\n        else:\n            query = np.asarray(query)\n\n            # Query vector(s)\n            if np.issubdtype(query.dtype, np.number):\n                single_query = query.ndim == 1\n                if single_query:\n                    query = query[np.newaxis, :]\n\n                return query, None, False, single_query\n\n            query_ids = list(query)\n            single_query = False\n\n        # Retrieve indices into active `dists` matrix, if possible\n        ids_to_inds = self._get_ids_to_inds(full=False)\n        query_inds = []\n        for _id in query_ids:\n            _ind = ids_to_inds.get(_id, None)\n            if _ind is not None:\n                query_inds.append(_ind)\n            else:\n                # At least one query ID is not in the active index\n                query_inds = None\n                break\n\n        # Retrieve embeddings\n        ids_to_inds = self._get_ids_to_inds(full=True)\n        inds = []\n        bad_ids = []\n        for _id in query_ids:\n            _ind = 
ids_to_inds.get(_id, None)\n            if _ind is not None:\n                inds.append(_ind)\n            else:\n                bad_ids.append(_id)\n\n        inds = np.array(inds)\n\n        if bad_ids:\n            raise ValueError(\n                \"Query IDs %s do not exist in this index\" % bad_ids\n            )\n\n        query = self._embeddings[inds, :]\n\n        if query_inds is not None:\n            query_inds = np.array(query_inds)\n\n        return query, query_inds, False, single_query\n\n    def _get_ids_to_inds(self, full=False):\n        if full:\n            if self._ids_to_inds is None:\n                if self.config.patches_field is not None:\n                    ids = self.label_ids\n                else:\n                    ids = self.sample_ids\n\n                self._ids_to_inds = {_id: i for i, _id in enumerate(ids)}\n\n            return self._ids_to_inds\n\n        if self._curr_ids_to_inds is None:\n            if self.config.patches_field is not None:\n                ids = self.current_label_ids\n            else:\n                ids = self.current_sample_ids\n\n            self._curr_ids_to_inds = {_id: i for i, _id in enumerate(ids)}\n\n        return self._curr_ids_to_inds\n\n    def _get_neighbors(self, can_use_neighbors=True, can_use_dists=True):\n        if self._neighbors_helper is None:\n            self._neighbors_helper = NeighborsHelper(\n                self._embeddings, self.config.metric\n            )\n\n        return self._neighbors_helper.get_neighbors(\n            keep_inds=self._current_inds,\n            can_use_neighbors=can_use_neighbors,\n            can_use_dists=can_use_dists,\n        )\n\n    def _format_output(\n        self, inds, dists, full_index, single_query, return_dists\n    ):\n        if full_index:\n            if return_dists:\n                return inds, dists\n\n            return inds\n\n        curr_sample_ids = self.current_sample_ids\n        sample_ids = [[curr_sample_ids[i] 
for i in _inds] for _inds in inds]\n        if single_query:\n            sample_ids = sample_ids[0]\n\n        if self.config.patches_field is not None:\n            curr_label_ids = self.current_label_ids\n            label_ids = [[curr_label_ids[i] for i in _inds] for _inds in inds]\n            if single_query:\n                label_ids = label_ids[0]\n        else:\n            label_ids = None\n\n        if return_dists:\n            dists = [list(d) for d in dists]\n            if single_query:\n                dists = dists[0]\n\n            return sample_ids, label_ids, dists\n\n        return sample_ids, label_ids\n\n    @staticmethod\n    def _parse_data(\n        samples,\n        config,\n        embeddings=None,\n        sample_ids=None,\n        label_ids=None,\n    ):\n        if embeddings is None:\n            samples = samples._dataset\n            if samples.media_type == fom.GROUP:\n                samples = samples.select_group_slices(_allow_mixed=True)\n\n            embeddings, sample_ids, label_ids = fbu.get_embeddings(\n                samples,\n                patches_field=config.patches_field,\n                embeddings_field=config.embeddings_field,\n            )\n        elif sample_ids is None:\n            sample_ids, label_ids = fbu.get_ids(\n                samples,\n                patches_field=config.patches_field,\n                data=embeddings,\n                data_type=\"embeddings\",\n            )\n\n        return embeddings, sample_ids, label_ids\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        embeddings = d.get(\"embeddings\", None)\n        if embeddings is not None:\n            embeddings = np.array(embeddings)\n\n        sample_ids = d.get(\"sample_ids\", None)\n        if sample_ids is not None:\n            sample_ids = np.array(sample_ids)\n\n        label_ids = d.get(\"label_ids\", None)\n        if label_ids is not None:\n            label_ids = 
np.array(label_ids)\n\n        return cls(\n            samples,\n            config,\n            brain_key,\n            embeddings=embeddings,\n            sample_ids=sample_ids,\n            label_ids=label_ids,\n        )\n\n\nclass NeighborsHelper(object):\n\n    _UNAVAILABLE = \"UNAVAILABLE\"\n\n    def __init__(self, embeddings, metric):\n        self.embeddings = embeddings\n        self.metric = metric\n\n        self._initialized = False\n        self._full_dists = None\n\n        self._curr_keep_inds = None\n        self._curr_neighbors = None\n        self._curr_dists = None\n\n    def get_neighbors(\n        self,\n        keep_inds=None,\n        can_use_neighbors=True,\n        can_use_dists=True,\n    ):\n        iokay = self._same_keep_inds(keep_inds)\n        nokay = not can_use_neighbors or self._curr_neighbors is not None\n        dokay = not can_use_dists or self._curr_dists is not None\n\n        if iokay and nokay and dokay:\n            neighbors = self._curr_neighbors\n            dists = self._curr_dists\n        else:\n            neighbors, dists = self._build(\n                keep_inds=keep_inds,\n                can_use_neighbors=can_use_neighbors,\n                can_use_dists=can_use_dists,\n            )\n\n            if not iokay:\n                self._curr_keep_inds = keep_inds\n\n            if self._curr_neighbors is None or not iokay:\n                self._curr_neighbors = neighbors\n\n            if self._curr_dists is None or not iokay:\n                self._curr_dists = dists\n\n        if not can_use_neighbors or neighbors is self._UNAVAILABLE:\n            neighbors = None\n\n        if not can_use_dists or dists is self._UNAVAILABLE:\n            dists = None\n\n        return neighbors, dists\n\n    def _same_keep_inds(self, keep_inds):\n        # This handles either argument being None\n        return np.array_equal(keep_inds, self._curr_keep_inds)\n\n    def _build(\n        self, keep_inds=None, 
can_use_neighbors=True, can_use_dists=True\n    ):\n        if can_use_dists:\n            if (\n                self._full_dists is None\n                and len(self.embeddings) <= _MAX_PRECOMPUTE_DISTS\n            ):\n                self._full_dists = self._build_dists(self.embeddings)\n\n            if self._full_dists is not None:\n                if keep_inds is not None:\n                    dists = self._full_dists[keep_inds, :][:, keep_inds]\n                else:\n                    dists = self._full_dists\n            elif (\n                keep_inds is not None\n                and len(keep_inds) <= _MAX_PRECOMPUTE_DISTS\n            ):\n                dists = self._build_dists(self.embeddings[keep_inds])\n            else:\n                dists = self._UNAVAILABLE\n        else:\n            dists = None\n\n        if can_use_neighbors:\n            if not isinstance(dists, np.ndarray):\n                embeddings = self.embeddings\n                if keep_inds is not None:\n                    embeddings = embeddings[keep_inds]\n\n                neighbors = self._build_neighbors(embeddings)\n            else:\n                neighbors = self._UNAVAILABLE\n        else:\n            neighbors = None\n\n        return neighbors, dists\n\n    def _build_dists(self, embeddings):\n        logger.debug(\"Generating index for %d embeddings...\", len(embeddings))\n\n        # Center embeddings\n        embeddings = np.asarray(embeddings)\n        embeddings -= embeddings.mean(axis=0, keepdims=True)\n\n        dists = skm.pairwise_distances(embeddings, metric=self.metric)\n        np.fill_diagonal(dists, np.nan)\n\n        logger.debug(\"Index complete\")\n\n        return dists\n\n    def _build_neighbors(self, embeddings):\n        logger.debug(\n            \"Generating neighbors graph for %d embeddings...\",\n            len(embeddings),\n        )\n\n        # Center embeddings\n        embeddings = np.asarray(embeddings)\n        embeddings -= 
embeddings.mean(axis=0, keepdims=True)\n\n        metric = self.metric\n\n        if metric == \"cosine\":\n            # Nearest neighbors does not directly support cosine distance, so\n            # we approximate via euclidean distance on unit-norm embeddings\n            cosine_hack = True\n            embeddings = skp.normalize(embeddings, axis=1)\n            metric = \"euclidean\"\n        else:\n            cosine_hack = False\n\n        neighbors = skn.NearestNeighbors(metric=metric)\n        neighbors.fit(embeddings)\n\n        setattr(neighbors, _COSINE_HACK_ATTR, cosine_hack)\n\n        logger.debug(\"Index complete\")\n\n        return neighbors\n\n\ndef _get_inds(ids, index_ids, ftype, allow_missing, warn_missing):\n    if etau.is_str(ids):\n        ids = [ids]\n\n    ids_map = {_id: i for i, _id in enumerate(index_ids)}\n\n    inds = []\n    bad_ids = []\n\n    for _id in ids:\n        idx = ids_map.get(_id, None)\n        if idx is not None:\n            inds.append(idx)\n        else:\n            bad_ids.append(_id)\n\n    num_missing = len(bad_ids)\n\n    if num_missing > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"Found %d %s IDs (eg '%s') that are not present in the index\"\n                % (num_missing, ftype, bad_ids[0])\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %d %s IDs that are not present in the index\",\n                num_missing,\n                ftype,\n            )\n\n    return np.array(inds)\n\n\ndef _nanargmin(array, k=1):\n    if k == 1:\n        inds = np.nanargmin(array, axis=1)\n        inds = [np.array([i]) for i in inds]\n    else:\n        inds = np.argsort(array, axis=1)\n        inds = list(inds[:, :k])\n\n    return inds\n"
  },
  {
    "path": "fiftyone/brain/internal/core/uniqueness.py",
    "content": "\"\"\"\nUniqueness methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport logging\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.brain as fb\nimport fiftyone.core.brain as fob\nimport fiftyone.core.fields as fof\nimport fiftyone.core.labels as fol\nimport fiftyone.core.utils as fou\nimport fiftyone.core.validation as fov\n\nimport fiftyone.brain.internal.core.utils as fbu\nimport fiftyone.brain.internal.models as fbm\n\n\nlogger = logging.getLogger(__name__)\n\n_ALLOWED_ROI_FIELD_TYPES = (\n    fol.Detection,\n    fol.Detections,\n    fol.Polyline,\n    fol.Polylines,\n)\n\n_DEFAULT_MODEL = \"simple-resnet-cifar10\"\n_DEFAULT_BATCH_SIZE = 16\n\n\ndef compute_uniqueness(\n    samples,\n    uniqueness_field,\n    roi_field,\n    embeddings,\n    similarity_index,\n    model,\n    model_kwargs,\n    force_square,\n    alpha,\n    batch_size,\n    num_workers,\n    skip_failures,\n    progress,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    #\n    # Algorithm\n    #\n    # Uniqueness is computed based on a classification model.  Each sample is\n    # embedded into a vector space based on the model. Then, we compute the\n    # knn's (k is a parameter of the uniqueness function). The uniqueness is\n    # then proportional to these distances. The intuition is that a sample is\n    # unique when it is far from other samples in the set. 
This is different\n    # than, say, \"representativeness\" which would stress samples that are core\n    # to dense clusters of related samples.\n    #\n\n    fov.validate_collection(samples)\n\n    if roi_field is not None:\n        fov.validate_collection_label_fields(\n            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES\n        )\n\n    if etau.is_str(embeddings):\n        embeddings_field, embeddings_exist = fbu.parse_data_field(\n            samples,\n            embeddings,\n            patches_field=roi_field,\n            data_type=\"embeddings\",\n        )\n        embeddings = None\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if etau.is_str(similarity_index):\n        similarity_index = samples.load_brain_results(similarity_index)\n\n    if (\n        model is None\n        and embeddings is None\n        and similarity_index is None\n        and not embeddings_exist\n    ):\n        model = fbm.load_model(_DEFAULT_MODEL)\n        if batch_size is None:\n            batch_size = _DEFAULT_BATCH_SIZE\n\n    config = UniquenessConfig(\n        uniqueness_field,\n        roi_field=roi_field,\n        embeddings_field=embeddings_field,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n    )\n    brain_key = uniqueness_field\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n    brain_method.register_run(samples, brain_key, cleanup=False)\n\n    if roi_field is not None:\n        # @todo experiment with mean(), max(), abs().max(), etc\n        agg_fcn = lambda e: np.mean(e, axis=0)\n    else:\n        agg_fcn = None\n\n    embeddings, sample_ids, _ = fbu.get_embeddings(\n        samples,\n        model=model,\n        model_kwargs=model_kwargs,\n        patches_field=roi_field,\n        embeddings_field=embeddings_field,\n        embeddings=embeddings,\n        similarity_index=similarity_index,\n        force_square=force_square,\n        
alpha=alpha,\n        handle_missing=\"image\",\n        agg_fcn=agg_fcn,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        skip_failures=skip_failures,\n        progress=progress,\n    )\n\n    if similarity_index is None:\n        similarity_index = fb.compute_similarity(\n            samples, backend=\"sklearn\", embeddings=False\n        )\n        similarity_index.add_to_index(embeddings, sample_ids)\n\n    logger.info(\"Computing uniqueness...\")\n    uniqueness = _compute_uniqueness(\n        embeddings, similarity_index, progress=progress\n    )\n\n    # Ensure field exists, even if `uniqueness` is empty\n    samples._dataset.add_sample_field(uniqueness_field, fof.FloatField)\n\n    uniqueness = {_id: u for _id, u in zip(sample_ids, uniqueness)}\n    if uniqueness:\n        samples.set_values(uniqueness_field, uniqueness, key_field=\"id\")\n\n    brain_method.save_run_results(samples, brain_key, None)\n\n    logger.info(\"Uniqueness computation complete\")\n\n\ndef _compute_uniqueness(\n    embeddings, similarity_index, batch_size=10, progress=None\n):\n    K = 3\n\n    num_embeddings = len(embeddings)\n    if num_embeddings <= K:\n        return [1] * num_embeddings\n\n    if similarity_index.config.method == \"sklearn\":\n        _, dists = similarity_index._kneighbors(k=K + 1, return_dists=True)\n    else:\n        dists = []\n        with fou.ProgressBar(total=num_embeddings, progress=progress) as pb:\n            for _embeddings in fou.iter_slices(embeddings, batch_size):\n                _, _, _dists = similarity_index._kneighbors(\n                    query=_embeddings, k=K + 1, return_dists=True\n                )\n                dists.extend(_dists)\n                pb.update(len(_dists))\n\n    dists = np.array(dists)\n\n    # @todo experiment on which method for assessing uniqueness is best\n    #\n    # To get something going, for now, just take a weighted mean\n    #\n    weights = [0.6, 0.3, 0.1]\n    sample_dists = 
np.mean(dists[:, 1:] * weights, axis=1)\n\n    # Normalize to keep the user on common footing across datasets\n    sample_dists /= sample_dists.max()\n\n    return sample_dists\n\n\n# @todo move to `fiftyone/brain/uniqueness.py`\n# Don't do this hastily; `get_brain_info()` on existing datasets has this\n# class's full path in it and may need migration\nclass UniquenessConfig(fob.BrainMethodConfig):\n    def __init__(\n        self,\n        uniqueness_field,\n        roi_field=None,\n        embeddings_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        **kwargs,\n    ):\n        if similarity_index is not None and not etau.is_str(similarity_index):\n            similarity_index = similarity_index.key\n\n        if model is not None and not etau.is_str(model):\n            model = etau.get_class_name(model)\n\n        self.uniqueness_field = uniqueness_field\n        self.roi_field = roi_field\n        self.embeddings_field = embeddings_field\n        self.similarity_index = similarity_index\n        self.model = model\n        self.model_kwargs = model_kwargs\n\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"uniqueness\"\n\n    @property\n    def method(self):\n        return \"neighbors\"\n\n\nclass Uniqueness(fob.BrainMethod):\n    def ensure_requirements(self):\n        pass\n\n    def get_fields(self, samples, brain_key):\n        fields = [self.config.uniqueness_field]\n        if self.config.roi_field is not None:\n            fields.append(self.config.roi_field)\n\n        if self.config.embeddings_field is not None:\n            fields.append(self.config.embeddings_field)\n\n        return fields\n\n    def cleanup(self, samples, brain_key):\n        uniqueness_field = self.config.uniqueness_field\n        samples._dataset.delete_sample_fields(uniqueness_field, error_level=1)\n\n    def _validate_run(self, samples, brain_key, existing_info):\n        
self._validate_fields_match(\n            brain_key, \"uniqueness_field\", existing_info\n        )\n"
  },
  {
    "path": "fiftyone/brain/internal/core/utils.py",
    "content": "\"\"\"\nUtilities.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport itertools\nimport logging\nimport random\nimport string\n\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.brain as fob\nimport fiftyone.core.fields as fof\nimport fiftyone.core.labels as fol\nimport fiftyone.core.models as fom\nimport fiftyone.core.media as fomm\nimport fiftyone.core.patches as fop\nimport fiftyone.zoo as foz\nfrom fiftyone import ViewField as F\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef parse_data(\n    samples,\n    patches_field=None,\n    data=None,\n    data_type=\"embeddings\",\n    allow_missing=True,\n    warn_missing=True,\n):\n    if isinstance(data, fob.SimilarityIndex):\n        return get_embeddings_from_index(\n            samples,\n            data,\n            patches_field=None,\n            allow_missing=True,\n            warn_missing=True,\n        )\n\n    _validate_args(samples, patches_field=patches_field)\n\n    if patches_field is None:\n        if isinstance(data, dict):\n            sample_ids, data = zip(*data.items())\n            return np.array(data), np.array(sample_ids), None\n\n        sample_ids, _ = get_ids(samples, data=data, data_type=data_type)\n        return data, sample_ids, None\n\n    if isinstance(data, dict):\n        value = next(iter(data.values()), None)\n        if isinstance(value, np.ndarray) and value.ndim == 1:\n            label_ids, data = zip(*data.items())\n            return _parse_label_data(\n                samples,\n                patches_field,\n                label_ids,\n                data,\n                data_type,\n                allow_missing,\n                warn_missing,\n            )\n\n    sample_ids, label_ids = get_ids(\n        samples,\n        patches_field=patches_field,\n        data=data,\n        data_type=data_type,\n    )\n\n    return data, sample_ids, label_ids\n\n\ndef _parse_label_data(\n  
  samples,\n    patches_field,\n    label_ids,\n    data,\n    data_type,\n    allow_missing,\n    warn_missing,\n):\n    if samples._is_patches:\n        sample_id_path = \"sample_id\"\n    else:\n        sample_id_path = \"id\"\n\n    label_type, label_id_path = samples._get_label_field_path(\n        patches_field, \"id\"\n    )\n    is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)\n\n    ref_sample_ids, ref_label_ids = samples._dataset.values(\n        [sample_id_path, label_id_path]\n    )\n\n    if is_list_field:\n        ids_map = {}\n        for _sample_id, _lids in zip(ref_sample_ids, ref_label_ids):\n            if _lids:\n                for _label_id in _lids:\n                    ids_map[_label_id] = _sample_id\n    else:\n        ids_map = dict(zip(ref_label_ids, ref_sample_ids))\n\n    _data = []\n    _sample_ids = []\n    _label_ids = []\n    _missing_ids = []\n    for _lid, _d in zip(label_ids, data):\n        _sid = ids_map.get(_lid, None)\n        if _sid is not None:\n            _data.append(_d)\n            _sample_ids.append(_sid)\n            _label_ids.append(_lid)\n        else:\n            _missing_ids.append(_lid)\n\n    num_missing = len(_missing_ids)\n    if num_missing > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"Unable to retrieve sample IDs for %d label IDs (eg %s)\"\n                % (num_missing, _missing_ids[0])\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %s for %d label IDs (eg %s) for which sample IDs \"\n                \"could not be retrieved\",\n                data_type,\n                num_missing,\n                _missing_ids[0],\n            )\n\n    return np.array(_data), np.array(_sample_ids), np.array(_label_ids)\n\n\ndef get_embeddings_from_index(\n    samples,\n    similarity_index,\n    patches_field=None,\n    allow_missing=True,\n    warn_missing=True,\n):\n    if patches_field is None:\n        if 
samples._is_patches:\n            sample_id_path = \"sample_id\"\n        else:\n            sample_id_path = \"id\"\n\n        sample_ids = samples.values(sample_id_path)\n        label_ids = None\n    else:\n        if samples._is_patches:\n            label_id_path = \"id\"\n        else:\n            _, label_id_path = samples._get_label_field_path(\n                patches_field, \"id\"\n            )\n\n        sample_ids = None\n        label_ids = samples.values(label_id_path, unwind=True)\n\n    logger.info(\"Retrieving embeddings from similarity index...\")\n    return similarity_index.get_embeddings(\n        sample_ids=sample_ids,\n        label_ids=label_ids,\n        allow_missing=allow_missing,\n        warn_missing=warn_missing,\n    )\n\n\ndef get_ids(\n    samples,\n    patches_field=None,\n    data=None,\n    data_type=\"embeddings\",\n    handle_missing=\"skip\",\n    ref_sample_ids=None,\n):\n    _validate_args(samples, patches_field=patches_field)\n\n    if patches_field is None:\n        if ref_sample_ids is not None:\n            sample_ids = ref_sample_ids\n        else:\n            sample_ids = samples.values(\"id\")\n\n        if data is not None and len(sample_ids) != len(data):\n            raise ValueError(\n                \"The number of %s (%d) in these results no longer matches the \"\n                \"number of samples (%d) in the collection. You must \"\n                \"regenerate the results\"\n                % (data_type, len(data), len(sample_ids))\n            )\n\n        return np.array(sample_ids), None\n\n    sample_ids, label_ids = _get_patch_ids(\n        samples,\n        patches_field,\n        handle_missing=handle_missing,\n        ref_sample_ids=ref_sample_ids,\n    )\n\n    if data is not None and len(sample_ids) != len(data):\n        raise ValueError(\n            \"The number of %s (%d) in these results no longer matches the \"\n            \"number of labels (%d) in the '%s' field of the collection. 
You \"\n            \"must regenerate the results\"\n            % (data_type, len(data), len(sample_ids), patches_field)\n        )\n\n    return np.array(sample_ids), np.array(label_ids)\n\n\ndef filter_ids(\n    samples,\n    index_sample_ids,\n    index_label_ids,\n    patches_field=None,\n    allow_missing=True,\n    warn_missing=False,\n):\n    _validate_args(samples, patches_field=patches_field)\n\n    if patches_field is None:\n        if samples._is_patches:\n            sample_ids = np.array(samples.values(\"sample_id\"))\n        else:\n            sample_ids = np.array(samples.values(\"id\"))\n\n        if index_sample_ids is None:\n            return sample_ids, None, None, None\n\n        keep_inds, good_inds, bad_ids = _parse_ids(\n            sample_ids,\n            index_sample_ids,\n            \"samples\",\n            allow_missing,\n            warn_missing,\n        )\n\n        if bad_ids is not None:\n            sample_ids = sample_ids[good_inds]\n\n        return sample_ids, None, keep_inds, good_inds\n\n    sample_ids, label_ids = _get_patch_ids(samples, patches_field)\n\n    if index_label_ids is None:\n        return sample_ids, label_ids, None, None\n\n    keep_inds, good_inds, bad_ids = _parse_ids(\n        label_ids,\n        index_label_ids,\n        \"labels\",\n        allow_missing,\n        warn_missing,\n    )\n\n    if bad_ids is not None:\n        sample_ids = sample_ids[good_inds]\n        label_ids = label_ids[good_inds]\n\n    return sample_ids, label_ids, keep_inds, good_inds\n\n\ndef _get_patch_ids(\n    samples, patches_field, handle_missing=\"skip\", ref_sample_ids=None\n):\n    if samples._is_patches:\n        sample_id_path = \"sample_id\"\n    else:\n        sample_id_path = \"id\"\n\n    label_type, label_id_path = samples._get_label_field_path(\n        patches_field, \"id\"\n    )\n    is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)\n\n    sample_ids, label_ids = samples.values([sample_id_path, 
label_id_path])\n\n    if ref_sample_ids is not None:\n        sample_ids, label_ids = _apply_ref_sample_ids(\n            sample_ids, label_ids, ref_sample_ids\n        )\n\n    if is_list_field:\n        sample_ids, label_ids = _flatten_list_ids(\n            sample_ids, label_ids, handle_missing\n        )\n\n    return np.array(sample_ids), np.array(label_ids)\n\n\ndef _apply_ref_sample_ids(sample_ids, label_ids, ref_sample_ids):\n    ref_label_ids = [None] * len(ref_sample_ids)\n    inds_map = {_id: i for i, _id in enumerate(ref_sample_ids)}\n    for _id, _lid in zip(sample_ids, label_ids):\n        idx = inds_map.get(_id, None)\n        if idx is not None:\n            ref_label_ids[idx] = _lid\n\n    return ref_sample_ids, ref_label_ids\n\n\ndef _flatten_list_ids(sample_ids, label_ids, handle_missing):\n    _sample_ids = []\n    _label_ids = []\n    _add_missing = handle_missing == \"image\"\n\n    for _id, _lids in zip(sample_ids, label_ids):\n        if _lids:\n            for _lid in _lids:\n                _sample_ids.append(_id)\n                _label_ids.append(_lid)\n        elif _add_missing:\n            _sample_ids.append(_id)\n            _label_ids.append(None)\n\n    return _sample_ids, _label_ids\n\n\ndef _parse_ids(ids, index_ids, ftype, allow_missing, warn_missing):\n    if np.array_equal(ids, index_ids):\n        return None, None, None\n\n    inds_map = {_id: idx for idx, _id in enumerate(index_ids)}\n\n    keep_inds = []\n    bad_inds = []\n    bad_ids = []\n    for _idx, _id in enumerate(ids):\n        ind = inds_map.get(_id, None)\n        if ind is not None:\n            keep_inds.append(ind)\n        else:\n            bad_inds.append(_idx)\n            bad_ids.append(_id)\n\n    num_missing_index = len(index_ids) - len(keep_inds)\n    if num_missing_index > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"The index contains %d %s that are not present in the \"\n                \"provided 
collection\" % (num_missing_index, ftype)\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %d %s from the index that are not present in the \"\n                \"provided collection\",\n                num_missing_index,\n                ftype,\n            )\n\n    num_missing_collection = len(bad_ids)\n    if num_missing_collection > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"The provided collection contains %d %s not present in the \"\n                \"index\" % (num_missing_collection, ftype)\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %d %s from the provided collection that are not \"\n                \"present in the index\",\n                num_missing_collection,\n                ftype,\n            )\n\n        bad_inds = np.array(bad_inds, dtype=np.int64)\n\n        good_inds = np.full(ids.shape, True)\n        good_inds[bad_inds] = False\n    else:\n        good_inds = None\n        bad_ids = None\n\n    keep_inds = np.array(keep_inds, dtype=np.int64)\n\n    return keep_inds, good_inds, bad_ids\n\n\ndef skip_ids(samples, ids, patches_field=None, warn_existing=False):\n    sample_ids, label_ids = get_ids(samples, patches_field=patches_field)\n\n    if patches_field is not None:\n        exclude_ids = list(set(label_ids) & set(ids))\n        num_existing = len(exclude_ids)\n\n        if num_existing > 0:\n            if warn_existing:\n                logger.warning(\"Skipping %d existing label IDs\", num_existing)\n\n            samples = samples.exclude_labels(\n                ids=exclude_ids, fields=patches_field\n            )\n    else:\n        exclude_ids = list(set(sample_ids) & set(ids))\n        num_existing = len(exclude_ids)\n\n        if num_existing > 0:\n            if warn_existing:\n                logger.warning(\"Skipping %d existing sample IDs\", num_existing)\n\n            samples = 
samples.exclude(exclude_ids)\n\n    return samples\n\n\ndef add_ids(\n    sample_ids,\n    label_ids,\n    index_sample_ids,\n    index_label_ids,\n    patches_field=None,\n    overwrite=True,\n    allow_existing=True,\n    warn_existing=False,\n):\n    if patches_field is not None:\n        ids = label_ids\n        index_ids = index_label_ids\n    else:\n        ids = sample_ids\n        index_ids = index_sample_ids\n\n    ii = []\n    jj = []\n\n    ids_map = {_id: _i for _i, _id in enumerate(index_ids)}\n    new_idx = len(index_ids)\n    for _i, _id in enumerate(ids):\n        _idx = ids_map.get(_id, None)\n        if _idx is None:\n            _idx = new_idx\n            new_idx += 1\n\n        ii.append(_i)\n        jj.append(_idx)\n\n    ii = np.array(ii)\n    jj = np.array(jj)\n\n    n = len(index_sample_ids)\n\n    if not overwrite:\n        existing_inds = np.nonzero(jj < n)[0]\n        num_existing = existing_inds.size\n\n        if num_existing > 0:\n            if not allow_existing:\n                raise ValueError(\n                    \"Found %d IDs (eg '%s') that are already present in the \"\n                    \"index\" % (num_existing, ids[ii[0]])\n                )\n            elif warn_existing:\n                logger.warning(\n                    \"Ignoring %d IDs (eg '%s') that are already present in \"\n                    \"the index\",\n                    num_existing,\n                    ids[ii[0]],\n                )\n\n                ii = np.delete(ii, existing_inds)\n                jj = np.delete(jj, existing_inds)\n\n    if ii.size > 0:\n        sample_ids = np.array(sample_ids)\n        if patches_field is not None:\n            label_ids = np.array(label_ids)\n\n        m = max(jj) - n + 1\n\n        if n == 0:\n            index_sample_ids = np.array([], dtype=sample_ids.dtype)\n            if patches_field is not None:\n                index_label_ids = np.array([], dtype=label_ids.dtype)\n\n        if m > 0:\n            
index_sample_ids = np.concatenate(\n                (index_sample_ids, np.empty(m, dtype=index_sample_ids.dtype))\n            )\n            if patches_field is not None:\n                index_label_ids = np.concatenate(\n                    (index_label_ids, np.empty(m, dtype=index_label_ids.dtype))\n                )\n\n        index_sample_ids[jj] = sample_ids[ii]\n        if patches_field is not None:\n            index_label_ids[jj] = label_ids[ii]\n\n    return index_sample_ids, index_label_ids, ii, jj\n\n\ndef add_embeddings(\n    samples,\n    embeddings,\n    sample_ids,\n    label_ids,\n    embeddings_field,\n    patches_field=None,\n):\n    dataset = samples._dataset\n    if dataset.media_type == fomm.GROUP:\n        view = dataset.select_group_slices(_allow_mixed=True)\n    else:\n        view = dataset\n\n    if patches_field is not None:\n        _, embeddings_path = dataset._get_label_field_path(\n            patches_field, embeddings_field\n        )\n\n        values = dict(zip(label_ids, embeddings))\n        view.set_label_values(embeddings_path, values, dynamic=True)\n    else:\n        values = dict(zip(sample_ids, embeddings))\n        view.set_values(embeddings_field, values, key_field=\"id\")\n\n\ndef remove_ids(\n    sample_ids,\n    label_ids,\n    index_sample_ids,\n    index_label_ids,\n    patches_field=None,\n    allow_missing=True,\n    warn_missing=False,\n):\n    rm_inds = []\n\n    if sample_ids is not None:\n        rm_inds.extend(\n            _find_ids(\n                sample_ids,\n                index_sample_ids,\n                allow_missing,\n                warn_missing,\n                \"sample\",\n            )\n        )\n\n    if label_ids is not None:\n        rm_inds.extend(\n            _find_ids(\n                label_ids,\n                index_label_ids,\n                allow_missing,\n                warn_missing,\n                \"label\",\n            )\n        )\n\n    rm_inds = np.array(rm_inds)\n\n  
  if rm_inds.size > 0:\n        index_sample_ids = np.delete(index_sample_ids, rm_inds)\n        if patches_field is not None:\n            index_label_ids = np.delete(index_label_ids, rm_inds)\n\n    return index_sample_ids, index_label_ids, rm_inds\n\n\ndef _find_ids(ids, index_ids, allow_missing, warn_missing, ftype):\n    found_inds = []\n    missing_ids = []\n\n    ids_map = {_id: _i for _i, _id in enumerate(index_ids)}\n    for _id in ids:\n        ind = ids_map.get(_id, None)\n        if ind is not None:\n            found_inds.append(ind)\n        elif not allow_missing:\n            missing_ids.append(_id)\n\n    num_missing = len(missing_ids)\n\n    if num_missing > 0:\n        if not allow_missing:\n            raise ValueError(\n                \"Found %d %d IDs (eg '%s') that are not present in the index\"\n                % (num_missing, ftype, missing_ids[0])\n            )\n\n        if warn_missing:\n            logger.warning(\n                \"Ignoring %d %d IDs (eg '%s') that are not present in the \"\n                \"index\",\n                num_missing,\n                ftype,\n                missing_ids[0],\n            )\n\n    return found_inds\n\n\ndef remove_embeddings(\n    samples,\n    embeddings_field,\n    sample_ids=None,\n    label_ids=None,\n    patches_field=None,\n):\n    dataset = samples._dataset\n    if dataset.media_type == fomm.GROUP:\n        view = dataset.select_group_slices(_allow_mixed=True)\n    else:\n        view = dataset\n\n    if patches_field is not None:\n        _, embeddings_path = dataset._get_label_field_path(\n            patches_field, embeddings_field\n        )\n\n        if sample_ids is not None and label_ids is None:\n            _, id_path = dataset._get_label_field_path(patches_field, \"id\")\n            label_ids = view.select(sample_ids).values(id_path, unwind=True)\n\n        if label_ids is not None:\n            values = dict(zip(label_ids, itertools.repeat(None)))\n            
view.set_label_values(embeddings_path, values)\n    elif sample_ids is not None:\n        values = dict(zip(sample_ids, itertools.repeat(None)))\n        view.set_values(embeddings_field, values, key_field=\"id\")\n\n\ndef filter_values(values, keep_inds, patches_field=None):\n    if patches_field:\n        _values = list(itertools.chain.from_iterable(values))\n    else:\n        _values = values\n\n    _values = np.asarray(_values)\n\n    if _values.size == keep_inds.size:\n        _values = _values[keep_inds]\n    else:\n        num_expected = np.count_nonzero(keep_inds)\n        if _values.size != num_expected:\n            raise ValueError(\n                \"Expected %d raw values or %d pre-filtered values; found %d \"\n                \"values\" % (keep_inds.size, num_expected, values.size)\n            )\n\n    # @todo we might need to re-ravel patch values here in the future\n    # We currently do not do this because all downstream users of this data\n    # will gracefully handle either flat or nested list data\n\n    return _values\n\n\ndef get_values(samples, path_or_expr, ids, patches_field=None):\n    _validate_args(\n        samples, patches_field=patches_field, path_or_expr=path_or_expr\n    )\n    return samples._get_values_by_id(\n        path_or_expr, ids, link_field=patches_field\n    )\n\n\ndef parse_data_field(\n    samples,\n    data_field,\n    patches_field=None,\n    data_type=\"embeddings\",\n):\n    if not etau.is_str(data_field):\n        raise ValueError(\n            \"Invalid %s field '%s'; expected a string field name\"\n            % (data_type, data_field)\n        )\n\n    if patches_field is None:\n        _data_field, is_frame_field = samples._handle_frame_field(data_field)\n\n        if \".\" in _data_field:\n            root, _ = _data_field.rsplit(\".\", 1)\n            if not samples.has_field(root):\n                raise ValueError(\n                    \"Invalid %s field '%s'; root field '%s' does not exist\"\n             
       % (data_type, data_field, root)\n                )\n\n        data_exists = samples.has_field(data_field)\n\n        return data_field, data_exists\n\n    if data_field.startswith(patches_field + \".\"):\n        _, root = samples._get_label_field_path(patches_field)\n        if not data_field.startswith(root + \".\"):\n            raise ValueError(\n                \"Invalid %s field '%s' for patches field '%s'\"\n                % (data_type, data_field, patches_field)\n            )\n\n        data_field = data_field[len(root) + 1 :]\n\n    if \".\" in data_field:\n        _, root = samples._get_label_field_path(patches_field)\n        root += data_field.rsplit(\".\", 1)[0]\n        if not samples.has_field(root):\n            raise ValueError(\n                \"Invalid %s field '%s'; root field '%s' does not exist\"\n                % (data_type, data_field, root)\n            )\n\n    _, data_path = samples._get_label_field_path(patches_field, data_field)\n\n    data_exists = samples.has_field(data_path)\n\n    return data_field, data_exists\n\n\ndef get_embeddings(\n    samples,\n    model=None,\n    model_kwargs=None,\n    patches_field=None,\n    embeddings_field=None,\n    embeddings=None,\n    similarity_index=None,\n    force_square=False,\n    alpha=None,\n    handle_missing=\"skip\",\n    agg_fcn=None,\n    batch_size=None,\n    num_workers=None,\n    skip_failures=True,\n    progress=None,\n):\n    _validate_args(samples, patches_field=patches_field)\n\n    if (\n        model is None\n        and embeddings_field is None\n        and embeddings is None\n        and similarity_index is None\n    ):\n        return _empty_embeddings(patches_field)\n\n    if similarity_index is not None:\n        return get_embeddings_from_index(\n            samples,\n            similarity_index,\n            patches_field=patches_field,\n            allow_missing=True,\n            warn_missing=True,\n        )\n\n    if (\n        embeddings is None\n        
and model is not None\n        and not _has_embeddings_field(samples, embeddings_field, patches_field)\n    ):\n        if etau.is_str(model):\n            model_kwargs = model_kwargs or {}\n            model = foz.load_zoo_model(model, **model_kwargs)\n\n        if patches_field is not None:\n            logger.info(\"Computing patch embeddings...\")\n            embeddings = samples.compute_patch_embeddings(\n                model,\n                patches_field,\n                embeddings_field=embeddings_field,\n                force_square=force_square,\n                alpha=alpha,\n                handle_missing=handle_missing,\n                batch_size=batch_size,\n                num_workers=num_workers,\n                skip_failures=skip_failures,\n                progress=progress,\n            )\n        else:\n            logger.info(\"Computing embeddings...\")\n            embeddings = samples.compute_embeddings(\n                model,\n                embeddings_field=embeddings_field,\n                batch_size=batch_size,\n                num_workers=num_workers,\n                skip_failures=skip_failures,\n                progress=progress,\n            )\n\n    if embeddings is None and embeddings_field is not None:\n        embeddings, samples = _load_embeddings(\n            samples, embeddings_field, patches_field=patches_field\n        )\n        ref_sample_ids = None\n    else:\n        if isinstance(embeddings, dict):\n            embeddings = [\n                embeddings.get(_id, None) for _id in samples.values(\"id\")\n            ]\n\n        embeddings, ref_sample_ids = _handle_missing_embeddings(\n            embeddings, samples\n        )\n\n    if not isinstance(embeddings, np.ndarray) and not embeddings:\n        return _empty_embeddings(patches_field)\n\n    if patches_field is not None:\n        if agg_fcn is not None:\n            embeddings = np.stack([agg_fcn(e) for e in embeddings])\n        else:\n            
embeddings = np.concatenate(embeddings, axis=0)\n    elif not isinstance(embeddings, np.ndarray):\n        embeddings = np.stack(embeddings)\n\n    if agg_fcn is not None:\n        patches_field = None\n\n    sample_ids, label_ids = get_ids(\n        samples,\n        patches_field=patches_field,\n        data=embeddings,\n        data_type=\"embeddings\",\n        handle_missing=handle_missing,\n        ref_sample_ids=ref_sample_ids,\n    )\n\n    return embeddings, sample_ids, label_ids\n\n\ndef get_unique_name(name, ref_names_or_fcn, max_len=None):\n    unique_name = _get_unique_name(name, ref_names_or_fcn)\n\n    if max_len is not None:\n        while name and len(unique_name) > max_len:\n            name = name[:-1]\n            unique_name = _get_unique_name(name, ref_names_or_fcn)\n\n    return unique_name\n\n\ndef _get_unique_name(name, ref_names_or_fcn):\n    if etau.is_container(ref_names_or_fcn):\n        return _get_unique_name_from_list(name, ref_names_or_fcn)\n\n    return _get_unique_name_from_function(name, ref_names_or_fcn)\n\n\ndef _get_unique_name_from_list(name, ref_names):\n    ref_names = set(ref_names)\n\n    if name not in ref_names:\n        return name\n\n    name += \"-\" + _get_random_characters(6)\n    while name in ref_names:\n        name += _get_random_characters(1)\n\n    return name\n\n\ndef _get_unique_name_from_function(name, exists_fcn):\n    if not exists_fcn(name):\n        return name\n\n    name += \"-\" + _get_random_characters(6)\n    while exists_fcn(name):\n        name += _get_random_characters(1)\n\n    return name\n\n\ndef _get_random_characters(n):\n    return \"\".join(\n        random.choice(string.ascii_lowercase + string.digits) for _ in range(n)\n    )\n\n\ndef _empty_embeddings(patches_field):\n    embeddings = np.empty((0, 0), dtype=float)\n    sample_ids = np.array([], dtype=\"<U24\")\n\n    if patches_field is not None:\n        label_ids = np.array([], dtype=\"<U24\")\n    else:\n        label_ids = 
None\n\n    return embeddings, sample_ids, label_ids\n\n\ndef _has_embeddings_field(samples, embeddings_field, patches_field=None):\n    if embeddings_field is None:\n        return False\n\n    if patches_field is not None:\n        _, embeddings_path = samples._get_label_field_path(\n            patches_field, embeddings_field\n        )\n    else:\n        embeddings_path = embeddings_field\n\n    return samples.has_field(embeddings_path)\n\n\ndef _load_embeddings(samples, embeddings_field, patches_field=None):\n    if patches_field is not None:\n        label_type, embeddings_path = samples._get_label_field_path(\n            patches_field, embeddings_field\n        )\n        is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)\n    elif samples.has_field(embeddings_field):\n        embeddings_path = embeddings_field\n        is_list_field = False\n    else:\n        return [], samples.limit(0)\n\n    if is_list_field:\n        samples = samples.filter_labels(\n            patches_field, F(embeddings_field) != None\n        )\n    else:\n        samples = samples.match(F(embeddings_path) != None)\n\n    if samples.has_field(embeddings_path):\n        _field = None\n    else:\n        _field = fof.VectorField()\n\n    embeddings = samples.values(embeddings_path, _field=_field)\n\n    if is_list_field:\n        embeddings = [np.stack(e) for e in embeddings if e]\n\n    return embeddings, samples\n\n\ndef _validate_args(samples, patches_field=None, path_or_expr=None):\n    if patches_field is not None:\n        _validate_patches_args(\n            samples, patches_field, path_or_expr=path_or_expr\n        )\n    else:\n        _validate_samples_args(samples, path_or_expr=path_or_expr)\n\n\ndef _validate_samples_args(samples, path_or_expr=None):\n    if not etau.is_str(path_or_expr):\n        return\n\n    path, _, list_fields, _, _ = samples._parse_field_name(path_or_expr)\n\n    if list_fields:\n        raise ValueError(\n            \"Values path '%s' 
contains invalid list field '%s'\"\n            % (path, list_fields[0])\n        )\n\n\ndef _validate_patches_args(samples, patches_field, path_or_expr=None):\n    if samples.media_type == fomm.VIDEO:\n        raise ValueError(\n            \"This method does not directly support frame patches for video \"\n            \"collections. Try converting to a frames view via `to_frames()` \"\n            \"first\"\n        )\n\n    if etau.is_str(path_or_expr) and not path_or_expr.startswith(\n        patches_field + \".\"\n    ):\n        raise ValueError(\n            \"Values path '%s' must start with patches field '%s'\"\n            % (path_or_expr, patches_field)\n        )\n\n    if (\n        isinstance(samples, fop.PatchesView)\n        and patches_field != samples.patches_field\n    ):\n        raise ValueError(\n            \"This patches view contains labels from field '%s', not \"\n            \"'%s'\" % (samples.patches_field, patches_field)\n        )\n\n    if isinstance(\n        samples, fop.EvaluationPatchesView\n    ) and patches_field not in (\n        samples.gt_field,\n        samples.pred_field,\n    ):\n        raise ValueError(\n            \"This evaluation patches view contains patches from \"\n            \"fields '%s' and '%s', not '%s'\"\n            % (samples.gt_field, samples.pred_field, patches_field)\n        )\n\n\ndef _handle_missing_embeddings(embeddings, samples):\n    if isinstance(embeddings, np.ndarray):\n        return embeddings, None\n\n    missing_inds = []\n    for idx, embedding in enumerate(embeddings):\n        if embedding is None:\n            missing_inds.append(idx)\n\n    if not missing_inds:\n        return embeddings, None\n\n    embeddings = [e for e in embeddings if e is not None]\n    ref_sample_ids = list(np.delete(samples.values(\"id\"), missing_inds))\n\n    return embeddings, ref_sample_ids\n"
  },
  {
    "path": "fiftyone/brain/internal/core/visualization.py",
    "content": "\"\"\"\nVisualization methods.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\n\n# For backwards-compatibility with older versions of plugins like\n# https://github.com/voxel51/fiftyone-plugins/blob/5c800f1ded53c285f8e17f37e1ad9b2472fa93e7/plugins/brain/__init__.py#L25\nfrom fiftyone.brain.visualization import (\n    Visualization,\n    UMAPVisualization,\n    TSNEVisualization,\n    PCAVisualization,\n    ManualVisualization,\n)\n"
  },
  {
    "path": "fiftyone/brain/internal/models/.gitignore",
    "content": "cache/\n"
  },
  {
    "path": "fiftyone/brain/internal/models/__init__.py",
    "content": "\"\"\"\nBrain models.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nfrom copy import deepcopy\nimport logging\nimport os\n\nfrom eta.core.config import ConfigError\nimport eta.core.learning as etal\nimport eta.core.models as etam\n\nimport fiftyone.core.models as fom\n\n\nlogger = logging.getLogger(__name__)\n\n\n_THIS_DIR = os.path.dirname(os.path.abspath(__file__))\n_MODELS_MANIFEST_PATH = os.path.join(_THIS_DIR, \"manifest.json\")\n_MODELS_DIR = os.path.join(_THIS_DIR, \"cache\")\n\n\ndef list_models():\n    \"\"\"Returns the list of available models.\n\n    Returns:\n        a list of model names\n    \"\"\"\n    manifest = _load_models_manifest()\n    return sorted([model.name for model in manifest])\n\n\ndef list_downloaded_models():\n    \"\"\"Returns information about the models that have been downloaded.\n\n    Returns:\n        a dict mapping model names to (model path, ``eta.core.models.Model``)\n        tuples\n    \"\"\"\n    manifest = _load_models_manifest()\n    models = {}\n    for model in manifest:\n        if model.is_in_dir(_MODELS_DIR):\n            model_path = model.get_path_in_dir(_MODELS_DIR)\n            models[model.name] = (model_path, model)\n\n    return models\n\n\ndef is_model_downloaded(name):\n    \"\"\"Determines whether the model of the given name is downloaded.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. 
If no version is\n            specified, the latest version of the model is used\n\n    Returns:\n        True/False\n    \"\"\"\n    model = _get_model(name)\n    return model.is_in_dir(_MODELS_DIR)\n\n\ndef download_model(name, overwrite=False):\n    \"\"\"Downloads the model of the given name.\n\n    If the model is already downloaded, it is not re-downloaded unless\n    ``overwrite == True`` is specified.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. If no version is\n            specified, the latest version of the model is used. Call\n            :func:`list_models` to see the available models\n        overwrite (False): whether to overwrite any existing files\n\n    Returns:\n        tuple of\n\n        -   model: the ``eta.core.models.Model`` instance for the model\n        -   model_path: the path to the downloaded model on disk\n    \"\"\"\n    model, model_path = _get_model_in_dir(name)\n\n    if not overwrite and is_model_downloaded(name):\n        logger.info(\"Model '%s' is already downloaded\", name)\n    else:\n        model.manager.download_model(model_path, force=overwrite)\n\n    return model, model_path\n\n\ndef install_model_requirements(name, error_level=0):\n    \"\"\"Installs any package requirements for the model with the given name.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. If no version is\n            specified, the latest version of the model is used. 
Call\n            :func:`list_models` to see the available models\n        error_level: the error level to use, defined as:\n\n            0: raise error if a requirement install fails\n            1: log warning if a requirement install fails\n            2: ignore install fails\n    \"\"\"\n    model = _get_model(name)\n    model.install_requirements(error_level=error_level)\n\n\ndef ensure_model_requirements(name, error_level=0):\n    \"\"\"Ensures that the package requirements for the model with the given name\n    are satisfied.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. If no version is\n            specified, the latest version of the model is used. Call\n            :func:`list_models` to see the available models\n        error_level: the error level to use, defined as:\n\n            0: raise error if a requirement is not satisfied\n            1: log warning if a requirement is not satisfied\n            2: ignore unsatisfied requirements\n    \"\"\"\n    model = _get_model(name)\n    model.ensure_requirements(error_level=error_level)\n\n\ndef load_model(\n    name,\n    download_if_necessary=True,\n    install_requirements=False,\n    error_level=0,\n    **kwargs\n):\n    \"\"\"Loads the model of the given name.\n\n    By default, the model will be downloaded if necessary.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. If no version is\n            specified, the latest version of the model is used. Call\n            :func:`list_models` to see the available models\n        download_if_necessary (True): whether to download the model if it is\n            not found in the specified directory\n        install_requirements: whether to install any requirements before\n            loading the model. 
By default, this is False\n        error_level: the error level to use, defined as:\n\n            0: raise error if a requirement is not satisfied\n            1: log warning if a requirement is not satisfied\n            2: ignore unsatisfied requirements\n\n        **kwargs: keyword arguments to inject into the model's ``Config``\n            instance\n\n    Returns:\n        a :class:`fiftyone.core.models.Model`\n    \"\"\"\n    model = _get_model(name)\n\n    if not model.is_in_dir(_MODELS_DIR):\n        if not download_if_necessary:\n            raise ValueError(\"Model '%s' is not downloaded\" % name)\n\n        download_model(name)\n\n    if install_requirements:\n        model.install_requirements(error_level=error_level)\n    else:\n        model.ensure_requirements(error_level=error_level)\n\n    config_dict = deepcopy(model.default_deployment_config_dict)\n    model_path = model.get_path_in_dir(_MODELS_DIR)\n\n    return fom.load_model(config_dict, model_path=model_path, **kwargs)\n\n\ndef find_model(name):\n    \"\"\"Returns the path to the model on disk.\n\n    The model must be downloaded. Use :func:`download_model` to download\n    models.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. 
If no version is\n            specified, the latest version of the model is used\n\n    Returns:\n        the path to the model on disk\n\n    Raises:\n        ValueError: if the model does not exist or has not been downloaded\n    \"\"\"\n    model, model_path = _get_model_in_dir(name)\n    if not model.is_model_downloaded(model_path):\n        raise ValueError(\"Model '%s' is not downloaded\" % name)\n\n    return model_path\n\n\ndef get_model(name):\n    \"\"\"Returns the ``eta.core.models.Model`` instance for the model with the\n    given name.\n\n    Args:\n        name: the name of the model\n\n    Returns:\n        an ``eta.core.models.Model``\n    \"\"\"\n    return _get_model(name)\n\n\ndef delete_model(name):\n    \"\"\"Deletes the model from local disk, if necessary.\n\n    Args:\n        name: the name of the model, which can have ``@<ver>`` appended to\n            refer to a specific version of the model. If no version is\n            specified, the latest version of the model is used\n    \"\"\"\n    model, model_path = _get_model_in_dir(name)\n    model.flush_model(model_path)\n\n\nclass HasBrainModel(etal.HasPublishedModel):\n    \"\"\"Mixin class for Config classes of :class:`fiftyone.core.models.Model`\n    instances whose models are stored privately by the FiftyOne Brain.\n    \"\"\"\n\n    def download_model_if_necessary(self):\n        # pylint: disable=attribute-defined-outside-init\n        if not self.model_name and not self.model_path:\n            raise ConfigError(\n                \"Either `model_name` or `model_path` must be provided\"\n            )\n\n        if self.model_path is None:\n            self.model_path = download_model(self.model_name)\n\n    @classmethod\n    def _get_model(cls, model_name):\n        return get_model(model_name)\n\n\ndef _load_models_manifest():\n    return etam.ModelsManifest.from_json(_MODELS_MANIFEST_PATH)\n\n\ndef _get_model_in_dir(name):\n    model = _get_model(name)\n    model_path = 
model.get_path_in_dir(_MODELS_DIR)\n    return model, model_path\n\n\ndef _get_model(name):\n    if etam.Model.has_version_str(name):\n        return _get_exact_model(name)\n\n    return _get_latest_model(name)\n\n\ndef _get_exact_model(name):\n    manifest = _load_models_manifest()\n    try:\n        return manifest.get_model_with_name(name)\n    except etam.ModelError:\n        raise ValueError(\"No model with name '%s' was found\" % name)\n\n\ndef _get_latest_model(base_name):\n    manifest = _load_models_manifest()\n    try:\n        return manifest.get_latest_model_with_base_name(base_name)\n    except etam.ModelError:\n        raise ValueError(\"No models found with base name '%s'\" % base_name)\n"
  },
  {
    "path": "fiftyone/brain/internal/models/manifest.json",
    "content": "{\n    \"models\": [\n        {\n            \"base_name\": \"simple-resnet-cifar10\",\n            \"base_filename\": \"simple-resnet-cifar10.pth\",\n            \"version\": \"1.0\",\n            \"description\": \"Simple ResNet trained on CIFAR-10\",\n            \"manager\": {\n                \"type\": \"fiftyone.core.models.ModelManager\",\n                \"config\": {\n                    \"google_drive_id\": \"1SIO9XreK0w1ja4EuhBWcR10CnWxCOsom\"\n                }\n            },\n            \"default_deployment_config_dict\": {\n                \"type\": \"fiftyone.brain.internal.models.torch.TorchImageModel\",\n                \"config\": {\n                    \"entrypoint_fcn\": \"fiftyone.brain.internal.models.simple_resnet.simple_resnet\",\n                    \"output_processor_cls\": \"fiftyone.utils.torch.ClassifierOutputProcessor\",\n                    \"labels_string\": \"airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truck\",\n                    \"image_size\": [32, 32],\n                    \"image_mean\": [0.4914, 0.4822, 0.4465],\n                    \"image_std\": [0.2023, 0.1994, 0.201],\n                    \"embeddings_layer\": \"flatten\",\n                    \"use_half_precision\": false,\n                    \"cudnn_benchmark\": true\n                }\n            },\n            \"date_created\": \"2020-05-07 08:25:51\"\n        }\n    ]\n}\n"
  },
  {
    "path": "fiftyone/brain/internal/models/simple_resnet.py",
    "content": "\"\"\"\nImplementation of a simple ResNet that is suitable only for smallish data.\n\nThe original implementation of this is from David Page's work on fast model\ntraining with resnets at https://github.com/davidcpage/cifar10-fast.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nfrom collections import namedtuple\nimport os\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\n\ndef simple_resnet(\n    channels=None,\n    weight=0.125,\n    pool=nn.MaxPool2d(2),\n    extra_layers=(),\n    res_layers=(\"layer1\", \"layer3\"),\n):\n    channels = channels or {\n        \"prep\": 64,\n        \"layer1\": 128,\n        \"layer2\": 256,\n        \"layer3\": 512,\n    }\n    net = {\n        \"input\": (None, []),\n        \"prep\": conv_bn(3, channels[\"prep\"]),\n        \"layer1\": dict(\n            conv_bn(channels[\"prep\"], channels[\"layer1\"]), pool=pool\n        ),\n        \"layer2\": dict(\n            conv_bn(channels[\"layer1\"], channels[\"layer2\"]), pool=pool\n        ),\n        \"layer3\": dict(\n            conv_bn(channels[\"layer2\"], channels[\"layer3\"]), pool=pool\n        ),\n        \"pool\": nn.MaxPool2d(4),\n        \"flatten\": Flatten(),\n        \"linear\": nn.Linear(channels[\"layer3\"], 10, bias=False),\n        \"logits\": Mul(weight),\n    }\n    for layer in res_layers:\n        net[layer][\"residual\"] = residual(channels[layer])\n\n    for layer in extra_layers:\n        net[layer][\"extra\"] = conv_bn(channels[layer], channels[layer])\n\n    return Network(net, input_layer=\"input\", output_layer=\"logits\")\n\n\nclass Network(nn.Module):\n    def __init__(self, net, input_layer=None, output_layer=None):\n        super().__init__()\n        self.input_layer = input_layer\n        self.output_layer = output_layer\n        self.graph = build_graph(net)\n        for path, (val, _) in self.graph.items():\n            setattr(self, path.replace(\"/\", \"_\"), val)\n\n   
 def nodes(self):\n        return (node for node, _ in self.graph.values())\n\n    def forward(self, inputs):\n        if self.input_layer:\n            outputs = {self.input_layer: inputs}\n        else:\n            outputs = dict(inputs)\n\n        for k, (node, ins) in self.graph.items():\n            # only compute nodes that are not supplied as inputs.\n            if k not in outputs:\n                outputs[k] = node(*[outputs[x] for x in ins])\n\n        if self.output_layer:\n            return outputs[self.output_layer]\n\n        return outputs\n\n    def half(self):\n        for node in self.nodes():\n            if isinstance(node, nn.Module) and not isinstance(\n                node, nn.BatchNorm2d\n            ):\n                node.half()\n\n        return self\n\n\ndef has_inputs(node):\n    return type(node) is tuple\n\n\ndef build_graph(net):\n    flattened = pipeline(net)\n    resolve_input = lambda rel_path, path, idx: (\n        os.path.normpath(os.path.sep.join((path, \"..\", rel_path)))\n        if isinstance(rel_path, str)\n        else flattened[idx + rel_path][0]\n    )\n    return {\n        path: (\n            node[0],\n            [resolve_input(rel_path, path, idx) for rel_path in node[1]],\n        )\n        for idx, (path, node) in enumerate(flattened)\n    }\n\n\ndef pipeline(net):\n    return [\n        (os.path.sep.join(path), (node if has_inputs(node) else (node, [-1])))\n        for (path, node) in path_iter(net)\n    ]\n\n\nclass Crop(namedtuple(\"Crop\", (\"h\", \"w\"))):\n    def __call__(self, x, x0, y0):\n        return x[..., y0 : y0 + self.h, x0 : x0 + self.w]\n\n    def options(self, shape):\n        *_, H, W = shape\n        return [\n            {\"x0\": x0, \"y0\": y0}\n            for x0 in range(W + 1 - self.w)\n            for y0 in range(H + 1 - self.h)\n        ]\n\n    def output_shape(self, shape):\n        *_, H, W = shape\n        return (*_, self.h, self.w)\n\n\nclass FlipLR(namedtuple(\"FlipLR\", 
())):\n    def __call__(self, x, choice):\n        if isinstance(x, np.ndarray):\n            return x[..., ::-1].copy()\n\n        return torch.flip(x, [-1]) if choice else x\n\n    def options(self, shape):\n        return [{\"choice\": b} for b in [True, False]]\n\n\nclass Cutout(namedtuple(\"Cutout\", (\"h\", \"w\"))):\n    def __call__(self, x, x0, y0):\n        x[..., y0 : y0 + self.h, x0 : x0 + self.w] = 0.0\n        return x\n\n    def options(self, shape):\n        *_, H, W = shape\n        return [\n            {\"x0\": x0, \"y0\": y0}\n            for x0 in range(W + 1 - self.w)\n            for y0 in range(H + 1 - self.h)\n        ]\n\n\nclass PiecewiseLinear(namedtuple(\"PiecewiseLinear\", (\"knots\", \"vals\"))):\n    def __call__(self, t):\n        return np.interp([t], self.knots, self.vals)[0]\n\n\nclass Const(namedtuple(\"Const\", [\"val\"])):\n    def __call__(self, x):\n        return self.val\n\n\nclass Identity(namedtuple(\"Identity\", [])):\n    def __call__(self, x):\n        return x\n\n\nclass Add(namedtuple(\"Add\", [])):\n    def __call__(self, x, y):\n        return x + y\n\n\nclass AddWeighted(namedtuple(\"AddWeighted\", [\"wx\", \"wy\"])):\n    def __call__(self, x, y):\n        return self.wx * x + self.wy * y\n\n\nclass Mul(nn.Module):\n    def __init__(self, weight):\n        super().__init__()\n        self.weight = weight\n\n    def __call__(self, x):\n        return x * self.weight\n\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), x.size(1))\n\n\nclass Concat(nn.Module):\n    def forward(self, *xs):\n        return torch.cat(xs, 1)\n\n\nclass BatchNorm(nn.BatchNorm2d):\n    def __init__(\n        self,\n        num_features,\n        eps=1e-05,\n        momentum=0.1,\n        weight_freeze=False,\n        bias_freeze=False,\n        weight_init=1.0,\n        bias_init=0.0,\n    ):\n        super().__init__(num_features, eps=eps, momentum=momentum)\n        if weight_init is not None:\n   
         self.weight.data.fill_(weight_init)\n\n        if bias_init is not None:\n            self.bias.data.fill_(bias_init)\n\n        self.weight.requires_grad = not weight_freeze\n        self.bias.requires_grad = not bias_freeze\n\n\ndef conv_bn(c_in, c_out):\n    return {\n        \"conv\": nn.Conv2d(\n            c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False\n        ),\n        \"bn\": BatchNorm(c_out),\n        \"relu\": nn.ReLU(True),\n    }\n\n\ndef residual(c):\n    return {\n        \"in\": Identity(),\n        \"res1\": conv_bn(c, c),\n        \"res2\": conv_bn(c, c),\n        \"add\": (Add(), [\"in\", \"res2/relu\"]),\n    }\n\n\ndef path_iter(nested_dict, pfx=()):\n    for name, val in nested_dict.items():\n        if isinstance(val, dict):\n            yield from path_iter(val, (*pfx, name))\n        else:\n            yield ((*pfx, name), val)\n\n\nMODEL = \"model\"\nVALID_MODEL = \"valid_model\"\nOUTPUT = \"output\"\n"
  },
  {
    "path": "fiftyone/brain/internal/models/torch.py",
    "content": "\"\"\"\nPyTorch utilities.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport fiftyone.utils.torch as fout\n\nfrom fiftyone.brain.internal.models import HasBrainModel\n\nimport torch\n\n\nclass TorchImageModelConfig(fout.TorchImageModelConfig, HasBrainModel):\n    \"\"\"Configuration for running a :class:`TorchImageModel`.\n\n    See :class:`fiftyone.utils.torch.TorchImageModelConfig` for additional\n    parameters.\n\n    Args:\n        model_name (None): the name of the Brain model state dict to load\n        model_path (None): the path to a state dict on disk to load\n    \"\"\"\n\n    def __init__(self, d):\n        d = self.init(d)\n        super().__init__(d)\n\n\nclass TorchImageModel(fout.TorchImageModel):\n    \"\"\"Wrapper for evaluating a Torch model on images whose state dict is\n    stored privately by the Brain.\n\n    Args:\n        config: an :class:`TorchImageModelConfig`\n    \"\"\"\n\n    def _download_model(self, config):\n        config.download_model_if_necessary()\n\n    def _load_state_dict(self, model, config):\n        state_dict = torch.load(config.model_path, map_location=self.device)\n        model.load_state_dict(state_dict)\n"
  },
  {
    "path": "fiftyone/brain/similarity.py",
    "content": "\"\"\"\nSimilarity interface.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nfrom collections import defaultdict\nfrom copy import deepcopy\nimport inspect\nimport logging\n\nfrom bson import ObjectId\nimport numpy as np\n\nimport eta.core.utils as etau\n\nimport fiftyone.brain as fb\nimport fiftyone.core.brain as fob\nimport fiftyone.core.context as foc\nimport fiftyone.core.dataset as fod\nimport fiftyone.core.fields as fof\nimport fiftyone.core.labels as fol\nimport fiftyone.core.media as fom\nimport fiftyone.core.patches as fop\nimport fiftyone.core.stages as fos\nimport fiftyone.core.utils as fou\nimport fiftyone.core.view as fov\nimport fiftyone.core.validation as fova\nimport fiftyone.zoo as foz\nfrom fiftyone import ViewField as F\n\nfbu = fou.lazy_import(\"fiftyone.brain.internal.core.utils\")\n\n\nlogger = logging.getLogger(__name__)\n\n_ALLOWED_ROI_FIELD_TYPES = (\n    fol.Detection,\n    fol.Detections,\n    fol.Polyline,\n    fol.Polylines,\n)\n\n_DEFAULT_MODEL = \"mobilenet-v2-imagenet-torch\"\n_DEFAULT_BATCH_SIZE = None\n\n\ndef compute_similarity(\n    samples,\n    patches_field,\n    roi_field,\n    embeddings,\n    brain_key,\n    model,\n    model_kwargs,\n    force_square,\n    alpha,\n    batch_size,\n    num_workers,\n    skip_failures,\n    progress,\n    backend,\n    **kwargs,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    fova.validate_collection(samples)\n\n    if roi_field is not None:\n        fova.validate_collection_label_fields(\n            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES\n        )\n\n    # Allow for `embeddings_field=XXX` and `embeddings=False` together\n    embeddings_field = kwargs.pop(\"embeddings_field\", None)\n    if embeddings_field is not None or etau.is_str(embeddings):\n        if embeddings_field is None:\n            embeddings_field = embeddings\n            embeddings = None\n\n        embeddings_field, embeddings_exist = 
fbu.parse_data_field(\n            samples,\n            embeddings_field,\n            patches_field=patches_field or roi_field,\n            data_type=\"embeddings\",\n        )\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if model is None and embeddings is None and not embeddings_exist:\n        model = _DEFAULT_MODEL\n        if batch_size is None:\n            batch_size = _DEFAULT_BATCH_SIZE\n\n    if etau.is_str(model):\n        _model_kwargs = model_kwargs or {}\n        _model = foz.load_zoo_model(model, **_model_kwargs)\n    else:\n        _model = model\n\n    try:\n        supports_prompts = _model.can_embed_prompts\n    except:\n        supports_prompts = False\n\n    if brain_key is not None and supports_prompts and not etau.is_str(model):\n        logger.warning(\n            \"This index will not support prompt queries in the App or in \"\n            \"future Python sessions. You can support this by providing the \"\n            \"string name of a zoo model rather than a Model instance to \"\n            \"compute_similarity(model=).\"\n        )\n\n    config = _parse_config(\n        backend,\n        embeddings_field=embeddings_field,\n        patches_field=patches_field,\n        roi_field=roi_field,\n        model=model,\n        model_kwargs=model_kwargs,\n        supports_prompts=supports_prompts,\n        **kwargs,\n    )\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n\n    # Similarity indexes can be modified after creation, so we always register\n    # the index on the full dataset so that queries will always be performed\n    # against the full index by default\n    dataset = samples._root_dataset\n    if samples._is_frames:\n        dataset = samples._base_view\n\n    if brain_key is not None:\n        # Don't allow overwriting an existing run with same key, since we\n        # need the existing run in order to perform workflows like\n        # automatically cleaning up 
the backend's index\n        brain_method.register_run(dataset, brain_key, overwrite=False)\n\n    results = brain_method.initialize(dataset, brain_key)\n\n    results._model = _model\n    results._supports_prompts = supports_prompts\n\n    get_embeddings = embeddings is not False\n    if not results.is_external and results.total_index_size > 0:\n        # No need to load embeddings because the index already has them\n        get_embeddings = False\n\n    if get_embeddings:\n        # Don't immediately store embeddings in DB; let `add_to_index()` do it\n        if not embeddings_exist:\n            embeddings_field = None\n\n        if roi_field is not None:\n            handle_missing = \"image\"\n            agg_fcn = lambda e: np.mean(e, axis=0)\n        else:\n            handle_missing = \"skip\"\n            agg_fcn = None\n\n        embeddings, sample_ids, label_ids = fbu.get_embeddings(\n            samples,\n            model=_model,\n            patches_field=patches_field or roi_field,\n            embeddings=embeddings,\n            embeddings_field=embeddings_field,\n            force_square=force_square,\n            alpha=alpha,\n            handle_missing=handle_missing,\n            agg_fcn=agg_fcn,\n            batch_size=batch_size,\n            num_workers=num_workers,\n            skip_failures=skip_failures,\n            progress=progress,\n        )\n    else:\n        embeddings = None\n        sample_ids = None\n        label_ids = None\n\n    if embeddings is not None:\n        results.add_to_index(embeddings, sample_ids, label_ids=label_ids)\n\n    brain_method.save_run_results(dataset, brain_key, results)\n\n    return results\n\n\ndef _parse_config(name, **kwargs):\n    if name is None:\n        name = fb.brain_config.default_similarity_backend\n\n    if inspect.isclass(name):\n        return name(**kwargs)\n\n    backends = fb.brain_config.similarity_backends\n\n    if name not in backends:\n        raise ValueError(\n            
\"Unsupported backend '%s'. The available backends are %s\"\n            % (name, sorted(backends.keys()))\n        )\n\n    params = deepcopy(backends[name])\n\n    config_cls = kwargs.pop(\"config_cls\", None)\n\n    if config_cls is None:\n        config_cls = params.pop(\"config_cls\", None)\n\n    if config_cls is None:\n        raise ValueError(\"Similarity backend '%s' has no `config_cls`\" % name)\n\n    if etau.is_str(config_cls):\n        config_cls = etau.get_class(config_cls)\n\n    params.update(**kwargs)\n    return config_cls(**params)\n\n\nclass SimilarityConfig(fob.BrainMethodConfig):\n    \"\"\"Similarity configuration.\n\n    Args:\n        embeddings_field (None): the sample field containing the embeddings,\n            if one was provided\n        model (None): the :class:`fiftyone.core.models.Model` or name of the\n            zoo model that was used to compute embeddings, if known\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        roi_field (None): the sample field defining a region of interest within\n            each image to use to compute embeddings, if any\n        supports_prompts (False): whether this run supports prompt queries\n    \"\"\"\n\n    def __init__(\n        self,\n        embeddings_field=None,\n        model=None,\n        model_kwargs=None,\n        patches_field=None,\n        roi_field=None,\n        supports_prompts=None,\n        **kwargs,\n    ):\n        if model is not None and not etau.is_str(model):\n            model = None\n\n            # We can't declare permanent support for prompts because we don't\n            # know how to load the model in future sessions\n            supports_prompts = None\n\n        self.embeddings_field = embeddings_field\n        self.model = model\n        
self.model_kwargs = model_kwargs\n        self.patches_field = patches_field\n        self.roi_field = roi_field\n        self.supports_prompts = supports_prompts\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"similarity\"\n\n    @property\n    def method(self):\n        \"\"\"The name of the similarity backend.\"\"\"\n        raise NotImplementedError(\"subclass must implement method\")\n\n    @property\n    def max_k(self):\n        \"\"\"A maximum k value for nearest neighbor queries, or None if there is\n        no limit.\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement max_k\")\n\n    @property\n    def supports_least_similarity(self):\n        \"\"\"Whether this backend supports least similarity queries.\"\"\"\n        raise NotImplementedError(\n            \"subclass must implement supports_least_similarity\"\n        )\n\n    @property\n    def supported_aggregations(self):\n        \"\"\"A tuple of supported values for the ``aggregation`` parameter of the\n        backend's\n        :meth:`sort_by_similarity() <SimilarityIndex.sort_by_similarity>` and\n        :meth:`_kneighbors() <SimilarityIndex._kneighbors>` methods.\n        \"\"\"\n        raise NotImplementedError(\n            \"subclass must implement supported_aggregations\"\n        )\n\n    def load_credentials(self, **kwargs):\n        self._load_parameters(**kwargs)\n\n    def _load_parameters(self, **kwargs):\n        name = self.method\n        parameters = fb.brain_config.similarity_backends.get(name, {})\n\n        for name, value in kwargs.items():\n            if value is None:\n                value = parameters.get(name, None)\n\n            if value is not None:\n                setattr(self, name, value)\n\n\nclass Similarity(fob.BrainMethod):\n    \"\"\"Base class for similarity factories.\n\n    Args:\n        config: a :class:`SimilarityConfig`\n    \"\"\"\n\n    def initialize(self, samples, brain_key):\n   
     \"\"\"Initializes a similarity index.\n\n        Args:\n            samples: a :class:`fiftyone.core.collections.SampleColllection`\n            brain_key: the brain key\n\n        Returns:\n            a :class:`SimilarityIndex`\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement initialize()\")\n\n    def get_fields(self, samples, brain_key):\n        fields = []\n        if self.config.patches_field is not None:\n            fields.append(self.config.patches_field)\n\n        if self.config.embeddings_field is not None:\n            fields.append(self.config.embeddings_field)\n\n        return fields\n\n\nclass SimilarityIndex(fob.BrainResults):\n    \"\"\"Base class for similarity indexes.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`SimilarityConfig` used\n        brain_key: the brain key\n        backend (None): a :class:`Similarity` backend\n    \"\"\"\n\n    def __init__(self, samples, config, brain_key, backend=None):\n        super().__init__(samples, config, brain_key, backend=backend)\n\n        self._model = None\n        self._supports_prompts = None\n        self._last_view = None\n        self._last_views = []\n        self._curr_view = None\n        self._curr_view_allow_missing = None\n        self._curr_view_warn_missing = None\n        self._curr_sample_ids = None\n        self._curr_label_ids = None\n        self._curr_keep_inds = None\n        self._curr_missing_size = None\n\n        self.use_view(samples)\n\n    def __enter__(self):\n        self._last_views.append(self._last_view)\n        return self\n\n    def __exit__(self, *args):\n        try:\n            last_view = self._last_views.pop()\n        except:\n            last_view = self._samples\n\n        self.use_view(last_view)\n\n    @property\n    def config(self):\n        \"\"\"The :class:`SimilarityConfig` for these results.\"\"\"\n        return self._config\n\n    
@property\n    def supports_prompts(self):\n        \"\"\"Whether this similarity index supports prompt queries.\"\"\"\n        if self._supports_prompts is not None:\n            return self._supports_prompts\n\n        return self.config.supports_prompts or False\n\n    @property\n    def is_external(self):\n        \"\"\"Whether this similarity index manages its own embeddings (True) or\n        loads them directly from the ``embeddings_field`` of the dataset\n        (False).\n        \"\"\"\n        return True  # assume external unless explicitly overridden\n\n    @property\n    def sample_ids(self):\n        \"\"\"The sample IDs of the full index, or ``None`` if not supported.\"\"\"\n        return None\n\n    @property\n    def label_ids(self):\n        \"\"\"The label IDs of the full index, or ``None`` if not applicable or\n        not supported.\n        \"\"\"\n        return None\n\n    @property\n    def total_index_size(self):\n        \"\"\"The total number of data points in the index.\n\n        If :meth:`use_view` has been called to restrict the index, this value\n        may be larger than the current :meth:`index_size`.\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement total_index_size\")\n\n    @property\n    def has_view(self):\n        \"\"\"Whether the index is currently restricted to a view.\n\n        Use :meth:`use_view` to restrict the index to a view, and use\n        :meth:`clear_view` to reset to the full index.\n        \"\"\"\n\n        # Full dataset\n        if isinstance(self._curr_view, fod.Dataset):\n            return False\n\n        # Full group slices view\n        if (\n            isinstance(self._curr_view, fov.DatasetView)\n            and self._curr_view._root_dataset.media_type == fom.GROUP\n            and len(self._curr_view._stages) == 1\n            and isinstance(self._curr_view._stages[0], fos.SelectGroupSlices)\n            and self._curr_view._pipeline() == []\n        ):\n            
return False\n\n        # Full patches view\n        if (\n            self.config.patches_field is not None\n            and isinstance(self._curr_view, fop.PatchesView)\n            and len(self._curr_view._all_stages) == 1\n        ):\n            return False\n\n        return self._curr_view.view() != self._samples.view()\n\n    @property\n    def view(self):\n        \"\"\"The :class:`fiftyone.core.collections.SampleCollection` against\n        which results are currently being generated.\n\n        If :meth:`use_view` has been called, this view may be different than\n        the collection on which the full index was generated.\n        \"\"\"\n        return self._curr_view\n\n    @property\n    def current_sample_ids(self):\n        \"\"\"The sample IDs of the currently active data points in the index.\n\n        If :meth:`use_view` has been called, this may be a subset of the full\n        index.\n\n        If the index does not support full sample ID lists (ie if\n        :meth:`sample_ids` is ``None``), then this will be all sample IDs in\n        the current :meth:`view` regardless of whether all samples are indexed.\n        \"\"\"\n        self._apply_view_if_necessary()\n        return self._curr_sample_ids\n\n    @property\n    def current_label_ids(self):\n        \"\"\"The label IDs of the currently active data points in the index, or\n        ``None`` if not applicable.\n\n        If :meth:`use_view` has been called, this may be a subset of the full\n        index.\n\n        If the index does not support full label ID lists (ie if\n        :meth:`label_ids` is ``None``), then this will be all label IDs in\n        the current :meth:`view` regardless of whether all labels are indexed.\n        \"\"\"\n        self._apply_view_if_necessary()\n        return self._curr_label_ids\n\n    @property\n    def _current_inds(self):\n        \"\"\"The indices of :meth:`current_sample_ids` in :meth:`sample_ids`, or\n        ``None`` if not supported or if 
the full index is currently being used.\n        \"\"\"\n        self._apply_view_if_necessary()\n        return self._curr_keep_inds\n\n    @property\n    def index_size(self):\n        \"\"\"The number of active data points in the index.\n\n        If :meth:`use_view` has been called to restrict the index, this\n        property will reflect the size of the active index.\n        \"\"\"\n        self._apply_view_if_necessary()\n        return len(self._curr_sample_ids)\n\n    @property\n    def missing_size(self):\n        \"\"\"The total number of data points in :meth:`view` that are missing\n        from this index, or ``None`` if unknown.\n\n        This property is only applicable when :meth:`use_view` has been called,\n        and it will be ``None`` if no data points are missing or when the\n        backend does not support it.\n        \"\"\"\n        self._apply_view_if_necessary()\n        return self._curr_missing_size\n\n    def add_to_index(\n        self,\n        embeddings,\n        sample_ids,\n        label_ids=None,\n        overwrite=True,\n        allow_existing=True,\n        warn_existing=False,\n        reload=True,\n    ):\n        \"\"\"Adds the given embeddings to the index.\n\n        Args:\n            embeddings: a ``num_embeddings x num_dims`` array of embeddings\n            sample_ids: a ``num_embeddings`` array of sample IDs\n            label_ids (None): a ``num_embeddings`` array of label IDs, if\n                applicable\n            overwrite (True): whether to replace (True) or ignore (False)\n                existing embeddings with the same sample/label IDs\n            allow_existing (True): whether to ignore (True) or raise an error\n                (False) when ``overwrite`` is False and a provided ID already\n                exists in the index\n            warn_existing (False): whether to log a warning if an embedding is\n                not added to the index because its ID already exists\n            reload (True): 
whether to call :meth:`reload` to refresh the\n                current view after the update\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement add_to_index()\")\n\n    def remove_from_index(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n        reload=True,\n    ):\n        \"\"\"Removes the specified embeddings from the index.\n\n        Args:\n            sample_ids (None): an array of sample IDs\n            label_ids (None): an array of label IDs, if applicable\n            allow_missing (True): whether to allow the index to not contain IDs\n                that you provide (True) or whether to raise an error in this\n                case (False)\n            warn_missing (False): whether to log a warning if the index does\n                not contain IDs that you provide\n            reload (True): whether to call :meth:`reload` to refresh the\n                current view after the update\n        \"\"\"\n        raise NotImplementedError(\n            \"subclass must implement remove_from_index()\"\n        )\n\n    def get_embeddings(\n        self,\n        sample_ids=None,\n        label_ids=None,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        \"\"\"Retrieves the embeddings for the given IDs from the index.\n\n        If no IDs are provided, the entire index is returned.\n\n        Args:\n            sample_ids (None): a sample ID or list of sample IDs for which to\n                retrieve embeddings\n            label_ids (None): a label ID or list of label IDs for which to\n                retrieve embeddings\n            allow_missing (True): whether to allow the index to not contain IDs\n                that you provide (True) or whether to raise an error in this\n                case (False)\n            warn_missing (False): whether to log a warning if the index does\n                not contain IDs that you 
provide\n\n        Returns:\n            a tuple of:\n\n            -   a ``num_embeddings x num_dims`` array of embeddings\n            -   a ``num_embeddings`` array of sample IDs\n            -   a ``num_embeddings`` array of label IDs, if applicable, or else\n                ``None``\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement get_embeddings()\")\n\n    def use_view(\n        self,\n        samples,\n        allow_missing=True,\n        warn_missing=False,\n    ):\n        \"\"\"Restricts the index to the provided view.\n\n        Subsequent calls to methods on this instance will only contain results\n        from the specified view rather than the full index.\n\n        Use :meth:`clear_view` to reset to the full index. Or, equivalently,\n        use the context manager interface as demonstrated below to\n        automatically reset the view when the context exits.\n\n        Example usage::\n\n            import fiftyone as fo\n            import fiftyone.brain as fob\n            import fiftyone.zoo as foz\n\n            dataset = foz.load_zoo_dataset(\"quickstart\")\n\n            results = fob.compute_similarity(dataset)\n            print(results.index_size)  # 200\n\n            view = dataset.take(50)\n\n            with results.use_view(view):\n                print(results.index_size)  # 50\n\n                results.find_unique(10)\n                print(results.unique_ids)\n\n                plot = results.visualize_unique()\n                plot.show()\n\n        Args:\n            samples: a :class:`fiftyone.core.collections.SampleCollection`\n            allow_missing (True): whether to allow the provided collection to\n                contain data points that this index does not contain (True) or\n                whether to raise an error in this case (False)\n            warn_missing (False): whether to log a warning if the provided\n                collection contains data points that this index does not\n      
          contain\n\n        Returns:\n            self\n        \"\"\"\n        self._last_view = self._curr_view\n        self._curr_view = samples\n        self._curr_view_allow_missing = allow_missing\n        self._curr_view_warn_missing = warn_missing\n        self._curr_sample_ids = None\n        self._curr_label_ids = None\n        self._curr_keep_inds = None\n        self._curr_missing_size = None\n\n        return self\n\n    def _apply_view(self):\n        sample_ids = self.sample_ids\n        label_ids = self.label_ids\n\n        if sample_ids is not None and not self.has_view:\n            keep_inds = None\n            good_inds = None\n        else:\n            sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids(\n                self._curr_view,\n                sample_ids,\n                label_ids,\n                patches_field=self.config.patches_field,\n                allow_missing=self._curr_view_allow_missing,\n                warn_missing=self._curr_view_warn_missing,\n            )\n\n        if good_inds is not None:\n            missing_size = good_inds.size - np.count_nonzero(good_inds)\n        else:\n            missing_size = None\n\n        self._curr_sample_ids = sample_ids\n        self._curr_label_ids = label_ids\n        self._curr_keep_inds = keep_inds\n        self._curr_missing_size = missing_size\n\n    def _apply_view_if_necessary(self):\n        if self._curr_sample_ids is None:\n            self._apply_view()\n\n    def clear_view(self):\n        \"\"\"Clears the view set by :meth:`use_view`, if any.\n\n        Subsequent operations will be performed on the full index.\n        \"\"\"\n        self.use_view(self._samples)\n\n    def reload(self):\n        \"\"\"Reloads the index for the current view.\n\n        Subclasses may override this method, but by default this method simply\n        passes the current :meth:`view` back into :meth:`use_view`, which\n        updates the index's current ID set based on any 
changes to the view\n        since the index was last loaded.\n        \"\"\"\n        self.use_view(self._curr_view)\n\n    def cleanup(self):\n        \"\"\"Deletes the similarity index from the backend.\"\"\"\n        raise NotImplementedError(\"subclass must implement cleanup()\")\n\n    def values(self, path_or_expr):\n        \"\"\"Extracts a flat list of values from the given field or expression\n        corresponding to the current :meth:`view`.\n\n        This method always returns values in the same order as\n        :meth:`current_sample_ids` and :meth:`current_label_ids`.\n\n        Args:\n            path_or_expr: the values to extract, which can be:\n\n                -   the name of a sample field or ``embedded.field.name`` from\n                    which to extract numeric or string values\n                -   a :class:`fiftyone.core.expressions.ViewExpression`\n                    defining numeric or string values to compute via\n                    :meth:`fiftyone.core.collections.SampleCollection.values`\n\n        Returns:\n            a list of values\n        \"\"\"\n        samples = self.view\n        patches_field = self.config.patches_field\n\n        if patches_field is not None:\n            ids = self.current_label_ids\n        else:\n            ids = self.current_sample_ids\n\n        return fbu.get_values(\n            samples, path_or_expr, ids, patches_field=patches_field\n        )\n\n    def sort_by_similarity(\n        self,\n        query,\n        k=None,\n        reverse=False,\n        aggregation=\"mean\",\n        dist_field=None,\n        _mongo=False,\n    ):\n        \"\"\"Returns a view that sorts the samples/labels in :meth:`view` by\n        similarity to the specified query.\n\n        When querying by IDs, the query can be any ID(s) in the full index of\n        this instance, even if the current :meth:`view` contains a subset of\n        the full index.\n\n        Args:\n            query: the query, which can be 
any of the following:\n\n                -   an ID or iterable of IDs\n                -   a ``num_dims`` vector or ``num_queries x num_dims`` array\n                    of vectors\n                -   a prompt or iterable of prompts (if supported by the index)\n\n            k (None): the number of matches to return. Some backends may\n                support ``None``, in which case all samples will be sorted\n            reverse (False): whether to sort by least similarity (True) or\n                greatest similarity (False). Some backends may not support\n                least similarity\n            aggregation (\"mean\"): the aggregation method to use when multiple\n                queries are provided. The default is ``\"mean\"``, which means\n                that the query vectors are averaged prior to searching. Some\n                backends may support additional options\n            dist_field (None): the name of a float field in which to store the\n                distance of each example to the specified query. 
The field is\n                created if necessary\n\n        Returns:\n            a :class:`fiftyone.core.view.DatasetView`\n        \"\"\"\n        samples = self.view\n        patches_field = self.config.patches_field\n\n        selecting_samples = patches_field is None or isinstance(\n            samples, fop.PatchesView\n        )\n\n        kwargs = dict(\n            query=self._parse_query(query),\n            k=k,\n            reverse=reverse,\n            aggregation=aggregation,\n            return_dists=dist_field is not None,\n        )\n\n        if dist_field is not None:\n            sample_ids, label_ids, dists = self._kneighbors(**kwargs)\n        else:\n            sample_ids, label_ids = self._kneighbors(**kwargs)\n\n        if selecting_samples:\n            if patches_field is not None:\n                ids = label_ids\n            else:\n                ids = sample_ids\n        else:\n            ids = label_ids\n\n        # Store query distances\n        if dist_field is not None:\n            if selecting_samples:\n                values = dict(zip(ids, dists))\n                samples.set_values(dist_field, values, key_field=\"id\")\n            else:\n                label_type, path = samples._get_label_field_path(\n                    patches_field, dist_field\n                )\n                if issubclass(label_type, fol._LABEL_LIST_FIELDS):\n                    samples._set_list_values_by_id(\n                        path,\n                        sample_ids,\n                        label_ids,\n                        dists,\n                        path.rsplit(\".\", 1)[0],\n                    )\n                else:\n                    values = dict(zip(sample_ids, dists))\n                    samples.set_values(path, values, key_field=\"id\")\n\n        # Construct sorted view\n        stages = []\n\n        if selecting_samples:\n            stage = fos.Select(ids, ordered=True)\n            stages.append(stage)\n        
else:\n            # Sorting by object similarity but this is not a patches view, so\n            # arrange the samples in order of their first occurring label\n            result_sample_ids = _unique_no_sort(sample_ids)\n            stage = fos.Select(result_sample_ids, ordered=True)\n            stages.append(stage)\n\n            if k is not None:\n                _ids = [ObjectId(_id) for _id in ids]\n                stage = fos.FilterLabels(patches_field, F(\"_id\").is_in(_ids))\n                stages.append(stage)\n\n        if _mongo:\n            pipeline = []\n            for stage in stages:\n                stage.validate(samples)\n                pipeline.extend(stage.to_mongo(samples))\n\n            return pipeline\n\n        view = samples\n        for stage in stages:\n            view = view.add_stage(stage)\n\n        return view\n\n    def _parse_query(self, query):\n        if query is None:\n            raise ValueError(\"At least one query must be provided\")\n\n        if isinstance(query, np.ndarray):\n            # Query by vector(s)\n            if query.size == 0:\n                raise ValueError(\"At least one query vector must be provided\")\n\n            return query\n\n        if etau.is_str(query):\n            query = [query]\n        else:\n            query = list(query)\n\n        if not query:\n            raise ValueError(\"At least one query must be provided\")\n\n        if etau.is_numeric(query[0]):\n            return np.asarray(query)\n\n        try:\n            ObjectId(query[0])\n            is_prompts = False\n        except:\n            is_prompts = True\n\n        if is_prompts:\n            if not self.supports_prompts:\n                raise ValueError(\n                    \"Invalid query '%s'; this model does not support prompts\"\n                    % query[0]\n                )\n\n            model = self.get_model()\n            with model:\n                return model.embed_prompts(query)\n\n        
return query\n\n    def _kneighbors(\n        self,\n        query=None,\n        k=None,\n        reverse=False,\n        aggregation=None,\n        return_dists=False,\n    ):\n        \"\"\"Returns the k-nearest neighbors for the given query.\n\n        This method should only return results from the current :meth:`view`.\n\n        Args:\n            query (None): the query, which can be any of the following:\n\n                -   an ID or list of IDs for which to return neighbors\n                -   an embedding or ``num_queries x num_dim`` array of\n                    embeddings for which to return neighbors\n                -   Some backends may also support ``None``, in which case the\n                    neighbors for all points in the current :meth:`view` are\n                    returned\n\n            k (None): the number of neighbors to return. Some backends may\n                enforce upper bounds on this parameter\n            reverse (False): whether to sort by least similarity (True) or\n                greatest similarity (False). Some backends may not support\n                least similarity\n            aggregation (None): an optional aggregation method to use when\n                multiple queries are provided. 
All backends must support\n                ``\"mean\"``, which averages query vectors prior to searching.\n                Backends may support additional options as well\n            return_dists (False): whether to return query-neighbor distances\n\n        Returns:\n            the query result, in one of the following formats:\n\n                -   a ``(sample_ids, label_ids, dists)`` tuple, when\n                    ``return_dists`` is True\n                -   a ``(sample_ids, label_ids)`` tuple, when ``return_dists``\n                    is False\n\n            In the above, ``sample_ids`` and ``label_ids`` (if applicable)\n            contain the IDs of the nearest neighbors, in one of the following\n            formats:\n\n                -   a list of nearest neighbor IDs, when a single query ID or\n                    vector is provided, **or** when an ``aggregation`` is\n                    provided\n                -   a list of lists of nearest neighbor IDs, when multiple\n                    query IDs/vectors and no ``aggregation`` is provided\n\n            and ``dists`` contains the corresponding query-neighbor distances\n            for each result.\n\n            If the backend supports full index queries (``query=None``), then\n            ``inds`` are returned rather than ``(sample_ids, label_ids)``, in\n            the following format:\n\n                -   a list of arrays of the **integer indexes** (not IDs) of\n                    nearest neighbor points for every vector in the index, when\n                    no query is provided\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement _kneighbors()\")\n\n    def get_model(self):\n        \"\"\"Returns the stored model for this index.\n\n        Returns:\n            a :class:`fiftyone.core.models.Model`\n        \"\"\"\n        if self._model is None:\n            model = self.config.model\n            if model is None:\n                raise ValueError(\"These 
results don't have a stored model\")\n\n            if etau.is_str(model):\n                model_kwargs = self.config.model_kwargs or {}\n                model = foz.load_zoo_model(model, **model_kwargs)\n\n            self._model = model\n\n        return self._model\n\n    def compute_embeddings(\n        self,\n        samples,\n        model=None,\n        batch_size=None,\n        num_workers=None,\n        skip_failures=True,\n        skip_existing=False,\n        warn_existing=False,\n        force_square=False,\n        alpha=None,\n        progress=None,\n    ):\n        \"\"\"Computes embeddings for the given samples using this backend's\n        model.\n\n        Args:\n            samples: a :class:`fiftyone.core.collections.SampleCollection`\n            model (None): a :class:`fiftyone.core.models.Model` to apply. If\n                not provided, these results must have been created with a\n                stored model, which will be used by default\n            batch_size (None): an optional batch size to use when computing\n                embeddings. Only applicable when a ``model`` is provided\n            num_workers (None): the number of workers to use when loading\n                images. Only applicable when a Torch-based model is being used\n                to compute embeddings\n            skip_failures (True): whether to gracefully continue without\n                raising an error if embeddings cannot be generated for a sample\n            skip_existing (False): whether to skip generating embeddings for\n                sample/label IDs that are already in the index\n            warn_existing (False): whether to log a warning if any IDs already\n                exist in the index\n            force_square (False): whether to minimally manipulate the patch\n                bounding boxes into squares prior to extraction. 
Only\n                applicable when a ``model`` and ``patches_field`` are specified\n            alpha (None): an optional expansion/contraction to apply to the\n                patches before extracting them, in ``[-1, inf)``. If provided,\n                the length and width of the box are expanded (or contracted,\n                when ``alpha < 0``) by ``(100 * alpha)%``. For example, set\n                ``alpha = 0.1`` to expand the boxes by 10%, and set\n                ``alpha = -0.1`` to contract the boxes by 10%. Only applicable\n                when a ``model`` and ``patches_field`` are specified\n            progress (None): whether to render a progress bar (True/False), use\n                the default value ``fiftyone.config.show_progress_bars``\n                (None), or a progress callback function to invoke instead\n\n        Returns:\n            a tuple of:\n\n            -   a ``num_embeddings x num_dims`` array of embeddings\n            -   a ``num_embeddings`` array of sample IDs\n            -   a ``num_embeddings`` array of label IDs, if applicable, or else\n                ``None``\n        \"\"\"\n        if model is None:\n            model = self.get_model()\n\n        if skip_existing:\n            if self.config.patches_field is not None:\n                index_ids = self.label_ids\n            else:\n                index_ids = self.sample_ids\n\n            if index_ids is not None:\n                samples = fbu.skip_ids(\n                    samples,\n                    index_ids,\n                    patches_field=self.config.patches_field,\n                    warn_existing=warn_existing,\n                )\n            else:\n                logger.warning(\n                    \"This index does not support skipping existing IDs\"\n                )\n\n        if self.config.roi_field is not None:\n            patches_field = self.config.roi_field\n            handle_missing = \"image\"\n            agg_fcn = lambda e: 
np.mean(e, axis=0)\n        else:\n            patches_field = self.config.patches_field\n            handle_missing = \"skip\"\n            agg_fcn = None\n\n        return fbu.get_embeddings(\n            samples,\n            model=model,\n            patches_field=patches_field,\n            force_square=force_square,\n            alpha=alpha,\n            handle_missing=handle_missing,\n            agg_fcn=agg_fcn,\n            batch_size=batch_size,\n            num_workers=num_workers,\n            skip_failures=skip_failures,\n            progress=progress,\n        )\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        \"\"\"Builds a :class:`SimilarityIndex` from a JSON representation of it.\n\n        Args:\n            d: a JSON dict\n            samples: the :class:`fiftyone.core.collections.SampleCollection`\n                for the run\n            config: the :class:`SimilarityConfig` for the run\n            brain_key: the brain key\n\n        Returns:\n            a :class:`SimilarityIndex`\n        \"\"\"\n        raise NotImplementedError(\"subclass must implement _from_dict()\")\n\n\nclass DuplicatesMixin(object):\n    \"\"\"Mixin for :class:`SimilarityIndex` instances that support duplicate\n    detection operations.\n\n    Similarity backends can expose this mixin simply by implementing\n    :meth:`_radius_neighbors`.\n    \"\"\"\n\n    def __init__(self):\n        self._thresh = None\n        self._unique_ids = None\n        self._duplicate_ids = None\n        self._neighbors_map = None\n\n    @property\n    def thresh(self):\n        \"\"\"The threshold used by the last call to :meth:`find_duplicates` or\n        :meth:`find_unique`.\n        \"\"\"\n        return self._thresh\n\n    @property\n    def unique_ids(self):\n        \"\"\"A list of unique IDs from the last call to :meth:`find_duplicates`\n        or :meth:`find_unique`.\n        \"\"\"\n        return self._unique_ids\n\n    @property\n    def 
duplicate_ids(self):\n        \"\"\"A list of duplicate IDs from the last call to\n        :meth:`find_duplicates` or :meth:`find_unique`.\n        \"\"\"\n        return self._duplicate_ids\n\n    @property\n    def neighbors_map(self):\n        \"\"\"A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples from\n        the last call to :meth:`find_duplicates`.\n        \"\"\"\n        return self._neighbors_map\n\n    def _radius_neighbors(self, query=None, thresh=None, return_dists=False):\n        \"\"\"Returns the neighbors within the given distance threshold for the\n        given query.\n\n        This method should only return results from the current :meth:`view`.\n\n        Args:\n            query (None): the query, which can be any of the following:\n\n                -   an ID or list of IDs for which to return neighbors\n                -   an embedding or ``num_queries x num_dim`` array of\n                    embeddings for which to return neighbors\n                -   ``None``, in which case the neighbors for all points in the\n                    current :meth:`view` are returned\n\n            thresh (None): the distance threshold to use\n            return_dists (False): whether to return query-neighbor distances\n\n        Returns:\n            the query result, in one of the following formats:\n\n                -   a ``(sample_ids, label_ids, dists)`` tuple, when\n                    ``return_dists`` is True\n                -   a ``(sample_ids, label_ids)`` tuple, when ``return_dists``\n                    is False\n\n            In the above, ``sample_ids`` and ``label_ids`` (if applicable)\n            contain the IDs of the nearest neighbors, in one of the following\n            formats:\n\n                -   a list of nearest neighbor IDs, when a single query ID or\n                    vector is provided\n                -   a list of lists of nearest neighbor IDs, 
when multiple\n                    query IDs/vectors are provided\n\n            and ``dists`` contains the corresponding query-neighbor distances\n            for each result.\n\n            If the backend supports full index queries (``query=None``), then\n            ``inds`` are returned rather than ``(sample_ids, label_ids)``, in\n            the following format:\n\n                -   a list of arrays of the **integer indexes** (not IDs) of\n                    nearest neighbor points for every vector in the index, when\n                    no query is provided\n        \"\"\"\n        raise NotImplementedError(\n            \"subclass must implement _radius_neighbors()\"\n        )\n\n    def find_duplicates(self, thresh=None, fraction=None):\n        \"\"\"Queries the index to find near-duplicate examples based on the\n        provided parameters.\n\n        Calling this method populates the :meth:`unique_ids`,\n        :meth:`duplicate_ids`, :attr:`neighbors_map`, and :attr:`thresh`\n        properties of this object with the results of the query.\n\n        Use :meth:`duplicates_view` and :meth:`visualize_duplicates` to analyze\n        the results generated by this method.\n\n        Args:\n            thresh (None): a distance threshold to use to determine duplicates.\n                If specified, the non-duplicate set will be the (approximately)\n                largest set such that all pairwise distances between\n                non-duplicate examples are greater than this threshold\n            fraction (None): a desired fraction of images/patches to tag as\n                duplicates, in ``[0, 1]``. 
In this case ``thresh`` is\n                automatically tuned to achieve the desired fraction of\n                duplicates\n        \"\"\"\n        if self.config.patches_field is not None:\n            logger.info(\"Computing duplicate patches...\")\n            ids = self.current_label_ids\n        else:\n            logger.info(\"Computing duplicate samples...\")\n            ids = self.current_sample_ids\n\n        # Detect duplicates\n        if fraction is not None:\n            num_keep = int(round(min(max(0, 1.0 - fraction), 1) * len(ids)))\n            unique_ids, thresh = self._remove_duplicates_count(\n                num_keep, ids, init_thresh=thresh\n            )\n        else:\n            unique_ids = self._remove_duplicates_thresh(thresh, ids)\n\n        _unique_ids = set(unique_ids)\n        duplicate_ids = [_id for _id in ids if _id not in _unique_ids]\n\n        # Locate nearest non-duplicate for each duplicate\n        if unique_ids and duplicate_ids:\n            if self.config.patches_field is not None:\n                unique_view = self._samples.select_labels(\n                    ids=unique_ids, fields=self.config.patches_field\n                )\n            else:\n                unique_view = self._samples.select(unique_ids)\n\n            with self.use_view(unique_view):\n                _sample_ids, _label_ids, dists = self._kneighbors(\n                    query=duplicate_ids, k=1, return_dists=True\n                )\n                if self.config.patches_field is not None:\n                    nearest_ids = _label_ids\n                else:\n                    nearest_ids = _sample_ids\n\n            neighbors_map = defaultdict(list)\n            for dup_id, _ids, _dists in zip(duplicate_ids, nearest_ids, dists):\n                neighbors_map[_ids[0]].append((dup_id, _dists[0]))\n\n            neighbors_map = {\n                k: sorted(v, key=lambda t: t[1])\n                for k, v in neighbors_map.items()\n            
}\n        else:\n            neighbors_map = {}\n\n        logger.info(\"Duplicates computation complete\")\n\n        self._thresh = thresh\n        self._unique_ids = unique_ids\n        self._duplicate_ids = duplicate_ids\n        self._neighbors_map = neighbors_map\n\n    def find_unique(self, count):\n        \"\"\"Queries the index to select a subset of examples of the specified\n        size that are maximally unique with respect to each other.\n\n        Calling this method populates the :meth:`unique_ids`,\n        :meth:`duplicate_ids`, and :attr:`thresh` properties of this object\n        with the results of the query.\n\n        Use :meth:`unique_view` and :meth:`visualize_unique` to analyze the\n        results generated by this method.\n\n        Args:\n            count: the desired number of unique examples\n        \"\"\"\n        if self.config.patches_field is not None:\n            logger.info(\"Computing unique patches...\")\n            ids = self.current_label_ids\n        else:\n            logger.info(\"Computing unique samples...\")\n            ids = self.current_sample_ids\n\n        unique_ids, thresh = self._remove_duplicates_count(count, ids)\n\n        _unique_ids = set(unique_ids)\n        duplicate_ids = [_id for _id in ids if _id not in _unique_ids]\n\n        logger.info(\"Uniqueness computation complete\")\n\n        self._thresh = thresh\n        self._unique_ids = unique_ids\n        self._duplicate_ids = duplicate_ids\n        self._neighbors_map = None\n\n    def _remove_duplicates_count(self, num_keep, ids, init_thresh=None):\n        if init_thresh is not None:\n            thresh = init_thresh\n        else:\n            thresh = 1\n\n        if num_keep <= 0:\n            logger.info(\n                \"threshold: -, kept: %d, target: %d\", num_keep, num_keep\n            )\n            return set(), None\n\n        if num_keep >= len(ids):\n            logger.info(\n                \"threshold: -, kept: %d, target: 
%d\", num_keep, num_keep\n            )\n            return set(ids), None\n\n        thresh_lims = [0, None]\n        num_target = num_keep\n        num_keep = -1\n\n        while True:\n            keep_ids = self._remove_duplicates_thresh(thresh, ids)\n            num_keep_last = num_keep\n            num_keep = len(keep_ids)\n\n            logger.info(\n                \"threshold: %f, kept: %d, target: %d\",\n                thresh,\n                num_keep,\n                num_target,\n            )\n\n            if num_keep == num_target or (\n                num_keep == num_keep_last\n                and thresh_lims[1] is not None\n                and thresh_lims[1] - thresh_lims[0] < 1e-6\n            ):\n                break\n\n            if num_keep < num_target:\n                # Need to decrease threshold\n                thresh_lims[1] = thresh\n                thresh = 0.5 * (thresh_lims[0] + thresh)\n            else:\n                # Need to increase threshold\n                thresh_lims[0] = thresh\n                if thresh_lims[1] is not None:\n                    thresh = 0.5 * (thresh + thresh_lims[1])\n                else:\n                    thresh *= 2\n\n        return keep_ids, thresh\n\n    def _remove_duplicates_thresh(self, thresh, ids):\n        nearest_inds = self._radius_neighbors(thresh=thresh)\n\n        n = len(ids)\n        keep = set(range(n))\n        for ind in range(n):\n            if ind in keep:\n                keep -= {i for i in nearest_inds[ind] if i > ind}\n\n        return [ids[i] for i in keep]\n\n    def plot_distances(self, bins=100, log=False, backend=\"plotly\", **kwargs):\n        \"\"\"Plots a histogram of the distance between each example and its\n        nearest neighbor.\n\n        If :meth:`find_duplicates` or :meth:`find_unique` has been executed,\n        the threshold used is also indicated on the plot.\n\n        Args:\n            bins (100): the number of bins to use\n            log 
(False): whether to use a log scale y-axis\n            backend (\"plotly\"): the plotting backend to use. Supported values\n                are ``(\"plotly\", \"matplotlib\")``\n            **kwargs: keyword arguments for the backend plotting method\n\n        Returns:\n            one of the following:\n\n            -   a :class:`fiftyone.core.plots.plotly.PlotlyNotebookPlot`, if\n                you are working in a notebook context and the plotly backend is\n                used\n            -   a plotly or matplotlib figure, otherwise\n        \"\"\"\n        metric = self.config.metric\n        thresh = self.thresh\n\n        _, dists = self._kneighbors(k=1, return_dists=True)\n        dists = np.array([d[0] for d in dists])\n\n        if backend == \"matplotlib\":\n            return _plot_distances_mpl(\n                dists, metric, thresh, bins, log, **kwargs\n            )\n\n        return _plot_distances_plotly(\n            dists, metric, thresh, bins, log, **kwargs\n        )\n\n    def duplicates_view(\n        self,\n        type_field=None,\n        id_field=None,\n        dist_field=None,\n        sort_by=\"distance\",\n        reverse=False,\n    ):\n        \"\"\"Returns a view that contains only the duplicate examples and their\n        corresponding nearest non-duplicate examples generated by the last call\n        to :meth:`find_duplicates`.\n\n        If you are analyzing patches, the returned view will be a\n        :class:`fiftyone.core.patches.PatchesView`.\n\n        The examples are organized so that each non-duplicate is immediately\n        followed by all duplicate(s) that are nearest to it.\n\n        Args:\n            type_field (None): the name of a string field in which to store\n                ``\"nearest\"`` and ``\"duplicate\"`` labels. 
The field is created\n                if necessary\n            id_field (None): the name of a string field in which to store the\n                ID of the nearest non-duplicate for each example in the view.\n                The field is created if necessary\n            dist_field (None): the name of a float field in which to store the\n                distance of each example to its nearest non-duplicate example.\n                The field is created if necessary\n            sort_by (\"distance\"): specifies how to sort the groups of duplicate\n                examples. The supported values are:\n\n                -   ``\"distance\"``: sort the groups by the distance between the\n                    non-duplicate and its (nearest, if multiple) duplicate\n                -   ``\"count\"``: sort the groups by the number of duplicate\n                    examples\n\n            reverse (False): whether to sort in descending order\n\n        Returns:\n            a :class:`fiftyone.core.view.DatasetView`\n        \"\"\"\n        if self.neighbors_map is None:\n            raise ValueError(\n                \"You must first call `find_duplicates()` to generate results\"\n            )\n\n        samples = self.view\n        patches_field = self.config.patches_field\n        neighbors_map = self.neighbors_map\n\n        if patches_field is not None and not isinstance(\n            samples, fop.PatchesView\n        ):\n            samples = samples.to_patches(patches_field)\n\n        if sort_by == \"distance\":\n            key = lambda kv: min(e[1] for e in kv[1])\n        elif sort_by == \"count\":\n            key = lambda kv: len(kv[1])\n        else:\n            raise ValueError(\n                \"Invalid sort_by='%s'; supported values are %s\"\n                % (sort_by, (\"distance\", \"count\"))\n            )\n\n        existing_ids = set(samples.values(\"id\"))\n        neighbors = [\n            (k, v) for k, v in neighbors_map.items() if k in 
existing_ids\n        ]\n\n        ids = []\n        types = {}\n        nearest_ids = {}\n        dists = {}\n        for _id, duplicates in sorted(neighbors, key=key, reverse=reverse):\n            ids.append(_id)\n            types[_id] = \"nearest\"\n            nearest_ids[_id] = _id\n            dists[_id] = 0.0\n\n            for dup_id, dist in duplicates:\n                ids.append(dup_id)\n                types[dup_id] = \"duplicate\"\n                nearest_ids[dup_id] = _id\n                dists[dup_id] = dist\n\n        if type_field is not None:\n            samples.set_values(type_field, types, key_field=\"id\")\n\n        if id_field is not None:\n            samples.set_values(id_field, nearest_ids, key_field=\"id\")\n\n        if dist_field is not None:\n            samples.set_values(dist_field, dists, key_field=\"id\")\n\n        return samples.select(ids, ordered=True)\n\n    def unique_view(self):\n        \"\"\"Returns a view that contains only the unique examples generated by\n        the last call to :meth:`find_duplicates` or :meth:`find_unique`.\n\n        If you are analyzing patches, the returned view will be a\n        :class:`fiftyone.core.patches.PatchesView`.\n\n        Returns:\n            a :class:`fiftyone.core.view.DatasetView`\n        \"\"\"\n        if self.unique_ids is None:\n            raise ValueError(\n                \"You must first call `find_unique()` or `find_duplicates()` \"\n                \"to generate results\"\n            )\n\n        samples = self.view\n        patches_field = self.config.patches_field\n        unique_ids = self.unique_ids\n\n        if patches_field is not None and not isinstance(\n            samples, fop.PatchesView\n        ):\n            samples = samples.to_patches(patches_field)\n\n        return samples.select(unique_ids)\n\n    def visualize_duplicates(self, visualization, backend=\"plotly\", **kwargs):\n        \"\"\"Generates an interactive scatterplot of the results 
generated by the\n        last call to :meth:`find_duplicates`.\n\n        The ``visualization`` argument can be any visualization computed on the\n        same dataset (or subset of it) as long as it contains every\n        sample/object in the view whose results you are visualizing.\n\n        The points are colored based on the following partition:\n\n            -   \"duplicate\": duplicate example\n            -   \"nearest\": nearest neighbor of a duplicate example\n            -   \"unique\": the remaining unique examples\n\n        Edges are also drawn between each duplicate and its nearest\n        non-duplicate neighbor.\n\n        You can attach plots generated by this method to an App session via its\n        :attr:`fiftyone.core.session.Session.plots` attribute, which will\n        automatically sync the session's view with the currently selected\n        points in the plot.\n\n        Args:\n            visualization: a\n                :class:`fiftyone.brain.visualization.VisualizationResults`\n                instance to use to visualize the results\n            backend (\"plotly\"): the plotting backend to use. 
Supported values\n                are ``(\"plotly\", \"matplotlib\")``\n            **kwargs: keyword arguments for the backend plotting method:\n\n                -   \"plotly\" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`\n                -   \"matplotlib\" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`\n\n        Returns:\n            a :class:`fiftyone.core.plots.base.InteractivePlot`\n        \"\"\"\n        if self.neighbors_map is None:\n            raise ValueError(\n                \"You must first call `find_duplicates()` to generate results\"\n            )\n\n        samples = self.view\n        duplicate_ids = self.duplicate_ids\n        neighbors_map = self.neighbors_map\n        patches_field = self.config.patches_field\n\n        dup_ids = set(duplicate_ids)\n        nearest_ids = set(neighbors_map.keys())\n\n        with visualization.use_view(samples, allow_missing=True):\n            if patches_field is not None:\n                ids = visualization.current_label_ids\n            else:\n                ids = visualization.current_sample_ids\n\n            labels = []\n            for _id in ids:\n                if _id in dup_ids:\n                    label = \"duplicate\"\n                elif _id in nearest_ids:\n                    label = \"nearest\"\n                else:\n                    label = \"unique\"\n\n                labels.append(label)\n\n            if backend == \"plotly\":\n                kwargs[\"edges\"] = _build_edges(ids, neighbors_map)\n                kwargs[\"edges_title\"] = \"neighbors\"\n                kwargs[\"labels_title\"] = \"type\"\n\n            return visualization.visualize(\n                labels=labels,\n                classes=[\"unique\", \"nearest\", \"duplicate\"],\n                backend=backend,\n                **kwargs,\n            )\n\n    def visualize_unique(self, visualization, backend=\"plotly\", **kwargs):\n        \"\"\"Generates an interactive scatterplot of 
the results generated by the\n        last call to :meth:`find_unique`.\n\n        The ``visualization`` argument can be any visualization computed on the\n        same dataset (or subset of it) as long as it contains every\n        sample/object in the view whose results you are visualizing.\n\n        The points are colored based on the following partition:\n\n            -   \"unique\": the unique examples\n            -   \"other\": the other examples\n\n        You can attach plots generated by this method to an App session via its\n        :attr:`fiftyone.core.session.Session.plots` attribute, which will\n        automatically sync the session's view with the currently selected\n        points in the plot.\n\n        Args:\n            visualization: a\n                :class:`fiftyone.brain.visualization.VisualizationResults`\n                instance to use to visualize the results\n            backend (\"plotly\"): the plotting backend to use. Supported values\n                are ``(\"plotly\", \"matplotlib\")``\n            **kwargs: keyword arguments for the backend plotting method:\n\n                -   \"plotly\" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`\n                -   \"matplotlib\" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`\n\n        Returns:\n            a :class:`fiftyone.core.plots.base.InteractivePlot`\n        \"\"\"\n        if self.unique_ids is None:\n            raise ValueError(\n                \"You must first call `find_unique()` to generate results\"\n            )\n\n        samples = self.view\n        unique_ids = self.unique_ids\n        patches_field = self.config.patches_field\n\n        unique_ids = set(unique_ids)\n\n        with visualization.use_view(samples, allow_missing=True):\n            if patches_field is not None:\n                ids = visualization.current_label_ids\n            else:\n                ids = visualization.current_sample_ids\n\n            labels = []\n           
 for _id in ids:\n                if _id in unique_ids:\n                    label = \"unique\"\n                else:\n                    label = \"other\"\n\n                labels.append(label)\n\n            return visualization.visualize(\n                labels=labels,\n                classes=[\"other\", \"unique\"],\n                backend=backend,\n                **kwargs,\n            )\n\n\ndef _unique_no_sort(values):\n    seen = set()\n    return [v for v in values if v not in seen and not seen.add(v)]\n\n\ndef _build_edges(ids, neighbors_map):\n    inds_map = {_id: idx for idx, _id in enumerate(ids)}\n\n    edges = []\n    for nearest_id, duplicates in neighbors_map.items():\n        nearest_ind = inds_map[nearest_id]\n        for dup_id, _ in duplicates:\n            dup_ind = inds_map[dup_id]\n            edges.append((dup_ind, nearest_ind))\n\n    return np.array(edges)\n\n\ndef _plot_distances_plotly(dists, metric, thresh, bins, log, **kwargs):\n    import plotly.graph_objects as go\n    import fiftyone.core.plots.plotly as fopl\n\n    counts, edges = np.histogram(dists, bins=bins)\n    left_edges = edges[:-1]\n    widths = edges[1:] - edges[:-1]\n    customdata = np.stack((edges[:-1], edges[1:]), axis=1)\n\n    hover_lines = [\n        \"<b>count: %{y}</b>\",\n        \"distance: [%{customdata[0]:.2f}, %{customdata[1]:.2f}]\",\n    ]\n    hovertemplate = \"<br>\".join(hover_lines) + \"<extra></extra>\"\n\n    bar = go.Bar(\n        x=left_edges,\n        y=counts,\n        width=widths,\n        customdata=customdata,\n        offset=0,\n        marker_color=\"#FF6D04\",\n        hovertemplate=hovertemplate,\n        showlegend=False,\n    )\n\n    traces = [bar]\n\n    if thresh is not None:\n        line = go.Scatter(\n            x=[thresh, thresh],\n            y=[0, max(counts)],\n            mode=\"lines\",\n            line=dict(color=\"#17191C\", width=3),\n            hovertemplate=\"<b>thresh: %{x}</b><extra></extra>\",\n            
showlegend=False,\n        )\n        traces.append(line)\n\n    figure = go.Figure(traces)\n\n    figure.update_layout(\n        xaxis_title=\"nearest neighbor distance (%s)\" % metric,\n        yaxis_title=\"count\",\n        hovermode=\"x\",\n        yaxis_rangemode=\"tozero\",\n    )\n\n    if log:\n        figure.update_layout(yaxis_type=\"log\")\n\n    figure.update_layout(**fopl._DEFAULT_LAYOUT)\n    figure.update_layout(**kwargs)\n\n    if foc.is_jupyter_context():\n        figure = fopl.PlotlyNotebookPlot(figure)\n\n    return figure\n\n\ndef _plot_distances_mpl(\n    dists, metric, thresh, bins, log, ax=None, figsize=None, **kwargs\n):\n    import matplotlib.pyplot as plt\n\n    if ax is None:\n        fig, ax = plt.subplots()\n    else:\n        fig = ax.figure\n\n    counts, edges = np.histogram(dists, bins=bins)\n    left_edges = edges[:-1]\n    widths = edges[1:] - edges[:-1]\n\n    ax.bar(\n        left_edges,\n        counts,\n        width=widths,\n        align=\"edge\",\n        color=\"#FF6D04\",\n        **kwargs,\n    )\n\n    if thresh is not None:\n        ax.vlines(thresh, 0, max(counts), color=\"#17191C\", linewidth=3)\n\n    if log:\n        ax.set_yscale(\"log\")\n\n    ax.set_xlabel(\"nearest neighbor distance (%s)\" % metric)\n    ax.set_ylabel(\"count\")\n\n    if figsize is not None:\n        fig.set_size_inches(*figsize)\n\n    plt.tight_layout()\n\n    return fig\n"
  },
  {
    "path": "fiftyone/brain/visualization.py",
    "content": "\"\"\"\nVisualization interface.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nfrom copy import deepcopy\nimport inspect\nimport logging\nfrom packaging import version\n\nimport numpy as np\nimport sklearn\nimport sklearn.decomposition as skd\nimport sklearn.manifold as skm\n\nimport eta.core.utils as etau\n\nimport fiftyone.brain as fb\nimport fiftyone.core.brain as fob\nimport fiftyone.core.expressions as foe\nimport fiftyone.core.fields as fof\nimport fiftyone.core.plots as fop\nimport fiftyone.core.utils as fou\nimport fiftyone.core.validation as fov\n\nfbu = fou.lazy_import(\"fiftyone.brain.internal.core.utils\")\n\numap = fou.lazy_import(\"umap\")\n\n\nlogger = logging.getLogger(__name__)\n\n_DEFAULT_MODEL = \"mobilenet-v2-imagenet-torch\"\n_DEFAULT_BATCH_SIZE = None\n\n\ndef compute_visualization(\n    samples,\n    patches_field,\n    embeddings,\n    points,\n    create_index,\n    points_field,\n    brain_key,\n    num_dims,\n    method,\n    similarity_index,\n    model,\n    model_kwargs,\n    force_square,\n    alpha,\n    batch_size,\n    num_workers,\n    skip_failures,\n    progress,\n    **kwargs,\n):\n    \"\"\"See ``fiftyone/brain/__init__.py``.\"\"\"\n\n    fov.validate_collection(samples)\n\n    if method == \"manual\" and points is None:\n        raise ValueError(\n            \"You must provide your own `points` when `method='manual'`\"\n        )\n\n    if points is not None:\n        method = \"manual\"\n        model = None\n        embeddings = None\n        embeddings_field = None\n        num_dims = _get_dimension(points)\n\n    if create_index and points_field is None:\n        points_field = brain_key\n\n    if points_field is not None and num_dims != 2:\n        raise ValueError(\"`points_field` is only supported when `num_dims=2`\")\n\n    if etau.is_str(embeddings):\n        embeddings_field, embeddings_exist = fbu.parse_data_field(\n            samples,\n            
embeddings,\n            patches_field=patches_field,\n            data_type=\"embeddings\",\n        )\n        embeddings = None\n    else:\n        embeddings_field = None\n        embeddings_exist = None\n\n    if points_field is not None:\n        points_field, _ = fbu.parse_data_field(\n            samples,\n            points_field,\n            patches_field=patches_field,\n            data_type=\"points\",\n        )\n\n    if etau.is_str(similarity_index):\n        similarity_index = samples.load_brain_results(similarity_index)\n\n    if (\n        model is None\n        and points is None\n        and embeddings is None\n        and similarity_index is None\n        and not embeddings_exist\n    ):\n        model = _DEFAULT_MODEL\n        if batch_size is None:\n            batch_size = _DEFAULT_BATCH_SIZE\n\n    config = _parse_config(\n        method,\n        embeddings_field=embeddings_field,\n        points_field=points_field,\n        similarity_index=similarity_index,\n        model=model,\n        model_kwargs=model_kwargs,\n        patches_field=patches_field,\n        num_dims=num_dims,\n        **kwargs,\n    )\n\n    brain_method = config.build()\n    brain_method.ensure_requirements()\n\n    if brain_key is not None:\n        brain_method.register_run(samples, brain_key)\n\n    if points is None:\n        embeddings, sample_ids, label_ids = fbu.get_embeddings(\n            samples,\n            model=model,\n            model_kwargs=model_kwargs,\n            patches_field=patches_field,\n            embeddings_field=embeddings_field,\n            embeddings=embeddings,\n            similarity_index=similarity_index,\n            force_square=force_square,\n            alpha=alpha,\n            batch_size=batch_size,\n            num_workers=num_workers,\n            skip_failures=skip_failures,\n            progress=progress,\n        )\n\n        logger.info(\"Generating visualization...\")\n        points = brain_method.fit(embeddings)\n  
  else:\n        points, sample_ids, label_ids = fbu.parse_data(\n            samples,\n            patches_field=patches_field,\n            data=points,\n            data_type=\"points\",\n        )\n\n    if points_field is not None:\n        _generate_spatial_index(\n            samples,\n            points,\n            points_field,\n            sample_ids,\n            label_ids=label_ids,\n            patches_field=patches_field,\n            create_index=create_index,\n            progress=progress,\n        )\n\n    results = VisualizationResults(\n        samples,\n        config,\n        brain_key,\n        points,\n        sample_ids=sample_ids,\n        label_ids=label_ids,\n    )\n\n    brain_method.save_run_results(samples, brain_key, results)\n\n    return results\n\n\ndef values(results, path_or_expr):\n    samples = results.view\n    patches_field = results.config.patches_field\n    if patches_field is not None:\n        ids = results.current_label_ids\n    else:\n        ids = results.current_sample_ids\n\n    return fbu.get_values(\n        samples, path_or_expr, ids, patches_field=patches_field\n    )\n\n\ndef visualize(\n    results,\n    labels=None,\n    sizes=None,\n    classes=None,\n    backend=\"plotly\",\n    **kwargs,\n):\n    points = results.current_points\n    samples = results.view\n    patches_field = results.config.patches_field\n    good_inds = results._curr_good_inds\n    if patches_field is not None:\n        ids = results.current_label_ids\n    else:\n        ids = results.current_sample_ids\n\n    if good_inds is not None:\n        if etau.is_container(labels) and not _is_expr(labels):\n            labels = fbu.filter_values(\n                labels, good_inds, patches_field=patches_field\n            )\n\n        if etau.is_container(sizes) and not _is_expr(sizes):\n            sizes = fbu.filter_values(\n                sizes, good_inds, patches_field=patches_field\n            )\n\n    if labels is not None and 
_is_expr(labels):\n        labels = fbu.get_values(\n            samples, labels, ids, patches_field=patches_field\n        )\n\n    if sizes is not None and _is_expr(sizes):\n        sizes = fbu.get_values(\n            samples, sizes, ids, patches_field=patches_field\n        )\n\n    return fop.scatterplot(\n        points,\n        samples=samples,\n        ids=ids,\n        link_field=patches_field,\n        labels=labels,\n        sizes=sizes,\n        classes=classes,\n        backend=backend,\n        **kwargs,\n    )\n\n\ndef _is_expr(arg):\n    return isinstance(arg, (foe.ViewExpression, dict))\n\n\ndef _parse_config(name, **kwargs):\n    if name is None:\n        name = fb.brain_config.default_visualization_method\n\n    if inspect.isclass(name):\n        return name(**kwargs)\n\n    methods = fb.brain_config.visualization_methods\n\n    if name not in methods:\n        raise ValueError(\n            \"Unsupported method '%s'. The available methods are %s\"\n            % (name, sorted(methods.keys()))\n        )\n\n    params = deepcopy(methods[name])\n\n    config_cls = kwargs.pop(\"config_cls\", None)\n\n    if config_cls is None:\n        config_cls = params.pop(\"config_cls\", None)\n\n    if config_cls is None:\n        raise ValueError(\n            \"Visualization method '%s' has no `config_cls`\" % name\n        )\n\n    if etau.is_str(config_cls):\n        config_cls = etau.get_class(config_cls)\n\n    params.update(**kwargs)\n    return config_cls(**params)\n\n\ndef _get_dimension(points):\n    if isinstance(points, dict):\n        points = next(iter(points.values()), None)\n\n    if isinstance(points, list):\n        points = next(iter(points), None)\n\n    if points is None:\n        return 2\n\n    return points.shape[-1]\n\n\ndef _generate_spatial_index(\n    samples,\n    points,\n    points_field,\n    sample_ids,\n    label_ids=None,\n    patches_field=None,\n    create_index=True,\n    progress=False,\n):\n    # Indexes are not 
currently usable on patch visualizations\n    if create_index and patches_field is not None:\n        create_index = False\n\n    dataset = samples._root_dataset\n    if patches_field is not None:\n        _, points_field = dataset._get_label_field_path(\n            patches_field, points_field\n        )\n\n    logger.info(\"Generating spatial index in field '%s'...\", points_field)\n\n    dataset.add_sample_field(\n        points_field, fof.ListField, subfield=fof.FloatField\n    )\n\n    points = points.astype(float)\n\n    if create_index:\n        min_val, max_val = points.min(), points.max()\n        dataset.create_index([(points_field, \"2d\")], min=min_val, max=max_val)\n\n    points = points.tolist()\n    if patches_field is not None:\n        values = dict(zip(label_ids, points))\n        dataset.set_label_values(points_field, values, progress=progress)\n    else:\n        values = dict(zip(sample_ids, points))\n        dataset.set_values(\n            points_field, values, key_field=\"id\", progress=progress\n        )\n\n\nclass VisualizationResults(fob.BrainResults):\n    \"\"\"Class storing the results of\n    :meth:`fiftyone.brain.compute_visualization`.\n\n    Args:\n        samples: the :class:`fiftyone.core.collections.SampleCollection` used\n        config: the :class:`VisualizationConfig` used\n        brain_key: the brain key\n        points: a ``num_points x num_dims`` array of visualization points\n        sample_ids (None): a ``num_points`` array of sample IDs\n        label_ids (None): a ``num_points`` array of label IDs, if applicable\n        backend (None): a :class:`Visualization` backend\n    \"\"\"\n\n    def __init__(\n        self,\n        samples,\n        config,\n        brain_key,\n        points,\n        sample_ids=None,\n        label_ids=None,\n        backend=None,\n    ):\n        super().__init__(samples, config, brain_key, backend=backend)\n\n        if sample_ids is None:\n            sample_ids, label_ids = 
fbu.get_ids(\n                samples,\n                patches_field=config.patches_field,\n                data=points,\n                data_type=\"points\",\n            )\n\n        self.points = points\n        self.sample_ids = sample_ids\n        self.label_ids = label_ids\n\n        self._last_view = None\n        self._curr_view = None\n        self._curr_points = None\n        self._curr_sample_ids = None\n        self._curr_label_ids = None\n        self._curr_keep_inds = None\n        self._curr_good_inds = None\n\n        self.use_view(samples)\n\n    def __enter__(self):\n        self._last_view = self.view\n        return self\n\n    def __exit__(self, *args):\n        self.use_view(self._last_view)\n        self._last_view = None\n\n    @property\n    def config(self):\n        \"\"\"The :class:`VisualizationConfig` for the results.\"\"\"\n        return self._config\n\n    @property\n    def index_size(self):\n        \"\"\"The number of active points in the index.\n\n        If :meth:`use_view` has been called to restrict the index, this\n        property will reflect the size of the active index.\n        \"\"\"\n        return len(self._curr_sample_ids)\n\n    @property\n    def total_index_size(self):\n        \"\"\"The total number of data points in the index.\n\n        If :meth:`use_view` has been called to restrict the index, this value\n        may be larger than the current :meth:`index_size`.\n        \"\"\"\n        return len(self.points)\n\n    @property\n    def missing_size(self):\n        \"\"\"The total number of data points in :meth:`view` that are missing\n        from this index.\n\n        This property is only applicable when :meth:`use_view` has been called,\n        and it will be ``None`` if no data points are missing.\n        \"\"\"\n        good = self._curr_good_inds\n\n        if good is None:\n            return None\n\n        return good.size - np.count_nonzero(good)\n\n    @property\n    def 
current_points(self):\n        \"\"\"The currently active points in the index.\n\n        If :meth:`use_view` has been called, this may be a subset of the full\n        index.\n        \"\"\"\n        return self._curr_points\n\n    @property\n    def current_sample_ids(self):\n        \"\"\"The sample IDs of the currently active points in the index.\n\n        If :meth:`use_view` has been called, this may be a subset of the full\n        index.\n        \"\"\"\n        return self._curr_sample_ids\n\n    @property\n    def current_label_ids(self):\n        \"\"\"The label IDs of the currently active points in the index, or\n        ``None`` if not applicable.\n\n        If :meth:`use_view` has been called, this may be a subset of the full\n        index.\n        \"\"\"\n        return self._curr_label_ids\n\n    @property\n    def view(self):\n        \"\"\"The :class:`fiftyone.core.collections.SampleCollection` against\n        which results are currently being generated.\n\n        If :meth:`use_view` has been called, this view may be different than\n        the collection on which the full index was generated.\n        \"\"\"\n        return self._curr_view\n\n    @property\n    def has_spatial_index(self):\n        \"\"\"Whether these results have a spatial index.\n\n        Use :meth:`index_points` to add a spatial index to an existing set of\n        visualization results.\n        \"\"\"\n        return self.config.points_field is not None\n\n    def use_view(\n        self, sample_collection, allow_missing=True, warn_missing=False\n    ):\n        \"\"\"Restricts the index to the provided view.\n\n        Subsequent calls to methods on this instance will only contain results\n        from the specified view rather than the full index.\n\n        Use :meth:`clear_view` to reset to the full index. 
Or, equivalently,\n        use the context manager interface as demonstrated below to\n        automatically reset the view when the context exits.\n\n        Example usage::\n\n            import fiftyone as fo\n            import fiftyone.brain as fob\n            import fiftyone.zoo as foz\n\n            dataset = foz.load_zoo_dataset(\"quickstart\")\n\n            results = fob.compute_visualization(dataset)\n            print(results.index_size)  # 200\n\n            view = dataset.take(50)\n\n            with results.use_view(view):\n                print(results.index_size)  # 50\n\n                plot = results.visualize()\n                plot.show()\n\n        Args:\n            sample_collection: a\n                :class:`fiftyone.core.collections.SampleCollection`\n            allow_missing (True): whether to allow the provided collection to\n                contain data points that this index does not contain (True) or\n                whether to raise an error in this case (False)\n            warn_missing (False): whether to log a warning if the provided\n                collection contains data points that this index does not\n                contain\n\n        Returns:\n            self\n        \"\"\"\n        sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids(\n            sample_collection,\n            self.sample_ids,\n            self.label_ids,\n            patches_field=self._config.patches_field,\n            allow_missing=allow_missing,\n            warn_missing=warn_missing,\n        )\n\n        if keep_inds is not None:\n            points = self.points[keep_inds, :]\n        else:\n            points = self.points\n\n        self._curr_view = sample_collection\n        self._curr_points = points\n        self._curr_sample_ids = sample_ids\n        self._curr_label_ids = label_ids\n        self._curr_keep_inds = keep_inds\n        self._curr_good_inds = good_inds\n\n        return self\n\n    def clear_view(self):\n        
\"\"\"Clears the view set by :meth:`use_view`, if any.\n\n        Subsequent operations will be performed on the full index.\n        \"\"\"\n        self.use_view(self._samples)\n\n    def values(self, path_or_expr):\n        \"\"\"Extracts a flat list of values from the given field or expression\n        corresponding to the current :meth:`view`.\n\n        This method always returns values in the same order as\n        :meth:`current_points`, :meth:`current_sample_ids`, and\n        :meth:`current_label_ids`.\n\n        Args:\n            path_or_expr: the values to extract, which can be:\n\n                -   the name of a sample field or ``embedded.field.name`` from\n                    which to extract numeric or string values\n                -   a :class:`fiftyone.core.expressions.ViewExpression`\n                    defining numeric or string values to compute via\n                    :meth:`fiftyone.core.collections.SampleCollection.values`\n\n        Returns:\n            a list of values\n        \"\"\"\n        return values(self, path_or_expr)\n\n    def visualize(\n        self,\n        labels=None,\n        sizes=None,\n        classes=None,\n        backend=\"plotly\",\n        **kwargs,\n    ):\n        \"\"\"Generates an interactive scatterplot of the visualization results\n        for the current :meth:`view`.\n\n        This method supports 2D or 3D visualizations, but interactive point\n        selection is only available in 2D.\n\n        You can use the ``labels`` parameter to define a coloring for the\n        points, and you can use the ``sizes`` parameter to scale the sizes of\n        the points.\n\n        You can attach plots generated by this method to an App session via its\n        :attr:`fiftyone.core.session.Session.plots` attribute, which will\n        automatically sync the session's view with the currently selected\n        points in the plot.\n\n        Args:\n            labels (None): data to use to color the points. 
Can be any of the\n                following:\n\n                -   the name of a sample field or ``embedded.field.name`` from\n                    which to extract numeric or string values\n                -   a :class:`fiftyone.core.expressions.ViewExpression`\n                    defining numeric or string values to compute via\n                    :meth:`fiftyone.core.collections.SampleCollection.values`\n                -   a list or array-like of numeric or string values\n                -   a list of lists of numeric or string values, if the data in\n                    this visualization corresponds to a label list field like\n                    :class:`fiftyone.core.labels.Detections`\n\n            sizes (None): data to use to scale the sizes of the points. Can be\n                any of the following:\n\n                -   the name of a sample field or ``embedded.field.name`` from\n                    which to extract numeric values\n                -   a :class:`fiftyone.core.expressions.ViewExpression`\n                    defining numeric values to compute via\n                    :meth:`fiftyone.core.collections.SampleCollection.values`\n                -   a list or array-like of numeric values\n                -   a list of lists of numeric values, if the data in this\n                    visualization corresponds to a label list field like\n                    :class:`fiftyone.core.labels.Detections`\n\n            classes (None): an optional list of classes whose points to plot.\n                Only applicable when ``labels`` contains strings\n            backend (\"plotly\"): the plotting backend to use. 
Supported values\n                are ``(\"plotly\", \"matplotlib\")``\n            **kwargs: keyword arguments for the backend plotting method:\n\n                -   \"plotly\" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`\n                -   \"matplotlib\" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`\n\n        Returns:\n            an :class:`fiftyone.core.plots.base.InteractivePlot`\n        \"\"\"\n        return visualize(\n            self,\n            labels=labels,\n            sizes=sizes,\n            classes=classes,\n            backend=backend,\n            **kwargs,\n        )\n\n    def index_points(\n        self,\n        points_field=None,\n        create_index=True,\n        progress=None,\n    ):\n        \"\"\"Adds a spatial index for these visualization results to its\n        dataset's samples.\n\n        This method is useful if you want to add a spatial index to existing\n        visualization results that don't yet have one.\n\n        Spatial indexes are highly recommended for large datasets as they\n        enable efficient querying when lassoing points in embeddings plots.\n\n        Args:\n            points_field (None): an optional field name in which to store the\n                spatial index. 
The default is the result's ``brain_key``\n            create_index (True): whether to create a database index for the\n                points\n            progress (None): whether to render a progress bar (True/False),\n                use the default value ``fiftyone.config.show_progress_bars``\n                (None), or a progress callback function to invoke instead\n        \"\"\"\n        if points_field is None:\n            if self.key is None:\n                raise ValueError(\n                    \"You must provide a `points_field` when indexing points \"\n                    \"that are not associated with a brain key\"\n                )\n\n            points_field = self.key\n\n        _generate_spatial_index(\n            self.samples,\n            self.points,\n            points_field,\n            self.sample_ids,\n            label_ids=self.label_ids,\n            patches_field=self.config.patches_field,\n            create_index=create_index,\n            progress=progress,\n        )\n\n        if self.key is not None:\n            self.config.points_field = points_field\n            self.save_config()\n\n    def remove_index(self):\n        \"\"\"Removes the spatial index from these visualization results, if one\n        exists.\n        \"\"\"\n        points_field = self.config.points_field\n        if points_field is None:\n            return\n\n        dataset = self.samples._root_dataset\n        if self.config.patches_field is not None:\n            _, points_field = dataset._get_label_field_path(\n                self.config.patches_field, points_field\n            )\n\n        dataset.delete_sample_field(points_field, error_level=1)\n\n        if self.key is not None:\n            self.config.points_field = None\n            self.save_config()\n\n    @classmethod\n    def _from_dict(cls, d, samples, config, brain_key):\n        points = np.array(d[\"points\"])\n\n        sample_ids = d.get(\"sample_ids\", None)\n        if sample_ids is 
not None:\n            sample_ids = np.array(sample_ids)\n\n        label_ids = d.get(\"label_ids\", None)\n        if label_ids is not None:\n            label_ids = np.array(label_ids)\n\n        return cls(\n            samples,\n            config,\n            brain_key,\n            points,\n            sample_ids=sample_ids,\n            label_ids=label_ids,\n        )\n\n\nclass VisualizationConfig(fob.BrainMethodConfig):\n    \"\"\"Base class for configuring visualization methods.\n\n    Args:\n        embeddings_field (None): the sample field containing the embeddings,\n            if one was provided\n        points_field (None): the name of a field in which to store the\n            visualization points, if requested\n        similarity_index (None): the similarity index containing the\n            embeddings, if one was provided\n        model (None): the :class:`fiftyone.core.models.Model` or name of the\n            zoo model that was used to compute embeddings, if known\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        num_dims (2): the dimension of the visualization space\n    \"\"\"\n\n    def __init__(\n        self,\n        embeddings_field=None,\n        points_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        patches_field=None,\n        num_dims=2,\n        **kwargs,\n    ):\n        if similarity_index is not None and not etau.is_str(similarity_index):\n            similarity_index = similarity_index.key\n\n        if model is not None and not etau.is_str(model):\n            model = None\n\n        self.embeddings_field = embeddings_field\n        self.points_field = points_field\n        self.similarity_index = similarity_index\n        self.model = model\n        
self.model_kwargs = model_kwargs\n        self.patches_field = patches_field\n        self.num_dims = num_dims\n        super().__init__(**kwargs)\n\n    @property\n    def type(self):\n        return \"visualization\"\n\n\nclass Visualization(fob.BrainMethod):\n    def fit(self, embeddings):\n        raise NotImplementedError(\"subclass must implement fit()\")\n\n    def get_fields(self, samples, brain_key):\n        fields = []\n        if self.config.patches_field is not None:\n            fields.append(self.config.patches_field)\n        elif self.config.points_field is not None:\n            fields.append(self.config.points_field)\n\n        return fields\n\n    def rename(self, samples, key, new_key):\n        patches_field = self.config.patches_field\n        points_field = self.config.points_field\n        dataset = samples._root_dataset\n\n        if points_field is not None and points_field == key:\n            old_path = key\n            new_path = new_key\n            if patches_field is not None:\n                _, old_path = dataset._get_label_field_path(\n                    patches_field, old_path\n                )\n                _, new_path = dataset._get_label_field_path(\n                    patches_field, new_path\n                )\n\n            self.config.points_field = new_key\n            self.update_run_config(samples, key, self.config)\n\n            dataset.rename_sample_field(old_path, new_path)\n\n    def cleanup(self, samples, key):\n        patches_field = self.config.patches_field\n        points_field = self.config.points_field\n        dataset = samples._root_dataset\n\n        if points_field is not None:\n            if patches_field is not None:\n                _, points_field = dataset._get_label_field_path(\n                    patches_field, points_field\n                )\n\n            dataset.delete_sample_field(points_field, error_level=1)\n\n\nclass UMAPVisualizationConfig(VisualizationConfig):\n    
\"\"\"Configuration for Uniform Manifold Approximation and Projection (UMAP)\n    embedding visualization.\n\n    See https://github.com/lmcinnes/umap for more information about the\n    supported parameters.\n\n    Args:\n        embeddings_field (None): the sample field containing the embeddings,\n            if one was provided\n        points_field (None): the name of a field in which to store the\n            visualization points, if requested\n        similarity_index (None): the similarity index containing the\n            embeddings, if one was provided\n        model (None): the :class:`fiftyone.core.models.Model` or name of the\n            zoo model that was used to compute embeddings, if known\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        num_dims (2): the dimension of the visualization space\n        num_neighbors (15): the number of neighboring points used in local\n            approximations of manifold structure. Larger values will result in\n            more global structure being preserved at the loss of detailed local\n            structure. Typical values are in ``[5, 50]``\n        metric (\"euclidean\"): the metric to use when calculating distance\n            between embeddings. See the UMAP documentation for supported values\n        min_dist (0.1): the effective minimum distance between embedded\n            points. This controls how tightly the embedding is allowed to\n            compress points together. Larger values ensure embedded points are\n            more evenly distributed, while smaller values allow the algorithm\n            to optimise more accurately with regard to local structure. 
Typical\n            values are in ``[0.001, 0.5]``\n        seed (None): a random seed\n        verbose (True): whether to log progress\n    \"\"\"\n\n    def __init__(\n        self,\n        embeddings_field=None,\n        points_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        patches_field=None,\n        num_dims=2,\n        num_neighbors=15,\n        metric=\"euclidean\",\n        min_dist=0.1,\n        seed=None,\n        verbose=True,\n        **kwargs,\n    ):\n        super().__init__(\n            embeddings_field=embeddings_field,\n            points_field=points_field,\n            similarity_index=similarity_index,\n            model=model,\n            model_kwargs=model_kwargs,\n            patches_field=patches_field,\n            num_dims=num_dims,\n            **kwargs,\n        )\n        self.num_neighbors = num_neighbors\n        self.metric = metric\n        self.min_dist = min_dist\n        self.seed = seed\n        self.verbose = verbose\n\n    @property\n    def method(self):\n        return \"umap\"\n\n\nclass UMAPVisualization(Visualization):\n    def ensure_requirements(self):\n        fou.ensure_package(\n            \"umap-learn>=0.5\",\n            error_msg=(\n                \"You must install the `umap-learn>=0.5` package in order to \"\n                \"use UMAP-based visualization. This is recommended, as UMAP \"\n                \"is awesome! 
If you do not wish to install UMAP, try \"\n                \"`method='tsne'` instead\"\n            ),\n        )\n\n    def fit(self, embeddings):\n        _umap = umap.UMAP(\n            n_components=self.config.num_dims,\n            n_neighbors=self.config.num_neighbors,\n            metric=self.config.metric,\n            min_dist=self.config.min_dist,\n            random_state=self.config.seed,\n            verbose=self.config.verbose,\n        )\n        return _umap.fit_transform(embeddings)\n\n\nclass TSNEVisualizationConfig(VisualizationConfig):\n    \"\"\"Configuration for t-distributed Stochastic Neighbor Embedding (t-SNE)\n    visualization.\n\n    See https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html\n    for more information about the supported parameters.\n\n    Args:\n        embeddings_field (None): the sample field containing the embeddings,\n            if one was provided\n        points_field (None): the name of a field in which to store the\n            visualization points, if requested\n        similarity_index (None): the similarity index containing the\n            embeddings, if one was provided\n        model (None): the :class:`fiftyone.core.models.Model` or name of the\n            zoo model that was used to compute embeddings, if known\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        num_dims (2): the dimension of the visualization space\n        pca_dims (50): the number of PCA dimensions to compute prior to running\n            t-SNE. It is highly recommended to reduce the number of dimensions\n            to a reasonable number (e.g. 
50) before running t-SNE, as this will\n            suppress some noise and speed up the computation of pairwise\n            distances between samples\n        svd_solver (\"randomized\"): the SVD solver to use when performing PCA.\n            Consult the sklearn documentation for details\n        metric (\"euclidean\"): the metric to use when calculating distance\n            between embeddings. Must be a supported value for the ``metric``\n            argument of ``scipy.spatial.distance.pdist``\n        perplexity (30.0): the perplexity to use. Perplexity is related to the\n            number of nearest neighbors that is used in other manifold learning\n            algorithms. Larger datasets usually require a larger perplexity.\n            Typical values are in ``[5, 50]``\n        learning_rate (200.0): the learning rate to use. Typical values are\n            in ``[10, 1000]``. If the learning rate is too high, the data may\n            look like a ball with any point approximately equidistant from its\n            nearest neighbours. If the learning rate is too low, most points\n            may look compressed in a dense cloud with few outliers. If the cost\n            function gets stuck in a bad local minimum increasing the learning\n            rate may help\n        max_iters (1000): the maximum number of iterations to run. 
Should be at\n            least 250\n        seed (None): a random seed\n        verbose (True): whether to log progress\n    \"\"\"\n\n    def __init__(\n        self,\n        embeddings_field=None,\n        points_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        patches_field=None,\n        num_dims=2,\n        pca_dims=50,\n        svd_solver=\"randomized\",\n        metric=\"euclidean\",\n        perplexity=30.0,\n        learning_rate=200.0,\n        max_iters=1000,\n        seed=None,\n        verbose=True,\n        **kwargs,\n    ):\n        super().__init__(\n            embeddings_field=embeddings_field,\n            points_field=points_field,\n            similarity_index=similarity_index,\n            model=model,\n            model_kwargs=model_kwargs,\n            patches_field=patches_field,\n            num_dims=num_dims,\n            **kwargs,\n        )\n        self.pca_dims = pca_dims\n        self.svd_solver = svd_solver\n        self.metric = metric\n        self.perplexity = perplexity\n        self.learning_rate = learning_rate\n        self.max_iters = max_iters\n        self.seed = seed\n        self.verbose = verbose\n\n    @property\n    def method(self):\n        return \"tsne\"\n\n\nclass TSNEVisualization(Visualization):\n    def fit(self, embeddings):\n        if self.config.pca_dims is not None:\n            _pca = skd.PCA(\n                n_components=self.config.pca_dims,\n                svd_solver=self.config.svd_solver,\n                random_state=self.config.seed,\n            )\n            embeddings = _pca.fit_transform(embeddings)\n\n        embeddings = embeddings.astype(np.float32, copy=False)\n\n        verbose = 2 if self.config.verbose else 0\n\n        sklearn_version = version.parse(sklearn.__version__)\n        iter_param = (\n            \"max_iter\"\n            if sklearn_version >= version.parse(\"1.5.0\")\n            else \"n_iter\"\n        )\n\n        
_tsne = skm.TSNE(\n            n_components=self.config.num_dims,\n            perplexity=self.config.perplexity,\n            learning_rate=self.config.learning_rate,\n            metric=self.config.metric,\n            init=\"pca\",\n            random_state=self.config.seed,\n            verbose=verbose,\n            **{iter_param: self.config.max_iters},\n        )\n        return _tsne.fit_transform(embeddings)\n\n\nclass PCAVisualizationConfig(VisualizationConfig):\n    \"\"\"Configuration for principal component analysis (PCA) embedding\n    visualization.\n\n    See https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html\n    for more information about the supported parameters.\n\n    Args:\n        embeddings_field (None): the sample field containing the embeddings,\n            if one was provided\n        points_field (None): the name of a field in which to store the\n            visualization points, if requested\n        similarity_index (None): the similarity index containing the\n            embeddings, if one was provided\n        model (None): the :class:`fiftyone.core.models.Model` or name of the\n            zoo model that was used to compute embeddings, if known\n        model_kwargs (None): a dictionary of optional keyword arguments to pass\n            to the model's ``Config`` when a model name is provided\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        num_dims (2): the dimension of the visualization space\n        svd_solver (\"randomized\"): the SVD solver to use. 
Consult the sklearn\n            documentation for details\n        seed (None): a random seed\n    \"\"\"\n\n    def __init__(\n        self,\n        embeddings_field=None,\n        points_field=None,\n        similarity_index=None,\n        model=None,\n        model_kwargs=None,\n        patches_field=None,\n        num_dims=2,\n        svd_solver=\"randomized\",\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(\n            embeddings_field=embeddings_field,\n            points_field=points_field,\n            similarity_index=similarity_index,\n            model=model,\n            model_kwargs=model_kwargs,\n            patches_field=patches_field,\n            num_dims=num_dims,\n            **kwargs,\n        )\n        self.svd_solver = svd_solver\n        self.seed = seed\n\n    @property\n    def method(self):\n        return \"pca\"\n\n\nclass PCAVisualization(Visualization):\n    def fit(self, embeddings):\n        _pca = skd.PCA(\n            n_components=self.config.num_dims,\n            svd_solver=self.config.svd_solver,\n            random_state=self.config.seed,\n        )\n        return _pca.fit_transform(embeddings)\n\n\nclass ManualVisualizationConfig(VisualizationConfig):\n    \"\"\"Configuration for manually-provided low-dimensional visualizations.\n\n    Args:\n        patches_field (None): the sample field defining the patches being\n            analyzed, if any\n        num_dims (2): the dimension of the visualization space\n    \"\"\"\n\n    def __init__(self, patches_field=None, num_dims=2, **kwargs):\n        super().__init__(\n            patches_field=patches_field, num_dims=num_dims, **kwargs\n        )\n\n    @property\n    def method(self):\n        return \"manual\"\n\n\nclass ManualVisualization(Visualization):\n    def fit(self, embeddings):\n        raise NotImplementedError(\n            \"The low-dimensional representation must be manually provided \"\n            \"when using this method\"\n        
)\n"
  },
  {
    "path": "install.bat",
    "content": "@echo off\n:: Installs the `fiftyone-brain` package and its dependencies.\n::\n:: Usage:\n:: .\\install.bat\n::\n:: Copyright 2017-2026, Voxel51, Inc.\n:: voxel51.com\n::\n:: Commands:\n:: -h      Display help message\n:: -d      Install developer dependencies.\n\nset SHOW_HELP=false\nset DEV_INSTALL=false\n\n:parse\nIF \"%~1\"==\"\" GOTO endparse\nIF \"%~1\"==\"-h\" GOTO helpmessage\nIF \"%~1\"==\"-d\" set DEV_INSTALL=true\nSHIFT\nGOTO parse\n:endparse\n\necho ***** INSTALLING FIFTYONE-BRAIN *****\nIF %DEV_INSTALL%==true (\n  echo Performing dev install\n  pip install -r requirements/dev.txt\n  pre-commit install\n  pip install -e .\n) else (\n  pip install -r requirements.txt\n  pip install .\n)\n\necho ***** INSTALLATION COMPLETE *****\nexit /b\n\n:helpmessage\necho Additional Arguments:\necho -h      Display help message\necho -d      Install developer dependencies.\nexit /b"
  },
  {
    "path": "install.sh",
    "content": "#!/bin/sh\n# Installs the `fiftyone-brain` package and its dependencies.\n#\n# Usage:\n#   sh install.sh\n#\n# Copyright 2017-2026, Voxel51, Inc.\n# voxel51.com\n#\n\n# Show usage information\nset -e\nusage() {\n    echo \"Usage:  sh $0 [-h] [-d]\n\nGetting help:\n-h      Display this help message.\n\nCustom installations:\n-d      Install developer dependencies.\n\"\n}\n\n# Parse flags\nSHOW_HELP=false\nDEV_INSTALL=false\nwhile getopts \"hd\" FLAG; do\n    case \"${FLAG}\" in\n        h) SHOW_HELP=true ;;\n        d) DEV_INSTALL=true ;;\n        *) usage ;;\n    esac\ndone\n[ ${SHOW_HELP} = true ] && usage && exit 0\nOS=$(uname -s)\n\necho \"***** INSTALLING FIFTYONE-BRAIN *****\"\nif [ ${DEV_INSTALL} = true ]; then\n    echo \"Performing dev install\"\n    pip install -r requirements/dev.txt\n    pre-commit install\n    pip install -e .\nelse\n    pip install -r requirements.txt\n    pip install .\nfi\n\necho \"***** INSTALLATION COMPLETE *****\"\n"
  },
  {
    "path": "pylintrc",
    "content": "[MASTER]\n\n# Specify a configuration file.\n#rcfile=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook=\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=CVS\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\nignore-patterns=\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# List of plugins (as comma separated values of python modules names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Use multiple processes to speed up Pylint.\njobs=1\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code\nextension-pkg-whitelist=\n\n# Allow optimization of some AST trees. This will activate a peephole AST\n# optimizer, which will apply various small optimizations. For instance, it can\n# be used to obtain the result of joining multiple strings with the addition\n# operator. Joining a lot of strings can lead to a maximum recursion error in\n# Pylint and this flag can prevent that. It has one side effect, the resulting\n# AST will be different than the one from reality. This option is deprecated\n# and it will be removed in Pylint 2.0.\noptimize-ast=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED\nconfidence=\n\n# Enable the message, report, category or checker with the given id(s). 
You can\n# either give multiple identifier separated by comma (,) or put this option\n# multiple time (only on the command line, not in the configuration file where\n# it should appear only once). See also the \"--disable\" option for examples.\n#enable=\n\n# Disable the message, report, category or checker with the given id(s). You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once).You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use\"--disable=all --enable=classes\n# --disable=W\"\n#disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating\n\ndisable=too-few-public-methods,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,non-iterator-returned,too-many-statements,useless-object-inheritance,abstract-method,too-many-ancestors,too-many-branches,unnecessa
ry-pass,too-many-public-methods,bad-continuation\n\n\n[REPORTS]\n\n# Set the output format. Available formats are text, parseable, colorized, msvs\n# (visual studio) and html. You can also give a reporter class, eg\n# mypackage.mymodule.MyReporterClass.\noutput-format=colorized\n\n# Put messages in a separate file for each module / package specified on the\n# command line instead of printing them on stdout. Reports (if any) will be\n# written in a file name \"pylint_global.[txt|html]\". This option is deprecated\n# and it will be removed in Pylint 2.0.\nfiles-output=no\n\n# Tells whether to display a full report or only the messages\nreports=no\nscore=no\n\n# Python expression which should return a note less than 10 (10 is the highest\n# note). You have access to the variables errors warning, statement which\n# respectively contain the number of errors / warnings messages and the total\n# number of statements analyzed. This is used by the global evaluation report\n# (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. See doc for all details\n#msg-template=\n\n\n[BASIC]\n\n# Good variable names which should always be accepted, separated by a comma\ngood-names=i,j,k\n\n# Bad variable names which should always be refused, separated by a comma\nbad-names=\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Include a hint for the correct naming format with invalid-name\ninclude-naming-hint=no\n\n# List of decorators that produce properties, such as abc.abstractproperty. 
Add\n# to this list to register other decorators that produce valid properties.\nproperty-classes=abc.abstractproperty\n\n# Regular expression matching correct function names\nfunction-rgx=[a-z_]([a-z0-9_]{0,30})$\n\n# Naming hint for function names\nfunction-name-hint=[a-z_]([a-z0-9_]{0,30})$\n\n# Regular expression matching correct variable names\nvariable-rgx=[a-z_]([a-z0-9_]{0,30})$\n\n# Naming hint for variable names\nvariable-name-hint=[a-z_]([a-z0-9_]{0,30})$\n\n# Regular expression matching correct constant names\nconst-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$\n\n# Naming hint for constant names\nconst-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$\n\n# Regular expression matching correct attribute names\nattr-rgx=[a-z_]([a-z0-9_]{0,30})$\n\n# Naming hint for attribute names\nattr-name-hint=[a-z_]([a-z0-9_]{0,30})$\n\n# Regular expression matching correct argument names\nargument-rgx=[a-z_]([a-z0-9_]{0,30})$\n\n# Naming hint for argument names\nargument-name-hint=[a-z_]([a-z0-9_]{0,30})$\n\n# Regular expression matching correct class attribute names\nclass-attribute-rgx=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$\n\n# Naming hint for class attribute names\nclass-attribute-name-hint=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$\n\n# Regular expression matching correct inline iteration names\ninlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$\n\n# Naming hint for inline iteration names\ninlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$\n\n# Regular expression matching correct class names\nclass-rgx=[A-Z_][a-zA-Z0-9]+$\n\n# Naming hint for class names\nclass-name-hint=[A-Z_][a-zA-Z0-9]+$\n\n# Regular expression matching correct module names\nmodule-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$\n\n# Naming hint for module names\nmodule-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$\n\n# Regular expression matching correct method names\nmethod-rgx=[a-z_]([a-z0-9_]{0,30})$\n\n# Naming hint for method names\nmethod-name-hint=[a-z_]([a-z0-9_]{0,30})$\n\n# Regular expression which should only 
match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=-1\n\n\n[ELIF]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n\n[FORMAT]\n\n# Maximum number of characters on a single line.\nmax-line-length=79\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n# List of optional constructs for which whitespace checking is disabled. `dict-\n# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\\n222: 2}.\n# `trailing-comma` allows a space between comma and closing bracket: (a, ).\n# `empty-line` allows space-only lines.\nno-space-check=trailing-comma,dict-separator\n\n# Maximum number of lines in a module\nmax-module-lines=1000\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Number of spaces of indent required inside a hanging  or continued line.\nindent-after-paren=4\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n\n[LOGGING]\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format\nlogging-modules=logging\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,XXX,TODO\n\n\n[SIMILARITIES]\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n\n[SPELLING]\n\n# Spelling dictionary name. Available dictionaries: none. 
To make it working\n# install python-enchant package.\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[TYPECHECK]\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=torch.*,fiftyone.*,fo.*\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. 
assign) instance attributes.\ndefining-attr-methods=__init__,__new__,setUp\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=mcs\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,_fields,_replace,_source,_make\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method\nmax-args=5\n\n# Argument names that match this expression will be ignored. Default to name\n# with leading underscore\nignored-argument-names=_.*\n\n# Maximum number of locals for function / method body\nmax-locals=15\n\n# Maximum number of return / yield for function / method body\nmax-returns=6\n\n# Maximum number of branch for function / method body\nmax-branches=12\n\n# Maximum number of statements in function / method body\nmax-statements=50\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=7\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=2\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=20\n\n# Maximum number of boolean expressions in a if statement\nmax-bool-expr=5\n\n\n[IMPORTS]\n\n# Deprecated modules which should not be used, separated by a comma\ndeprecated-modules=regsub,TERMIOS,Bastion,rexec\n\n# Create a graph of every (i.e. 
internal and external) dependencies in the\n# given file (report RP0402 must not be disabled)\nimport-graph=\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled)\next-import-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled)\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"Exception\"\novergeneral-exceptions=Exception\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\nline-length = 79\ninclude = '\\.pyi?$'\nexclude = '''\n/(\n  | \\.git\n)/\n'''\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\npython_files = *test*.py\nfilterwarnings =\n    ignore:dns.hash module will be removed in future versions:DeprecationWarning\n    ignore:the imp module is deprecated in favour of importlib:DeprecationWarning\n    ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning\n    ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning\n    ignore:numpy.* size changed, may indicate binary incompatibility:RuntimeWarning\n"
  },
  {
    "path": "requirements/build.txt",
    "content": "-r common.txt\n\npytest==5.4.3\ntwine>=3\n"
  },
  {
    "path": "requirements/common.txt",
    "content": "numpy\nscipy\nscikit-learn\n"
  },
  {
    "path": "requirements/dev.txt",
    "content": "-r common.txt\n\nflickrapi==2.4.0\nimageio==2.8.0\nipython>=7.16.1\npandas\npre-commit==2.0.1\npylint==2.3.1\npytest==7.3.1\ntwine>=3\nvoxel51-eta[storage]\n"
  },
  {
    "path": "requirements/prod.txt",
    "content": "-r common.txt\n"
  },
  {
    "path": "requirements.txt",
    "content": "-r requirements/prod.txt\n"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\n\"\"\"\nInstalls `fiftyone-brain`.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport os\nfrom setuptools import setup\n\n\nVERSION = \"0.21.4\"\n\n\ndef get_version():\n    if \"RELEASE_VERSION\" in os.environ:\n        version = os.environ[\"RELEASE_VERSION\"]\n        if not version.startswith(VERSION):\n            raise ValueError(\n                \"Release version doest not match version: %s and %s\"\n                % (version, VERSION)\n            )\n        return version\n\n    return VERSION\n\n\nwith open(\"README.md\", \"r\") as fh:\n    long_description = fh.read()\n\n\nsetup(\n    name=\"fiftyone-brain\",\n    version=get_version(),\n    description=\"FiftyOne Brain\",\n    author=\"Voxel51, Inc.\",\n    author_email=\"info@voxel51.com\",\n    url=\"https://github.com/voxel51/fiftyone-brain\",\n    license=\"Apache\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    packages=[\"fiftyone.brain\"],\n    include_package_data=True,\n    install_requires=[\"numpy\", \"scipy>=1.2.0\", \"scikit-learn\"],\n    classifiers=[\n        \"Development Status :: 4 - Beta\",\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Science/Research\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n        \"Topic :: Scientific/Engineering :: Image Processing\",\n        \"Topic :: Scientific/Engineering :: Image Recognition\",\n        \"Topic :: Scientific/Engineering :: Information Analysis\",\n        \"Topic :: Scientific/Engineering :: Visualization\",\n        \"Operating System :: MacOS :: MacOS X\",\n        \"Operating System :: POSIX :: Linux\",\n        \"Operating System :: Microsoft :: Windows\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.9\",\n        
\"Programming Language :: Python :: 3.10\",\n        \"Programming Language :: Python :: 3.11\",\n    ],\n    scripts=[],\n    python_requires=\">=3.9\",\n)\n"
  },
  {
    "path": "tests/README.md",
    "content": "# FiftyOne-Brain Tests\n\nThe brain currently uses both\n[unittest](https://docs.python.org/3/library/unittest.html) and\n[pytest](https://docs.pytest.org/en/stable) to implement its tests.\n\n## Contents\n\n| File                 | Description                                              |\n| -------------------- | -------------------------------------------------------- |\n| `test_uniqueness.py` | Tests of the uniqueness capability                       |\n| `models/*.py`        | Tests of the various models used by the brain            |\n| `intensive/*.py`     | Intensive tests that are not included in automated tests |\n\n## Running tests\n\nTo run all tests in this directory, execute:\n\n```shell\npytest . -s\n```\n\nTo run a specific set of tests, execute:\n\n```shell\npytest <file>.py -s\n```\n\nTo run a specific test case, execute:\n\n```shell\npytest <file>.py -s -k <test_function_name>\n```\n\n## Copyright\n\nCopyright 2017-2026, Voxel51, Inc.<br> voxel51.com\n"
  },
  {
    "path": "tests/intensive/test_interface.py",
    "content": "\"\"\"\nBrain interface tests.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport unittest\n\nimport fiftyone as fo\nimport fiftyone.brain as fob\nimport fiftyone.zoo as foz\n\n\ndef test_uniqueness():\n    dataset = foz.load_zoo_dataset(\"quickstart\").clone()\n\n    fob.compute_uniqueness(dataset)\n    print(dataset.list_brain_runs())\n    print(dataset.get_brain_info(\"uniqueness\"))\n    print(dataset.bounds(\"uniqueness\"))\n\n    dataset.delete_brain_runs()\n    print(dataset)\n\n\ndef test_detection_mistakenness():\n    dataset = foz.load_zoo_dataset(\"quickstart\").clone()\n\n    fob.compute_mistakenness(\n        dataset, \"predictions\", label_field=\"ground_truth\", copy_missing=True\n    )\n    print(dataset.list_brain_runs())\n    print(dataset.get_brain_info(\"mistakenness\"))\n\n    # should be non-trivial\n    print(dataset.bounds(\"mistakenness\"))\n    print(dataset.bounds(\"possible_missing\"))\n    print(dataset.bounds(\"possible_spurious\"))\n    print(dataset.bounds(\"ground_truth.detections.mistakenness\"))\n    print(dataset.bounds(\"ground_truth.detections.mistakenness_loc\"))\n    print(dataset.count_values(\"ground_truth.detections.possible_spurious\"))\n    print(dataset.count_values(\"predictions.detections.possible_missing\"))\n    print(dataset.count_values(\"ground_truth.detections.possible_missing\"))\n\n    dataset.delete_brain_runs()\n    print(dataset)\n\n    # should be None\n    print(dataset.bounds(\"ground_truth.detections.mistakenness\"))\n    print(dataset.bounds(\"ground_truth.detections.mistakenness_loc\"))\n    print(dataset.count_values(\"ground_truth.detections.possible_spurious\"))\n    print(dataset.count_values(\"predictions.detections.possible_missing\"))\n    print(dataset.count_values(\"ground_truth.detections.possible_missing\"))\n\n\ndef test_classification_mistakenness_confidence():\n    dataset = foz.load_zoo_dataset(\"quickstart\").clone()\n    
test_view = dataset.take(10)\n\n    # labels proxy\n    model = foz.load_zoo_model(\"alexnet-imagenet-torch\")\n    test_view.apply_model(model, \"alexnet\")\n\n    # predictions proxy\n    model = foz.load_zoo_model(\"resnet50-imagenet-torch\")\n    test_view.apply_model(model, \"resnet50\")\n\n    fob.compute_mistakenness(test_view, \"resnet50\", label_field=\"alexnet\")\n    print(dataset.list_brain_runs())\n    print(dataset.load_brain_view(\"mistakenness\"))\n    print(dataset.bounds(\"mistakenness\"))\n\n    dataset.delete_brain_runs()\n    print(dataset)\n\n\ndef test_classification_mistakenness_logits():\n    dataset = foz.load_zoo_dataset(\"quickstart\").clone()\n    test_view = dataset.take(10)\n\n    # labels proxy\n    model = foz.load_zoo_model(\"alexnet-imagenet-torch\")\n    test_view.apply_model(model, \"alexnet\")\n\n    # predictions proxy\n    model = foz.load_zoo_model(\"resnet50-imagenet-torch\")\n    test_view.apply_model(model, \"resnet50\", store_logits=True)\n\n    fob.compute_mistakenness(\n        test_view, \"resnet50\", label_field=\"alexnet\", use_logits=True\n    )\n    print(dataset.list_brain_runs())\n    print(dataset.load_brain_view(\"mistakenness\"))\n    print(dataset.bounds(\"mistakenness\"))\n\n    dataset.delete_brain_runs()\n    print(dataset)\n\n\ndef test_hardness():\n    dataset = foz.load_zoo_dataset(\"quickstart\").clone()\n    test_view = dataset.take(10)\n    model = foz.load_zoo_model(\"alexnet-imagenet-torch\")\n    test_view.apply_model(model, \"alexnet\", store_logits=True)\n\n    fob.compute_hardness(test_view, \"alexnet\")\n    print(dataset.list_brain_runs())\n    print(dataset.get_brain_info(\"hardness\"))\n    print(dataset.load_brain_view(\"hardness\"))\n    print(dataset.bounds(\"hardness\"))\n\n    dataset.delete_brain_runs()\n    print(dataset)\n\n\nif __name__ == \"__main__\":\n    fo.config.show_progress_bars = True\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/intensive/test_similarity.py",
    "content": "\"\"\"\nSimilarity tests.\n\nUsage::\n\n    # Optional: specific backends to test\n    export SIMILARITY_BACKENDS=qdrant,pinecone,milvus,redis,elasticsearch,mosaic,pgvector,lancedb\n\n    pytest tests/intensive/test_similarity.py -s -k test_XXX\n\nQdrant setup::\n\n    docker pull qdrant/qdrant\n    docker run -p 6333:6333 qdrant/qdrant\n\n    pip install qdrant-client\n\nPinecone setup::\n\n    # Sign up at https://www.pinecone.io\n    # Download API key and environment\n\n    pip install pinecone-client\n\nMilvus setup::\n\n    # Instructions from: https://milvus.io/docs/install_standalone-docker.md\n    wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standalone-docker-compose.yml -O docker-compose.yml\n    docker compose up -d\n\n    pip install pymilvus\n\nLanceDB setup::\n\n    pip install lancedb\n\nRedis setup::\n\n    brew tap redis-stack/redis-stack\n    brew install redis-stack\n    redis-stack-server\n\n    pip install redis\n\nElasticsearch setup::\n\n    # Instructions from: https://www.elastic.co/guide/en/elasticsearch/reference/current/getting-started.html#run-elasticsearch\n    docker run -p 127.0.0.1:9200:9200 -d \\\n        --name elasticsearch \\\n        -e ELASTIC_PASSWORD=elastic \\\n        -e \"discovery.type=single-node\" \\\n        -e \"xpack.security.http.ssl.enabled=false\" \\\n        -e \"xpack.license.self_generated.type=trial\" \\\n        docker.elastic.co/elasticsearch/elasticsearch:8.15.0\n\n    pip install elasticsearch\n\nMosaic setup::\n\n    # In your databricks workspace, generate a personal access token for authentication.\n    # You will also need to create a catalog and schema in your workspace.\n    # You will have to create an endpoint under `compute` -> `vector search`\n\n    pip install databricks-vectorsearch\n\nPGVector setup::\n\n    # Run a postgres instance locally with pgvector extension\n    docker pull pgvector/pgvector:pg17\n    docker run --name postgres -p 5432:5432 
-e POSTGRES_PASSWORD=mysecretpassword -d pgvector/pgvector:pg17\n\n    # Enter the container and create the vector extension\n    docker exec -it postgres ./bin/psql -U postgres\n    CREATE EXTENSION IF NOT EXISTS vector;  # run in container\n\n    pip install psycopg2\n\nBrain config setup at `~/.fiftyone/brain_config.json`::\n\n    {\n        \"similarity_backends\": {\n            \"pinecone\": {\n                \"api_key\": \"XXXXXXXX\",\n                \"cloud\": \"aws\",\n                \"region\": \"us-east-1\",\n                \"environment\": \"us-east-1-aws\"\n            },\n            \"qdrant\": {\n                \"url\": \"http://localhost:6333\"\n            },\n            \"milvus\": {\n                \"uri\": \"http://localhost:19530\"\n            },\n            \"lancedb\": {\n                \"uri\": \"/tmp/lancedb\"\n            },\n            \"redis\": {\n                \"host\": \"localhost\",\n                \"port\": 6379\n            },\n            \"elasticsearch\": {\n                \"hosts\": \"http://localhost:9200\",\n                \"username\": \"elastic\",\n                \"password\": \"elastic\"\n            },\n            \"mosaic\": {\n                \"workspace_url\": \"https://<unique-url>.cloud.databricks.com/\",\n                \"personal_access_token\": \"<personal-access-token>\",\n                \"catalog_name\": \"<catalog_name>\",\n                \"schema_name\": \"<schema_name>\",\n                \"endpoint_name\": \"<endpoint_name>\"\n            },\n            \"pgvector\": {\n                \"connection_string\": \"postgresql://postgres:mysecretpassword@localhost:5432/postgres\"\n            }\n        }\n    }\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport random\nimport os\nimport time\nimport unittest\n\nimport numpy as np\n\nimport fiftyone as fo\nimport fiftyone.brain as fob  # pylint: disable=import-error,no-name-in-module\nimport 
fiftyone.zoo as foz\nfrom fiftyone import ViewField as F\n\n\nCUSTOM_BACKENDS = [\n    \"qdrant\",\n    \"pinecone\",\n    \"milvus\",\n    \"redis\",\n    \"elasticsearch\",\n    \"mosaic\",\n    \"pgvector\",\n    \"lancedb\",\n]\n\n\ndef get_custom_backends():\n    if \"SIMILARITY_BACKENDS\" in os.environ:\n        return os.environ[\"SIMILARITY_BACKENDS\"].split(\",\")\n\n    return CUSTOM_BACKENDS\n\n\ndef test_brain_config():\n    similarity_backends = fob.brain_config.similarity_backends\n\n    assert \"sklearn\" in similarity_backends\n\n    for backend in get_custom_backends():\n        if backend == \"qdrant\":\n            assert \"qdrant\" in similarity_backends\n\n            # this isn't mandatory\n            # assert \"url\" in similarity_backends[\"qdrant\"]\n\n        if backend == \"pinecone\":\n            assert \"pinecone\" in similarity_backends\n\n            # this isn't mandatory\n            # assert \"api_key\" in similarity_backends[\"pinecone\"]\n            # assert \"cloud\" in similarity_backends[\"pinecone\"]\n            # assert \"region\" in similarity_backends[\"pinecone\"]\n            # assert \"environment\" in similarity_backends[\"pinecone\"]\n\n        if backend == \"milvus\":\n            assert \"milvus\" in similarity_backends\n\n            # this isn't mandatory\n            # assert \"uri\" in similarity_backends[\"milvus\"]\n\n        if backend == \"lancedb\":\n            assert \"lancedb\" in similarity_backends\n\n            # this isn't mandatory\n            # assert \"uri\" in similarity_backends[\"lancedb\"]\n\n        if backend == \"redis\":\n            assert \"redis\" in similarity_backends\n\n            # this isn't mandatory\n            # assert \"host\" in similarity_backends[\"redis\"]\n            # assert \"port\" in similarity_backends[\"redis\"]\n\n        if backend == \"elasticsearch\":\n            assert \"elasticsearch\" in similarity_backends\n\n\ndef 
test_image_similarity_backends():\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\",\n        dataset_name=\"quickstart-test-similarity-image\",\n        drop_existing_dataset=True,\n    )\n\n    # sklearn backend\n    ###########################################################################\n\n    index1 = fob.compute_similarity(\n        dataset,\n        model=\"clip-vit-base32-torch\",\n        metric=\"euclidean\",\n        embeddings=False,\n        backend=\"sklearn\",\n        brain_key=\"clip_sklearn\",\n    )\n\n    embeddings, sample_ids, _ = index1.compute_embeddings(dataset)\n\n    index1.add_to_index(embeddings, sample_ids)\n    index1.save()\n    index1.reload()\n    assert index1.total_index_size == 200\n    assert index1.index_size == 200\n    assert index1.missing_size is None\n\n    prompt = \"kites high in the air\"\n\n    view1 = dataset.sort_by_similarity(prompt, k=10, brain_key=\"clip_sklearn\")\n    assert len(view1) == 10\n\n    del index1\n    dataset.clear_cache()\n\n    print(dataset.get_brain_info(\"clip_sklearn\"))\n\n    index1 = dataset.load_brain_results(\"clip_sklearn\")\n    assert index1.total_index_size == 200\n\n    embeddings1, sample_ids1, _ = index1.get_embeddings()\n    assert embeddings1.shape == (200, 512)\n    assert sample_ids1.shape == (200,)\n\n    ids = random.sample(list(index1.sample_ids), 100)\n\n    embeddings1, sample_ids1, _ = index1.get_embeddings(sample_ids=ids)\n    assert embeddings1.shape == (100, 512)\n    assert sample_ids1.shape == (100,)\n\n    index1.remove_from_index(sample_ids=ids)\n    assert index1.total_index_size == 100\n\n    index1.cleanup()\n    dataset.delete_brain_run(\"clip_sklearn\")\n\n    # custom backends\n    ###########################################################################\n\n    for backend in get_custom_backends():\n        brain_key = \"clip_\" + backend\n\n        index2 = fob.compute_similarity(\n            dataset,\n            
model=\"clip-vit-base32-torch\",\n            metric=\"euclidean\",\n            embeddings=False,\n            backend=backend,\n            brain_key=brain_key,\n        )\n\n        index2.add_to_index(embeddings, sample_ids)\n        assert _verify_total_index_size(index=index2, expected_size=200)\n        assert index2.total_index_size == 200\n        assert index2.index_size == 200\n        assert index2.missing_size is None\n\n        view2 = dataset.sort_by_similarity(prompt, k=10, brain_key=brain_key)\n        assert len(view2) == 10\n\n        del index2\n        dataset.clear_cache()\n\n        print(dataset.get_brain_info(brain_key))\n\n        index2 = dataset.load_brain_results(brain_key)\n        assert index2.total_index_size == 200\n\n        # Pinecone and Milvus require IDs, so this method is not supported\n        if backend not in (\"pinecone\", \"milvus\"):\n            embeddings2, sample_ids2, _ = index2.get_embeddings()\n            assert embeddings2.shape == (200, 512)\n            assert sample_ids2.shape == (200,)\n\n        embeddings2, sample_ids2, _ = index2.get_embeddings(sample_ids=ids)\n        assert embeddings2.shape == (100, 512)\n        assert sample_ids2.shape == (100,)\n        assert set(sample_ids1) == set(sample_ids2)\n\n        embeddings2_dict = dict(zip(sample_ids2, embeddings2))\n        _embeddings2 = np.array([embeddings2_dict[i] for i in sample_ids1])\n        assert np.allclose(embeddings1, _embeddings2)\n\n        index2.remove_from_index(sample_ids=ids)\n\n        # Collection size is known to be wrong in Milvus after deletions\n        # As of July 5, 2023 this has not been fixed\n        # https://github.com/milvus-io/milvus/issues/17193\n        if backend != \"milvus\":\n            assert index2.total_index_size == 100\n\n        index2.cleanup()\n        dataset.delete_brain_run(brain_key)\n\n    dataset.delete()\n\n\ndef test_patch_similarity_backends():\n    dataset = foz.load_zoo_dataset(\n        
\"quickstart\",\n        dataset_name=\"quickstart-test-similarity-patch\",\n        drop_existing_dataset=True,\n    )\n\n    # sklearn backend\n    ###########################################################################\n\n    index1 = fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        model=\"clip-vit-base32-torch\",\n        metric=\"euclidean\",\n        embeddings=False,\n        backend=\"sklearn\",\n        brain_key=\"gt_clip_sklearn\",\n    )\n\n    embeddings, sample_ids, label_ids = index1.compute_embeddings(dataset)\n\n    index1.add_to_index(embeddings, sample_ids, label_ids=label_ids)\n    index1.save()\n    index1.reload()\n    assert index1.total_index_size == 1232\n    assert index1.index_size == 1232\n    assert index1.missing_size is None\n\n    view = dataset.to_patches(\"ground_truth\")\n\n    prompt = \"cute puppies\"\n\n    view1 = view.sort_by_similarity(prompt, k=10, brain_key=\"gt_clip_sklearn\")\n    assert len(view1) == 10\n\n    del index1\n    dataset.clear_cache()\n\n    print(dataset.get_brain_info(\"gt_clip_sklearn\"))\n\n    index1 = dataset.load_brain_results(\"gt_clip_sklearn\")\n    assert index1.total_index_size == 1232\n\n    embeddings1, sample_ids1, label_ids1 = index1.get_embeddings()\n    assert embeddings1.shape == (1232, 512)\n    assert sample_ids1.shape == (1232,)\n    assert label_ids1.shape == (1232,)\n\n    ids = random.sample(list(index1.label_ids), 100)\n\n    embeddings1, sample_ids1, label_ids1 = index1.get_embeddings(label_ids=ids)\n    assert embeddings1.shape == (100, 512)\n    assert sample_ids1.shape == (100,)\n    assert label_ids1.shape == (100,)\n\n    index1.remove_from_index(label_ids=ids)\n    assert index1.total_index_size == 1132\n\n    index1.cleanup()\n\n    dataset.delete_brain_run(\"gt_clip_sklearn\")\n\n    # custom backends\n    ###########################################################################\n\n    for backend in 
get_custom_backends():\n        brain_key = \"gt_clip_\" + backend\n\n        index2 = fob.compute_similarity(\n            dataset,\n            patches_field=\"ground_truth\",\n            model=\"clip-vit-base32-torch\",\n            metric=\"euclidean\",\n            embeddings=False,\n            backend=backend,\n            brain_key=brain_key,\n        )\n\n        index2.add_to_index(embeddings, sample_ids, label_ids=label_ids)\n        assert _verify_total_index_size(index=index2, expected_size=1232)\n        assert index2.total_index_size == 1232\n        assert index2.index_size == 1232\n        assert index2.missing_size is None\n\n        view2 = view.sort_by_similarity(prompt, k=10, brain_key=brain_key)\n        assert len(view2) == 10\n\n        del index2\n        dataset.clear_cache()\n\n        print(dataset.get_brain_info(brain_key))\n\n        index2 = dataset.load_brain_results(brain_key)\n        assert index2.total_index_size == 1232\n\n        # Pinecone and Milvus require IDs, so this method is not supported\n        if backend not in (\"pinecone\", \"milvus\"):\n            embeddings2, sample_ids2, label_ids2 = index2.get_embeddings()\n            assert embeddings2.shape == (1232, 512)\n            assert sample_ids2.shape == (1232,)\n            assert label_ids2.shape == (1232,)\n\n        embeddings2, sample_ids2, label_ids2 = index2.get_embeddings(\n            label_ids=ids\n        )\n        assert embeddings2.shape == (100, 512)\n        assert sample_ids2.shape == (100,)\n        assert label_ids2.shape == (100,)\n        assert set(label_ids1) == set(label_ids2)\n\n        embeddings2_dict = dict(zip(label_ids2, embeddings2))\n        _embeddings2 = np.array([embeddings2_dict[i] for i in label_ids1])\n        assert np.allclose(embeddings1, _embeddings2)\n\n        index2.remove_from_index(label_ids=ids)\n\n        # Collection size is known to be wrong in Milvus after deletions\n        # As of July 5, 2023 this has not been 
fixed\n        # https://github.com/milvus-io/milvus/issues/17193\n        if backend != \"milvus\":\n            assert index2.total_index_size == 1132\n\n        index2.cleanup()\n        dataset.delete_brain_run(brain_key)\n\n    dataset.delete()\n\n\ndef test_qdrant_backend_config():\n    \"\"\"\n    - *_similarity_backends tests run with custom backends as \"externally\" configured\n    - To test varying connection details (eg with qdrant), re-configure externally and re-run tests\n    - This test white-box tests that gRPC-related config settings are applied to QdrantClient\n    \"\"\"\n\n    backend = \"qdrant\"\n    if backend not in get_custom_backends():\n        return\n\n    dataset = foz.load_zoo_dataset(\"quickstart\", max_samples=5)\n    brain_key = \"clip_\" + backend\n    index = fob.compute_similarity(\n        dataset,\n        model=\"clip-vit-base32-torch\",\n        metric=\"euclidean\",\n        embeddings=False,\n        backend=backend,\n        brain_key=brain_key,\n    )\n\n    qclient = index.client\n    qremote = qclient._client\n    qdrant_config = fob.brain_config.similarity_backends[\"qdrant\"]\n\n    if \"prefer_grpc\" in qdrant_config:\n        prefer_grpc = qdrant_config[\"prefer_grpc\"]\n        assert qremote._prefer_grpc == prefer_grpc\n        print(f\"Applied qdrant config prefer_grpc={prefer_grpc}\")\n    else:\n        print(\"Qdrant config prefer_grpc unset\")\n\n    if \"grpc_port\" in qdrant_config:\n        grpc_port = qdrant_config[\"grpc_port\"]\n        assert qremote._grpc_port == grpc_port\n        print(f\"Applied qdrant config grpc_port={grpc_port}\")\n    else:\n        print(\"Qdrant config grpc_port unset\")\n\n    dataset.delete()\n\n\ndef test_images():\n    dataset = _load_images_dataset()\n\n    index = dataset.load_brain_results(\"img_sim\")\n\n    assert index.total_index_size == len(dataset)\n    assert set(dataset.values(\"id\")) == set(index.sample_ids)\n\n\ndef test_images_subset():\n    dataset = 
_load_images_dataset()\n\n    index = dataset.load_brain_results(\"img_sim\")\n\n    view = dataset.take(10)\n    index.use_view(view)\n\n    assert index.index_size == len(view)\n    assert set(view.values(\"id\")) == set(index.current_sample_ids)\n\n\ndef test_images_missing():\n    dataset = _load_images_dataset().limit(4).clone()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    sample_ids = dataset[:4].values(\"id\")\n\n    index = fob.compute_similarity(dataset, batch_size=1)\n\n    assert index.total_index_size == 4\n    assert set(sample_ids) == set(index.sample_ids)\n\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    index = fob.compute_similarity(\n        dataset,\n        model=model,\n        embeddings=\"embeddings_missing\",\n        batch_size=1,\n    )\n\n    assert len(dataset.exists(\"embeddings_missing\")) == 4\n    assert index.index_size == 4\n    assert set(sample_ids) == set(index.sample_ids)\n\n\ndef test_images_embeddings():\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=10, drop_existing_dataset=True\n    )\n    model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n    n = len(dataset)\n\n    # Embeddings are computed on-the-fly and stored on dataset\n    index1 = fob.compute_similarity(\n        dataset,\n        embeddings=\"embeddings\",\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"img_sim1\",\n        backend=\"sklearn\",\n    )\n    assert index1.total_index_size == n\n    assert index1.config.supports_prompts is True\n    assert \"embeddings\" not in index1.serialize()\n\n    # Embeddings already exist on dataset\n    dataset.compute_embeddings(model, embeddings_field=\"embeddings2\")\n    index2 = fob.compute_similarity(\n        
dataset,\n        embeddings=\"embeddings2\",\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"img_sim2\",\n        backend=\"sklearn\",\n    )\n    assert index2.total_index_size == n\n    assert index2.config.supports_prompts is True\n    assert \"embeddings\" not in index2.serialize()\n\n    # Embeddings stored in index itself\n    index3 = fob.compute_similarity(\n        dataset,\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"img_sim3\",\n        backend=\"sklearn\",\n    )\n    assert index3.total_index_size == n\n    assert index3.config.supports_prompts is True\n    assert \"embeddings\" in index3.serialize()\n\n    # Embeddings stored on dataset (but field doesn't initially exist)\n    index4 = fob.compute_similarity(\n        dataset,\n        embeddings=\"embeddings4\",\n        brain_key=\"img_sim4\",\n        backend=\"sklearn\",\n    )\n    embeddings = np.random.randn(n, 512)\n    sample_ids = dataset.values(\"id\")\n    index4.add_to_index(embeddings, sample_ids)\n    assert index4.total_index_size == n\n    assert index4.config.supports_prompts is not True\n    assert \"embeddings\" not in index4.serialize()\n\n    dataset.delete()\n\n\ndef test_patches():\n    dataset = _load_patches_dataset()\n\n    index = dataset.load_brain_results(\"gt_sim\")\n\n    label_ids = dataset.values(\"ground_truth.detections.id\", unwind=True)\n\n    assert index.total_index_size == len(label_ids)\n    assert set(label_ids) == set(index.label_ids)\n\n\ndef test_patches_subset():\n    dataset = _load_patches_dataset()\n\n    index = dataset.load_brain_results(\"gt_sim\")\n\n    label_ids = dataset.values(\"ground_truth.detections.id\", unwind=True)\n\n    assert index.total_index_size == len(label_ids)\n    assert set(label_ids) == set(index.label_ids)\n\n    view = dataset.filter_labels(\"ground_truth\", F(\"label\") == \"person\")\n    index.use_view(view)\n\n    label_ids = view.values(\"ground_truth.detections.id\", unwind=True)\n\n   
 assert index.index_size == len(label_ids)\n    assert set(label_ids) == set(index.current_label_ids)\n\n\ndef test_patches_missing():\n    dataset = _load_patches_dataset().limit(4).clone()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    for sample in dataset[4:]:\n        sample[\"ground_truth\"] = fo.Detections(\n            detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])]\n        )\n        sample.save()\n\n    index = fob.compute_similarity(\n        dataset, patches_field=\"ground_truth\", batch_size=1\n    )\n\n    num_patches = dataset[:4].count(\"ground_truth.detections\")\n    label_ids = dataset[:4].values(\"ground_truth.detections.id\", unwind=True)\n\n    assert index.total_index_size == num_patches\n    assert set(label_ids) == set(index.label_ids)\n\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    index = fob.compute_similarity(\n        dataset,\n        model=model,\n        patches_field=\"ground_truth\",\n        embeddings=\"embeddings_missing\",\n        batch_size=1,\n    )\n\n    view = dataset.filter_labels(\n        \"ground_truth\", F(\"embeddings_missing\") != None\n    )\n\n    assert view.count(\"ground_truth.detections\") == num_patches\n    assert index.total_index_size == num_patches\n    assert set(label_ids) == set(index.label_ids)\n\n\ndef test_patches_embeddings():\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=10, drop_existing_dataset=True\n    )\n    model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n    n = dataset.count(\"ground_truth.detections\")\n\n    # Embeddings are computed on-the-fly and stored on dataset\n    index1 = fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        
embeddings=\"embeddings\",\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"gt_sim1\",\n        backend=\"sklearn\",\n    )\n    assert index1.total_index_size == n\n    assert index1.config.supports_prompts is True\n    assert \"embeddings\" not in index1.serialize()\n\n    # Embeddings already exist on dataset\n    dataset.compute_patch_embeddings(\n        model, \"ground_truth\", embeddings_field=\"embeddings2\"\n    )\n    index2 = fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        embeddings=\"embeddings2\",\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"gt_sim2\",\n        backend=\"sklearn\",\n    )\n    assert index2.total_index_size == n\n    assert index2.config.supports_prompts is True\n    assert \"embeddings\" not in index2.serialize()\n\n    # Embeddings stored in index itself\n    index3 = fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        model=\"clip-vit-base32-torch\",\n        brain_key=\"gt_sim3\",\n        backend=\"sklearn\",\n    )\n    assert index3.total_index_size == n\n    assert index3.config.supports_prompts is True\n    assert \"embeddings\" in index3.serialize()\n\n    # Embeddings stored on dataset (but field doesn't initially exist)\n    index4 = fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        embeddings=\"embeddings4\",\n        brain_key=\"gt_sim4\",\n        backend=\"sklearn\",\n    )\n    embeddings = np.random.randn(n, 512)\n    view = dataset.to_patches(\"ground_truth\")\n    sample_ids, label_ids = view.values([\"sample_id\", \"id\"])\n    index4.add_to_index(embeddings, sample_ids, label_ids=label_ids)\n    assert index4.total_index_size == n\n    assert index4.config.supports_prompts is not True\n    assert \"embeddings\" not in index4.serialize()\n\n    dataset.delete()\n\n\ndef _load_images_dataset():\n    name = \"test-similarity-images\"\n\n    if 
fo.dataset_exists(name):\n        return fo.load_dataset(name)\n\n    return _make_images_dataset(name)\n\n\ndef _load_patches_dataset():\n    name = \"test-similarity-patches\"\n\n    if fo.dataset_exists(name):\n        return fo.load_dataset(name)\n\n    return _make_patches_dataset(name)\n\n\ndef _make_images_dataset(name):\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=20, dataset_name=name\n    )\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n\n    # Embed images\n    dataset.compute_embeddings(\n        model, embeddings_field=\"embeddings\", batch_size=8\n    )\n\n    # Image similarity\n    fob.compute_similarity(\n        dataset, embeddings=\"embeddings\", brain_key=\"img_sim\"\n    )\n\n    return dataset\n\n\ndef _make_patches_dataset(name):\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=20, dataset_name=name\n    )\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n\n    # Embed ground truth patches\n    dataset.compute_patch_embeddings(\n        model,\n        \"ground_truth\",\n        embeddings_field=\"embeddings\",\n        batch_size=8,\n        force_square=True,\n    )\n\n    # Patch similarity\n    fob.compute_similarity(\n        dataset,\n        patches_field=\"ground_truth\",\n        embeddings=\"embeddings\",\n        brain_key=\"gt_sim\",\n    )\n\n    return dataset\n\n\ndef _verify_total_index_size(index, expected_size, timeout=10, interval=1):\n    elapsed_time = 0\n    while index.total_index_size != expected_size and elapsed_time < timeout:\n        time.sleep(interval)\n        elapsed_time += interval\n\n    return index.total_index_size == expected_size\n\n\nif __name__ == \"__main__\":\n    fo.config.show_progress_bars = True\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/intensive/test_uniqueness.py",
    "content": "\"\"\"\nUniqueness tests.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport unittest\n\nimport fiftyone as fo\nimport fiftyone.brain as fob\nimport fiftyone.zoo as foz\n\n\ndef test_uniqueness():\n    _run_uniqueness()\n\n\ndef test_uniqueness_torch():\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    _run_uniqueness(model=model, batch_size=16)\n\n\ndef test_uniqueness_tf():\n    model = foz.load_zoo_model(\"resnet-v2-50-imagenet-tf1\")\n    _run_uniqueness(model=model, batch_size=16)\n\n\ndef test_uniqueness_missing():\n    dataset = fo.Dataset()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    fob.compute_uniqueness(dataset, batch_size=1)\n\n    view = dataset.exists(\"uniqueness\")\n\n    assert dataset.has_field(\"uniqueness\")\n    assert len(view) == 0\n\n\ndef test_roi_uniqueness():\n    _run_uniqueness(roi_field=\"ground_truth\")\n\n\ndef test_roi_uniqueness_torch():\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    _run_uniqueness(roi_field=\"ground_truth\", model=model, batch_size=16)\n\n\ndef test_roi_uniqueness_tf():\n    model = foz.load_zoo_model(\"resnet-v2-50-imagenet-tf1\")\n    _run_uniqueness(roi_field=\"ground_truth\", model=model, batch_size=16)\n\n\ndef test_roi_uniqueness_missing():\n    dataset = fo.Dataset()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    for sample in dataset:\n        sample[\"ground_truth\"] = fo.Detections(\n            
detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])]\n        )\n        sample.save()\n\n    fob.compute_uniqueness(dataset, roi_field=\"ground_truth\", batch_size=1)\n\n    view = dataset.exists(\"uniqueness\")\n\n    assert dataset.has_field(\"uniqueness\")\n    assert len(view) == 0\n\n\ndef test_uniqueness_similarity_index():\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", dataset_name=fo.get_default_dataset_name()\n    )\n    dataset.delete_sample_field(\"uniqueness\")\n\n    # Full similarity index\n\n    similarity_index = fob.compute_similarity(\n        dataset, brain_key=\"sklearn_index\", backend=\"sklearn\"\n    )\n\n    fob.compute_uniqueness(dataset, similarity_index=similarity_index)\n\n    assert dataset.has_field(\"uniqueness\")\n\n    dataset.clear_cache()\n    dataset.delete_sample_field(\"uniqueness\")\n\n    fob.compute_uniqueness(dataset, similarity_index=\"sklearn_index\")\n\n    assert dataset.has_field(\"uniqueness\")\n\n    # Partial similarity index\n\n    view = dataset.take(100, seed=51)\n    similarity_index2 = fob.compute_similarity(\n        view, brain_key=\"sklearn_index2\", backend=\"sklearn\"\n    )\n\n    fob.compute_uniqueness(\n        dataset,\n        uniqueness_field=\"uniqueness2\",\n        similarity_index=\"sklearn_index2\",\n    )\n\n    assert len(dataset.exists(\"uniqueness2\")) == len(view)\n\n\ndef _run_uniqueness(roi_field=None, model=None, batch_size=None):\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", dataset_name=fo.get_default_dataset_name()\n    )\n    dataset.delete_sample_field(\"uniqueness\")\n\n    view = dataset.take(50)\n    num_samples = len(view)\n\n    fob.compute_uniqueness(\n        view, roi_field=roi_field, model=model, batch_size=batch_size\n    )\n\n    num_uniqueness = dataset.count(\"uniqueness\")\n    assert num_uniqueness == num_samples\n\n    bounds = dataset.bounds(\"uniqueness\")\n\n    assert bounds[0] >= 0\n    assert bounds[1] <= 1\n\n\nif 
__name__ == \"__main__\":\n    fo.config.show_progress_bars = True\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/intensive/test_visualization.py",
    "content": "\"\"\"\nVisualization tests.\n\nAll of these tests are designed to be run manually via::\n\n    pytest tests/intensive/test_visualization.py -s -k test_<name>\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport unittest\n\nimport cv2\nimport numpy as np\n\nimport fiftyone as fo\nimport fiftyone.brain as fob\nimport fiftyone.zoo as foz\nfrom fiftyone import ViewField as F\n\n\ndef test_mnist():\n    dataset = foz.load_zoo_dataset(\"mnist\", split=\"test\")\n\n    # pylint: disable=no-member\n    embeddings = np.array(\n        [\n            cv2.imread(f, cv2.IMREAD_UNCHANGED).ravel()\n            for f in dataset.values(\"filepath\")\n        ]\n    )\n\n    results = fob.compute_visualization(\n        dataset,\n        embeddings=embeddings,\n        num_dims=2,\n        verbose=True,\n        seed=51,\n    )\n\n    plot = results.visualize(labels=\"ground_truth.label\")\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n\ndef test_images():\n    dataset = _load_images_dataset()\n\n    results = dataset.load_brain_results(\"img_viz\")\n\n    assert results.total_index_size == len(dataset)\n    assert set(dataset.values(\"id\")) == set(results.sample_ids)\n\n    plot = results.visualize(labels=\"uniqueness\")\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n\ndef test_images_subset():\n    dataset = _load_images_dataset()\n\n    results = dataset.load_brain_results(\"img_viz\")\n\n    view = dataset.take(10)\n    results.use_view(view)\n\n    assert results.index_size == len(view)\n    assert set(view.values(\"id\")) == set(results.current_sample_ids)\n\n    plot = results.visualize(labels=\"uniqueness\")\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n\ndef test_images_missing():\n    dataset = _load_images_dataset().limit(4).clone()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            
fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    sample_ids = dataset[:4].values(\"id\")\n\n    results = fob.compute_visualization(dataset, batch_size=1)\n\n    assert results.total_index_size == 4\n    assert set(sample_ids) == set(results.sample_ids)\n\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    results = fob.compute_visualization(\n        dataset,\n        model=model,\n        embeddings=\"embeddings_missing\",\n        batch_size=1,\n    )\n\n    assert len(dataset.exists(\"embeddings_missing\")) == 4\n    assert results.total_index_size == 4\n    assert set(sample_ids) == set(results.sample_ids)\n\n\ndef test_patches():\n    dataset = _load_patches_dataset()\n\n    results = dataset.load_brain_results(\"gt_viz\")\n\n    label_ids = dataset.values(\"ground_truth.detections.id\", unwind=True)\n\n    assert results.total_index_size == len(label_ids)\n    assert set(label_ids) == set(results.label_ids)\n\n    plot = results.visualize(labels=\"ground_truth.detections.label\")\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n\ndef test_patches_subset():\n    dataset = _load_patches_dataset()\n\n    results = dataset.load_brain_results(\"gt_viz\")\n\n    plot = results.visualize(\n        labels=\"ground_truth.detections.label\",\n        classes=[\"person\"],\n    )\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n    view = dataset.filter_labels(\"ground_truth\", F(\"label\") == \"person\")\n    results.use_view(view)\n\n    label_ids = view.values(\"ground_truth.detections.id\", unwind=True)\n\n    assert results.index_size == len(label_ids)\n    assert set(label_ids) == set(results.current_label_ids)\n\n    plot = results.visualize(labels=\"ground_truth.detections.label\")\n    plot.show()\n\n    input(\"Press enter to continue...\")\n\n\ndef test_patches_missing():\n    
dataset = _load_patches_dataset().limit(4).clone()\n    dataset.add_samples(\n        [\n            fo.Sample(filepath=\"non-existent1.png\"),\n            fo.Sample(filepath=\"non-existent2.png\"),\n            fo.Sample(filepath=\"non-existent3.png\"),\n            fo.Sample(filepath=\"non-existent4.png\"),\n        ]\n    )\n\n    for sample in dataset[4:]:\n        sample[\"ground_truth\"] = fo.Detections(\n            detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])]\n        )\n        sample.save()\n\n    results = fob.compute_visualization(\n        dataset, patches_field=\"ground_truth\", batch_size=1\n    )\n\n    num_patches = dataset[:4].count(\"ground_truth.detections\")\n    label_ids = dataset[:4].values(\"ground_truth.detections.id\", unwind=True)\n\n    assert results.total_index_size == num_patches\n    assert set(label_ids) == set(results.label_ids)\n\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n    results = fob.compute_visualization(\n        dataset,\n        model=model,\n        patches_field=\"ground_truth\",\n        embeddings=\"embeddings_missing\",\n        batch_size=1,\n    )\n\n    view = dataset.filter_labels(\n        \"ground_truth\", F(\"embeddings_missing\") != None\n    )\n\n    assert view.count(\"ground_truth.detections\") == num_patches\n    assert results.total_index_size == num_patches\n    assert set(label_ids) == set(results.label_ids)\n\n\ndef test_points():\n    dataset = foz.load_zoo_dataset(\"quickstart\")\n\n    n = len(dataset)\n    p = dataset.count(\"ground_truth.detections\")\n    d = 512\n\n    points1 = np.random.rand(n, d)\n    results1 = fob.compute_visualization(\n        dataset,\n        points=points1,\n        brain_key=\"test1\",\n    )\n    assert results1.points.shape == (n, d)\n\n    points2 = {_id: np.random.rand(d) for _id in dataset.values(\"id\")}\n    results2 = fob.compute_visualization(\n        dataset,\n        points=points2,\n        brain_key=\"test2\",\n 
   )\n    assert results2.points.shape == (n, d)\n\n    points3 = np.random.rand(p, d)\n    results3 = fob.compute_visualization(\n        dataset,\n        patches_field=\"ground_truth\",\n        points=points3,\n        brain_key=\"test3\",\n    )\n    assert results3.points.shape == (p, d)\n\n    points4 = {\n        _id: np.random.rand(d)\n        for _id in dataset.values(\"ground_truth.detections.id\", unwind=True)\n    }\n    results4 = fob.compute_visualization(\n        dataset,\n        patches_field=\"ground_truth\",\n        points=points4,\n        brain_key=\"test4\",\n    )\n    assert results4.points.shape == (p, d)\n\n    dataset.delete()\n\n\ndef test_similarity_index():\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", dataset_name=fo.get_default_dataset_name()\n    )\n\n    # Full similarity index\n\n    similarity_index = fob.compute_similarity(\n        dataset, brain_key=\"sklearn_index\", backend=\"sklearn\"\n    )\n\n    results = fob.compute_visualization(\n        dataset,\n        brain_key=\"img_viz\",\n        similarity_index=similarity_index,\n    )\n\n    assert len(results.points) == len(dataset)\n\n    # Partial similarity index\n\n    view = dataset.take(100, seed=51)\n    similarity_index2 = fob.compute_similarity(\n        view, brain_key=\"sklearn_index2\", backend=\"sklearn\"\n    )\n\n    results2 = fob.compute_visualization(\n        dataset,\n        brain_key=\"img_viz2\",\n        similarity_index=\"sklearn_index2\",\n    )\n\n    assert len(results2.points) == len(view)\n\n\ndef test_points_field():\n    dataset = _load_images_dataset()\n\n    num_points = len(dataset)\n    points = np.random.randn(num_points, 2)\n\n    brain_key = \"test_points\"\n    points_field = brain_key\n\n    fob.compute_visualization(\n        dataset,\n        brain_key=brain_key,\n        points=points,\n        create_index=True,\n    )\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    
assert results.config.points_field == points_field\n    assert dataset.has_sample_field(points_field)\n    assert points_field in dataset.list_indexes()\n\n    sample_points = dataset.first()[points_field]\n\n    assert isinstance(sample_points, list)\n    assert len(sample_points) == 2\n    assert isinstance(sample_points[0], float)\n\n    points = results.points\n\n    assert len(points) == num_points\n    assert len(points[0]) == 2\n\n    all_points = dataset.values(points_field)\n\n    assert np.allclose(points, all_points)\n\n    dataset.delete_brain_run(brain_key)\n\n    assert not dataset.has_sample_field(points_field)\n    assert points_field not in dataset.list_indexes()\n\n\ndef test_points_field_patches():\n    dataset = _load_patches_dataset()\n\n    num_points = dataset.count(\"ground_truth.detections\")\n    points = np.random.randn(num_points, 2)\n\n    brain_key = \"test_points\"\n    points_field = brain_key\n    points_path = f\"ground_truth.detections.{points_field}\"\n\n    fob.compute_visualization(\n        dataset,\n        brain_key=brain_key,\n        points=points,\n        patches_field=\"ground_truth\",\n        create_index=True,\n    )\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field == points_field\n    assert dataset.has_sample_field(points_path)\n    # Patch visualizations can't currently make use of database indexes\n    assert points_path not in dataset.list_indexes()\n\n    label_points = dataset.first().ground_truth.detections[0][points_field]\n\n    assert isinstance(label_points, list)\n    assert len(label_points) == 2\n    assert isinstance(label_points[0], float)\n\n    points = results.points\n\n    assert len(points) == num_points\n    assert len(points[0]) == 2\n\n    all_points = dataset.values(f\"ground_truth.detections[].{points_field}\")\n\n    assert np.allclose(points, all_points)\n\n    dataset.delete_brain_run(brain_key)\n\n    assert not 
dataset.has_sample_field(points_path)\n\n\ndef test_index_points():\n    dataset = _load_images_dataset()\n\n    num_points = len(dataset)\n    points = np.random.randn(num_points, 2)\n\n    brain_key = \"test_points\"\n    points_field = brain_key\n\n    fob.compute_visualization(dataset, brain_key=brain_key, points=points)\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field is None\n    assert not dataset.has_sample_field(points_field)\n    assert points_field not in dataset.list_indexes()\n\n    results.index_points()\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field == points_field\n    assert dataset.has_sample_field(points_field)\n    assert points_field in dataset.list_indexes()\n\n    points = results.points\n    all_points = dataset.values(points_field)\n\n    assert np.allclose(points, all_points)\n\n    results.remove_index()\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field is None\n    assert not dataset.has_sample_field(points_field)\n    assert points_field not in dataset.list_indexes()\n\n\ndef test_index_points_patches():\n    dataset = _load_patches_dataset()\n\n    num_points = dataset.count(\"ground_truth.detections\")\n    points = np.random.randn(num_points, 2)\n\n    brain_key = \"test_points\"\n    points_field = brain_key\n    points_path = f\"ground_truth.detections.{points_field}\"\n\n    fob.compute_visualization(\n        dataset,\n        brain_key=brain_key,\n        points=points,\n        patches_field=\"ground_truth\",\n    )\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field is None\n    assert not dataset.has_sample_field(points_path)\n\n    results.index_points()\n\n    dataset.clear_cache()\n    results = 
dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field == points_field\n    assert dataset.has_sample_field(points_path)\n\n    points = results.points\n    all_points = dataset.values(f\"ground_truth.detections[].{points_field}\")\n\n    assert np.allclose(points, all_points)\n\n    results.remove_index()\n\n    dataset.clear_cache()\n    results = dataset.load_brain_results(brain_key)\n\n    assert results.config.points_field is None\n    assert not dataset.has_sample_field(points_path)\n\n\ndef _load_images_dataset():\n    name = \"test-visualization-images\"\n\n    if fo.dataset_exists(name):\n        return fo.load_dataset(name)\n\n    return _make_images_dataset(name)\n\n\ndef _load_patches_dataset():\n    name = \"test-visualization-patches\"\n\n    if fo.dataset_exists(name):\n        return fo.load_dataset(name)\n\n    return _make_patches_dataset(name)\n\n\ndef _make_images_dataset(name):\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=20, dataset_name=name\n    )\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n\n    # Embed images\n    dataset.compute_embeddings(\n        model, embeddings_field=\"embeddings\", batch_size=8\n    )\n\n    # Image visualization\n    fob.compute_visualization(\n        dataset,\n        embeddings=\"embeddings\",\n        num_dims=2,\n        verbose=True,\n        seed=51,\n        brain_key=\"img_viz\",\n    )\n\n    return dataset\n\n\ndef _make_patches_dataset(name):\n    dataset = foz.load_zoo_dataset(\n        \"quickstart\", max_samples=20, dataset_name=name\n    )\n    model = foz.load_zoo_model(\"inception-v3-imagenet-torch\")\n\n    # Embed ground truth patches\n    dataset.compute_patch_embeddings(\n        model,\n        \"ground_truth\",\n        embeddings_field=\"embeddings\",\n        batch_size=8,\n        force_square=True,\n    )\n\n    # Patch visualization\n    fob.compute_visualization(\n        dataset,\n        
patches_field=\"ground_truth\",\n        embeddings=\"embeddings\",\n        num_dims=2,\n        verbose=True,\n        seed=51,\n        brain_key=\"gt_viz\",\n    )\n\n    return dataset\n\n\nif __name__ == \"__main__\":\n    fo.config.show_progress_bars = True\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/models/test_simple_resnet.py",
    "content": "\"\"\"\nTests for :mod:`fiftyone.brain.internal.models.simple_resnet`.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport imageio\nfrom PIL import Image\nimport torch\n\nimport eta.core.image as etai\n\nimport fiftyone as fo\nimport fiftyone.core.utils as fou\nimport fiftyone.zoo as foz\n\nimport fiftyone.brain.internal.models as fbm\n\n\ndef _transpose(x, source, target):\n    return x.permute([source.index(d) for d in target])\n\n\ndef _check_prediction(actual, expected):\n    assert isinstance(actual, fo.Classification)\n    assert isinstance(expected, fo.Classification)\n    # @todo fix me on 3.9\n    # assert actual.label == expected.label\n\n\ndef test_simple_resnet():\n    dataset = foz.load_zoo_dataset(\n        \"cifar10\",\n        split=\"test\",\n        dataset_name=fo.get_default_dataset_name(),\n        shuffle=True,\n        max_samples=1,\n    )\n\n    sample = dataset.first()\n    filepath = sample.filepath\n    print(\"Working on image at %s\" % filepath)\n\n    img_pil = Image.open(filepath)\n    print(\"img_pil is type %s\" % type(img_pil))\n\n    img_numpy = imageio.imread(filepath)\n    print(\"img_numpy is type %s\" % type(img_numpy))\n    print(img_numpy.shape)\n\n    img_torch = torch.from_numpy(img_numpy)\n    img_torch = _transpose(img_torch, \"HWC\", \"CHW\")\n    print(\"img_torch is type %s\" % type(img_torch))\n    print(img_torch.shape)\n    assert tuple(reversed(img_torch.shape)) == img_numpy.shape\n\n    img_eta = etai.read(filepath)\n    print(\"img_eta is type %s\" % type(img_eta))\n    print(img_eta.shape)\n    assert tuple(img_eta.shape) == img_numpy.shape\n\n    model = fbm.load_model(\"simple-resnet-cifar10\")\n\n    with model:\n        print(\"PIL\")\n        p_pil = model.predict(img_pil)\n        print(p_pil)\n\n        print(\"IMAGEIO\")\n        p_numpy = model.predict(img_numpy)\n        print(p_numpy)\n        _check_prediction(p_numpy, p_pil)\n\n       
 print(\"ETA\")\n        p_eta = model.predict(img_eta)\n        print(p_eta)\n        _check_prediction(p_eta, p_pil)\n\n        print(\"PIL (manual preprocessing)\")\n        with fou.SetAttributes(model, preprocess=False):\n            img_tensor = model.transforms(img_pil)\n            p_pil2 = model.predict(img_tensor)\n            print(p_pil2)\n            _check_prediction(p_pil2, p_pil)\n\n        print(\"IMAGEIO (manual preprocessing)\")\n        with fou.SetAttributes(model, preprocess=False):\n            img_tensor = model.transforms(img_numpy)\n            p_numpy2 = model.predict(img_tensor)\n            print(p_numpy2)\n            _check_prediction(p_numpy2, p_numpy)\n\n\nif __name__ == \"__main__\":\n    test_simple_resnet()\n"
  },
  {
    "path": "tests/test_uniqueness.py",
    "content": "\"\"\"\nUniqueness tests.\n\n| Copyright 2017-2026, Voxel51, Inc.\n| `voxel51.com <https://voxel51.com/>`_\n|\n\"\"\"\nimport os\nimport unittest\n\nimport eta.core.storage as etas\nimport eta.core.utils as etau\n\nimport fiftyone as fo\nimport fiftyone.brain as fob\nimport fiftyone.zoo as foz\n\n\ndef test_uniqueness():\n    dataset = foz.load_zoo_dataset(\"cifar10\", split=\"test\")\n    assert \"uniqueness\" not in dataset.get_field_schema()\n\n    view = dataset.view().take(100)\n    fob.compute_uniqueness(view)\n\n    print(dataset)\n    assert \"uniqueness\" in dataset.get_field_schema()\n\n\ndef test_gray():\n    \"\"\"Test default support for handling grayscale images.\n\n    Requires Voxel51 Google Drive credentials to download the test data.\n    \"\"\"\n    with etau.TempDir() as tmpdir:\n        tmp_zip = os.path.join(tmpdir, \"data.zip\")\n        tmp_data = os.path.join(tmpdir, \"brain_grayscale_test_data\")\n        client = etas.GoogleDriveStorageClient()\n        client.download(\"1ECeNnLmKQCHxlVdRqGefV5eXOD_OkmWx\", tmp_zip)\n        etau.extract_zip(tmp_zip, delete_zip=True)\n\n        dataset = fo.Dataset.from_dir(tmp_data, fo.types.ImageDirectory)\n\n        fob.compute_uniqueness(dataset)\n        print(dataset)\n\n\nif __name__ == \"__main__\":\n    fo.config.show_progress_bars = True\n    unittest.main(verbosity=2)\n"
  }
]