Repository: voxel51/fiftyone-brain Branch: develop Commit: 05fccee0ae1c Files: 61 Total size: 549.0 KB Directory structure: gitextract_zu49vpz0/ ├── .github/ │ ├── CODEOWNERS │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ └── build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierrc ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── RELEASING.md ├── STYLE_GUIDE.md ├── fiftyone/ │ ├── __init__.py │ └── brain/ │ ├── __init__.py │ ├── config.py │ ├── internal/ │ │ ├── __init__.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── duplicates.py │ │ │ ├── elasticsearch.py │ │ │ ├── hardness.py │ │ │ ├── lancedb.py │ │ │ ├── leaky_splits.py │ │ │ ├── milvus.py │ │ │ ├── mistakenness.py │ │ │ ├── mongodb.py │ │ │ ├── mosaic.py │ │ │ ├── pgvector.py │ │ │ ├── pinecone.py │ │ │ ├── qdrant.py │ │ │ ├── redis.py │ │ │ ├── representativeness.py │ │ │ ├── sklearn.py │ │ │ ├── uniqueness.py │ │ │ ├── utils.py │ │ │ └── visualization.py │ │ └── models/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── manifest.json │ │ ├── simple_resnet.py │ │ └── torch.py │ ├── similarity.py │ └── visualization.py ├── install.bat ├── install.sh ├── pylintrc ├── pyproject.toml ├── pytest.ini ├── requirements/ │ ├── build.txt │ ├── common.txt │ ├── dev.txt │ └── prod.txt ├── requirements.txt ├── setup.py └── tests/ ├── README.md ├── intensive/ │ ├── test_interface.py │ ├── test_similarity.py │ ├── test_uniqueness.py │ └── test_visualization.py ├── models/ │ └── test_simple_resnet.py └── test_uniqueness.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ * @voxel51/developers # Aloha! 
.github/ @voxel51/aloha-shirts pyproject.toml @voxel51/aloha-shirts RELEASING.md @voxel51/aloha-shirts setup.py @voxel51/aloha-shirts ================================================ FILE: .github/dependabot.yml ================================================ --- version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" day: "wednesday" time: "14:00" timezone: "UTC" ================================================ FILE: .github/pull_request_template.md ================================================ # Rationale ## Changes ## Testing ================================================ FILE: .github/workflows/build.yml ================================================ name: Build on: pull_request: branches: - develop types: [opened, synchronize] push: branches: - develop tags: - v* jobs: build: runs-on: ubuntu-latest steps: - name: Clone fiftyone-brain uses: actions/checkout@v6 with: submodules: true - name: Set up Python uses: actions/setup-python@v6 with: python-version: 3.9 - name: Check Python version run: | python --version pip --version - name: Install dependencies run: | pip install --upgrade pip setuptools wheel pip install -r requirements/build.txt - name: Set environment env: RELEASE_TAG: ${{ github.ref }} run: | if [[ $RELEASE_TAG =~ ^refs\/tags\/v.* ]]; then echo "RELEASE_VERSION=$(echo '${{ github.ref }}' | sed 's/^refs\/tags\/v//')" >> $GITHUB_ENV fi - name: Build wheel run: | python setup.py sdist bdist_wheel - name: Upload wheel uses: actions/upload-artifact@v7 with: name: dist path: dist/ retention-days: 1 test: needs: [build] runs-on: ubuntu-latest env: FIFTYONE_DATASET_ZOO_DIR: ${{ github.workspace }}/.fiftyone FIFTYONE_DO_NOT_TRACK: true FIFTYONE_MODEL_ZOO_DIR: ${{ github.workspace }}/.fiftyone permissions: contents: read id-token: write strategy: fail-fast: false matrix: python: - "3.9" - "3.10" - "3.11" steps: - name: Clone fiftyone-brain uses: actions/checkout@v6 with: submodules: true - name: Clone 
fiftyone uses: actions/checkout@v6 with: fetch-depth: 1 path: fiftyone-src ref: develop repository: voxel51/fiftyone - name: Clone voxel51-eta uses: actions/checkout@v6 if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }} with: fetch-depth: 1 path: eta ref: develop repository: voxel51/eta # ETA tests will create a storage client which, # in its __init__, tries to log in to GCP. # See tests/test_uniqueness.py - name: Authenticate to Google Cloud uses: google-github-actions/auth@v3 with: project_id: ${{ secrets.REPO_GCP_PROJECT }} service_account: ${{ secrets.REPO_GCP_SERVICE_ACCOUNT }} workload_identity_provider: ${{ secrets.REPO_GOOGLE_WORKLOAD_IDP }} - name: Set Up Cloud SDK uses: google-github-actions/setup-gcloud@v3 - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} - name: Free Disk Space (Ubuntu) # The standard runner's 14 GB of available disk space isn't enough. Need at least 22 GB free. uses: jlumbroso/free-disk-space@v1.3.1 - name: Install dependencies run: | pip install --upgrade pip setuptools wheel - name: Download fiftyone-brain wheel uses: actions/download-artifact@v8 with: name: dist path: dist/ - name: Install fiftyone working-directory: fiftyone-src run: | python setup.py bdist_wheel pip install voxel51-eta[storage] fiftyone-db pip install ./dist/*.whl - name: Install ETA from source working-directory: eta # Don't install from source if this is a release.
# Install from PyPI if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }} run: | echo "Installing ETA from source because github.ref = ${{ github.ref }} (not a release)" python setup.py bdist_wheel pip install ./dist/*.whl --force-reinstall - name: Reinstall fiftyone-brain run: | pip install --force-reinstall --no-deps dist/*.whl - name: Install test dependencies run: | pip install imageio pytest torch torchvision - name: Cache Zoo id: fiftyone-cache uses: actions/cache@v5 with: path: | .fiftyone key: zoo-${{ hashFiles('tests/**') }} - name: Run tests run: | pytest --verbose tests/ --ignore tests/intensive/ publish: needs: [build, test] if: startsWith(github.ref, 'refs/tags/v') runs-on: ubuntu-latest environment: release # For trusted publishing. See below. permissions: contents: read id-token: write steps: - name: Download wheels uses: actions/download-artifact@v8 with: name: dist path: dist/ # Utilize # [trusted publishers](https://docs.pypi.org/trusted-publishers/) # This will use OIDC to publish the dists/ package to pypi. 
# See # [fiftyone-brain](https://pypi.org/manage/project/fiftyone-brain/settings/publishing/) - name: Publish uses: pypa/gh-action-pypi-publish@v1.14.0 ================================================ FILE: .gitignore ================================================ __pycache__ .DS_store .ipynb_checkpoints *~ *.egg-info *.py[cod] *.pth *.swp .idea .project .pydevproject build/ dist/ /fiftyone/brain/internal/models/cache/**/* !/fiftyone/brain/internal/models/cache/manifest.json *.pth ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/asottile/blacken-docs rev: v1.12.0 hooks: - id: blacken-docs additional_dependencies: [black==21.12b0] args: ["-l 79"] - repo: https://github.com/ambv/black rev: 22.3.0 hooks: - id: black language_version: python3 args: ["-l 79"] - repo: local hooks: - id: pylint name: pylint language: system files: \.py$ entry: pylint args: ["--errors-only"] - repo: local hooks: - id: ipynb-strip name: ipynb-strip language: system files: \.ipynb$ entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True args: ["--log-level=ERROR"] - repo: https://github.com/pre-commit/mirrors-prettier rev: v2.6.2 hooks: - id: prettier language_version: system ================================================ FILE: .prettierrc ================================================ { "overrides": [ { "files": "*.md", "options": { "printWidth": 79, "proseWrap": "always", "tabWidth": 4 } }, { "files": "*.json", "options": { "tabWidth": 4 } } ] } ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to FiftyOne Brain All Brain contributions should follow the practices established in [FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/CONTRIBUTING.md). 
## Adding new public methods to the Brain package The `fiftyone.brain` package should expose all core user-functionality at the base level. For example, for hardness, the user should be able to execute calls in the following way: ```py # Users should be able to do this import fiftyone.brain as fob fob.compute_hardness(...) # And NOT have to do this import fiftyone.brain.hardness as fobh fobh.compute_hardness(...) ``` To achieve this, follow the existing pattern of declaring new public methods in [`fiftyone/brain/__init__.py`](https://github.com/voxel51/fiftyone-brain/blob/develop/fiftyone/brain/__init__.py). Be sure to include a detailed docstring for all methods in this file, as they are pulled in by FiftyOne documentation builds and are made available in the [public docs](https://docs.voxel51.com/api/fiftyone.brain.html). ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2017-2026, Voxel51, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ global-include * prune fiftyone/brain/internal/models/cache/ include fiftyone/brain/internal/models/cache/manifest.json ================================================ FILE: README.md ================================================

**Open Source AI from [Voxel51](https://voxel51.com)** FiftyOne Website • FiftyOne Docs • FiftyOne Brain Docs • Blog • Community [![PyPI python](https://img.shields.io/pypi/pyversions/fiftyone-brain)](https://pypi.org/project/fiftyone-brain) [![PyPI version](https://badge.fury.io/py/fiftyone-brain.svg)](https://pypi.org/project/fiftyone-brain) [![Downloads](https://static.pepy.tech/badge/fiftyone-brain)](https://pepy.tech/project/fiftyone-brain) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![Discord](https://img.shields.io/badge/Discord-7289DA?logo=discord&logoColor=white)](https://discord.gg/fiftyone-community) [![Hugging Face](https://img.shields.io/badge/Hugging_Face-purple?style=flat&logo=huggingface)](https://huggingface.co/Voxel51) [![Voxel51 Blog](https://img.shields.io/badge/Voxel51_Blog-ff6d04?style=flat)](https://voxel51.com/blog) [![Newsletter](https://img.shields.io/badge/Newsletter-BE5B25?logo=mail.ru&logoColor=white)](https://share.hsforms.com/1zpJ60ggaQtOoVeBqIZdaaA2ykyk) [![LinkedIn](https://img.shields.io/badge/In-white?style=flat&label=Linked&labelColor=blue)](https://www.linkedin.com/company/voxel51) [![Twitter](https://img.shields.io/badge/Twitter-000000?logo=x&logoColor=white)](https://x.com/voxel51) [![Medium](https://img.shields.io/badge/Medium-12100E?logo=medium&logoColor=white)](https://medium.com/voxel51)

--- FiftyOne Brain contains the open source AI/ML capabilities for the [FiftyOne ecosystem](https://github.com/voxel51/fiftyone), enabling users to automatically analyze and manipulate their datasets and models. FiftyOne Brain includes features like visual similarity search, query by text, finding unique and representative samples, finding media quality problems and annotation mistakes, and more 🚀 ## Documentation Public documentation for the FiftyOne Brain is [available here](https://docs.voxel51.com/user_guide/brain.html). ## Installation The FiftyOne Brain is distributed via the `fiftyone-brain` package, and a suitable version is automatically included with every `fiftyone` install: ```shell pip install fiftyone pip show fiftyone-brain ``` ### Installing from source If you wish to do a source install of the latest FiftyOne Brain version, simply clone this repository: ```shell git clone https://github.com/voxel51/fiftyone-brain cd fiftyone-brain ``` and run the install script: ```shell # Mac or Linux bash install.sh # Windows .\install.bat ``` ### Developer installation If you are a developer contributing to this repository, you should perform a developer installation using the `-d` flag of the install script: ```shell # Mac or Linux bash install.sh -d # Windows .\install.bat -d ``` Check out the [contribution guide](CONTRIBUTING.md) to get started. ## Uninstallation ```shell pip uninstall fiftyone-brain ``` ## Repository layout - `fiftyone/brain/` definition of the `fiftyone.brain` namespace - `requirements/` Python requirements for the project - `tests/` tests for the various components of the Brain ## Citation If you use the FiftyOne Brain in your research, please cite the project: ```bibtex @article{moore2020fiftyone, title={FiftyOne}, author={Moore, B. E. and Corso, J. J.}, journal={GitHub. 
Note: https://github.com/voxel51/fiftyone-brain}, year={2020} } ``` ================================================ FILE: RELEASING.md ================================================ # Releasing the Brain package > [!NOTE] > These steps are to be performed by authorized Voxel51 engineers. The `fiftyone-brain` repository follows `Gitflow`. Releases will be initiated when a teammate submits a pull request from their respective `release/v*` branch to `main`. See the example PR for [version 0.21.4](https://github.com/voxel51/fiftyone-brain/pull/265). Reviewers should always check that the version in `setup.py` matches the branch version. The release engineer will merge the pull request once it is approved. The PyPI uploads will be triggered when a release tag is pushed to the repository: 1. Navigate to the [releases page](https://github.com/voxel51/fiftyone-brain/releases). 1. Select `Draft a new release`. 1. Select `Create new tag` with the appropriate version and set the target to `main`. 1. The tag format is `v<version>`. For example, `v0.21.4`. This should match the `setup.py` and release branch. 1. Select `Generate release notes`. 1. Select `Set as the latest release`. 1. Select `Publish release`. This will create a new tag in the repository and will trigger the [build/publish workflow](https://github.com/voxel51/fiftyone-brain/actions/workflows/build.yml). This workflow will build the `.whl` artifacts and publish them to [PyPI](https://pypi.org/project/fiftyone-brain/). Once the builds have finished, submit a PR from `main` to `develop` to complete the `Gitflow` process. ================================================ FILE: STYLE_GUIDE.md ================================================ # FiftyOne Brain Style Guide The Brain follows the same style guidelines as [FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/STYLE_GUIDE.md).
================================================ FILE: fiftyone/__init__.py ================================================ from pkgutil import extend_path # # This statement allows multiple `fiftyone.XXX` packages to be installed in the # same environment and used simultaneously. # # https://docs.python.org/3/library/pkgutil.html#pkgutil.extend_path # __path__ = extend_path(__path__, __name__) from fiftyone.__public__ import * ================================================ FILE: fiftyone/brain/__init__.py ================================================ """ The brains behind FiftyOne: a powerful package for dataset curation, analysis, and visualization. See https://github.com/voxel51/fiftyone for more information. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import fiftyone.brain.config as _foc from .similarity import ( Similarity, SimilarityConfig, SimilarityIndex, ) from .visualization import ( Visualization, VisualizationConfig, VisualizationResults, ) brain_config = _foc.load_brain_config() def compute_hardness( samples, label_field, hardness_field="hardness", progress=None, ): """Adds a hardness field to each sample scoring the difficulty that the specified label field observed in classifying the sample. Hardness is a measure computed based on model prediction output (through logits) that summarizes a measure of the uncertainty the model had with the sample. This makes hardness quantitative and can be used to detect things like hard samples, annotation errors during noisy training, and more. All classifications must have their :attr:`logits ` attributes populated in order to use this method. .. note:: Runs of this method can be referenced later via brain key ``hardness_field``. 
Args: samples: a :class:`fiftyone.core.collections.SampleCollection` label_field: the :class:`fiftyone.core.labels.Classification` or :class:`fiftyone.core.labels.Classifications` field to use from each sample hardness_field ("hardness"): the field name to use to store the hardness value for each sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead """ import fiftyone.brain.internal.core.hardness as fbh return fbh.compute_hardness(samples, label_field, hardness_field, progress) def compute_mistakenness( samples, pred_field, label_field, mistakenness_field="mistakenness", missing_field="possible_missing", spurious_field="possible_spurious", use_logits=False, copy_missing=False, progress=None, ): """Computes the mistakenness (likelihood of being incorrect) of the labels in ``label_field`` based on the predcted labels in ``pred_field``. Mistakenness is measured based on either the ``confidence`` or ``logits`` of the predictions in ``pred_field``. This measure can be used to detect things like annotation errors and unusually hard samples. For classifications, a ``mistakenness_field`` field is populated on each sample that quantifies the likelihood that the label in the ``label_field`` of that sample is incorrect. For objects (detections, polylines, keypoints, etc), the mistakenness of each object in ``label_field`` is computed, using :meth:`fiftyone.core.collections.SampleCollection.evaluate_detections` to locate corresponding objects in ``pred_field``. Three types of mistakes are identified: - **(Mistakes)** Objects in ``label_field`` with a match in ``pred_field`` are assigned a mistakenness value in their ``mistakenness_field`` that captures the likelihood that the class label of the object in ``label_field`` is a mistake. 
A ``mistakenness_field + "_loc"`` field is also populated that captures the likelihood that the object in ``label_field`` is a mistake due to its localization (bounding box). - **(Missing)** Objects in ``pred_field`` with no matches in ``label_field`` but which are likely to be correct will have their ``missing_field`` attribute set to True. In addition, if ``copy_missing`` is True, copies of these objects are *added* to the ground truth ``label_field``. - **(Spurious)** Objects in ``label_field`` with no matches in ``pred_field`` but which are likely to be incorrect will have their ``spurious_field`` attribute set to True. In addition, for objects, the following sample-level fields are populated: - **(Mistakes)** The ``mistakenness_field`` of each sample is populated with the maximum mistakenness of the objects in ``label_field`` - **(Missing)** The ``missing_field`` of each sample is populated with the number of missing objects that were deemed missing from ``label_field``. - **(Spurious)** The ``spurious_field`` of each sample is populated with the number of objects in ``label_field`` that were given deemed spurious. .. note:: Runs of this method can be referenced later via brain key ``mistakenness_field``. Args: samples: a :class:`fiftyone.core.collections.SampleCollection` pred_field: the name of the predicted label field to use from each sample. Can be of type :class:`fiftyone.core.labels.Classification`, :class:`fiftyone.core.labels.Classifications`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polylines`, :class:`fiftyone.core.labels.Keypoints`, or :class:`fiftyone.core.labels.TemporalDetections` label_field: the name of the "ground truth" label field that you want to test for mistakes with respect to the predictions in ``pred_field``. 
Must have the same type as ``pred_field`` mistakenness_field ("mistakenness"): the field name to use to store the mistakenness value for each sample missing_field ("possible_missing): the field in which to store per-sample counts of potential missing objects spurious_field ("possible_spurious): the field in which to store per-sample counts of potential spurious objects use_logits (False): whether to use logits (True) or confidence (False) to compute mistakenness. Logits typically yield better results, when they are available copy_missing (False): whether to copy predicted objects that were deemed to be missing into ``label_field`` progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead """ import fiftyone.brain.internal.core.mistakenness as fbm return fbm.compute_mistakenness( samples, pred_field, label_field, mistakenness_field, missing_field, spurious_field, use_logits, copy_missing, progress, ) def compute_uniqueness( samples, uniqueness_field="uniqueness", roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """Adds a uniqueness field to each sample scoring how unique it is with respect to the rest of the samples. This function only uses the pixel data and can therefore process labeled or unlabeled samples. If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a default model is used to generate embeddings. .. note:: Runs of this method can be referenced later via brain key ``uniqueness_field``. 
Args: samples: a :class:`fiftyone.core.collections.SampleCollection` uniqueness_field ("uniqueness"): the field name to use to store the uniqueness value for each sample roi_field (None): an optional :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` field defining a region of interest within each image to use to compute uniqueness embeddings (None): if no ``model`` is provided, this argument specifies pre-computed embeddings to use, which can be any of the following: - a ``num_samples x num_dims`` array of embeddings - if ``roi_field`` is specified, a dict mapping sample IDs to ``num_patches x num_dims`` arrays of patch embeddings - the name of a dataset field containing the embeddings to use If a ``model`` is provided, this argument specifies the name of a field in which to store the computed embeddings. In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``roi_field`` similarity_index (None): a :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key of a similarity index to use to load pre-computed embeddings model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use to generate embeddings. The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``roi_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. 
If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``roi_field`` are specified batch_size (None): a batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead """ import fiftyone.brain.internal.core.uniqueness as fbu return fbu.compute_uniqueness( samples, uniqueness_field, roi_field, embeddings, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, ) def compute_representativeness( samples, representativeness_field="representativeness", method="cluster-center", roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """Adds a representativeness field to each sample scoring how representative of nearby samples it is. This function only uses the pixel data and can therefore process labeled or unlabeled samples. If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a default model is used to generate embeddings. .. note:: Runs of this method can be referenced later via brain key ``representativeness_field``. 
Args: samples: a :class:`fiftyone.core.collections.SampleCollection` representativeness_field ("representativeness"): the field name to use to store the representativeness value for each sample method ("cluster-center"): the name of the method to use to compute the representativeness. The supported values are ``["cluster-center", 'cluster-center-downweight']``. ``"cluster-center"` will make a sample's representativeness proportional to it's proximity to cluster centers, while ``"cluster-center-downweight"`` will ensure more diversity in representative samples roi_field (None): an optional :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` field defining a region of interest within each image to use to compute representativeness embeddings (None): if no ``model`` is provided, this argument specifies pre-computed embeddings to use, which can be any of the following: - a ``num_samples x num_dims`` array of embeddings - if ``roi_field`` is specified, a dict mapping sample IDs to ``num_patches x num_dims`` arrays of patch embeddings - the name of a dataset field containing the embeddings to use If a ``model`` is provided, this argument specifies the name of a field in which to store the computed embeddings. In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``roi_field`` similarity_index (None): a :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key of a similarity index to use to load pre-computed embeddings model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use to generate embeddings. 
The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``roi_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``roi_field`` are specified batch_size (None): a batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead """ import fiftyone.brain.internal.core.representativeness as fbr return fbr.compute_representativeness( samples, representativeness_field, method, roi_field, embeddings, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, ) def compute_visualization( samples, patches_field=None, embeddings=None, points=None, create_index=False, points_field=None, brain_key=None, num_dims=2, method=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, **kwargs, ): """Computes a low-dimensional 
representation of the samples' media or their patches that can be interactively visualized. The representation can be visualized by calling the :meth:`visualize() ` method of the returned :class:`fiftyone.brain.visualization.VisualizationResults` object. If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a default model is used to generate embeddings. You can use the ``method`` parameter to select the dimensionality reduction method to use, and you can optionally customize the method by passing additional parameters for the method's :class:`fiftyone.brain.visualization.VisualizationConfig` class as ``kwargs``. The builtin ``method`` values and their associated config classes are: - ``"umap"``: :class:`fiftyone.brain.visualization.UMAPVisualizationConfig` - ``"tsne"``: :class:`fiftyone.brain.visualization.TSNEVisualizationConfig` - ``"pca"``: :class:`fiftyone.brain.visualization.PCAVisualizationConfig` - ``"manual"``: :class:`fiftyone.brain.visualization.ManualVisualizationConfig` You can pass ``create_index=True`` to create a spatial index of the computed points on your dataset's samples. This is highly recommended for large datasets as it enables efficient querying when lassoing points in embeddings plots. By default, spatial indexes are created in a field with name ``points_field=brain_key``, but you can customize this by manually providing a ``points_field``. You can also provide a ``points_field`` with ``create_index=False`` to store the points on your dataset without explicitly creating a database index. This will allow lasso callbacks to leverage point data rather than relying on ID selection, but without the added benefit of a database index to further optimize performance. Args: samples: a :class:`fiftyone.core.collections.SampleCollection` patches_field (None): a sample field defining the image patches in each sample that have been/will be embedded. 
Must be of type :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` embeddings (None): if no ``model`` is provided, this argument specifies pre-computed embeddings to use, which can be any of the following: - a dict mapping sample IDs to embedding vectors - a ``num_samples x num_embedding_dims`` array of embeddings corresponding to the samples in ``samples`` - if ``patches_field`` is specified, a dict mapping label IDs to to embedding vectors - if ``patches_field`` is specified, a dict mapping sample IDs to ``num_patches x num_embedding_dims`` arrays of patch embeddings - the name of a dataset field containing the embeddings to use If a ``model`` is provided, this argument specifies the name of a field in which to store the computed embeddings. In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``patches_field`` points (None): a pre-computed low-dimensional representation to use. If provided, no embeddings will be used/computed. Can be any of the following: - a dict mapping sample IDs to points vectors - a ``num_samples x num_dims`` array of points corresponding to the samples in ``samples`` - if ``patches_field`` is specified, a dict mapping label IDs to points vectors - if ``patches_field`` is specified, a ``num_patches x num_dims`` array of points whose rows correspond to the flattened list of patches whose IDs are shown below:: # The list of patch IDs that the rows of `points` must match _, id_field = samples._get_label_field_path(patches_field, "id") patch_ids = samples.values(id_field, unwind=True) create_index (False): whether to create a spatial index for the computed points on your dataset points_field (None): an optional field name in which to store the spatial index. 
When ``create_index=True``, this defaults to ``points_field=brain_key``. When working with patches, you can provide either the fully-qualified path to the points field or just the name of the label attribute in ``patches_field`` brain_key (None): a brain key under which to store the results of this method num_dims (2): the dimension of the visualization space method (None): the dimensionality reduction method to use. The supported values are ``fiftyone.brain.brain_config.visualization_methods.keys()`` and the default is ``fiftyone.brain.brain_config.default_visualization_method`` similarity_index (None): a :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key of a similarity index to use to load pre-computed embeddings model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use to generate embeddings. The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``patches_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``patches_field`` are specified batch_size (None): an optional batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. 
Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead **kwargs: optional keyword arguments for the constructor of the :class:`fiftyone.brain.visualization.VisualizationConfig` being used Returns: a :class:`fiftyone.brain.visualization.VisualizationResults` """ import fiftyone.brain.visualization as fbv return fbv.compute_visualization( samples, patches_field, embeddings, points, create_index, points_field, brain_key, num_dims, method, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, **kwargs, ) def compute_similarity( samples, patches_field=None, roi_field=None, embeddings=None, brain_key=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, backend=None, **kwargs, ): """Uses embeddings to index the samples or their patches so that you can query/sort by similarity. Calling this method only creates the index. You can then call the methods exposed on the retuned :class:`fiftyone.brain.similarity.SimilarityIndex` object to perform the following operations: - :meth:`sort_by_similarity() `: Sort the samples in the collection by similarity to a specific example or example(s) All indexes support querying by image similarity by passing sample IDs to :meth:`sort_by_similarity() `. In addition, if you pass the name of a model from the `FiftyOne Model Zoo `_ like ``model="clip-vit-base32-torch"`` that can embed prompts to this method, then you can query the index by text similarity as well. 
In addition, if the backend supports it, you can call the following duplicate detection methods: - :meth:`find_duplicates() `: Query the index to find all examples with near-duplicates in the collection - :meth:`find_unique() `: Query the index to select a subset of examples of a specified size that are maximally unique with respect to each other If no ``embeddings`` or ``model`` is provided, a default model is used to generate embeddings. Args: samples: a :class:`fiftyone.core.collections.SampleCollection` patches_field (None): a sample field defining the image patches in each sample that have been/will be embedded. Must be of type :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` roi_field (None): an optional :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` field defining a region of interest within each image to use to compute embeddings embeddings (None): embeddings to feed the index. This argument's behavior depends on whether a ``model`` is provided, as described below. 
If no ``model`` is provided, this argument specifies pre-computed embeddings to use: - a ``num_samples x num_dims`` array of embeddings - if ``patches_field``/``roi_field`` is specified, a dict mapping sample IDs to ``num_patches x num_dims`` arrays of patch embeddings - the name of a dataset field from which to load embeddings - ``None``: use the default model to compute embeddings - ``False``: **do not** compute embeddings right now If a ``model`` is provided, this argument specifies where to store the model's embeddings: - the name of a field in which to store the computed embeddings - ``False``: **do not** compute embeddings right now In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``patches_field``/``roi_field`` brain_key (None): a brain key under which to store the results of this method model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use, or that was already used, to generate embeddings. The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``patches_field``/``roi_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``patches_field``/``roi_field`` are specified batch_size (None): an optional batch size to use when computing embeddings. 
Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead backend (None): the similarity backend to use. The supported values are ``fiftyone.brain.brain_config.similarity_backends.keys()`` and the default is ``fiftyone.brain.brain_config.default_similarity_backend`` **kwargs: keyword arguments for the :class:`fiftyone.brian.SimilarityConfig` subclass of the backend being used Returns: a :class:`fiftyone.brain.similarity.SimilarityIndex` """ import fiftyone.brain.similarity as fbs return fbs.compute_similarity( samples, patches_field, roi_field, embeddings, brain_key, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, backend, **kwargs, ) def compute_near_duplicates( samples, threshold=0.2, roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """Detects potential duplicates in the given sample collection. Calling this method only initializes the index. You can then call the methods exposed on the returned object to perform the following operations: - :meth:`duplicate_ids `: A list of duplicate IDs - :meth:`neighbors_map `: A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples - :meth:`duplicates_view() `: Returns a view of all duplicates in the input collection Args: samples: a :class:`fiftyone.core.collections.SampleCollection` threshold (0.2): the similarity distance threshold to use when detecting duplicates. 
Values in ``[0.1, 0.25]`` work well for the default setup roi_field (None): an optional :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` field defining a region of interest within each image to use to compute embeddings embeddings (None): if no ``model`` is provided, this argument specifies pre-computed embeddings to use, which can be any of the following: - a ``num_samples x num_dims`` array of embeddings - if ``roi_field`` is specified, a dict mapping sample IDs to ``num_patches x num_dims`` arrays of patch embeddings - the name of a dataset field containing the embeddings to use If a ``model`` is provided, this argument specifies the name of a field in which to store the computed embeddings. In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``roi_field`` similarity_index (None): a :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key of a similarity index to use to load pre-computed embeddings model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use to generate embeddings. The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``roi_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. 
For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``roi_field`` are specified batch_size (None): a batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead Returns: a :class:`fiftyone.brain.similarity.SimilarityIndex` """ import fiftyone.brain.internal.core.duplicates as fbd return fbd.compute_near_duplicates( samples, threshold=threshold, roi_field=roi_field, embeddings=embeddings, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, force_square=force_square, alpha=alpha, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) def compute_exact_duplicates( samples, num_workers=None, skip_failures=True, progress=None, ): """Detects duplicate media in a sample collection. This method detects exact duplicates with the same filehash. Use :meth:`compute_near_duplicates` to detect near-duplicates. If duplicates are found, the first instance in ``samples`` will be the key in the returned dictionary, while the subsequent duplicates will be the values in the corresponding list. 
Args: samples: a :class:`fiftyone.core.collections.SampleCollection` num_workers (None): an optional number of processes to use skip_failures (True): whether to gracefully ignore samples whose filehash cannot be computed progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead Returns: a dictionary mapping IDs of samples with exact duplicates to lists of IDs of the duplicates for the corresponding sample """ import fiftyone.brain.internal.core.duplicates as fbd return fbd.compute_exact_duplicates( samples, num_workers, skip_failures, progress ) def compute_leaky_splits( samples, splits, threshold=0.2, roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """Computes potential leaks between splits of the given sample collection. Calling this method only initializes the index. You can then call the methods exposed on the returned object to perform the following operations: - :meth:`leaks_view() `: Returns a view of all leaks in the input collection - :meth:`no_leaks_view() `: Returns the subset of the input collection without any leaks - :meth:`leaks_for_sample() `: Returns a view with leaks corresponding to the given sample - :meth:`tag_leaks() `: Tags leaks in the dataset as leaks Args: samples: a :class:`fiftyone.core.collections.SampleCollection` splits: the dataset splits, specified in one of the following ways: - a list of tag strings - the name of a string/list field that encodes the split memberships - a dict mapping split names to :class:`fiftyone.core.view.DatasetView` instances threshold (0.2): the similarity distance threshold to use when detecting leaks. 
Values in ``[0.1, 0.25]`` work well for the default setup roi_field (None): an optional :class:`fiftyone.core.labels.Detection`, :class:`fiftyone.core.labels.Detections`, :class:`fiftyone.core.labels.Polyline`, or :class:`fiftyone.core.labels.Polylines` field defining a region of interest within each image to use to compute leaks embeddings (None): if no ``model`` is provided, this argument specifies pre-computed embeddings to use, which can be any of the following: - a ``num_samples x num_dims`` array of embeddings - if ``roi_field`` is specified, a dict mapping sample IDs to ``num_patches x num_dims`` arrays of patch embeddings - the name of a dataset field containing the embeddings to use If a ``model`` is provided, this argument specifies the name of a field in which to store the computed embeddings. In either case, when working with patch embeddings, you can provide either the fully-qualified path to the patch embeddings or just the name of the label attribute in ``roi_field`` similarity_index (None): a :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key of a similarity index to use to load pre-computed embeddings model (None): a :class:`fiftyone.core.models.Model` or the name of a model from the `FiftyOne Model Zoo `_ to use to generate embeddings. The model must expose embeddings (``model.has_embeddings = True``) model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``roi_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. 
For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when a ``model`` and ``roi_field`` are specified batch_size (None): a batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead Returns: a :class:`fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex` """ import fiftyone.brain.internal.core.leaky_splits as fbl return fbl.compute_leaky_splits( samples, splits, threshold=threshold, roi_field=roi_field, embeddings=embeddings, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, force_square=force_square, alpha=alpha, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) ================================================ FILE: fiftyone/brain/config.py ================================================ """ Brain config. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import os from fiftyone.core.config import EnvConfig class BrainConfig(EnvConfig): """FiftyOne brain configuration settings.""" _BUILTIN_SIMILARITY_BACKENDS = { "sklearn": { "config_cls": "fiftyone.brain.internal.core.sklearn.SklearnSimilarityConfig", }, "pinecone": { "config_cls": "fiftyone.brain.internal.core.pinecone.PineconeSimilarityConfig", }, "qdrant": { "config_cls": "fiftyone.brain.internal.core.qdrant.QdrantSimilarityConfig", }, "milvus": { "config_cls": "fiftyone.brain.internal.core.milvus.MilvusSimilarityConfig", }, "lancedb": { "config_cls": "fiftyone.brain.internal.core.lancedb.LanceDBSimilarityConfig", }, "redis": { "config_cls": "fiftyone.brain.internal.core.redis.RedisSimilarityConfig", }, "mongodb": { "config_cls": "fiftyone.brain.internal.core.mongodb.MongoDBSimilarityConfig", }, "elasticsearch": { "config_cls": "fiftyone.brain.internal.core.elasticsearch.ElasticsearchSimilarityConfig", }, "pgvector": { "config_cls": "fiftyone.brain.internal.core.pgvector.PgVectorSimilarityConfig", }, "mosaic": { "config_cls": "fiftyone.brain.internal.core.mosaic.MosaicSimilarityConfig", }, } _BUILTIN_VISUALIZATION_METHODS = { "umap": { "config_cls": "fiftyone.brain.visualization.UMAPVisualizationConfig", }, "tsne": { "config_cls": "fiftyone.brain.visualization.TSNEVisualizationConfig", }, "pca": { "config_cls": "fiftyone.brain.visualization.PCAVisualizationConfig", }, "manual": { "config_cls": "fiftyone.brain.visualization.ManualVisualizationConfig", }, } def __init__(self, d=None): if d is None: d = {} self.default_similarity_backend = self.parse_string( d, "default_similarity_backend", env_var="FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND", default="sklearn", ) self.similarity_backends = self._parse_similarity_backends(d) if self.default_similarity_backend not in self.similarity_backends: self.default_similarity_backend = next( iter(sorted(self.similarity_backends.keys())), None ) self.default_visualization_method = self.parse_string( d, 
"default_visualization_method", env_var="FIFTYONE_BRAIN_DEFAULT_VISUALIZATION_METHOD", default="umap", ) self.visualization_methods = self._parse_visualization_methods(d) if self.default_visualization_method not in self.visualization_methods: self.default_visualization_method = next( iter(sorted(self.visualization_methods.keys())), None ) def _parse_similarity_backends(self, d): d = d.get("similarity_backends", {}) env_vars = dict(os.environ) # # `FIFTYONE_BRAIN_SIMILARITY_BACKENDS` can be used to declare which # backends are exposed. This may exclude builtin backends and/or # declare new backends # if "FIFTYONE_BRAIN_SIMILARITY_BACKENDS" in env_vars: backends = env_vars["FIFTYONE_BRAIN_SIMILARITY_BACKENDS"].split( "," ) # Special syntax to append rather than override default backends if "*" in backends: backends = set(b for b in backends if b != "*") backends |= set(self._BUILTIN_SIMILARITY_BACKENDS.keys()) d = {backend: d.get(backend, {}) for backend in backends} else: backends = self._BUILTIN_SIMILARITY_BACKENDS.keys() for backend in backends: if backend not in d: d[backend] = {} # # Extract parameters from any environment variables of the form # `FIFTYONE_BRAIN_SIMILARITY__` # for backend, d_backend in d.items(): prefix = "FIFTYONE_BRAIN_SIMILARITY_%s_" % backend.upper() for env_name, env_value in env_vars.items(): if env_name.startswith(prefix): name = env_name[len(prefix) :].lower() value = _parse_env_value(env_value) d_backend[name] = value # # Set default parameters for builtin similarity backends # for backend, defaults in self._BUILTIN_SIMILARITY_BACKENDS.items(): if backend not in d: continue d_backend = d[backend] for name, value in defaults.items(): if name not in d_backend: d_backend[name] = value return d def _parse_visualization_methods(self, d): d = d.get("visualization_methods", {}) env_vars = dict(os.environ) # # `FIFTYONE_BRAIN_VISUALIZATION_METHODS` can be used to declare which # methods are exposed. 
This may exclude builtin methods and/or declare # new methods # if "FIFTYONE_BRAIN_VISUALIZATION_METHODS" in env_vars: methods = env_vars["FIFTYONE_BRAIN_VISUALIZATION_METHODS"].split( "," ) # Special syntax to append rather than override default methods if "*" in methods: methods = set(m for m in methods if m != "*") methods |= set(self._BUILTIN_VISUALIZATION_METHODS.keys()) d = {method: d.get(method, {}) for method in methods} else: methods = self._BUILTIN_VISUALIZATION_METHODS.keys() for method in methods: if method not in d: d[method] = {} # # Extract parameters from any environment variables of the form # `FIFTYONE_BRAIN_VISUALIZATION__` # for method, d_method in d.items(): prefix = "FIFTYONE_BRAIN_VISUALIZATION_%s_" % method.upper() for env_name, env_value in env_vars.items(): if env_name.startswith(prefix): name = env_name[len(prefix) :].lower() value = _parse_env_value(env_value) d_method[name] = value # # Set default parameters for builtin visualization methods # for method, defaults in self._BUILTIN_VISUALIZATION_METHODS.items(): if method not in d: continue d_method = d[method] for name, value in defaults.items(): if name not in d_method: d_method[name] = value return d def locate_brain_config(): """Returns the path to the :class:`BrainConfig` on disk. The default location is ``~/.fiftyone/brain_config.json``, but you can override this path by setting the ``FIFTYONE_BRAIN_CONFIG_PATH`` environment variable. Note that a config file may not actually exist on disk. Returns: the path to the :class:`BrainConfig` on disk """ if "FIFTYONE_BRAIN_CONFIG_PATH" not in os.environ: return os.path.join( os.path.expanduser("~"), ".fiftyone", "brain_config.json" ) return os.environ["FIFTYONE_BRAIN_CONFIG_PATH"] def load_brain_config(): """Loads the FiftyOne brain config. 
Returns: a :class:`BrainConfig` instance """ brain_config_path = locate_brain_config() if os.path.isfile(brain_config_path): return BrainConfig.from_json(brain_config_path) return BrainConfig() def _parse_env_value(value): try: return int(value) except: pass try: return float(value) except: pass if value in ("True", "true"): return True if value in ("False", "false"): return False if value in ("None", ""): return None if "," in value: return [_parse_env_value(v) for v in value.split(",")] return value ================================================ FILE: fiftyone/brain/internal/__init__.py ================================================ """ Internal FiftyOne Brain package. Contains all non-public code powering the ``fiftyone.brain`` public namespace. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ ================================================ FILE: fiftyone/brain/internal/core/__init__.py ================================================ """ | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ ================================================ FILE: fiftyone/brain/internal/core/duplicates.py ================================================ """ Duplicates methods. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ from collections import defaultdict import itertools import logging import multiprocessing import eta.core.utils as etau import fiftyone.core.media as fom import fiftyone.core.utils as fou import fiftyone.core.validation as fov import fiftyone.brain as fb import fiftyone.brain.similarity as fbs import fiftyone.brain.internal.core.utils as fbu logger = logging.getLogger(__name__) _DEFAULT_MODEL = "resnet18-imagenet-torch" def compute_near_duplicates( samples, threshold=None, roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """See ``fiftyone/brain/__init__.py``.""" fov.validate_collection(samples) if etau.is_str(embeddings): embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings, data_type="embeddings", ) embeddings = None else: embeddings_field = None embeddings_exist = None if etau.is_str(similarity_index): similarity_index = samples.load_brain_results(similarity_index) if ( model is None and embeddings is None and similarity_index is None and not embeddings_exist ): model = _DEFAULT_MODEL if similarity_index is None: similarity_index = fb.compute_similarity( samples, backend="sklearn", roi_field=roi_field, embeddings=embeddings_field or embeddings, model=model, model_kwargs=model_kwargs, force_square=force_square, alpha=alpha, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) elif not isinstance(similarity_index, fbs.DuplicatesMixin): raise ValueError( "This method only supports similarity indexes that implement the " "%s mixin" % fbs.DuplicatesMixin ) similarity_index.find_duplicates(thresh=threshold) return similarity_index def compute_exact_duplicates(samples, num_workers, skip_failures, progress): """See ``fiftyone/brain/__init__.py``.""" fov.validate_collection(samples) if num_workers is None: if samples.media_type == 
fom.VIDEO: num_workers = multiprocessing.cpu_count() else: num_workers = 1 logger.info("Computing filehashes...") method = "md5" if samples.media_type == fom.VIDEO else None if num_workers <= 1: hashes = _compute_filehashes(samples, method, progress) else: hashes = _compute_filehashes_multi( samples, method, num_workers, progress ) num_missing = sum(h is None for h in hashes) if num_missing > 0: msg = "Failed to compute %d filehashes" % num_missing if skip_failures: logger.warning(msg) else: raise ValueError(msg) neighbors_map = defaultdict(list) observed_hashes = {} for _id, _hash in hashes.items(): if _hash is None: continue if _hash in observed_hashes: neighbors_map[observed_hashes[_hash]].append(_id) else: observed_hashes[_hash] = _id return dict(neighbors_map) def _compute_filehashes(samples, method, progress): ids, filepaths = samples.values(["id", "filepath"]) with fou.ProgressBar(total=len(ids), progress=progress) as pb: return { _id: _compute_filehash(filepath, method) for _id, filepath in pb(zip(ids, filepaths)) } def _compute_filehashes_multi(samples, method, num_workers, progress): ids, filepaths = samples.values(["id", "filepath"]) methods = itertools.repeat(method) inputs = list(zip(ids, filepaths, methods)) with fou.ProgressBar(total=len(inputs), progress=progress) as pb: with multiprocessing.Pool(processes=num_workers) as pool: return { k: v for k, v in pb( pool.imap_unordered(_do_compute_filehash, inputs) ) } def _compute_filehash(filepath, method): try: filehash = fou.compute_filehash(filepath, method=method) except: filehash = None return filehash def _do_compute_filehash(args): _id, filepath, method = args try: filehash = fou.compute_filehash(filepath, method=method) except: filehash = None return _id, filehash ================================================ FILE: fiftyone/brain/internal/core/elasticsearch.py ================================================ """ Elastisearch similarity backend. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau from fiftyone import ViewField as F import fiftyone.core.utils as fou import fiftyone.brain.internal.core.utils as fbu from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) es = fou.lazy_import("elasticsearch") logger = logging.getLogger(__name__) _SUPPORTED_METRICS = { "cosine": "cosine", "dotproduct": "dot_product", "euclidean": "l2_norm", "innerproduct": "max_inner_product", } class ElasticsearchSimilarityConfig(SimilarityConfig): """Configuration for a Elasticsearch similarity instance. Args: index_name (None): the name of the Elasticsearch index to use or create. If none is provided, a new index will be created metric ("cosine"): the embedding distance metric to use when creating a new index. Supported values are ``("cosine", "dotproduct", "euclidean", "innerproduct")`` hosts (None): the full Elasticsearch server address(es) to use. Can be a string or list of strings cloud_id (None): the Cloud ID of an Elastic Cloud to connect to username (None): a username to use password (None): a password to use api_key (None): an API key to use ca_certs (None): a path to a CA certificate bearer_auth (None): a bearer token to use ssl_assert_fingerprint (None): a SHA256 fingerprint to use verify_certs (None): whether to verify SSL certificates **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__( self, index_name=None, metric="cosine", hosts=None, cloud_id=None, username=None, password=None, api_key=None, ca_certs=None, bearer_auth=None, ssl_assert_fingerprint=None, verify_certs=None, **kwargs, ): if metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. 
Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.index_name = index_name self.metric = metric self._hosts = hosts self._cloud_id = cloud_id self._username = username self._password = password self._api_key = api_key self._ca_certs = ca_certs self._bearer_auth = bearer_auth self._ssl_assert_fingerprint = ssl_assert_fingerprint self._verify_certs = verify_certs @property def method(self): return "elasticsearch" @property def hosts(self): return self._hosts @hosts.setter def hosts(self, value): self._hosts = value @property def cloud_id(self): return self._cloud_id @cloud_id.setter def cloud_id(self, value): self._cloud_id = value @property def username(self): return self._username @username.setter def username(self, value): self._username = value @property def password(self): return self._password @password.setter def password(self, value): self._password = value @property def api_key(self): return self._api_key @api_key.setter def api_key(self, value): self._api_key = value @property def ca_certs(self): return self._ca_certs @ca_certs.setter def ca_certs(self, value): self._ca_certs = value @property def bearer_auth(self): return self._bearer_auth @bearer_auth.setter def bearer_auth(self, value): self._bearer_auth = value @property def ssl_assert_fingerprint(self): return self._ssl_assert_fingerprint @ssl_assert_fingerprint.setter def ssl_assert_fingerprint(self, value): self._ssl_assert_fingerprint = value @property def verify_certs(self): return self._verify_certs @verify_certs.setter def verify_certs(self, value): self._verify_certs = value @property def max_k(self): return 10000 # Elasticsearch limit @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) def load_credentials( self, hosts=None, cloud_id=None, username=None, password=None, api_key=None, ca_certs=None, bearer_auth=None, ssl_assert_fingerprint=None, verify_certs=None, ): 
self._load_parameters( hosts=hosts, cloud_id=cloud_id, username=username, password=password, api_key=api_key, ca_certs=ca_certs, bearer_auth=bearer_auth, ssl_assert_fingerprint=ssl_assert_fingerprint, verify_certs=verify_certs, ) class ElasticsearchSimilarity(Similarity): """Elasticsearch similarity factory. Args: config: a :class:`ElasticsearchSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("elasticsearch") def ensure_usage_requirements(self): fou.ensure_package("elasticsearch") def initialize(self, samples, brain_key): return ElasticsearchSimilarityIndex( samples, self.config, brain_key, backend=self ) class ElasticsearchSimilarityIndex(SimilarityIndex): """Class for interacting with Elasticsearch similarity indexes. Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`ElasticsearchSimilarityConfig` used brain_key: the brain key backend (None): a :class:`ElasticsearchSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._client = None self._metric = None self._initialize() @property def total_index_size(self): try: return self._client.count(index=self.config.index_name)["count"] except: return 0 @property def client(self): """The ``elasticsearch.Elasticsearch`` instance for this index.""" return self._client def _initialize(self): kwargs = {} for key in ( "hosts", "cloud_id", "username", "password", "api_key", "ca_certs", "bearer_auth", "ssl_assert_fingerprint", "verify_certs", ): value = getattr(self.config, key, None) if value is not None: kwargs[key] = value username = kwargs.pop("username", None) password = kwargs.pop("password", None) if username is not None and password is not None: kwargs["basic_auth"] = (username, password) try: self._client = es.Elasticsearch(**kwargs) except Exception as e: raise ValueError( "Failed to connect to Elasticsearch backend. 
Refer to " "https://docs.voxel51.com/integrations/elasticsearch.html for more " "information" ) from e if self.config.index_name is None: root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) index_name = fbu.get_unique_name(root, self._get_index_names()) self.config.index_name = index_name self.save_config() def _get_index_names(self): return self._client.indices.get_alias().keys() def _get_index_ids(self, batch_size=1000): sample_ids = [] label_ids = [] for batch in range(0, self.total_index_size, batch_size): response = self._client.search( index=self.config.index_name, body={ "fields": ["sample_id"], "from": batch, "query": { "bool": { "must": [ {"exists": {"field": "vector"}}, {"exists": {"field": "sample_id"}}, ] } }, }, source=False, size=batch_size, ) for doc in response["hits"]["hits"]: sample_id = doc["fields"]["sample_id"][0] sample_or_label_id = doc["_id"] sample_ids.append(sample_id) label_ids.append(sample_or_label_id) return sample_ids, label_ids def _get_dimension(self): if self.total_index_size == 0: return None if self.config.patches_field is not None: embeddings, _, _ = self.get_embeddings( label_ids=self._label_ids[:1] ) else: embeddings, _, _ = self.get_embeddings( sample_ids=self._sample_ids[:1] ) return embeddings.shape[1] def _get_metric(self): if self._metric is None: try: # We must ask ES rather than using `self.config.metric` because # we may be working with a preexisting index self._metric = self._client.indices.get_mapping( index=self.config.index_name )[self.config.index_name]["mappings"]["properties"]["vector"][ "similarity" ] except: logger.warning( "Failed to infer similarity metric from index '%s'", self.config.index_name, ) return self._metric def _index_exists(self): if self.config.index_name is None: return False return self.config.index_name in self._get_index_names() def _create_index(self, dimension): metric = _SUPPORTED_METRICS[self.config.metric] mappings = { "properties": { "vector": { "type": "dense_vector", 
"dims": dimension, "index": "true", "similarity": metric, } } } self._client.indices.create( index=self.config.index_name, mappings=mappings ) self._metric = metric def _get_existing_ids(self, ids): docs = [{"_index": self.config.index_name, "_id": i} for i in ids] resp = self._client.mget(docs=docs) return [d["_id"] for d in resp["docs"] if d["found"]] def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=500, ): if not self._index_exists(): self._create_index(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: existing_ids = self._get_existing_ids(ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) if self._get_metric() == _SUPPORTED_METRICS["dotproduct"]: embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis] embeddings = [e.tolist() for e in embeddings] sample_ids = list(sample_ids) if label_ids is not None: ids = list(label_ids) else: ids = list(sample_ids) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): operations = [] for _e, _id, _sid in zip(_embeddings, _ids, _sample_ids): operations.append( {"index": {"_index": 
self.config.index_name, "_id": _id}} ) operations.append({"sample_id": _sid, "vector": _e}) self._client.bulk( index=self.config.index_name, operations=operations, refresh=True, ) if reload: self.reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = self._get_existing_ids(ids) missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids operations = [ {"delete": {"_index": self.config.index_name, "_id": i}} for i in ids ] self._client.bulk(body=operations, refresh=True) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in 
the index", num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids def _parse_embeddings_response(self, response, label_id=True): found_embeddings = [] found_sample_ids = [] found_label_ids = [] for r in response: if r.get("found", True): found_embeddings.append(r["_source"]["vector"]) if label_id: found_sample_ids.append(r["_source"]["sample_id"]) found_label_ids.append(r["_id"]) else: found_sample_ids.append(r["_id"]) return found_embeddings, found_sample_ids, found_label_ids def _get_sample_embeddings(self, sample_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] if sample_ids is None: sample_ids, label_ids = self._get_index_ids(batch_size=batch_size) for batch_ids in fou.iter_batches(sample_ids, batch_size): response = self._client.mget( index=self.config.index_name, ids=batch_ids, source=True ) ( _found_embeddings, _found_sample_ids, _, ) = self._parse_embeddings_response( response["docs"], label_id=False ) found_embeddings += _found_embeddings found_sample_ids += _found_sample_ids missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if label_ids is None: sample_ids, label_ids = self._get_index_ids(batch_size=batch_size) for batch_ids in fou.iter_batches(label_ids, batch_size): response = self._client.mget( index=self.config.index_name, ids=batch_ids, source=True ) ( _found_embeddings, _found_sample_ids, _found_label_ids, ) = self._parse_embeddings_response(response["docs"]) found_embeddings += _found_embeddings found_sample_ids += _found_sample_ids found_label_ids += _found_label_ids missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, 
missing_ids def _get_patch_embeddings_from_sample_ids( self, sample_ids, batch_size=100 ): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if sample_ids is None: sample_ids, label_ids = self._get_index_ids(batch_size=batch_size) for batch_ids in fou.iter_batches(sample_ids, batch_size): response = self._client.search( index=self.config.index_name, body={"query": {"terms": {"sample_id": sample_ids}}}, ) ( _found_embeddings, _found_sample_ids, _found_label_ids, ) = self._parse_embeddings_response(response["hits"]["hits"]) found_embeddings += _found_embeddings found_sample_ids += _found_sample_ids found_label_ids += _found_label_ids missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def cleanup(self): self._client.indices.delete( index=self.config.index_name, ignore_unavailable=True ) def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError( "Elasticsearch does not support full index neighbors" ) if reverse is True: raise ValueError( "Elasticsearch does not support least similarity queries" ) if aggregation not in (None, "mean"): raise ValueError( f"Elasticsearch does not support {aggregation} aggregation" ) query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = self.current_label_ids else: index_ids = self.current_sample_ids _filter = {"terms": {"_id": list(index_ids)}} else: _filter = None sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: if self._get_metric() == _SUPPORTED_METRICS["dotproduct"]: q /= np.linalg.norm(q) knn = { "field": "vector", "query_vector": q.tolist(), "k": k, "num_candidates": 10 * k, } if _filter: knn["filter"] = _filter 
source = self.config.patches_field is not None response = self._client.search( index=self.config.index_name, knn=knn, size=k, source=source, ) if self.config.patches_field is not None: sample_ids.append( [ r["_source"]["sample_id"] for r in response["hits"]["hits"] ] ) label_ids.append([r["_id"] for r in response["hits"]["hits"]]) else: sample_ids.append([r["_id"] for r in response["hits"]["hits"]]) if return_dists: dists.append([r["_score"] for r in response["hits"]["hits"]]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) response = self._client.mget( index=self.config.index_name, ids=query_ids, source=True ) query = np.array( [r["_source"]["vector"] for r in response["docs"] if r["found"]] ) if query.size == 0: raise ValueError( "Query IDs %s were not found in the index" % query_ids ) if single_query: query = query[0, :] return query @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/hardness.py ================================================ """ Hardness methods. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import numpy as np from scipy.special import softmax from scipy.stats import entropy import fiftyone.core.brain as fob import fiftyone.core.labels as fol import fiftyone.core.media as fom import fiftyone.core.utils as fou import fiftyone.core.validation as fov logger = logging.getLogger(__name__) _ALLOWED_TYPES = (fol.Classification, fol.Classifications) def compute_hardness(samples, label_field, hardness_field, progress): """See ``fiftyone/brain/__init__.py``.""" # # Algorithm # # Hardness is computed directly as the entropy of the logits # fov.validate_collection(samples) fov.validate_collection_label_fields(samples, label_field, _ALLOWED_TYPES) if samples.media_type == fom.VIDEO: hardness_field, _ = samples._handle_frame_field(hardness_field) config = HardnessConfig(label_field, hardness_field) brain_key = hardness_field brain_method = config.build() brain_method.ensure_requirements() brain_method.register_run(samples, brain_key, cleanup=False) brain_method.register_samples(samples) view = samples.select_fields(label_field) processing_frames = samples._is_frame_field(label_field) logger.info("Computing hardness...") for sample in view.iter_samples(progress=progress): if processing_frames: images = sample.frames.values() else: images = [sample] sample_hardness = [] for image in images: hardness = brain_method.process_image(image) if hardness is not None: sample_hardness.append(hardness) if processing_frames: image[hardness_field] = hardness if sample_hardness: sample[hardness_field] = np.max(sample_hardness) else: sample[hardness_field] = None sample.save() brain_method.save_run_results(samples, brain_key, None) logger.info("Hardness computation complete") # @todo move to `fiftyone/brain/hardness.py` class HardnessConfig(fob.BrainMethodConfig): def __init__(self, label_field, hardness_field, **kwargs): self.label_field = label_field self.hardness_field = hardness_field super().__init__(**kwargs) @property def type(self): 
return "mistakenness" @property def method(self): return "entropy" class Hardness(fob.BrainMethod): def __init__(self, config): super().__init__(config) self.label_field = None def ensure_requirements(self): pass def register_samples(self, samples): self.label_field, _ = samples._handle_frame_field( self.config.label_field ) def process_image(self, sample_or_frame): label = _get_data(sample_or_frame, self.label_field) if label is None: return None return entropy(softmax(np.asarray(label.logits))) def get_fields(self, samples, brain_key): label_field = self.config.label_field hardness_field = self.config.hardness_field fields = [label_field, hardness_field] if samples._is_frame_field(label_field): fields.append(samples._FRAMES_PREFIX + hardness_field) return fields def cleanup(self, samples, brain_key): label_field = self.config.label_field hardness_field = self.config.hardness_field samples._dataset.delete_sample_fields(hardness_field, error_level=1) if samples._is_frame_field(label_field): samples._dataset.delete_frame_fields(hardness_field, error_level=1) def _validate_run(self, samples, brain_key, existing_info): self._validate_fields_match(brain_key, "hardness_field", existing_info) def _get_data(sample, label_field): label = sample[label_field] if label is None: return None if label.logits is None: raise ValueError( "Sample '%s' field '%s' has no logits" % (sample.id, label_field) ) return label ================================================ FILE: fiftyone/brain/internal/core/lancedb.py ================================================ """ LanceDB similarity backend. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.core.storage as fos import fiftyone.core.utils as fou import fiftyone.brain.internal.core.utils as fbu from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) lancedb = fou.lazy_import("lancedb") pa = fou.lazy_import("pyarrow") _SUPPORTED_METRICS = { "cosine": "cosine", "euclidean": "l2", } logger = logging.getLogger(__name__) class LanceDBSimilarityConfig(SimilarityConfig): """Configuration for a LanceDB similarity instance. Args: table_name (None): the name of the LanceDB table to use. If none is provided, a new table will be created metric ("cosine"): the embedding distance metric to use when creating a new index. Supported values are ``("cosine", "euclidean")`` uri ("/tmp/lancedb"): the database URI to use **kwargs: keyword arguments for :class:`SimilarityConfig` """ def __init__( self, table_name=None, metric="cosine", uri="/tmp/lancedb", **kwargs, ): if metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.table_name = table_name self.metric = metric # store privately so these aren't serialized self._uri = uri @property def method(self): return "lancedb" @property def uri(self): return self._uri @uri.setter def uri(self, value): self._uri = value @property def max_k(self): return None @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) def load_credentials(self, uri=None): self._load_parameters(uri=uri) class LanceDBSimilarity(Similarity): """LanceDB similarity factory. 
Args: config: a :class:`LanceDBSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("lancedb") def ensure_usage_requirements(self): fou.ensure_package("lancedb") def initialize(self, samples, brain_key): return LanceDBSimilarityIndex( samples, self.config, brain_key, backend=self ) class LanceDBSimilarityIndex(SimilarityIndex): """Class for interacting with LanceDB similarity indexes. Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`LanceDBSimilarityConfig` used brain_key: the brain key backend (None): a :class:`LanceDBSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._table = None self._db = None self._initialize() def _initialize(self): try: db = lancedb.connect(self.config.uri) except Exception as e: raise ValueError( "Failed to connect to LanceDB backend at URI '%s'. Refer to " "https://docs.voxel51.com/integrations/lancedb.html for more " "information" % self.config.uri ) from e table_names = db.table_names() if self.config.table_name is None: root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) table_name = fbu.get_unique_name(root, table_names) self.config.table_name = table_name self.save_config() if self.config.table_name in table_names: table = db.open_table(self.config.table_name) else: table = None self._db = db self._table = table @property def table(self): """The ``lancedb.LanceTable`` instance for this index.""" return self._table @property def total_index_size(self): if self._table is None: return 0 return len(self._table) def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, ): if self._table is None: pa_table = pa.Table.from_arrays( [[], [], []], names=["id", "sample_id", "vector"] ) else: pa_table = self._table.to_arrow() if label_ids is not None: ids = label_ids else: ids = sample_ids 
if warn_existing or not allow_existing or not overwrite: existing_ids = set(pa_table["id"].to_pylist()) & set(ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) if label_ids is not None: ids = list(label_ids) else: ids = list(sample_ids) dim = embeddings.shape[1] if self._table: prev_embeddings = np.concatenate( pa_table["vector"].to_numpy() ).reshape(-1, dim) embeddings = np.concatenate([prev_embeddings, embeddings]) ids = pa_table["id"].to_pylist() + ids sample_ids = pa_table["sample_id"].to_pylist() + sample_ids embeddings = pa.array(embeddings.reshape(-1), type=pa.float32()) embeddings = pa.FixedSizeListArray.from_arrays(embeddings, dim) sample_ids = list(sample_ids) pa_table = pa.Table.from_arrays( [ids, sample_ids, embeddings], names=["id", "sample_id", "vector"] ) self._table = self._db.create_table( self.config.table_name, pa_table, mode="overwrite" ) if reload: self.reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = list(self._index.fetch(ids).vectors.keys()) missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are 
not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids df = self._table.to_pandas() df = df[~df["id"].isin(ids)] self._table = self._db.create_table( self.config.table_name, df, mode="overwrite" ) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) df = self._table.to_pandas() found_embeddings = [] found_sample_ids = [] found_label_ids = [] missing_ids = [] if sample_ids is not None and self.config.patches_field is not None: df.set_index("sample_id", drop=False, inplace=True) if not etau.is_container(sample_ids): sample_ids = [sample_ids] for sample_id in sample_ids: if sample_id in df.index: found_embeddings.append(df.loc[sample_id]["vector"]) found_sample_ids.append(sample_id) found_label_ids.append(df.loc[sample_id]["id"]) else: missing_ids.append(sample_id) elif self.config.patches_field is not None: df.set_index("id", drop=False, inplace=True) if label_ids is None: label_ids = list(df.index) elif not etau.is_container(label_ids): label_ids = [label_ids] for label_id in label_ids: if label_id in df.index: found_embeddings.append(df.loc[label_id]["vector"]) found_sample_ids.append(df.loc[label_id]["sample_id"]) found_label_ids.append(label_id) else: missing_ids.append(label_id) else: df.set_index("id", drop=False, inplace=True) if sample_ids is None: sample_ids = list(df.index) elif not etau.is_container(sample_ids): sample_ids = [sample_ids] for sample_id in sample_ids: if sample_id in df.index: found_embeddings.append(df.loc[sample_id]["vector"]) found_sample_ids.append(sample_id) else: missing_ids.append(sample_id) num_missing_ids = 
len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = np.array(found_embeddings) sample_ids = np.array(found_sample_ids) if label_ids is not None: label_ids = np.array(found_label_ids) return embeddings, sample_ids, label_ids def cleanup(self): if self._db is None: return for tbl in ( self.config.table_name, self.config.table_name + "_filter", ): if tbl in self._db.table_names(): self._db.drop_table(tbl) self._table = None def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("LanceDB does not support full index neighbors") if reverse is True: raise ValueError( "LanceDB does not support least similarity queries" ) if aggregation not in (None, "mean"): raise ValueError( f"LanceDB does not support {aggregation} aggregation" ) if k is None: k = self.index_size query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] table = self._table if self.has_view: if self.config.patches_field is not None: index_ids = list(self.current_label_ids) else: index_ids = list(self.current_sample_ids) df = table.to_pandas() df = df[df["id"].isin(index_ids)] table = self._db.create_table( self.config.table_name + "_filter", df, mode="overwrite" ) metric = _SUPPORTED_METRICS[self.config.metric] sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: results = table.search(q).metric(metric).limit(k).to_df() if self.config.patches_field is not None: sample_ids.append(results.sample_id.tolist()) label_ids.append(results.id.tolist()) else: sample_ids.append(results.id.tolist()) if return_dists: 
dists.append(results._distance.tolist()) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) df = self._table.to_pandas() df = df[df["id"].isin(query_ids)] query = np.array([v for v in df["vector"]]) if query.size == 0: raise ValueError( "Query IDs %s were not found in the index" % query_ids ) if single_query: query = query[0, :] return query @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/leaky_splits.py ================================================ """ Finds leaks between splits. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import eta.core.utils as etau import fiftyone.core.brain as fob import fiftyone.core.fields as fof import fiftyone.core.validation as fov import fiftyone.zoo as foz from fiftyone import ViewField as F import fiftyone.brain as fb import fiftyone.brain.similarity as fbs import fiftyone.brain.internal.core.utils as fbu logger = logging.getLogger(__name__) _DEFAULT_MODEL = "resnet18-imagenet-torch" _DEFAULT_BATCH_SIZE = None def compute_leaky_splits( samples, splits, threshold=None, roi_field=None, embeddings=None, similarity_index=None, model=None, model_kwargs=None, force_square=False, alpha=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): """See ``fiftyone/brain/__init__.py``.""" fov.validate_collection(samples) if etau.is_str(embeddings): embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings, data_type="embeddings", ) embeddings = None else: embeddings_field = None embeddings_exist = None if etau.is_str(similarity_index): similarity_index = samples.load_brain_results(similarity_index) if ( model is None and embeddings is None and similarity_index is None and not embeddings_exist ): model = foz.load_zoo_model(_DEFAULT_MODEL) if batch_size is None: batch_size = _DEFAULT_BATCH_SIZE config = LeakySplitsConfig( splits=splits, embeddings_field=embeddings_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, ) brain_method = config.build() brain_method.ensure_requirements() if similarity_index is None: similarity_index = fb.compute_similarity( samples, backend="sklearn", roi_field=roi_field, embeddings=embeddings_field or embeddings, model=model, model_kwargs=model_kwargs, force_square=force_square, alpha=alpha, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) elif not isinstance(similarity_index, fbs.DuplicatesMixin): raise ValueError( "This method only supports similarity indexes that implement the " 
"%s mixin" % fbs.DuplicatesMixin ) split_views = _to_split_views(samples, splits) index = brain_method.initialize(samples, similarity_index, split_views) if threshold is not None: index.find_leaks(threshold) return index class LeakySplitsConfig(fob.BrainMethodConfig): def __init__( self, splits=None, embeddings_field=None, similarity_index=None, model=None, model_kwargs=None, **kwargs, ): if isinstance(splits, dict): splits = None if similarity_index is not None and not etau.is_str(similarity_index): similarity_index = similarity_index.key if model is not None and not etau.is_str(model): model = etau.get_class_name(model) self.splits = splits self.embeddings_field = embeddings_field self.similarity_index = similarity_index self.model = model self.model_kwargs = model_kwargs super().__init__(**kwargs) @property def type(self): return "leakage" @property def method(self): return "similarity" class LeakySplits(fob.BrainMethod): def initialize(self, samples, similarity_index, split_views): return LeakySplitsIndex( samples, self.config, similarity_index, split_views ) def get_fields(self, samples, _): fields = [] if self.config.embeddings_field is not None: fields.append(self.config.embeddings_field) return fields class LeakySplitsIndex(fob.BrainResults): def __init__(self, samples, config, similarity_index, split_views): super().__init__(samples, config, None) self._similarity_index = similarity_index self._split_views = split_views self._id2split = None self._thresh = None self._leak_ids = None self._initialize() @property def split_views(self): """A dict mapping split names to views.""" return self._split_views @property def thresh(self): """The threshold used by the last call to :meth:`find_leaks`.""" return self._thresh @property def leak_ids(self): """The list of leaky sample IDs from the last call to :meth:`find_leaks`. """ return self._leak_ids def find_leaks(self, thresh): """Scans the index for leaks between splits. 
Args: thresh: the similarity distance threshold to use when detecting potential leaks """ if thresh == self._thresh: return # Find duplicates self._thresh = thresh if self._similarity_index.thresh != self._thresh: self._similarity_index.find_duplicates(self._thresh) # Filter duplicates to just those with neighbors in different splits leak_ids = [] neighbors_map = self._similarity_index.neighbors_map for sample_id, neighbors in neighbors_map.items(): _leak_ids = [] sample_split = self._id2split.get(sample_id, None) if sample_split is None: continue for n in neighbors: neighbor_id = n[0] neighbor_split = self._id2split.get(neighbor_id, None) if neighbor_split is None: continue if neighbor_split != sample_split: _leak_ids.append(neighbor_id) if _leak_ids: leak_ids.append(sample_id) leak_ids.extend(_leak_ids) self._leak_ids = leak_ids def leaks_view(self): """Returns a view containg all potential leaks generated by the last call to :meth:`find_leaks`. Returns: a :class:`fiftyone.core.view.DatasetView` """ if self._thresh is None: raise ValueError("You must first call `find_leaks()`") return self.samples.select(self._leak_ids, ordered=True) def leaks_for_sample(self, sample_or_id): """Returns a view that contains all leaks related to the given sample. The given sample is always first in the returned view, followed by any related leaks. 
Args: sample_or_id: a :class:`fiftyone.core.sample.Sample` or sample ID Returns: a :class:`fiftyone.core.view.DatasetView` """ if self._thresh is None: raise ValueError("You must first call `find_leaks()`") if etau.is_str(sample_or_id): sample_id = sample_or_id else: sample_id = sample_or_id.id sample_split = self._id2split[sample_id] neighbors_map = self._similarity_index.neighbors_map leak_ids = [] if sample_id in neighbors_map.keys(): neighbors = neighbors_map[sample_id] leak_ids = [ n[0] for n in neighbors if self._id2split[n[0]] != sample_split ] else: for unique_id, neighbors in neighbors_map.items(): if sample_id in [n[0] for n in neighbors]: leak_ids = [ n[0] for n in neighbors if self._id2split[n[0]] != sample_split ] leak_ids.append(unique_id) break return self.samples.select([sample_id] + leak_ids, ordered=True) def no_leaks_view(self, view=None): """Returns a view with leaks excluded. Args: view (None): an optional :class:`fiftyone.core.view.DatasetView` from which to exclude. By default, :meth:`samples` is used """ if self._thresh is None: raise ValueError("You must first call `find_leaks()`") if view is None: view = self.samples return view.exclude(self._leak_ids) def tag_leaks(self, tag="leak"): """Tags all potential leaks in :meth:`leaks_view` with the given tag. Args: tag ("leak"): the tag string to apply """ self.leaks_view().tag_samples(tag) def _initialize(self): id2split = {} split_ids = {} for split_name, split_view in self.split_views.items(): sample_ids = set(split_view.values("id")) split_ids[split_name] = sample_ids id2split.update({sid: split_name for sid in sample_ids}) # Check for overlapping splits split_names = list(split_ids.keys()) for idx, split1 in enumerate(split_names): for split2 in split_names[idx + 1 :]: overlap = split_ids[split1] & split_ids[split2] if overlap: logger.warning( "The '%s' and '%s' splits contain %d overlapping samples." 
"Use dataset.match_tags('%s').match_tags('%s') to " "identify them", split1, split2, len(overlap), split1, split2, ) # Check for samples not in index index_ids = self._similarity_index.sample_ids if index_ids is not None: index_ids = set(index_ids) all_split_ids = set(id2split.keys()) missing_ids = all_split_ids - index_ids if missing_ids: logger.warning( "The provided splits contain %d samples (eg '%s') that " "are not present in the index", len(missing_ids), next(iter(missing_ids)), ) self._id2split = id2split def _to_split_views(samples, splits): if etau.is_container(splits): return {tag: samples.match_tags(tag) for tag in splits} if isinstance(splits, str): field = samples.get_field(splits) if isinstance(field, fof.ListField): return { value: samples.exists(splits).match(F(splits).contains(value)) for value in samples.distinct(splits) } else: return { value: samples.match(F(splits) == value) for value in samples.distinct(splits) } ================================================ FILE: fiftyone/brain/internal/core/milvus.py ================================================ """ Milvus similarity backend. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import numpy as np from uuid import uuid4 import eta.core.utils as etau import fiftyone.core.utils as fou from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu pymilvus = fou.lazy_import("pymilvus") logger = logging.getLogger(__name__) _SUPPORTED_METRICS = { "cosine": "COSINE", "dotproduct": "IP", "euclidean": "L2", } class MilvusSimilarityConfig(SimilarityConfig): """Configuration for the Milvus similarity backend. Args: collection_name (None): the name of a Milvus collection to use or create. If none is provided, a new collection will be created metric ("dotproduct"): the embedding distance metric to use when creating a new index. 
Supported values are ``("cosine", "dotproduct", "euclidean")`` consistency_level ("Session"): the consistency level to use. Supported values are ``("Session", "Strong", "Bounded", "Eventually")`` uri (None): a full Milvus server address to use, like ``"http://localhost:19530"``, ``"tcp:localhost:19530"``, or ``"https://ok.s3.south.com:19530"`` user (None): a username to use password (None): a password to use secure (None): whether to enable TLS (True) token (None): a header token for RPC calls db_name (None): a database name for the connection client_key_path (None): a client.key path for TLS two-way client_pem_path (None): a client.pem path for TLS two-way ca_pem_path (None): a ca.pem path for TLS two-way server_pem_path (None): a server.pem path for TLS one-way server_name (None): the server name, for TLS **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__( self, collection_name=None, metric="dotproduct", consistency_level="Session", uri=None, user=None, password=None, secure=None, token=None, db_name=None, client_key_path=None, client_pem_path=None, ca_pem_path=None, server_pem_path=None, server_name=None, **kwargs, ): if metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. 
Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.collection_name = collection_name self.metric = metric self.consistency_level = consistency_level # store privately so these aren't serialized self._uri = uri self._user = user self._password = password self._secure = secure self._token = token self._db_name = db_name self._client_key_path = client_key_path self._client_pem_path = client_pem_path self._ca_pem_path = ca_pem_path self._server_pem_path = server_pem_path self._server_name = server_name @property def method(self): return "milvus" @property def uri(self): return self._uri @uri.setter def uri(self, value): self._uri = value @property def user(self): return self._user @user.setter def user(self, value): self._user = value @property def password(self): return self._password @password.setter def password(self, value): self._password = value @property def secure(self): return self._secure @secure.setter def secure(self, value): self._secure = value @property def token(self): return self._token @token.setter def token(self, value): self._token = value @property def db_name(self): return self._db_name @db_name.setter def db_name(self, value): self._db_name = value @property def client_key_path(self): return self._client_key_path @client_key_path.setter def client_key_path(self, value): self._client_key_path = value @property def client_pem_path(self): return self._client_pem_path @client_pem_path.setter def client_pem_path(self, value): self._client_pem_path = value @property def ca_pem_path(self): return self._ca_pem_path @ca_pem_path.setter def ca_pem_path(self, value): self._ca_pem_path = value @property def server_pem_path(self): return self._server_pem_path @server_pem_path.setter def server_pem_path(self, value): self._server_pem_path = value @property def server_name(self): return self._server_name @server_name.setter def server_name(self, value): self._server_name = value @property def max_k(self): 
return 16384 @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) @property def index_params(self): return { "metric_type": _SUPPORTED_METRICS[self.metric], "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}, } @property def search_params(self): return { "HNSW": { "metric_type": _SUPPORTED_METRICS[self.metric], "params": {"ef": 10}, }, } def load_credentials( self, uri=None, user=None, password=None, secure=None, token=None, db_name=None, client_key_path=None, client_pem_path=None, ca_pem_path=None, server_pem_path=None, server_name=None, ): self._load_parameters( uri=uri, user=user, password=password, secure=secure, token=token, db_name=db_name, client_key_path=client_key_path, client_pem_path=client_pem_path, ca_pem_path=ca_pem_path, server_pem_path=server_pem_path, server_name=server_name, ) class MilvusSimilarity(Similarity): """Milvus similarity factory. Args: config: a :class:`MilvusSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("pymilvus") def ensure_usage_requirements(self): fou.ensure_package("pymilvus") def initialize(self, samples, brain_key): return MilvusSimilarityIndex( samples, self.config, brain_key, backend=self ) class MilvusSimilarityIndex(SimilarityIndex): """Class for interacting with Milvus similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`MilvusSimilarityConfig` used brain_key: the brain key backend (None): a :class:`MilvusSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._alias = None self._collection = None self._initialize() def _initialize(self): kwargs = {} for key in ( "uri", "user", "password", "secure", "token", "db_name", "client_key_path", "client_pem_path", "ca_pem_path", "server_pem_path", "server_name", ): value = getattr(self.config, key, None) if value is not None: kwargs[key] = value alias = uuid4().hex if kwargs else "default" try: pymilvus.connections.connect(alias=alias, **kwargs) except pymilvus.MilvusException as e: raise ValueError( "Failed to connect to Milvus backend at URI '%s'. Refer to " "https://docs.voxel51.com/integrations/milvus.html for more " "information" % self.config.uri ) from e collection_names = pymilvus.utility.list_collections(using=alias) if self.config.collection_name is None: # Milvus only supports numbers, letters and underscores root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) root = root.replace("-", "_") collection_name = fbu.get_unique_name(root, collection_names) collection_name = collection_name.replace("-", "_") self.config.collection_name = collection_name self.save_config() if self.config.collection_name in collection_names: collection = pymilvus.Collection( self.config.collection_name, using=alias ) collection.load() else: collection = None self._alias = alias self._collection = collection def _create_collection(self, dimension): schema = pymilvus.CollectionSchema( [ pymilvus.FieldSchema( "pk", pymilvus.DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64000, ), pymilvus.FieldSchema( "vector", pymilvus.DataType.FLOAT_VECTOR, dim=dimension ), pymilvus.FieldSchema( "sample_id", pymilvus.DataType.VARCHAR, max_length=64000 ), 
] ) collection = pymilvus.Collection( self.config.collection_name, schema, consistency_level=self.config.consistency_level, using=self._alias, ) collection.create_index( "vector", index_params=self.config.index_params ) collection.load() self._collection = collection @property def collection(self): """The ``pymilvus.Collection`` instance for this index.""" return self._collection @property def total_index_size(self): if self._collection is None: return 0 return self._collection.num_entities def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=100, ): if self._collection is None: self._create_collection(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: existing_ids = self._get_existing_ids(ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) elif existing_ids and overwrite: self._delete_ids(existing_ids) embeddings = [e.tolist() for e in embeddings] sample_ids = list(sample_ids) ids = list(ids) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): insert_data = [ list(_ids), list(_embeddings), list(_sample_ids), ] 
self._collection.insert(insert_data) self._collection.flush() if reload: self.reload() def _get_existing_ids(self, ids): ids = ['"' + str(entry) + '"' for entry in ids] expr = f"""pk in [{','.join(ids)}]""" return self._collection.query(expr) def _delete_ids(self, ids): ids = ['"' + str(entry) + '"' for entry in ids] expr = f"""pk in [{','.join(ids)}]""" self._collection.delete(expr) self._collection.flush() def _get_embeddings(self, ids): ids = ['"' + str(entry) + '"' for entry in ids] expr = f"""pk in [{','.join(ids)}]""" return self._collection.query( expr, output_fields=["pk", "sample_id", "vector"] ) def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = self._get_existing_ids(ids) missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids self._delete_ids(ids=ids) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = 
self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids def cleanup(self): pymilvus.utility.drop_collection( self.config.collection_name, using=self._alias ) self._collection = None def _get_sample_embeddings(self, sample_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] if sample_ids is None: raise ValueError( "Milvus does not support retrieving all vectors in an index" ) for batch_ids in fou.iter_batches(sample_ids, batch_size): response = self._get_embeddings(list(batch_ids)) for r in response: found_embeddings.append(r["vector"]) found_sample_ids.append(r["sample_id"]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if label_ids is None: raise ValueError( "Milvus does not support retrieving all vectors in an index" ) for batch_ids in fou.iter_batches(label_ids, batch_size): response = self._get_embeddings(list(batch_ids)) for r in response: found_embeddings.append(r["vector"]) found_sample_ids.append(r["sample_id"]) found_label_ids.append(r["pk"]) missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _get_patch_embeddings_from_sample_ids( self, sample_ids, batch_size=100 ): found_embeddings = [] found_sample_ids = [] found_label_ids = [] query_vector = [0.0] * self._get_dimension() top_k = min(batch_size, 
self.config.max_k) for batch_ids in fou.iter_batches(sample_ids, batch_size): ids = ['"' + str(entry) + '"' for entry in batch_ids] expr = f"""pk in [{','.join(ids)}]""" response = self._collection.search( data=[query_vector], anns_field="vector", param=self.config.search_params, expr=expr, limit=top_k, ) ids = [x.id for x in response[0]] response = self._get_embeddings(ids) for r in response: found_embeddings.append(r["vector"]) found_sample_ids.append(r["sample_id"]) found_label_ids.append(r["pk"]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("Milvus does not support full index neighbors") if reverse is True: raise ValueError( "Milvus does not support least similarity queries" ) if k is None or k > self.config.max_k: raise ValueError("Milvus requires k<=%s" % self.config.max_k) if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = self.current_label_ids else: index_ids = self.current_sample_ids expr = ['"' + str(entry) + '"' for entry in index_ids] expr = f"""pk in [{','.join(expr)}]""" else: expr = None sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: if self.config.patches_field is not None: output_fields = ["sample_id"] else: output_fields = None response = self._collection.search( data=[q.tolist()], anns_field="vector", limit=k, expr=expr, param=self.config.search_params, output_fields=output_fields, ) if self.config.patches_field is not None: sample_ids.append( [r.entity.get("sample_id") 
for r in response[0]] ) label_ids.append([r.id for r in response[0]]) else: sample_ids.append([r.id for r in response[0]]) if return_dists: dists.append([r.score for r in response[0]]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) response = self._get_embeddings(query_ids) query = np.array([x["vector"] for x in response]) if query.size == 0: raise ValueError( "Query IDs %s were not found in the index" % query_ids ) if single_query: query = query[0, :] return query def _get_dimension(self): if self._collection is None: return None for field in self._collection.describe()["fields"]: if field["name"] == "vector": return field["params"]["dim"] @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/mistakenness.py ================================================ """ Mistakenness methods. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_
|
"""
import logging
from math import exp

import numpy as np
from scipy.special import softmax
from scipy.stats import entropy

from fiftyone import ViewField as F
import fiftyone.core.brain as fob
import fiftyone.core.labels as fol
import fiftyone.core.media as fom
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov


logger = logging.getLogger(__name__)

# Label types accepted for both the prediction and ground truth fields
_ALLOWED_TYPES = (
    fol.Classification,
    fol.Classifications,
    fol.Detections,
    fol.Polylines,
    fol.Keypoints,
    fol.TemporalDetections,
)
_MISSED_CONFIDENCE_THRESHOLD = 0.95
# IoU threshold passed to evaluate_detections() for object label types
_DETECTION_IOU = 0.5


def compute_mistakenness(
    samples,
    pred_field,
    label_field,
    mistakenness_field,
    missing_field,
    spurious_field,
    use_logits,
    copy_missing,
    progress,
):
    """See ``fiftyone/brain/__init__.py``."""

    #
    # Algorithm
    #
    # The chance of a mistake is related to how confident the model prediction
    # was as well as whether or not the prediction is correct. A prediction
    # that is highly confident and incorrect is likely to be a mistake. A
    # prediction that is low confidence and incorrect is not likely to be a
    # mistake.
    #
    # Let us compute a confidence measure based on negative entropy of logits:
    # $c = -entropy(logits)$. This value is large when there is low uncertainty
    # and small when there is high uncertainty. Let us define modulator, $m$,
    # based on whether or not the answer is correct. $m = -1$ when the label is
    # correct and $1$ otherwise. Then, mistakenness is computed as
    # $(m * exp(c) + 1) / 2$ so that high confidence correct predictions result
    # in low mistakenness, high confidence incorrect predictions result in high
    # mistakenness, and low confidence predictions result in middling
    # mistakenness.
    #

    fov.validate_collection_label_fields(
        samples, (pred_field, label_field), _ALLOWED_TYPES, same_type=True
    )

    # For videos, strip any frame prefix from the output field names
    if samples.media_type == fom.VIDEO:
        mistakenness_field, _ = samples._handle_frame_field(mistakenness_field)
        missing_field, _ = samples._handle_frame_field(missing_field)
        spurious_field, _ = samples._handle_frame_field(spurious_field)

    is_objects = samples._is_label_field(
        pred_field,
        (fol.Detections, fol.Polylines, fol.Keypoints, fol.TemporalDetections),
    )
    if is_objects:
        # Object labels are first matched via a temporary detection evaluation
        eval_key = _make_eval_key(samples, mistakenness_field)
        config = DetectionMistakennessConfig(
            pred_field,
            label_field,
            mistakenness_field,
            missing_field,
            spurious_field,
            use_logits,
            copy_missing,
            eval_key,
        )
    else:
        eval_key = None
        config = ClassificationMistakennessConfig(
            pred_field, label_field, mistakenness_field, use_logits
        )

    # The mistakenness field name doubles as the brain key
    brain_key = mistakenness_field
    brain_method = config.build()
    brain_method.ensure_requirements()
    brain_method.register_run(samples, brain_key, cleanup=False)
    brain_method.register_samples(samples)

    if is_objects:
        samples.evaluate_detections(
            pred_field,
            gt_field=label_field,
            eval_key=eval_key,
            classwise=False,
            iou=_DETECTION_IOU,
            progress=progress,
        )

    view = samples.select_fields([label_field, pred_field])
    processing_frames = samples._is_frame_field(label_field)

    logger.info("Computing mistakenness...")
    for sample in view.iter_samples(progress=progress):
        # Frame-level labels are processed per frame; otherwise the sample
        # itself is the only "image"
        if processing_frames:
            images = sample.frames.values()
        else:
            images = [sample]

        sample_mistakenness = []
        num_missing = 0
        num_spurious = 0
        for image in images:
            if is_objects:
                (
                    img_mistakenness,
                    img_missing,
                    img_spurious,
                ) = brain_method.process_image(image, eval_key)
                num_missing += img_missing
                num_spurious += img_spurious
                if processing_frames:
                    image[missing_field] = img_missing
                    image[spurious_field] = img_spurious
            else:
                img_mistakenness = brain_method.process_image(image)

            if img_mistakenness is not None:
                sample_mistakenness.append(img_mistakenness)

            if processing_frames:
                image[mistakenness_field] = img_mistakenness

        # Sample-level mistakenness is the max over its images/frames
        if sample_mistakenness:
            sample[mistakenness_field] = np.max(sample_mistakenness)
        else:
            sample[mistakenness_field] = None

        if is_objects:
            sample[missing_field] = num_missing
            sample[spurious_field] = num_spurious

        sample.save()

    # The temporary evaluation is only needed for matching; remove it
    if eval_key is not None:
        samples.delete_evaluation(eval_key)

    brain_method.save_run_results(samples, brain_key, None)

    logger.info("Mistakenness computation complete")


# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class MistakennessMethodConfig(fob.BrainMethodConfig):
    """Base configuration for mistakenness runs.

    Args:
        pred_field: the predictions field
        label_field: the ground truth field
        mistakenness_field: the field in which to store mistakenness
    """

    def __init__(self, pred_field, label_field, mistakenness_field, **kwargs):
        super().__init__(**kwargs)
        self.pred_field = pred_field
        self.label_field = label_field
        self.mistakenness_field = mistakenness_field

    @property
    def type(self):
        return "mistakenness"


class MistakennessMethod(fob.BrainMethod):
    """Base class for mistakenness methods.

    :meth:`register_samples` resolves the (frame-prefix-stripped) field names
    and the label type before images are processed.
    """

    def __init__(self, config):
        super().__init__(config)
        self.pred_field = None
        self.label_field = None
        self.label_type = None

    def ensure_requirements(self):
        pass

    def register_samples(self, samples):
        self.pred_field, _ = samples._handle_frame_field(
            self.config.pred_field
        )
        self.label_field, _ = samples._handle_frame_field(
            self.config.label_field
        )
        self.label_type = samples._get_label_field_type(self.config.pred_field)

    def _validate_run(self, samples, brain_key, existing_info):
        # Re-running a brain key must use the same fields as before
        self._validate_fields_match(brain_key, "pred_field", existing_info)
        self._validate_fields_match(brain_key, "label_field", existing_info)
        self._validate_fields_match(
            brain_key, "mistakenness_field", existing_info
        )


# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class ClassificationMistakennessConfig(MistakennessMethodConfig):
    """Configuration for classification mistakenness.

    Args:
        pred_field: the predictions field
        label_field: the ground truth field
        mistakenness_field: the field in which to store mistakenness
        use_logits: whether to use prediction logits (True) or confidences
            (False)
    """

    def __init__(
        self, pred_field, label_field, mistakenness_field, use_logits, **kwargs
    ):
        super().__init__(pred_field, label_field, mistakenness_field, **kwargs)
        self.use_logits = use_logits

    @property
    def method(self):
        return "classification"


class ClassificationMistakenness(MistakennessMethod):
    def process_image(self, sample_or_frame):
        """Computes the mistakenness of a single sample or frame.

        Returns:
            the mistakenness value, or None if neither a prediction nor a
            ground truth label is present
        """
        use_logits = self.config.use_logits

        pred_label, gt_label = _get_data(
            sample_or_frame, self.pred_field, self.label_field, use_logits
        )

        if pred_label is None and gt_label is None:
            return None

        # NOTE(review): here m is 1.0 when the labels agree and 0.0 otherwise,
        # while the algorithm comment above uses m = -1/1; presumably the
        # helpers (defined outside this view) map this. Also, a missing
        # ground truth yields m = 1.0, the same as a match — TODO confirm
        if pred_label is None or gt_label is None:
            m = 1.0
        elif isinstance(pred_label, fol.Classifications):
            # For multilabel problems, all labels must match
            pred_labels = set(c.label for c in pred_label.classifications)
            gt_labels = set(c.label for c in gt_label.classifications)
            m = float(pred_labels == gt_labels)
        else:
            m = float(pred_label.label == gt_label.label)

        if pred_label is None:
            mistakenness = 1.0
        elif use_logits:
            mistakenness = _compute_mistakenness_class(pred_label.logits, m)
        else:
            mistakenness = _compute_mistakenness_class_conf(
                pred_label.confidence, m
            )

        return mistakenness

    def get_fields(self, samples, brain_key):
        pred_field = self.config.pred_field
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field

        fields = [pred_field, label_field, mistakenness_field]
        # Frame-level runs also populate a frame-level mistakenness field
        if samples._is_frame_field(label_field):
            fields.append(samples._FRAMES_PREFIX + mistakenness_field)

        return fields

    def cleanup(self, samples, brain_key):
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field

        samples._dataset.delete_sample_fields(
            mistakenness_field, error_level=1
        )
        if samples._is_frame_field(label_field):
            samples._dataset.delete_frame_fields(
                mistakenness_field, error_level=1
            )


# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class DetectionMistakennessConfig(MistakennessMethodConfig):
    def __init__(
        self,
        pred_field,
        label_field,
        mistakenness_field,
        missing_field,
        spurious_field,
        use_logits,
        copy_missing,
        eval_key,
        **kwargs
    ):
super().__init__(pred_field, label_field, mistakenness_field, **kwargs) self.missing_field = missing_field self.spurious_field = spurious_field self.use_logits = use_logits self.copy_missing = copy_missing self.eval_key = eval_key @property def method(self): return "detection" class DetectionMistakenness(MistakennessMethod): def process_image(self, sample_or_frame, eval_key): missing_field = self.config.missing_field spurious_field = self.config.spurious_field mistakenness_field = self.config.mistakenness_field copy_missing = self.config.copy_missing use_logits = self.config.use_logits pred_label, gt_label = _get_data( sample_or_frame, self.pred_field, self.label_field, use_logits ) list_field = self.label_type._LABEL_LIST_FIELD if pred_label is None: pred_label = self.label_type() if gt_label is None: gt_label = self.label_type() num_spurious = 0 num_missing = 0 missing_objects = {} image_mistakenness = [] pred_map = {} for pred_obj in pred_label[list_field]: pred_map[pred_obj.id] = pred_obj gt_id = pred_obj[eval_key + "_id"] conf = pred_obj.confidence if gt_id == "" and conf > _MISSED_CONFIDENCE_THRESHOLD: # Unmached FP with high confidence are missing pred_obj[missing_field] = True num_missing += 1 missing_objects[pred_obj.id] = pred_obj for gt_obj in gt_label[list_field]: # Avoid adding the same unmatched FP predictions upon multiple runs # of this method if copy_missing and gt_obj.has_field(missing_field): if gt_obj.id in missing_objects: del missing_objects[gt_obj.id] continue pred_id = gt_obj[eval_key + "_id"] if pred_id == "": # FN may be spurious gt_obj[spurious_field] = True num_spurious += 1 else: # For matched FP, compute mistakenness iou = gt_obj[eval_key + "_iou"] pred_obj = pred_map[pred_id] m = float(gt_obj.label == pred_obj.label) if use_logits: mistakenness_class = _compute_mistakenness_class( pred_obj.logits, m ) mistakenness_loc = _compute_mistakenness_loc( pred_obj.logits, iou ) else: mistakenness_class = _compute_mistakenness_class_conf( 
pred_obj.confidence, m ) mistakenness_loc = _compute_mistakenness_loc_conf( pred_obj.confidence, iou ) gt_obj[mistakenness_field] = mistakenness_class gt_obj[mistakenness_field + "_loc"] = mistakenness_loc image_mistakenness.append(mistakenness_class) if copy_missing: gt_label[list_field].extend(missing_objects.values()) sample_or_frame[self.label_field] = gt_label if image_mistakenness: mistakenness = np.max(image_mistakenness) else: mistakenness = -1 return mistakenness, num_missing, num_spurious def get_fields(self, samples, brain_key): pred_field = self.config.pred_field label_field = self.config.label_field mistakenness_field = self.config.mistakenness_field missing_field = self.config.missing_field spurious_field = self.config.spurious_field label_type = samples._get_label_field_type(pred_field) list_field = label_type._LABEL_LIST_FIELD fields = [ mistakenness_field, missing_field, spurious_field, "%s.%s.%s" % (label_field, list_field, mistakenness_field), "%s.%s.%s_loc" % (label_field, list_field, mistakenness_field), "%s.%s.%s" % (pred_field, list_field, missing_field), "%s.%s.%s" % (label_field, list_field, spurious_field), ] if samples._is_frame_field(pred_field): fields.extend( [ samples._FRAMES_PREFIX + mistakenness_field, samples._FRAMES_PREFIX + missing_field, samples._FRAMES_PREFIX + spurious_field, ] ) return fields def cleanup(self, samples, brain_key): pred_field = self.config.pred_field label_field = self.config.label_field mistakenness_field = self.config.mistakenness_field missing_field = self.config.missing_field spurious_field = self.config.spurious_field eval_key = self.config.eval_key label_type = samples._get_label_field_type(pred_field) list_field = label_type._LABEL_LIST_FIELD pred_field, is_frame_field = samples._handle_frame_field(pred_field) label_field, _ = samples._handle_frame_field(label_field) fields = [ mistakenness_field, missing_field, spurious_field, "%s.%s.%s" % (label_field, list_field, mistakenness_field), "%s.%s.%s_loc" % 
(label_field, list_field, mistakenness_field), "%s.%s.%s" % (pred_field, list_field, missing_field), "%s.%s.%s" % (label_field, list_field, spurious_field), ] if self.config.copy_missing: # Remove objects that were added to `label_field` samples._dataset.filter_labels( self.config.label_field, F(missing_field).exists(False) ).save() if is_frame_field: samples._dataset.delete_sample_fields( [mistakenness_field, spurious_field, missing_field], error_level=1, ) samples._dataset.delete_frame_fields(fields, error_level=1) else: samples._dataset.delete_sample_fields(fields, error_level=1) if eval_key in samples.list_evaluations(): samples.delete_evaluation(eval_key) def _validate_run(self, samples, brain_key, existing_info): super()._validate_run(samples, brain_key, existing_info) self._validate_fields_match(brain_key, "missing_field", existing_info) self._validate_fields_match(brain_key, "spurious_field", existing_info) self._validate_fields_match(brain_key, "copy_missing", existing_info) def _make_eval_key(samples, brain_key): existing_eval_keys = samples.list_evaluations() eval_key = brain_key + "_eval" if eval_key not in existing_eval_keys: return eval_key idx = 2 while eval_key + str(idx) in existing_eval_keys: idx += 1 return eval_key + str(idx) def _get_data(sample, pred_field, label_field, use_logits): pred_label = sample[pred_field] label = sample[label_field] if pred_label is None: return pred_label, label if isinstance(pred_label, fol.Detections): for det in pred_label.detections: if det.confidence is None: raise ValueError( "Detection '%s' in sample '%s' field '%s' has no " "confidence" % (det.id, sample.id, pred_field) ) elif isinstance(pred_label, fol.Polylines): for poly in pred_label.polylines: if poly.confidence is None: raise ValueError( "Polyline '%s' in sample '%s' field '%s' has no " "confidence" % (poly.id, sample.id, pred_field) ) elif use_logits: if pred_label.logits is None: raise ValueError( "Sample '%s' field '%s' has no logits" % (sample.id, 
pred_field) ) else: if pred_label.confidence is None: raise ValueError( "Sample '%s' field '%s' has no confidence" % (sample.id, pred_field) ) return pred_label, label def _compute_mistakenness_class(logits, m): # constrain m to either 1 (incorrect) or -1 (correct) m = m * -2.0 + 1.0 c = -1.0 * entropy(softmax(np.asarray(logits))) mistakenness = (m * exp(c) + 1.0) / 2.0 return mistakenness def _compute_mistakenness_loc(logits, iou): # i = 0 for high iou, i = 1 for low iou i = (1.0 / (1.0 - _DETECTION_IOU)) * (1.0 - iou) # c = 0 for low confidence, c = 1 for high confidence c = exp(-1.0 * entropy(softmax(np.asarray(logits)))) # mistakenness = i when c = i, mistakenness = 0.5 if c = 0 # mistakenness is higher with lower IoU and closer to 0 or 1 with higher # confidence mistakenness = (c * ((2.0 * i) - 1.0) + 1.0) / 2.0 return mistakenness def _compute_mistakenness_class_conf(confidence, m): # constrain m to either 1 (incorrect) or -1 (correct) m = m * -2.0 + 1.0 mistakenness = (m * confidence + 1.0) / 2.0 return mistakenness def _compute_mistakenness_loc_conf(confidence, iou): # i = 0 for high iou, i = 1 for low iou i = (1.0 / (1.0 - _DETECTION_IOU)) * (1.0 - iou) # c = 0 for low confidence, c = 1 for high confidence c = confidence # mistakenness = i when c = i, mistakenness = 0.5 if c = 0 # mistakenness is higher with lower IoU and closer to 0 or 1 with higher # confidence mistakenness = (c * ((2.0 * i) - 1.0) + 1.0) / 2.0 return mistakenness ================================================ FILE: fiftyone/brain/internal/core/mongodb.py ================================================ """ MongoDB similarity backend. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging from bson import ObjectId import numpy as np from pymongo.errors import OperationFailure import eta.core.utils as etau from fiftyone import ViewField as F import fiftyone.core.fields as fof import fiftyone.core.media as fom import fiftyone.core.utils as fou import fiftyone.brain.internal.core.utils as fbu from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) logger = logging.getLogger(__name__) _SUPPORTED_METRICS = { "cosine": "cosine", "dotproduct": "dotProduct", "euclidean": "euclidean", } class MongoDBSimilarityConfig(SimilarityConfig): """Configuration for a MongoDB similarity instance. Args: index_name (None): the name of the MongoDB vector index to use or create. If none is provided, a new index will be created metric ("cosine"): the embedding distance metric to use when creating a new index. Supported values are ``("cosine", "dotproduct", "euclidean")`` **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__(self, index_name=None, metric="cosine", **kwargs): if kwargs.get("embeddings_field") is None and index_name is None: raise ValueError( "You must provide either the name of a field to read/write " "embeddings for this index by passing the `embeddings` " "parameter, or you must provide the name of an existing " "vector search index via the `index_name` parameter" ) # @todo support this. Will likely require copying embeddings to a new # collection as vector search indexes do not yet support array fields if kwargs.get("patches_field") is not None: raise ValueError( "The MongoDB backend does not yet support patch embeddings" ) if metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. 
Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.index_name = index_name self.metric = metric @property def method(self): return "mongodb" @property def max_k(self): return 10000 # MongoDB limit @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) class MongoDBSimilarity(Similarity): """MongoDB similarity factory. Args: config: a :class:`MongoDBSimilarityConfig` """ def ensure_requirements(self): # # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.create_search_index # # Could also validate that user is connected to an Atlas cluster here # eg Atlas clusters generally have hostnames which end in "mongodb.net" # https://stackoverflow.com/q/73180110 # fou.ensure_package("pymongo>=4.7") def ensure_usage_requirements(self): fou.ensure_package("pymongo>=4.7") def initialize(self, samples, brain_key): return MongoDBSimilarityIndex( samples, self.config, brain_key, backend=self ) class MongoDBSimilarityIndex(SimilarityIndex): """Class for interacting with MongoDB similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`MongoDBSimilarityConfig` used brain_key: the brain key backend (None): a :class:`MongoDBSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._dataset = samples._dataset self._sample_ids = None self._label_ids = None self._index = None self._initialize() @property def is_external(self): return False @property def total_index_size(self): if self._sample_ids is not None: return len(self._sample_ids) if self._dataset.media_type == fom.GROUP: samples = self._dataset.select_group_slices(_allow_mixed=True) else: samples = self._dataset patches_field = self.config.patches_field embeddings_field = self.config.embeddings_field if patches_field is not None: _, embeddings_path = self._dataset._get_label_field_path( patches_field, embeddings_field ) samples = samples.filter_labels( patches_field, F(embeddings_field).exists() ) return samples.count(embeddings_path) if samples.has_field(embeddings_field): return samples.exists(embeddings_field).count() return 0 def _initialize(self): coll = self._dataset._sample_collection try: indexes = { i["name"]: i for i in coll.aggregate([{"$listSearchIndexes": {}}]) } except OperationFailure: # https://www.mongodb.com/docs/manual/release-notes/7.0/#atlas-search-index-management # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview if self.config.index_name is None: raise ValueError( "You must be running MongoDB Atlas 7.0 or later in order " "to use vector search indexes" ) # Must assume index exists because we can't use pymongo to check... 
self._index = True return if self.config.index_name is None: root = self.config.embeddings_field index_name = fbu.get_unique_name(root, list(indexes.keys())) self.config.index_name = index_name self.save_config() elif self.config.embeddings_field is None: info = indexes.get(self.config.index_name, None) if info is None: raise ValueError( "Index '%s' does not exist" % self.config.index_name ) self.config.embeddings_field = next( iter(info["latestDefinition"]["mappings"]["fields"].keys()) ) self.save_config() if self.config.index_name in indexes: # Index already exists self._index = True elif self.total_index_size > 0: # Embeddings already exist but the index hasn't been declared yet dimension = self._get_dimension() self._create_index(dimension) else: # Index will be created when add_to_index() is called pass def _get_dimension(self): if self._dataset.media_type == fom.GROUP: samples = self._dataset.select_group_slices(_allow_mixed=True) else: samples = self._dataset patches_field = self.config.patches_field embeddings_field = self.config.embeddings_field if patches_field is not None: _, embeddings_path = self._dataset._get_label_field_path( patches_field, embeddings_field ) view = samples.filter_labels( patches_field, F(embeddings_field).exists() ).limit(1) embeddings = view.values(embeddings_path, unwind=True) else: view = samples.exists(embeddings_field).limit(1) embeddings = view.values(embeddings_field) embedding = next(iter(embeddings), None) if embedding is None: return None return len(embedding) # MongoDB requires list fields def _create_index(self, dimension): # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage # https://www.mongodb.com/docs/languages/python/pymongo-driver/current/indexes/atlas-search-index/ from pymongo.operations import SearchIndexModel field = self._dataset.get_field(self.config.embeddings_field) if field is not None and not isinstance(field, fof.ListField): raise ValueError( "MongoDB vector search indexes 
require embeddings to be " "stored in list fields" ) metric = _SUPPORTED_METRICS[self.config.metric] fields = [ { "type": "vector", "numDimensions": dimension, "path": self.config.embeddings_field, "similarity": metric, }, { "type": "filter", "path": "_id", }, ] """ if self._dataset.media_type == fom.GROUP: fields.append( { "type": "filter", "path": self._dataset.group_field + ".name", } ) """ model = SearchIndexModel( name=self.config.index_name, type="vectorSearch", # requires pymongo>=4.7 definition={"fields": fields}, ) coll = self._dataset._sample_collection coll.create_search_index(model=model) self._index = True @property def ready(self): """Returns True/False whether the vector search index is ready to be queried. """ if self._index is None: return False try: coll = self._dataset._sample_collection indexes = { i["name"]: i for i in coll.aggregate([{"$listSearchIndexes": {}}]) } except OperationFailure: # requires MongoDB Atlas 7.0 or later return None info = indexes.get(self.config.index_name, {}) return info.get("status", None) == "READY" def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, ): if self._index is None: self._create_index(embeddings.shape[1]) sample_ids = np.asarray(sample_ids) label_ids = np.asarray(label_ids) if label_ids is not None else None if not overwrite or not allow_existing or warn_existing: if self._sample_ids is not None: _sample_ids, _label_ids = self._sample_ids, self._label_ids else: _sample_ids, _label_ids = self._parse_data( self._dataset, self.config ) index_sample_ids, index_label_ids, ii, _ = fbu.add_ids( sample_ids, label_ids, _sample_ids, _label_ids, patches_field=self.config.patches_field, overwrite=overwrite, allow_existing=allow_existing, warn_existing=warn_existing, ) self._sample_ids = index_sample_ids self._label_ids = index_label_ids if ii.size == 0: return embeddings = embeddings[ii, :] sample_ids = sample_ids[ii] label_ids = 
label_ids[ii] if label_ids is not None else None else: index_sample_ids = None index_label_ids = None fbu.add_embeddings( self._dataset, embeddings.tolist(), # MongoDB requires list fields sample_ids, label_ids, self.config.embeddings_field, patches_field=self.config.patches_field, ) if reload: super().reload() self._sample_ids = index_sample_ids self._label_ids = index_label_ids def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if not allow_missing or warn_missing: if self._sample_ids is not None: _sample_ids, _label_ids = self._sample_ids, self._label_ids else: _sample_ids, _label_ids = self._parse_data( self._dataset, self.config ) index_sample_ids, index_label_ids, rm_inds = fbu.remove_ids( sample_ids, label_ids, _sample_ids, _label_ids, patches_field=self.config.patches_field, allow_missing=allow_missing, warn_missing=warn_missing, ) self._sample_ids = index_sample_ids self._label_ids = index_label_ids if rm_inds.size == 0: return if self.config.patches_field is not None: sample_ids = None label_ids = _label_ids[rm_inds] else: sample_ids = _sample_ids[rm_inds] label_ids = None else: index_sample_ids = None index_label_ids = None fbu.remove_embeddings( self._dataset, self.config.embeddings_field, sample_ids=sample_ids, label_ids=label_ids, patches_field=self.config.patches_field, ) if reload: super().reload() self._sample_ids = index_sample_ids self._label_ids = index_label_ids def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if self._dataset.media_type == fom.GROUP: samples = self._dataset.select_group_slices(_allow_mixed=True) else: samples = self._dataset if sample_ids is not None: samples = samples.select(sample_ids) elif label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) 
samples = samples.select_labels( ids=label_ids, fields=self.config.patches_field ) _embeddings, _sample_ids, _label_ids = fbu.get_embeddings( samples, patches_field=self.config.patches_field, embeddings_field=self.config.embeddings_field, ) if label_ids is not None: inds = _get_inds( label_ids, _label_ids, "label", allow_missing, warn_missing, ) embeddings = _embeddings[inds, :] sample_ids = _sample_ids[inds] label_ids = np.asarray(label_ids) elif sample_ids is not None: if etau.is_str(sample_ids): sample_ids = [sample_ids] if self.config.patches_field is not None: sample_ids = set(sample_ids) bools = [_id in sample_ids for _id in _sample_ids] inds = np.nonzero(bools)[0] else: inds = _get_inds( sample_ids, _sample_ids, "sample", allow_missing, warn_missing, ) embeddings = _embeddings[inds, :] sample_ids = _sample_ids[inds] if self.config.patches_field is not None: label_ids = _label_ids[inds] else: label_ids = None else: embeddings = _embeddings sample_ids = _sample_ids label_ids = _label_ids return embeddings, sample_ids, label_ids def reload(self): self._sample_ids = None self._label_ids = None super().reload() def cleanup(self): if self._index is None: return try: coll = self._dataset._sample_collection coll.drop_search_index(self.config.index_name) except OperationFailure: # requires MongoDB Atlas 7.0 or later pass self._index = None def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("MongoDB does not support full index neighbors") if reverse is True: raise ValueError( "MongoDB does not support least similarity queries" ) if aggregation not in (None, "mean"): raise ValueError( f"MongoDB does not support {aggregation} aggregation" ) if k is None: k = min(self.index_size, self.config.max_k) # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage num_candidates = min(10 * k, self.config.max_k) query = self._parse_neighbors_query(query) if aggregation == "mean" 
and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: index_ids = self.current_sample_ids # if self.config.patches_field is not None: # index_ids = self.current_label_ids else: index_ids = None dataset = self._dataset sample_ids = [] label_ids = None # if self.config.patches_field is not None: # label_ids = [] dists = [] for q in query: search = { "index": self.config.index_name, "path": self.config.embeddings_field, "limit": k, "numCandidates": num_candidates, "queryVector": q.tolist(), } if index_ids is not None: search["filter"] = { "_id": {"$in": [ObjectId(_id) for _id in index_ids]} } """ elif dataset.media_type == fom.GROUP: # $vectorSearch must be the first stage in all pipelines, so we # have to incorporate slice selection as a $filter name_field = dataset.group_field + ".name" group_slice = self.view.group_slice or dataset.group_slice search["filter"] = {name_field: {"$eq": group_slice}} """ project = {"_id": 1} # if self.config.patches_field is not None: # project["_sample_id"] = 1 if return_dists: project["score"] = {"$meta": "vectorSearchScore"} # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage pipeline = [{"$vectorSearch": search}, {"$project": project}] try: matches = list( dataset._aggregate( pipeline=pipeline, manual_group_select=True ) ) except OperationFailure as e: if index_ids is None: raise e logger.warning( "This legacy search index does not yet support views. 
" "Please follow the instructions at " "https://github.com/voxel51/fiftyone-brain/pull/248 " "to upgrade it.\n\nIn the meantime, the full index will " "instead be queried, which may result in fewer " "matches in your current view" ) search.pop("filter") matches = list( dataset._aggregate( pipeline=pipeline, manual_group_select=True ) ) sample_ids.append([str(m["_id"]) for m in matches]) # if self.config.patches_field is not None: # sample_ids.append([str(m["_sample_id"]) for m in matches]) # label_ids.append([str(m["_id"]) for m in matches]) if return_dists: dists.append([m["score"] for m in matches]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) embeddings = self._get_embeddings(query_ids) num_missing = len(query_ids) - len(embeddings) for e in embeddings: num_missing += int(e is None) if num_missing > 0: if single_query: raise ValueError("The query ID does not exist in this index") else: raise ValueError( f"{num_missing} query IDs do not exist in this index" ) query = np.array(embeddings) if single_query: query = query[0, :] return query def _get_embeddings(self, query_ids): if self._dataset.media_type == fom.GROUP: samples = self._dataset.select_group_slices(_allow_mixed=True) else: samples = self._dataset patches_field = self.config.patches_field embeddings_field = self.config.embeddings_field if patches_field is not None: _, embeddings_path = self._dataset._get_label_field_path( patches_field, embeddings_field ) view = samples.filter_labels( patches_field, F("_id").is_in(query_ids) ) embeddings = view.values(embeddings_path, 
unwind=True) else: view = samples.select(query_ids) embeddings = view.values(embeddings_field) return embeddings @staticmethod def _parse_data(samples, config): if samples.media_type == fom.GROUP: samples = samples.select_group_slices(_allow_mixed=True) if config.patches_field is not None: samples = samples.filter_labels( config.patches_field, F(config.embeddings_field).exists() ) else: samples = samples.exists(config.embeddings_field) return fbu.get_ids(samples, patches_field=config.patches_field) @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) def _get_inds(ids, index_ids, ftype, allow_missing, warn_missing): if etau.is_str(ids): ids = [ids] ids_map = {_id: i for i, _id in enumerate(index_ids)} inds = [] bad_ids = [] for _id in ids: idx = ids_map.get(_id, None) if idx is not None: inds.append(idx) else: bad_ids.append(_id) num_missing = len(bad_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d %s IDs (eg '%s') that are not present in the index" % (num_missing, ftype, bad_ids[0]) ) if warn_missing: logger.warning( "Ignoring %d %s IDs that are not present in the index", num_missing, ftype, ) return np.array(inds) ================================================ FILE: fiftyone/brain/internal/core/mosaic.py ================================================ """ Mosaic similarity backend. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.core.utils as fou from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu vector_search_client = fou.lazy_import("databricks.vector_search.client") logger = logging.getLogger(__name__) # Todo: add in required for arguments that are necessary to create the index table class MosaicSimilarityConfig(SimilarityConfig): """Configuration for the Mosaic similarity backend. 
Args: endpoint_name (None): the name of the vector search endpoint that was created in the Databricks workspace workspace_url (None): the URL of the Databricks workspace catalog_name (None): the name of the catalog in the Databricks workspace schema_name (None): the name of the schema in the Databricks workspace index_name (None): the name of the index to use, if one is not provided, a unique name will be generated service_principal_client_id (None): the client ID of the service principal created for authentication service_principal_client_secret (None): the client secret of the service principal created for authentication personal_access_token (None): the personal access token created for authentication **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__( self, endpoint_name=None, workspace_url=None, catalog_name=None, schema_name=None, index_name=None, service_principal_client_id=None, service_principal_client_secret=None, personal_access_token=None, **kwargs, ): super().__init__(**kwargs) self.index_name = index_name self.endpoint_name = endpoint_name self.catalog_name = catalog_name self.schema_name = schema_name # store privately so these aren't serialized self._workspace_url = workspace_url self._service_principal_client_id = service_principal_client_id self._service_principal_client_secret = service_principal_client_secret self._personal_access_token = personal_access_token @property def method(self): return "mosaic" @property def workspace_url(self): return self._workspace_url @workspace_url.setter def workspace_url(self, workspace_url): self._workspace_url = workspace_url @property def service_principal_client_id(self): return self._service_principal_client_id @service_principal_client_id.setter def service_principal_client_id(self, service_principal_client_id): self._service_principal_client_id = service_principal_client_id @property def service_principal_client_secret(self): return 
self._service_principal_client_secret @service_principal_client_secret.setter def service_principal_client_secret(self, service_principal_client_secret): self._service_principal_client_secret = service_principal_client_secret @property def personal_access_token(self): return self._personal_access_token @personal_access_token.setter def personal_access_token(self, personal_access_token): self._personal_access_token = personal_access_token @property def max_k(self): return None @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) def load_credentials( self, workspace_url=None, service_principal_client_id=None, service_principal_client_secret=None, personal_access_token=None, ): self._load_parameters( workspace_url=workspace_url, service_principal_client_id=service_principal_client_id, service_principal_client_secret=service_principal_client_secret, personal_access_token=personal_access_token, ) class MosaicSimilarity(Similarity): """Mosaic similarity factory. Args: config: a :class:`MosaicSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("databricks-vectorsearch") def ensure_usage_requirements(self): fou.ensure_package("databricks-vectorsearch") def initialize(self, samples, brain_key): return MosaicSimilarityIndex( samples, self.config, brain_key, backend=self ) class MosaicSimilarityIndex(SimilarityIndex): """Class for interacting with Mosaic similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`MosaicSimilarityConfig` used brain_key: the brain key backend (None): a :class:`MosaicSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._client = None self._index = None self._initialize() def _initialize(self): self._client = vector_search_client.VectorSearchClient( workspace_url=self.config.workspace_url, service_principal_client_id=self.config.service_principal_client_id, service_principal_client_secret=self.config.service_principal_client_secret, personal_access_token=self.config.personal_access_token, ) try: index_names_result = self._client.list_indexes( self.config.endpoint_name ) except Exception as e: raise ValueError( f"Failed to list indexes from endpoint :{self.config.endpoint_name}" ) from e index_prefix = f"{self.config.catalog_name}.{self.config.schema_name}." if not index_names_result: index_names = [] else: index_names = [ ind["name"].replace(index_prefix, "") for ind in index_names_result["vector_indexes"] if ind["name"].startswith(index_prefix) ] if self.config.index_name is None: root = "fiftyone-" + fou.to_slug(self._samples._root_dataset.name) index_name = fbu.get_unique_name(root, index_names) self.config.index_name = index_name self.save_config() if self.config.index_name in index_names: index = self._client.get_index( endpoint_name=self.config.endpoint_name, index_name=f"{index_prefix}{self.config.index_name}", ) else: index = None self._index = index def _create_index(self, dimension): self._index = self._client.create_direct_access_index( endpoint_name=self.config.endpoint_name, index_name=f"{self.config.catalog_name}.{self.config.schema_name}.{self.config.index_name}", primary_key="foid", embedding_dimension=dimension, embedding_vector_column="embedding_vector", schema={ "foid": "string", "sample_id": "string", "embedding_vector": "array", 
}, ) @property def client(self): """The ``databricks.vector_search.client.VectorSearchClient`` instance for this index.""" return self._client @property def total_index_size(self): if self._index is None: return 0 return self._index.describe()["status"]["indexed_row_count"] def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=200, ): if self._index is None: self._create_index(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: index_ids = self._get_index_ids() existing_ids = set(ids) & set(index_ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): result = [ {"foid": f, "sample_id": s, "embedding_vector": list(e)} for f, s, e in zip(_ids, _sample_ids, _embeddings) ] self._index.upsert(result) if reload: self.reload() def _get_index_ids(self, batch_size=200): ids = set() result = self._index.scan(num_results=batch_size) while len(result) > 0: ids.update( [ doc["fields"][0]["value"]["string_value"] for doc in result["data"] ] ) last_primary_key = result["last_primary_key"] result = 
self._index.scan( num_results=batch_size, last_primary_key=last_primary_key ) return list(ids) def _get_values(self, ids, batch_size=200): embeddings = [] result = self._index.scan(num_results=batch_size) while len(result) > 0: for doc in result["data"]: foid = doc["fields"][0]["value"]["string_value"] if foid in ids: embedding = [ d["number_value"] for d in doc["fields"][2]["value"]["list_value"][ "values" ] ] embeddings.append(embedding) last_primary_key = result["last_primary_key"] result = self._index.scan( num_results=batch_size, last_primary_key=last_primary_key ) return embeddings def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = self._get_index_ids() missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids self._index.delete(ids) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = 
self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids # Note: might be an arg in delete_brain_run? def cleanup(self): if self._index is not None: self._client.delete_index( self.config.endpoint_name, f"{self.config.catalog_name}.{self.config.schema_name}.{self.config.index_name}", ) self._index = None def _get_sample_embeddings(self, sample_ids, batch_size=200): found_embeddings = [] found_sample_ids = [] if sample_ids is None: sample_ids = self._get_index_ids() result = self._index.scan(num_results=batch_size) while len(result) > 0: for doc in result["data"]: sample_id = doc["fields"][1]["value"]["string_value"] if sample_id in sample_ids: embedding = [ d["number_value"] for d in doc["fields"][2]["value"]["list_value"][ "values" ] ] found_embeddings.append(embedding) found_sample_ids.append(sample_id) last_primary_key = result["last_primary_key"] result = self._index.scan( num_results=batch_size, last_primary_key=last_primary_key ) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=200): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if label_ids is None: label_ids = self._get_index_ids() result = self._index.scan(num_results=batch_size) while len(result) > 0: for doc in result["data"]: label_id = doc["fields"][0]["value"]["string_value"] if label_id in label_ids: embedding = [ d["number_value"] for d in doc["fields"][2]["value"]["list_value"][ "values" ] ] 
found_embeddings.append(embedding) found_label_ids.append(label_id) found_sample_ids.append( doc["fields"][1]["value"]["string_value"] ) last_primary_key = result["last_primary_key"] result = self._index.scan( num_results=batch_size, last_primary_key=last_primary_key ) missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _get_patch_embeddings_from_sample_ids( self, sample_ids, batch_size=200 ): found_embeddings = [] found_sample_ids = [] found_label_ids = [] result = self._index.scan(num_results=batch_size) while len(result) > 0: for doc in result["data"]: sample_id = doc["fields"][1]["value"]["string_value"] if sample_id in sample_ids: embedding = [ d["number_value"] for d in doc["fields"][2]["value"]["list_value"][ "values" ] ] found_embeddings.append(embedding) found_sample_ids.append(sample_id) found_label_ids.append( doc["fields"][0]["value"]["string_value"] ) last_primary_key = result["last_primary_key"] result = self._index.scan( num_results=batch_size, last_primary_key=last_primary_key ) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("Mosaic does not support full index neighbors") if reverse is True: raise ValueError( "Mosaic does not support least similarity queries" ) if k is None: k = self.index_size if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = self.current_label_ids else: index_ids = self.current_sample_ids # @todo apply filtering in similarity_search(), 
not post-hoc # As of this writing, filtering is supported in Mosaic but it is # not robust and cannot handle a large number of IDs logger.warning( "The Mosaic backend does not yet support view filters; the " "full index will instead be queried, which may result in " "fewer matches in your current view" ) _filter = {"foid": set(index_ids)} else: _filter = None sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: results = self._index.similarity_search( columns=["foid", "sample_id"], query_vector=[float(i) for i in list(q)], num_results=k, )["result"]["data_array"] if _filter is not None: results = [r for r in results if r[0] in _filter["foid"]] if self.config.patches_field is not None: sample_ids.append([r[1] for r in results]) label_ids.append([r[0] for r in results]) else: sample_ids.append([r[0] for r in results]) if return_dists: dists.append([r[2] for r in results]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) embeddings = self._get_values(query_ids) if len(embeddings) == 0: raise ValueError( "Query IDs %s do not exist in this index" % query_ids ) query = np.array(embeddings) if single_query: query = query[0, :] return query @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/pgvector.py ================================================ """ PGVector similarity backend. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_
|
"""
import logging

import numpy as np

import eta.core.utils as etau

import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
    SimilarityConfig,
    Similarity,
    SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu

# Imported lazily so psycopg2 is only required when this backend is used
psycopg2 = fou.lazy_import("psycopg2")
psy_extras = fou.lazy_import("psycopg2.extras")

logger = logging.getLogger(__name__)

# Supported metrics for pgvector, mapped to the operator class used when
# building the HNSW index
# NOTE(review): pgvector documents Hamming/Jaccard distance only for ``bit``
# vectors (``bit_hamming_ops``/``bit_jaccard_ops``), while the table column is
# ``VECTOR``; the "jaccard"/"hamming" entries may fail at index creation —
# confirm
_SUPPORTED_METRICS = {
    "cosine": "vector_cosine_ops",
    "dotproduct": "vector_ip_ops",
    "euclidean": "vector_l2_ops",
    "l1": "vector_l1_ops",
    "jaccard": "vector_jaccard_ops",
    "hamming": "vector_hamming_ops",
}


class PgVectorSimilarityConfig(SimilarityConfig):
    """Configuration for the PGVector similarity backend.

    Args:
        index_name (None): the name of the PGVector index to use or create.
            If none is provided, a default index name will be used.
        table_name (None): the name of the table to use or create. If none is
            provided, a default table name will be used.
        metric ("cosine"): the similarity metric to use. Supported values are
            ``("cosine", "dotproduct", "euclidean", "l1", "jaccard",
            "hamming")``
        connection_string (None): the connection string to the PostgreSQL
            database
        ssl_cert (None): the path to the SSL certificate file
        ssl_key (None): the path to the secret key used for the client
            certificate
        ssl_root_cert (None): the path to the file containing SSL certificate
            authority (CA) certificate(s).
        work_mem ("64MB"): the base maximum amount of memory to be used by a
            query operation (such as a sort or hash table) before writing to
            temporary disk files
        hnsw_m (16): the max number of connections per layer in the HNSW
            index
        hnsw_ef_construction (64): the size of the dynamic candidate list for
            constructing the graph for the HNSW index
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`
    """

    def __init__(
        self,
        index_name=None,
        table_name=None,
        metric="cosine",
        connection_string=None,
        ssl_cert=None,
        ssl_key=None,
        ssl_root_cert=None,
        work_mem="64MB",
        hnsw_m=16,
        hnsw_ef_construction=64,
        **kwargs,
    ):
        # Validate the metric before initializing the base config
        if metric not in _SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported metric '{metric}'. "
                f"Supported values are {_SUPPORTED_METRICS}"
            )

        super().__init__(**kwargs)

        self.metric = metric
        self.ssl_cert = ssl_cert
        self.ssl_key = ssl_key
        self.ssl_root_cert = ssl_root_cert
        self.work_mem = work_mem
        self.index_name = index_name
        self.table_name = table_name
        self.hnsw_m = hnsw_m
        self.hnsw_ef_construction = hnsw_ef_construction

        # Kept private, presumably so the credential is not serialized with
        # the config (as other backends do) — confirm
        self._connection_string = connection_string

    @property
    def method(self):
        """The name of this similarity backend: ``"pgvector"``."""
        return "pgvector"

    @property
    def connection_string(self):
        """The PostgreSQL connection string."""
        return self._connection_string

    @connection_string.setter
    def connection_string(self, connection_string):
        self._connection_string = connection_string

    @property
    def max_k(self):
        """The maximum ``k`` declared for queries on this backend."""
        return 10000

    @property
    def supports_least_similarity(self):
        """Least-similarity queries are not supported (returns ``False``)."""
        return False

    @property
    def supported_aggregations(self):
        """Only ``"mean"`` aggregation of multiple queries is supported."""
        return ("mean",)

    def load_credentials(
        self,
        connection_string=None,
    ):
        """Forwards the connection string to ``self._load_parameters()``,
        which is defined on the base config."""
        self._load_parameters(connection_string=connection_string)


class PgVectorSimilarity(Similarity):
    """PGVector similarity factory.
Args: config: a :class:`PgVectorSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("psycopg2|psycopg2-binary") def ensure_usage_requirements(self): fou.ensure_package("psycopg2|psycopg2-binary") def initialize(self, samples, brain_key): return PgVectorSimilarityIndex( samples, self.config, brain_key, backend=self ) class PgVectorSimilarityIndex(SimilarityIndex): """Class for interacting with PGVector similarity indexes. Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`PGVectorSimilarityConfig` used brain_key: the brain key backend (None): a :class:`PGVectorSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._conn = None self._cur = None self._initialize() @property def total_index_size(self): if self._conn.closed: self._initialize() try: self._cur.execute( f"""SELECT COUNT(*) FROM "{self.config.table_name}";""" ) return self._cur.fetchone()[0] except Exception as e: logger.error(f"Error getting index size: {str(e)}") return 0 def _initialize(self): ssl_options = {} if self.config.ssl_cert: ssl_options["sslcert"] = self.config.ssl_cert if self.config.ssl_key: ssl_options["sslkey"] = self.config.ssl_key if self.config.ssl_root_cert: ssl_options["sslrootcert"] = self.config.ssl_root_cert logger.info(f"Connecting to PostgreSQL database") self._conn = psycopg2.connect( self.config.connection_string, **ssl_options ) self._cur = self._conn.cursor() try: self._cur.execute("CREATE EXTENSION IF NOT EXISTS vector") self._conn.commit() except Exception as e: logger.error(f"Error creating vector extension: {str(e)}") raise if self.config.table_name is None: table_names = self._get_table_names() root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) table_name = fbu.get_unique_name(root, table_names) self.config.table_name = table_name self.save_config() existing_indexes = [] else: existing_indexes = 
self._get_index_names(self.config.table_name) if self.config.index_name is None: root = "fiftyone-index-" + fou.to_slug( self.samples._root_dataset.name ) index_name = fbu.get_unique_name(root, existing_indexes) self.config.index_name = index_name self.save_config() def _get_table_names(self): self._cur.execute( "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" ) return [row[0] for row in self._cur.fetchall()] def _get_index_names(self, table_name): self._cur.execute( f"SELECT indexname FROM pg_indexes WHERE tablename = '{table_name}' AND schemaname = 'public';" ) return [row[0] for row in self._cur.fetchall()] def _create_table(self, dimension): try: self._cur.execute( f""" CREATE TABLE IF NOT EXISTS "{self.config.table_name}" ( id TEXT PRIMARY KEY, sample_id TEXT, embedding_vector VECTOR({dimension}) ); """ ) self._conn.commit() except Exception as e: logger.error( f"Error creating table: {self.config.table_name} with dimension {dimension}: {str(e)}" ) raise def create_hnsw_index(self): operator_class = _SUPPORTED_METRICS[self.config.metric] try: self._cur.execute( f"""DROP INDEX IF EXISTS "{self.config.index_name}";""" ) self._conn.commit() self._cur.execute( f""" CREATE INDEX "{self.config.index_name}" ON "{self.config.table_name}" USING hnsw (embedding_vector {operator_class}) WITH (m = %s, ef_construction = %s); """, (self.config.hnsw_m, self.config.hnsw_ef_construction), ) self._conn.commit() except Exception as e: logger.error( f"Error creating HNSW index on table {self.config.table_name}:{str(e)}" ) raise def _get_index_ids(self, batch_size=1000): named_cursor = self._conn.cursor( name="id_cursor" ) # Named cursor for server-side query named_cursor.execute(f"""SELECT id FROM "{self.config.table_name}";""") existing_ids = [] while True: rows = named_cursor.fetchmany(batch_size) if not rows: break ids = [row[0] for row in rows] existing_ids.extend(ids) named_cursor.close() return existing_ids def add_to_index( self, 
embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=5000, close_conn=True, ): if self._conn.closed: self._initialize() self._cur.execute(f"SET work_mem TO '{self.config.work_mem}'") if self.config.table_name not in self._get_table_names(): self._create_table(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: index_ids = self._get_index_ids() existing_ids = set(ids) & set(index_ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: query = f""" INSERT INTO "{self.config.table_name}" (id, sample_id, embedding_vector) VALUES %s ON CONFLICT (id) DO NOTHING; """ else: query = f""" INSERT INTO "{self.config.table_name}" (id, sample_id, embedding_vector) VALUES %s ON CONFLICT (id) DO UPDATE SET sample_id = EXCLUDED.sample_id, embedding_vector = EXCLUDED.embedding_vector; """ embeddings = [e.tolist() for e in embeddings] sample_ids = list(sample_ids) if label_ids is not None: ids = list(label_ids) else: ids = list(sample_ids) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): data = list(zip(_ids, _sample_ids, _embeddings)) psy_extras.execute_values(self._cur, query, data) self._conn.commit() self.create_hnsw_index() if close_conn: self.close_connections() if reload: self.reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, 
): if self._conn.closed: self._initialize() if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_missing or not allow_missing: response = self.get_embeddings_by_id(ids) existing_ids = [id for id, emb in response] missing_ids = set(ids) - set(existing_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, next(iter(missing_ids))) ) if warn_missing and not allow_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) try: # Use parameterized query to delete multiple IDs self._cur.execute( f"""DELETE FROM "{self.config.table_name}" WHERE id IN %s;""", (tuple(ids),), ) except Exception as e: self._conn.rollback() logger.error(f"Error removing embeddings for ids {ids}: {str(e)}") raise deleted_count = self._cur.rowcount self._conn.commit() logger.info(f"Deleted {deleted_count} embeddings from the index.") if reload: self.reload() def close_connections(self): if not self._cur.closed: self._cur.close() if not self._conn.closed: self._conn.close() def get_embeddings_by_id(self, sample_ids=None, label_ids=None): if self._conn.closed: self._initialize() if label_ids is not None: try: self._cur.execute( f"""SELECT id, sample_id, embedding_vector FROM "{self.config.table_name}" WHERE id = ANY(%s)""", (list(label_ids),), ) except Exception as e: logger.error( f"Error fetching embeddings for labels {label_ids}: {str(e)}" ) raise elif sample_ids is not None: try: self._cur.execute( f"""SELECT id, sample_id, embedding_vector FROM "{self.config.table_name}" WHERE sample_id = ANY(%s)""", (list(sample_ids),), ) except Exception as e: logger.error( f"Error fetching embeddings for samples {sample_ids}: {str(e)}" ) raise else: try: self._cur.execute( f"""SELECT id, sample_id, embedding_vector FROM "{self.config.table_name}";""" ) except Exception as e: logger.error( f"Error fetching embeddings for all 
samples: {str(e)}" ) raise results = self._cur.fetchall() fo_id = [] sample_id = [] embeddings = [] for result in results: # Convert string "[1.2,3.4,5.6]" to float array if isinstance(result[2], str): emb = np.array( [float(x) for x in result[2].strip("[]").split(",")], dtype=np.float32, ) embeddings.append(emb) else: # Already numeric emb = np.array(result[2], dtype=np.float32) embeddings.append(emb) fo_id.append(result[0]) sample_id.append(result[1]) return fo_id, sample_id, embeddings def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( label_ids, found_sample_ids, embeddings, ) = self.get_embeddings_by_id(sample_ids=sample_ids) missing_ids = list(set(sample_ids) - set(found_sample_ids)) sample_ids = found_sample_ids elif self.config.patches_field is not None: ( found_label_ids, sample_ids, embeddings, ) = self.get_embeddings_by_id(label_ids=label_ids) missing_ids = ( list(set(label_ids) - set(found_label_ids)) if label_ids is not None else [] ) label_ids = found_label_ids else: ( label_ids, found_sample_ids, embeddings, ) = self.get_embeddings_by_id(sample_ids=sample_ids) missing_ids = ( list(set(sample_ids) - set(found_sample_ids)) if sample_ids is not None else [] ) sample_ids = found_sample_ids num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, 
sample_ids, label_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, close_conn=True, ): if self._conn.closed: self._initialize() if query is None: raise ValueError("Postgres does not support full index neighbors") if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) if k is None: k = self.index_size query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] index_ids = None if self.has_view: if self.config.patches_field is not None: index_ids = list(self.current_label_ids) else: index_ids = list(self.current_sample_ids) _filter = True else: _filter = False sort_order = "DESC" if reverse else "ASC" sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: if _filter: self._cur.execute( f""" SELECT id, sample_id, embedding_vector <-> %s::vector AS distance FROM "{self.config.table_name}" WHERE id = ANY(%s) ORDER BY distance {sort_order} LIMIT %s; """, (q.tolist(), index_ids, k), ) else: self._cur.execute( f""" SELECT id, sample_id, embedding_vector <-> %s::vector AS distance FROM "{self.config.table_name}" ORDER BY distance {sort_order} LIMIT %s; """, (q.tolist(), k), ) results = self._cur.fetchall() if self.config.patches_field is not None: sample_ids.append([r[1] for r in results]) label_ids.append([r[0] for r in results]) else: sample_ids.append([r[0] for r in results]) if return_dists: dists.append([r[2] for r in results]) if close_conn: self.close_connections() if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = 
np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False _, _, embeddings = self.get_embeddings_by_id(label_ids=query_ids) if len(embeddings) == 0: raise ValueError( "Query IDs %s do not exist in this index" % query_ids ) query = np.array(embeddings) if single_query: query = query[0, :] return query def cleanup(self, drop_table=False): """ Clean up the database by dropping the HNSW index and optionally the embeddings table. """ logger.info( f"Cleaning up: Deleting HNSW index '{self.config.index_name}'" ) self._cur.execute( f"""DROP INDEX IF EXISTS "{self.config.index_name}";""" ) if self._conn.closed: self._initialize() if drop_table: self._cur.execute( f"""DROP TABLE IF EXISTS "{self.config.table_name}";""" ) logger.info( f"{self.config.table_name} table deleted successfully." ) self._conn.commit() # Close the database connection self.close_connections() logger.info("Database connection closed.") @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/pinecone.py ================================================ """ Piencone similarity backend. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.core.utils as fou from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu pinecone = fou.lazy_import("pinecone") logger = logging.getLogger(__name__) _SUPPORTED_METRICS = ("cosine", "dotproduct", "euclidean") class PineconeSimilarityConfig(SimilarityConfig): """Configuration for the Pinecone similarity backend. Args: index_name (None): the name of a Pinecone index to use or create. If none is provided, a new index will be created index_type (None): the index type to use when creating a new index. 
The supported values are ``["serverless", "pod"]`` and the default
            is ``"serverless"``
        namespace (None): a namespace under which to store vectors added to
            the index
        metric (None): the embedding distance metric to use when creating a
            new index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        replicas (None): an optional number of replicas when creating a new
            pod-based index
        shards (None): an optional number of shards when creating a new
            pod-based index
        pods (None): an optional number of pods when creating a new pod-based
            index
        pod_type (None): an optional pod type when creating a new pod-based
            index
        api_key (None): a Pinecone API key to use
        cloud (None): a cloud to use when creating serverless indexes
        region (None): a region to use when creating serverless indexes
        environment (None): an environment to use when creating pod-based
            indexes
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`
    """

    def __init__(
        self,
        index_name=None,
        index_type=None,
        namespace=None,
        metric=None,
        replicas=None,
        shards=None,
        pods=None,
        pod_type=None,
        api_key=None,
        cloud=None,
        region=None,
        environment=None,
        **kwargs,
    ):
        # ``metric=None`` is allowed; a default is chosen when the index is
        # created
        if metric is not None and metric not in _SUPPORTED_METRICS:
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, _SUPPORTED_METRICS)
            )

        super().__init__(**kwargs)

        self.index_name = index_name
        self.index_type = index_type
        self.namespace = namespace
        self.metric = metric
        self.replicas = replicas
        self.shards = shards
        self.pods = pods
        self.pod_type = pod_type

        # store privately so these aren't serialized
        self._api_key = api_key
        self._cloud = cloud
        self._region = region
        self._environment = environment

    @property
    def method(self):
        """The name of this similarity backend: ``"pinecone"``."""
        return "pinecone"

    @property
    def api_key(self):
        """The Pinecone API key (not serialized)."""
        return self._api_key

    @api_key.setter
    def api_key(self, value):
        self._api_key = value

    @property
    def cloud(self):
        """The cloud for serverless indexes (not serialized)."""
        return self._cloud

    @cloud.setter
    def cloud(self, value):
        self._cloud = value

    @property
    def region(self):
        """The region for serverless indexes (not serialized)."""
        return self._region

    @region.setter
    def region(self, value):
        self._region = value

    @property
    def environment(self):
        """The environment for pod-based indexes (not serialized)."""
        return self._environment

    @environment.setter
    def environment(self, value):
        self._environment = value

    @property
    def max_k(self):
        return 10000  # Pinecone limit

    @property
    def supports_least_similarity(self):
        """Least-similarity queries are not supported (returns ``False``)."""
        return False

    @property
    def supported_aggregations(self):
        """Only ``"mean"`` aggregation of multiple queries is supported."""
        return ("mean",)

    def load_credentials(
        self, api_key=None, cloud=None, region=None, environment=None
    ):
        """Forwards the given credentials to ``self._load_parameters()``,
        which is defined on the base config."""
        self._load_parameters(
            api_key=api_key,
            cloud=cloud,
            region=region,
            environment=environment,
        )


class PineconeSimilarity(Similarity):
    """Pinecone similarity factory.

    Args:
        config: a :class:`PineconeSimilarityConfig`
    """

    def ensure_requirements(self):
        fou.ensure_package("pinecone-client")

    def ensure_usage_requirements(self):
        # Using an index requires a newer client version than building one
        fou.ensure_package("pinecone-client>=3.2")

    def initialize(self, samples, brain_key):
        """Returns a :class:`PineconeSimilarityIndex` backed by this factory."""
        return PineconeSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )


class PineconeSimilarityIndex(SimilarityIndex):
    """Class for interacting with Pinecone similarity indexes.
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`PineconeSimilarityConfig` used brain_key: the brain key backend (None): a :class:`PineconeSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._pinecone = None self._index = None self._initialize() def _initialize(self): self._pinecone = pinecone.Pinecone(api_key=self.config.api_key) try: index_names = [d["name"] for d in self._pinecone.list_indexes()] except Exception as e: raise ValueError( "Failed to connect to Pinecone backend. " "Refer to https://docs.voxel51.com/integrations/pinecone.html " "for more information" ) from e if self.config.index_name is None: # https://docs.pinecone.io/troubleshooting/restrictions-on-index-names root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) index_name = fbu.get_unique_name(root, index_names, max_len=45) self.config.index_name = index_name self.save_config() if self.config.index_name in index_names: index = self._pinecone.Index(self.config.index_name) else: index = None self._index = index def _create_index(self, dimension): index_type = self.config.index_type or "serverless" if index_type == "serverless": spec = pinecone.ServerlessSpec( self.config.cloud, self.config.region, ) elif index_type == "pod": kwargs = dict( pod_type=self.config.pod_type, pods=self.config.pods, replicas=self.config.replicas, shards=self.config.shards, ) kwargs = {k: v for k, v in kwargs.items() if v is not None} spec = pinecone.PodSpec(self.config.environment, **kwargs) else: raise TypeError( f"Invalid index_type='{index_type}'. 
The supported values are " "['serverless', 'pod']" ) metric = self.config.metric or "cosine" self._pinecone.create_index( name=self.config.index_name, dimension=dimension, metric=metric, spec=spec, ) self._index = self._pinecone.Index(self.config.index_name) @property def index(self): """The ``pinecone.Index`` instance for this index.""" return self._index @property def total_index_size(self): if self._index is None: return 0 return self._index.describe_index_stats()["total_vector_count"] @property def ready(self): return self._pinecone.describe_index(self.config.index_name).status[ "ready" ] def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=100, namespace=None, ): if namespace is None: namespace = self.config.namespace if self._index is None: self._create_index(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: existing_ids = self._get_existing_ids(ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) embeddings = [e.tolist() for e in embeddings] sample_ids = list(sample_ids) if label_ids is not None: ids = list(label_ids) else: ids = list(sample_ids) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), 
fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): _id_dicts = [ {"id": _id, "sample_id": _sid} for _id, _sid in zip(_ids, _sample_ids) ] self._index.upsert( list(zip(_ids, _embeddings, _id_dicts)), namespace=namespace, ) if reload: self.reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = list(self._index.fetch(ids).vectors.keys()) missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids self._index.delete(ids=ids) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", 
num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids def cleanup(self): self._pinecone.delete_index(self.config.index_name) self._index = None def _get_existing_ids(self, ids, batch_size=1000): existing_ids = set() for batch_ids in fou.iter_batches(ids, batch_size): response = self._index.fetch(ids=list(batch_ids))["vectors"] existing_ids.update(response.keys()) return existing_ids def _get_sample_embeddings(self, sample_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] if sample_ids is None: raise ValueError( "Pinecone does not support retrieving all vectors in an index" ) for batch_ids in fou.iter_batches(sample_ids, batch_size): response = self._index.fetch(ids=list(batch_ids))["vectors"] for r in response.values(): found_embeddings.append(r["values"]) found_sample_ids.append(r["id"]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if label_ids is None: raise ValueError( "Pinecone does not support retrieving all vectors in an index" ) for batch_ids in fou.iter_batches(label_ids, batch_size): response = self._index.fetch(ids=list(batch_ids))["vectors"] for r in response.values(): found_embeddings.append(r["values"]) found_sample_ids.append(r["metadata"]["sample_id"]) found_label_ids.append(r["id"]) missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _get_patch_embeddings_from_sample_ids( self, sample_ids, batch_size=100 ): found_embeddings = [] found_sample_ids = [] found_label_ids = [] query_vector = [0.0] * self._get_dimension() top_k = min(batch_size, self.config.max_k) for batch_ids in fou.iter_batches(sample_ids, 
batch_size): response = self._index.query( vector=query_vector, filter={"sample_id": {"$in": list(batch_ids)}}, top_k=top_k, include_values=True, include_metadata=True, ) for r in response["matches"]: found_embeddings.append(r["values"]) found_sample_ids.append(r["metadata"]["sample_id"]) found_label_ids.append(r["id"]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("Pinecone does not support full index neighbors") if reverse is True: raise ValueError( "Pinecone does not support least similarity queries" ) if k is None or k > self.config.max_k: raise ValueError("Pinecone requires k<=%s" % self.config.max_k) if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = self.current_label_ids else: index_ids = self.current_sample_ids _filter = {"id": {"$in": list(index_ids)}} else: _filter = None sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: include_metadata = self.config.patches_field is not None response = self._index.query( vector=q.tolist(), top_k=k, filter=_filter, include_metadata=include_metadata, ) if self.config.patches_field is not None: sample_ids.append( [r["metadata"]["sample_id"] for r in response["matches"]] ) label_ids.append([r["id"] for r in response["matches"]]) else: sample_ids.append([r["id"] for r in response["matches"]]) if return_dists: dists.append([r["score"] for r in response["matches"]]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = 
label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) response = self._index.fetch(query_ids)["vectors"] query = np.array([response[_id]["values"] for _id in query_ids]) if query.size == 0: raise ValueError( "Query IDs %s were not found in the index" % query_ids ) if single_query: query = query[0, :] return query def _get_dimension(self): if self._index is None: return None return self._index.describe_index_stats().dimension @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/qdrant.py ================================================ """ Qdrant similarity backend. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.core.utils as fou from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu qdrant = fou.lazy_import("qdrant_client") qmodels = fou.lazy_import("qdrant_client.http.models") logger = logging.getLogger(__name__) _SUPPORTED_METRICS = { "cosine": qmodels.Distance.COSINE, "dotproduct": qmodels.Distance.DOT, "euclidean": qmodels.Distance.EUCLID, } class QdrantSimilarityConfig(SimilarityConfig): """Configuration for the Qdrant similarity backend. Args: collection_name (None): the name of a Qdrant collection to use or create. If none is provided, a new collection will be created metric (None): the embedding distance metric to use when creating a new index. 
Supported values are ``("cosine", "dotproduct", "euclidean")`` replication_factor (None): an optional replication factor to use when creating a new index shard_number (None): an optional number of shards to use when creating a new index write_consistency_factor (None): an optional write consistsency factor to use when creating a new index hnsw_config (None): an optional dict of HNSW config parameters to use when creating a new index optimizers_config (None): an optional dict of optimizer parameters to use when creating a new index wal_config (None): an optional dict of WAL config parameters to use when creating a new index url (None): a Qdrant server URL to use api_key (None): a Qdrant API key to use grpc_port (None): Port of Qdrant gRPC interface prefer_grpc (None): If `true`, use gRPC interface when possible **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__( self, collection_name=None, metric=None, replication_factor=None, shard_number=None, write_consistency_factor=None, hnsw_config=None, optimizers_config=None, wal_config=None, url=None, api_key=None, grpc_port=None, prefer_grpc=None, **kwargs, ): if metric is not None and metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. 
Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.collection_name = collection_name self.metric = metric self.replication_factor = replication_factor self.shard_number = shard_number self.write_consistency_factor = write_consistency_factor self.hnsw_config = hnsw_config self.optimizers_config = optimizers_config self.wal_config = wal_config # store privately so these aren't serialized self._url = url self._api_key = api_key self._grpc_port = grpc_port self._prefer_grpc = prefer_grpc @property def method(self): return "qdrant" @property def url(self): return self._url @url.setter def url(self, value): self._url = value @property def api_key(self): return self._api_key @api_key.setter def api_key(self, value): self._api_key = value @property def grpc_port(self): return self._grpc_port @grpc_port.setter def grpc_port(self, value): self._grpc_port = value @property def prefer_grpc(self): return self._prefer_grpc @prefer_grpc.setter def prefer_grpc(self, value): self._prefer_grpc = value @property def max_k(self): return None @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) def load_credentials( self, url=None, api_key=None, grpc_port=None, prefer_grpc=None ): self._load_parameters( url=url, api_key=api_key, grpc_port=grpc_port, prefer_grpc=prefer_grpc, ) class QdrantSimilarity(Similarity): """Qdrant similarity factory. Args: config: a :class:`QdrantSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("qdrant-client") def ensure_usage_requirements(self): fou.ensure_package("qdrant-client") def initialize(self, samples, brain_key): return QdrantSimilarityIndex( samples, self.config, brain_key, backend=self ) class QdrantSimilarityIndex(SimilarityIndex): """Class for interacting with Qdrant similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`QdrantSimilarityConfig` used brain_key: the brain key backend (None): a :class:`QdrantSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._client = None self._initialize() def _initialize(self): # QdrantClient does not appear to like passing None as defaults grpc_port = ( self.config.grpc_port if self.config.grpc_port is not None else 6334 ) prefer_grpc = ( self.config.prefer_grpc if self.config.prefer_grpc is not None else False ) self._client = qdrant.QdrantClient( url=self.config.url, api_key=self.config.api_key, grpc_port=grpc_port, prefer_grpc=prefer_grpc, ) try: collection_names = self._get_collection_names() except Exception as e: raise ValueError( "Failed to connect to Qdrant backend at URL '%s'. Refer to " "https://docs.voxel51.com/integrations/qdrant.html for more " "information" % self.config.url ) from e if self.config.collection_name is None: root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name) collection_name = fbu.get_unique_name(root, collection_names) self.config.collection_name = collection_name self.save_config() def _get_collection_names(self): return [c.name for c in self._client.get_collections().collections] def _create_collection(self, dimension): if self.config.metric: metric = self.config.metric else: metric = "cosine" vectors_config = qmodels.VectorParams( size=dimension, distance=_SUPPORTED_METRICS[metric], ) if self.config.hnsw_config: hnsw_config = qmodels.HnswConfig(**self.config.hnsw_config) else: hnsw_config = None if self.config.optimizers_config: optimizers_config = qmodels.OptimizersConfig( **self.config.optimizers_config ) else: optimizers_config = None if self.config.wal_config: wal_config = qmodels.WalConfig(**self.config.wal_config) else: wal_config = None self._client.recreate_collection( 
collection_name=self.config.collection_name, vectors_config=vectors_config, shard_number=self.config.shard_number, replication_factor=self.config.replication_factor, hnsw_config=hnsw_config, optimizers_config=optimizers_config, wal_config=wal_config, ) def _get_index_ids(self, batch_size=1000): ids = [] offset = 0 while offset is not None: response = self._client.scroll( collection_name=self.config.collection_name, offset=offset, limit=batch_size, with_payload=True, with_vectors=False, ) ids.extend([self._to_fiftyone_id(r.id) for r in response[0]]) offset = response[-1] return ids @property def total_index_size(self): try: return self._client.count(self.config.collection_name).count except: return 0 @property def client(self): """The ``qdrant.QdrantClient`` instance for this index.""" return self._client def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, batch_size=1000, ): if self.config.collection_name not in self._get_collection_names(): self._create_collection(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: index_ids = self._get_index_ids() existing_ids = set(ids) & set(index_ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) embeddings = 
[e.tolist() for e in embeddings] sample_ids = list(sample_ids) if label_ids is not None: ids = list(label_ids) else: ids = list(sample_ids) for _embeddings, _ids, _sample_ids in zip( fou.iter_batches(embeddings, batch_size), fou.iter_batches(ids, batch_size), fou.iter_batches(sample_ids, batch_size), ): self._client.upsert( collection_name=self.config.collection_name, points=qmodels.Batch( ids=self._to_qdrant_ids(_ids), payloads=[{"sample_id": _id} for _id in _sample_ids], vectors=_embeddings, ), ) if reload: self.reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids qids = self._to_qdrant_ids(ids) if warn_missing or not allow_missing: response = self._retrieve_points(qids, with_vectors=False) existing_ids = self._to_fiftyone_ids([r.id for r in response]) missing_ids = set(ids) - set(existing_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, next(iter(missing_ids))) ) if warn_missing and not allow_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) qids = self._to_qdrant_ids(existing_ids) self._client.delete( collection_name=self.config.collection_name, points_selector=qmodels.PointIdsList(points=qids), ) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( 
embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids def cleanup(self): self._client.delete_collection(self.config.collection_name) def _retrieve_points(self, qids, with_vectors=True, with_payload=True): # @todo add batching? return self._client.retrieve( collection_name=self.config.collection_name, ids=qids, with_vectors=with_vectors, with_payload=with_payload, ) def _get_sample_embeddings(self, sample_ids): if sample_ids is None: sample_ids = self._get_index_ids() response = self._retrieve_points( self._to_qdrant_ids(sample_ids), with_vectors=True, ) found_embeddings = [r.vector for r in response] found_sample_ids = self._to_fiftyone_ids([r.id for r in response]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids): if label_ids is None: label_ids = self._get_index_ids() response = self._retrieve_points( self._to_qdrant_ids(label_ids), with_vectors=True, with_payload=True, ) found_embeddings = [r.vector for r in response] found_sample_ids = [r.payload["sample_id"] for r in response] found_label_ids = self._to_fiftyone_ids([r.id for r in response]) missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _get_patch_embeddings_from_sample_ids(self, sample_ids): 
_filter = qmodels.Filter( should=[ qmodels.FieldCondition( key="sample_id", match=qmodels.MatchValue(value=sid) ) for sid in sample_ids ] ) response = self._client.scroll( collection_name=self.config.collection_name, scroll_filter=_filter, with_vectors=True, with_payload=True, )[0] found_embeddings = [r.vector for r in response] found_sample_ids = [r.payload["sample_id"] for r in response] found_label_ids = [r.id for r in response] missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if query is None: raise ValueError("Qdrant does not support full index neighbors") if reverse is True: raise ValueError( "Qdrant does not support least similarity queries" ) if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) if k is None: k = self.index_size query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = self.current_label_ids else: index_ids = self.current_sample_ids _filter = qmodels.Filter( must=[ qmodels.HasIdCondition( has_id=self._to_qdrant_ids(index_ids) ) ] ) else: _filter = None sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: with_payload = self.config.patches_field is not None results = self._client.search( collection_name=self.config.collection_name, query_vector=q, query_filter=_filter, with_payload=with_payload, limit=k, ) if self.config.patches_field is not None: sample_ids.append( self._to_fiftyone_ids( [r.payload["sample_id"] for r in results] ) ) label_ids.append( self._to_fiftyone_ids([r.id for r in results]) ) else: sample_ids.append( self._to_fiftyone_ids([r.id for r in results]) ) 
if return_dists: dists.append([r.score for r in results]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) single_query = False # Query by ID(s) qids = self._to_qdrant_ids(query_ids) response = self._retrieve_points(qids, with_vectors=True) query = np.array([r.vector for r in response]) if query.size == 0: raise ValueError( "Query IDs %s were not found in the index" % query_ids ) if single_query: query = query[0, :] return query def _to_qdrant_id(self, _id): return _id + "00000000" def _to_qdrant_ids(self, ids): return [self._to_qdrant_id(_id) for _id in ids] def _to_fiftyone_id(self, qid): return qid.replace("-", "")[:-8] def _to_fiftyone_ids(self, qids): return [self._to_fiftyone_id(qid) for qid in qids] @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/redis.py ================================================ """ Redis similarity backend. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.core.utils as fou from fiftyone.brain.similarity import ( SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu redis = fou.lazy_import("redis") logger = logging.getLogger(__name__) _SUPPORTED_METRICS = { "cosine": "COSINE", "dotproduct": "IP", "euclidean": "L2", } class RedisSimilarityConfig(SimilarityConfig): """Configuration for the Redis similarity backend. 
Args: index_name (None): the name of a Redis index to use or create. If none is provided, a new index will be created metric ("cosine"): the embedding distance metric to use when creating a new index. Supported values are ``("cosine", "dotproduct", "euclidean")`` algorithm ("FLAT"): the search algorithm to use. The supported values are ``("FLAT", "HNSW")`` host ("localhost"): the host to use port (6379): the port to use db (0): the database to use username (None): a username to use password (None): a password to use **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__( self, index_name=None, metric="cosine", algorithm="FLAT", host="localhost", port=6379, db=0, username=None, password=None, **kwargs, ): if metric not in _SUPPORTED_METRICS: raise ValueError( "Unsupported metric '%s'. Supported values are %s" % (metric, tuple(_SUPPORTED_METRICS.keys())) ) super().__init__(**kwargs) self.index_name = index_name self.metric = metric self.algorithm = algorithm # store privately so these aren't serialized self._host = host self._port = port self._db = db self._username = username self._password = password @property def method(self): return "redis" @property def host(self): return self._host @host.setter def host(self, value): self._host = value @property def port(self): return self._port @port.setter def port(self, value): self._port = value @property def db(self): return self._db @db.setter def db(self, value): self._db = value @property def username(self): return self._username @username.setter def username(self, value): self._username = value @property def password(self): return self._password @password.setter def password(self, value): self._password = value @property def max_k(self): return None @property def supports_least_similarity(self): return False @property def supported_aggregations(self): return ("mean",) def load_credentials( self, host=None, port=None, db=None, username=None, password=None ): 
self._load_parameters( host=host, port=port, db=db, username=username, password=password ) class RedisSimilarity(Similarity): """Redis similarity factory. Args: config: a :class:`RedisSimilarityConfig` """ def ensure_requirements(self): fou.ensure_package("redis") def ensure_usage_requirements(self): fou.ensure_package("redis") def initialize(self, samples, brain_key): return RedisSimilarityIndex( samples, self.config, brain_key, backend=self ) class RedisSimilarityIndex(SimilarityIndex): """Class for interacting with Redis similarity indexes. Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`RedisSimilarityConfig` used brain_key: the brain key backend (None): a :class:`RedisSimilarity` instance """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._client = None self._index = None self._initialize() def _initialize(self): client = redis.Redis( host=self.config.host, port=self.config.port, db=self.config.db, username=self.config.username, password=self.config.password, decode_responses=True, ) if self.config.index_name is None: def index_exists(index_name): try: client.ft(index_name).info() return True except: return False root = "fiftyone-" + fou.to_slug(self._samples._root_dataset.name) index_name = fbu.get_unique_name(root, index_exists) self.config.index_name = index_name self.save_config() try: index = client.ft(self.config.index_name) index.info() except: index = None self._client = client self._index = index def _create_index(self, dimension): from redis.commands.search.field import TagField, VectorField from redis.commands.search.indexDefinition import ( IndexDefinition, IndexType, ) schema = ( TagField("$.foid", as_name="foid"), TagField("$.sample_id", as_name="sample_id"), VectorField( "$.vector", self.config.algorithm, { "TYPE": "FLOAT32", "DIM": dimension, "DISTANCE_METRIC": _SUPPORTED_METRICS[self.config.metric], }, 
as_name="vector", ), ) definition = IndexDefinition( prefix=[self.config.index_name + ":"], index_type=IndexType.JSON, ) index = self._client.ft(self.config.index_name) index.create_index(fields=schema, definition=definition) self._index = index @property def client(self): """The ``redis.client.Redis`` instance for this index.""" return self._client @property def total_index_size(self): try: return int(self._index.info()["num_docs"]) except: return 0 def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, ): if self._index is None: self._create_index(embeddings.shape[1]) if label_ids is not None: ids = label_ids else: ids = sample_ids if warn_existing or not allow_existing or not overwrite: existing_ids = self._get_existing_ids(ids) num_existing = len(existing_ids) if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg %s) that already exist in the index" % (num_existing, next(iter(existing_ids))) ) if warn_existing: if overwrite: logger.warning( "Overwriting %d IDs that already exist in the " "index", num_existing, ) else: logger.warning( "Skipping %d IDs that already exist in the index", num_existing, ) else: existing_ids = set() if existing_ids and not overwrite: del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids] embeddings = np.delete(embeddings, del_inds, axis=0) sample_ids = np.delete(sample_ids, del_inds) if label_ids is not None: label_ids = np.delete(label_ids, del_inds) elif existing_ids and overwrite: self._delete_ids(existing_ids) pipeline = self._client.pipeline() for e, id, sample_id in zip(embeddings, ids, sample_ids): key = f"{self.config.index_name}:{id}" d = { "foid": id, "sample_id": sample_id, "vector": e.astype(np.float32).tolist(), } pipeline.json().set(key, "$", d) pipeline.execute() if reload: self.reload() def _get_existing_ids(self, ids): return [d["foid"] for d in self._get_values(ids)] def _delete_ids(self, ids): keys = 
[f"{self.config.index_name}:{id}" for id in ids] self._client.delete(*keys) def _get_values(self, ids): pipeline = self._client.pipeline() for id in ids: pipeline.json().get(f"{self.config.index_name}:{id}") return [d for d in pipeline.execute() if d is not None] def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): if label_ids is not None: ids = label_ids else: ids = sample_ids if not allow_missing or warn_missing: existing_ids = self._get_existing_ids(ids) missing_ids = set(ids) - set(existing_ids) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that are not present in the " "index" % (num_missing, next(iter(missing_ids))) ) if warn_missing: logger.warning( "Ignoring %d IDs that are not present in the index", num_missing, ) ids = existing_ids self._delete_ids(ids=ids) if reload: self.reload() def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) if sample_ids is not None and self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_sample_ids(sample_ids) elif self.config.patches_field is not None: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_patch_embeddings_from_label_ids(label_ids) else: ( embeddings, sample_ids, label_ids, missing_ids, ) = self._get_sample_embeddings(sample_ids) num_missing_ids = len(missing_ids) if num_missing_ids > 0: if not allow_missing: raise ValueError( "Found %d IDs (eg %s) that do not exist in the index" % (num_missing_ids, missing_ids[0]) ) if warn_missing: logger.warning( "Skipping %d IDs that do not exist in the index", num_missing_ids, ) embeddings = 
np.array(embeddings) sample_ids = np.array(sample_ids) if label_ids is not None: label_ids = np.array(label_ids) return embeddings, sample_ids, label_ids def cleanup(self): if self._index is None: return self._index.dropindex(delete_documents=True) self._index = None def _get_sample_embeddings(self, sample_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] if sample_ids is None: get_id = lambda key: key.rsplit(":", 1)[1] keys = self._client.keys(f"{self.config.index_name}:*") sample_ids = map(get_id, keys) for batch_ids in fou.iter_batches(sample_ids, batch_size): for d in self._get_values(batch_ids): found_embeddings.append(d["vector"]) found_sample_ids.append(d["sample_id"]) missing_ids = list(set(sample_ids) - set(found_sample_ids)) return found_embeddings, found_sample_ids, None, missing_ids def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000): found_embeddings = [] found_sample_ids = [] found_label_ids = [] if label_ids is None: get_id = lambda key: key.rsplit(":", 1)[1] keys = self._client.keys(f"{self.config.index_name}:*") label_ids = map(get_id, keys) for batch_ids in fou.iter_batches(label_ids, batch_size): for d in self._get_values(batch_ids): found_embeddings.append(d["vector"]) found_sample_ids.append(d["sample_id"]) found_label_ids.append(d["foid"]) missing_ids = list(set(label_ids) - set(found_label_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _get_patch_embeddings_from_sample_ids( self, sample_ids, batch_size=100 ): from redis.commands.search.query import Query found_embeddings = [] found_sample_ids = [] found_label_ids = [] for batch_ids in fou.iter_batches(sample_ids, batch_size): filter = "@sample_id:{ " + " | ".join(batch_ids) + " }" query = Query(filter).dialect(2) for doc in self._index.search(query).docs: found_embeddings.append(doc.embeddings) found_sample_ids.append(doc.sample_id) found_label_ids.append(doc.foid) missing_ids = list(set(sample_ids) - 
set(found_sample_ids)) return found_embeddings, found_sample_ids, found_label_ids, missing_ids def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): from redis.commands.search.query import Query if query is None: raise ValueError("Redis does not support full index neighbors") if reverse is True: raise ValueError("Redis does not support least similarity queries") if k is None: k = self.index_size if aggregation not in (None, "mean"): raise ValueError("Unsupported aggregation '%s'" % aggregation) query = self._parse_neighbors_query(query) if aggregation == "mean" and query.ndim == 2: query = query.mean(axis=0) single_query = query.ndim == 1 if single_query: query = [query] if self.has_view: if self.config.patches_field is not None: index_ids = list(self.current_label_ids) else: index_ids = list(self.current_sample_ids) filter = "@foid:{ " + " | ".join(index_ids) + " }" else: filter = "*" sample_ids = [] label_ids = [] if self.config.patches_field is not None else None dists = [] for q in query: _query = ( Query(f"({filter})=>[KNN {k} @vector $query AS score]") .sort_by("score") .return_fields("score", "foid", "sample_id") .dialect(2) .paging(0, k) ) _q = q.astype(np.float32).tobytes() docs = self._index.search(_query, {"query": _q}).docs if self.config.patches_field is not None: sample_ids.append([doc.sample_id for doc in docs]) label_ids.append([doc.foid for doc in docs]) else: sample_ids.append([doc.foid for doc in docs]) if return_dists: dists.append([doc.score for doc in docs]) if single_query: sample_ids = sample_ids[0] if label_ids is not None: label_ids = label_ids[0] if return_dists: dists = dists[0] if return_dists: return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query by vector(s) if np.issubdtype(query.dtype, np.number): return query query_ids = list(query) 
single_query = False # Query by ID(s) dicts = self._get_values(query_ids) if not dicts: raise ValueError( "Query IDs %s do not exist in this index" % query_ids ) query = np.array([d["vector"] for d in dicts]) if single_query: query = query[0, :] return query @classmethod def _from_dict(cls, d, samples, config, brain_key): return cls(samples, config, brain_key) ================================================ FILE: fiftyone/brain/internal/core/representativeness.py ================================================ """ Representativeness methods. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import logging import copy import numpy as np import sklearn.cluster as skc from scipy.spatial import cKDTree import eta.core.utils as etau import fiftyone.core.brain as fob import fiftyone.core.fields as fof import fiftyone.core.labels as fol import fiftyone.core.validation as fov import fiftyone.brain.internal.core.utils as fbu import fiftyone.brain.internal.models as fbm logger = logging.getLogger(__name__) _ALLOWED_ROI_FIELD_TYPES = ( fol.Detection, fol.Detections, fol.Polyline, fol.Polylines, ) _DEFAULT_MODEL = "simple-resnet-cifar10" _DEFAULT_BATCH_SIZE = 16 def compute_representativeness( samples, representativeness_field, method, roi_field, embeddings, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, ): """See ``fiftyone/brain/__init__.py``.""" # # Algorithm # # Compute cluster centers with MeanShift. The representativeness will # then be a scaled distance to the nearest cluster center. This puts # cluster centers which should represent the data the highest with a high # ranking and points on the outliers with low ranking. 
# fov.validate_collection(samples) if roi_field is not None: fov.validate_collection_label_fields( samples, roi_field, _ALLOWED_ROI_FIELD_TYPES ) if etau.is_str(embeddings): embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings, patches_field=roi_field, data_type="embeddings", ) embeddings = None else: embeddings_field = None embeddings_exist = None if etau.is_str(similarity_index): similarity_index = samples.load_brain_results(similarity_index) if ( model is None and embeddings is None and similarity_index is None and not embeddings_exist ): model = fbm.load_model(_DEFAULT_MODEL) if batch_size is None: batch_size = _DEFAULT_BATCH_SIZE config = RepresentativenessConfig( representativeness_field, method=method, roi_field=roi_field, embeddings_field=embeddings_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, ) brain_key = representativeness_field brain_method = config.build() brain_method.ensure_requirements() brain_method.register_run(samples, brain_key, cleanup=False) if roi_field is not None: # @todo experiment with mean(), max(), abs().max(), etc agg_fcn = lambda e: np.mean(e, axis=0) else: agg_fcn = None embeddings, sample_ids, _ = fbu.get_embeddings( samples, model=model, model_kwargs=model_kwargs, patches_field=roi_field, embeddings_field=embeddings_field, embeddings=embeddings, similarity_index=similarity_index, force_square=force_square, alpha=alpha, handle_missing="image", agg_fcn=agg_fcn, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) logger.info("Computing representativeness...") representativeness = _compute_representativeness(embeddings, method=method) # Ensure field exists, even if `representativeness` is empty samples._dataset.add_sample_field(representativeness_field, fof.FloatField) representativeness = { _id: u for _id, u in zip(sample_ids, representativeness) } if representativeness: samples.set_values( representativeness_field, 
representativeness, key_field="id" ) brain_method.save_run_results(samples, brain_key, None) logger.info("Representativeness computation complete") def _compute_representativeness(embeddings, method="cluster-center"): # # @todo experiment on which method for assessing representativeness # num_embeddings = len(embeddings) logger.info( "Computing clusters for %d embeddings; this may take awhile...", num_embeddings, ) initial_ranking, _ = _cluster_ranker(embeddings) if method == "cluster-center": final_ranking = initial_ranking elif method == "cluster-center-downweight": logger.info("Applying iterative downweighting...") final_ranking = _adjust_rankings( embeddings, initial_ranking, ball_radius=0.5 ) else: raise ValueError( ( "Method '%s' not supported. Please use one of " "['cluster-center', 'cluster-center-downweight']" ) % method ) return final_ranking def _cluster_ranker( embeddings, cluster_algorithm="kmeans", N=20, norm_method="local" ): # Cluster if cluster_algorithm == "meanshift": bandwidth = skc.estimate_bandwidth( embeddings, quantile=0.8, n_samples=500 ) clusterer = skc.MeanShift(bandwidth=bandwidth, bin_seeding=True).fit( embeddings ) elif cluster_algorithm == "kmeans": clusterer = skc.KMeans(n_clusters=N, random_state=1234).fit(embeddings) else: raise ValueError( ( "Clustering algorithm '%s' not supported. 
Please use one of " "['meanshift', 'kmeans']" ) % cluster_algorithm ) cluster_centers = clusterer.cluster_centers_ cluster_ids = clusterer.labels_ # Get distance from each point to it's closest cluster center sample_dists = np.linalg.norm( embeddings - cluster_centers[cluster_ids], axis=1 ) centerness_ranking = 1 / (1 + sample_dists) # Normalize per cluster vs globally norm_method = "local" if norm_method == "global": centerness_ranking = centerness_ranking / centerness_ranking.max() elif norm_method == "local": unique_ids = np.unique(cluster_ids) for unique_id in unique_ids: cluster_indices = np.where(cluster_ids == unique_id)[0] cluster_dists = sample_dists[cluster_indices] cluster_dists /= cluster_dists.max() sample_dists[cluster_indices] = cluster_dists centerness_ranking = sample_dists return centerness_ranking, clusterer # Step 3: Adjust rankings to avoid redundancy def _adjust_rankings(embeddings, initial_ranking, ball_radius=0.5): tree = cKDTree(embeddings) new_ranking = copy.deepcopy(initial_ranking) ordered_ranking = np.argsort(new_ranking)[::-1] visited_indices = set() for ranked_index in ordered_ranking: visited_indices.add(ranked_index) query_embedding = embeddings[ranked_index, :] nearby_indices = tree.query_ball_point( query_embedding, ball_radius, return_sorted=True ) filtered_indices = [ idx for idx in nearby_indices if idx not in visited_indices ] visited_indices |= set(filtered_indices) new_ranking[filtered_indices] = new_ranking[filtered_indices] * 0.7 new_ranking = new_ranking / new_ranking.max() return new_ranking # @todo move to `fiftyone/brain/representativeness.py` # Don't do this hastily; `get_brain_info()` on existing datasets has this # class's full path in it and may need migration class RepresentativenessConfig(fob.BrainMethodConfig): def __init__( self, representativeness_field, method=None, roi_field=None, embeddings_field=None, similarity_index=None, model=None, model_kwargs=None, **kwargs, ): if similarity_index is not None and not 
etau.is_str(similarity_index): similarity_index = similarity_index.key if model is not None and not etau.is_str(model): model = etau.get_class_name(model) self.representativeness_field = representativeness_field self._method = method self.roi_field = roi_field self.embeddings_field = embeddings_field self.similarity_index = similarity_index self.model = model self.model_kwargs = model_kwargs super().__init__(**kwargs) @property def type(self): return "representativeness" @property def method(self): return self._method @classmethod def _virtual_attributes(cls): # By default 'method' is virtual but we omit so it *IS* serialized return ["cls", "type"] class Representativeness(fob.BrainMethod): def ensure_requirements(self): pass def get_fields(self, samples, brain_key): fields = [self.config.representativeness_field] if self.config.roi_field is not None: fields.append(self.config.roi_field) if self.config.embeddings_field is not None: fields.append(self.config.embeddings_field) return fields def cleanup(self, samples, brain_key): representativeness_field = self.config.representativeness_field samples._dataset.delete_sample_fields( representativeness_field, error_level=1 ) def _validate_run(self, samples, brain_key, existing_info): self._validate_fields_match( brain_key, "representativeness_field", existing_info ) ================================================ FILE: fiftyone/brain/internal/core/sklearn.py ================================================ """ Sklearn similarity backend. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import numpy as np import sklearn.metrics as skm import sklearn.neighbors as skn import sklearn.preprocessing as skp import eta.core.utils as etau import fiftyone.core.media as fom from fiftyone.brain.similarity import ( DuplicatesMixin, SimilarityConfig, Similarity, SimilarityIndex, ) import fiftyone.brain.internal.core.utils as fbu logger = logging.getLogger(__name__) _AGGREGATIONS = { "mean": np.mean, "post-mean": np.nanmean, "post-min": np.nanmin, "post-max": np.nanmax, } _MAX_PRECOMPUTE_DISTS = 15000 # ~1.7GB to store distance matrix in-memory _COSINE_HACK_ATTR = "_cosine_hack" class SklearnSimilarityConfig(SimilarityConfig): """Configuration for the sklearn similarity backend. Args: metric ("cosine"): the embedding distance metric to use. See ``sklearn.metrics.pairwise_distance`` for supported values **kwargs: keyword arguments for :class:`fiftyone.brain.similarity.SimilarityConfig` """ def __init__(self, metric="cosine", **kwargs): super().__init__(**kwargs) self.metric = metric @property def method(self): return "sklearn" @property def max_k(self): return None @property def supports_least_similarity(self): return True @property def supported_aggregations(self): return tuple(_AGGREGATIONS.keys()) class SklearnSimilarity(Similarity): """Sklearn similarity factory. Args: config: an :class:`SklearnSimilarityConfig` """ def initialize(self, samples, brain_key): return SklearnSimilarityIndex( samples, self.config, brain_key, backend=self ) class SklearnSimilarityIndex(SimilarityIndex, DuplicatesMixin): """Class for interacting with sklearn similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`SklearnSimilarityConfig` used brain_key: the brain key embeddings (None): a ``num_embeddings x num_dims`` array of embeddings sample_ids (None): a ``num_embeddings`` array of sample IDs label_ids (None): a ``num_embeddings`` array of label IDs, if applicable backend (None): a :class:`SklearnSimilarity` instance """ def __init__( self, samples, config, brain_key, embeddings=None, sample_ids=None, label_ids=None, backend=None, ): embeddings, sample_ids, label_ids = self._parse_data( samples, config, embeddings=embeddings, sample_ids=sample_ids, label_ids=label_ids, ) self._dataset = samples._dataset self._embeddings = embeddings self._sample_ids = sample_ids self._label_ids = label_ids self._ids_to_inds = None self._curr_ids_to_inds = None self._neighbors_helper = None SimilarityIndex.__init__( self, samples, config, brain_key, backend=backend ) DuplicatesMixin.__init__(self) @property def is_external(self): return self.config.embeddings_field is None @property def embeddings(self): return self._embeddings @property def sample_ids(self): return self._sample_ids @property def label_ids(self): return self._label_ids @property def total_index_size(self): return len(self._sample_ids) def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, ): sample_ids = np.asarray(sample_ids) label_ids = np.asarray(label_ids) if label_ids is not None else None _sample_ids, _label_ids, ii, jj = fbu.add_ids( sample_ids, label_ids, self._sample_ids, self._label_ids, patches_field=self.config.patches_field, overwrite=overwrite, allow_existing=allow_existing, warn_existing=warn_existing, ) if ii.size == 0: return _embeddings = embeddings[ii, :] if self.config.embeddings_field is not None: fbu.add_embeddings( self._dataset, _embeddings, sample_ids[ii], label_ids[ii] if label_ids is not None else None, 
self.config.embeddings_field, patches_field=self.config.patches_field, ) _e = self._embeddings n = _e.shape[0] if n == 0: _e = np.empty((0, embeddings.shape[1]), dtype=embeddings.dtype) d = _e.shape[1] m = max(jj) - n + 1 if m > 0: if _e.size > 0: _e = np.concatenate((_e, np.empty((m, d), dtype=_e.dtype))) else: _e = np.empty_like(_embeddings) _e[jj, :] = _embeddings self._embeddings = _e self._sample_ids = _sample_ids self._label_ids = _label_ids self._ids_to_inds = None self._curr_ids_to_inds = None self._neighbors_helper = None if reload: super().reload() def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): _sample_ids, _label_ids, rm_inds = fbu.remove_ids( sample_ids, label_ids, self._sample_ids, self._label_ids, patches_field=self.config.patches_field, allow_missing=allow_missing, warn_missing=warn_missing, ) if rm_inds.size == 0: return if self.config.embeddings_field is not None: if self.config.patches_field is not None: rm_sample_ids = None rm_label_ids = self._label_ids[rm_inds] else: rm_sample_ids = self._sample_ids[rm_inds] rm_label_ids = None fbu.remove_embeddings( self._dataset, self.config.embeddings_field, sample_ids=rm_sample_ids, label_ids=rm_label_ids, patches_field=self.config.patches_field, ) _embeddings = np.delete(self._embeddings, rm_inds, axis=0) self._embeddings = _embeddings self._sample_ids = _sample_ids self._label_ids = _label_ids self._ids_to_inds = None self._curr_ids_to_inds = None self._neighbors_helper = None if reload: super().reload() def use_view(self, *args, **kwargs): self._curr_ids_to_inds = None return super().use_view(*args, **kwargs) def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): if label_ids is not None: if self.config.patches_field is None: raise ValueError("This index does not support label IDs") if sample_ids is not None: logger.warning( "Ignoring sample IDs when label IDs are provided" ) inds = 
_get_inds( label_ids, self.label_ids, "label", allow_missing, warn_missing, ) embeddings = self._embeddings[inds, :] sample_ids = self.sample_ids[inds] label_ids = np.asarray(label_ids) elif sample_ids is not None: if etau.is_str(sample_ids): sample_ids = [sample_ids] if self.config.patches_field is not None: sample_ids = set(sample_ids) bools = [_id in sample_ids for _id in self.sample_ids] inds = np.nonzero(bools)[0] else: inds = _get_inds( sample_ids, self.sample_ids, "sample", allow_missing, warn_missing, ) embeddings = self._embeddings[inds, :] sample_ids = self.sample_ids[inds] if self.config.patches_field is not None: label_ids = self.label_ids[inds] else: label_ids = None else: embeddings = self._embeddings.copy() sample_ids = self.sample_ids.copy() if self.config.patches_field is not None: label_ids = self.label_ids.copy() else: label_ids = None return embeddings, sample_ids, label_ids def reload(self): if self.config.embeddings_field is not None: embeddings, sample_ids, label_ids = self._parse_data( self._dataset, self.config ) self._embeddings = embeddings self._sample_ids = sample_ids self._label_ids = label_ids self._ids_to_inds = None self._curr_ids_to_inds = None self._neighbors_helper = None super().reload() def cleanup(self): pass def attributes(self): attrs = super().attributes() if self.config.embeddings_field is None: attrs.extend(["embeddings", "sample_ids", "label_ids"]) return attrs def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): if aggregation is not None: return self._kneighbors_aggregate( query, k, reverse, aggregation, return_dists ) ( query, query_inds, full_index, single_query, ) = self._parse_neighbors_query(query) can_use_dists = full_index or query_inds is not None neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists) if dists is not None: # Use pre-computed distances if query_inds is not None: _dists = dists[query_inds, :] else: _dists = dists # note: this must 
gracefully ignore nans inds = _nanargmin(_dists, k=k) if return_dists: dists = [d[i] for i, d in zip(inds, _dists)] else: dists = None else: if return_dists: dists, inds = neighbors.kneighbors( X=query, n_neighbors=k, return_distance=True ) inds = list(inds) dists = list(dists) else: inds = neighbors.kneighbors( X=query, n_neighbors=k, return_distance=False ) inds = list(inds) dists = None return self._format_output( inds, dists, full_index, single_query, return_dists ) def _radius_neighbors(self, query=None, thresh=None, return_dists=False): ( query, query_inds, full_index, single_query, ) = self._parse_neighbors_query(query) can_use_dists = full_index or query_inds is not None neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists) # When not using brute force, we approximate cosine distance by # computing Euclidean distance on unit-norm embeddings. # ED = sqrt(2 * CD), so we need to scale the threshold appropriately if getattr(neighbors, _COSINE_HACK_ATTR, False): thresh = np.sqrt(2.0 * thresh) if dists is not None: # Use pre-computed distances if query_inds is not None: _dists = dists[query_inds, :] else: _dists = dists # note: this must gracefully ignore nans inds = [np.nonzero(d <= thresh)[0] for d in _dists] if return_dists: dists = [d[i] for i, d in zip(inds, _dists)] else: dists = None else: if return_dists: dists, inds = neighbors.radius_neighbors( X=query, radius=thresh, return_distance=True ) else: dists = None inds = neighbors.radius_neighbors( X=query, radius=thresh, return_distance=False ) return self._format_output( inds, dists, full_index, single_query, return_dists ) def _kneighbors_aggregate( self, query, k, reverse, aggregation, return_dists ): if query is None: raise ValueError("Full index queries do not support aggregation") if aggregation not in _AGGREGATIONS: raise ValueError( "Unsupported aggregation method '%s'. 
Supported values are %s" % (aggregation, tuple(_AGGREGATIONS.keys())) ) query, query_inds, _, _ = self._parse_neighbors_query(query) # Pre-aggregation if aggregation == "mean": if query.shape[0] > 1: query = query.mean(axis=0, keepdims=True) query_inds = None aggregation = None can_use_dists = query_inds is not None _, dists = self._get_neighbors( can_use_neighbors=False, can_use_dists=can_use_dists ) if dists is not None: # Use pre-computed distances dists = dists[query_inds, :] else: keep_inds = self._current_inds index_embeddings = self._embeddings if keep_inds is not None: index_embeddings = index_embeddings[keep_inds] dists = skm.pairwise_distances( query, index_embeddings, metric=self.config.metric ) # Post-aggregation if aggregation is not None: # note: this must gracefully ignore nans agg_fcn = _AGGREGATIONS[aggregation] dists = agg_fcn(dists, axis=0) else: dists = dists[0, :] if can_use_dists: dists[np.isnan(dists)] = 0.0 inds = np.argsort(dists) if reverse: inds = np.flip(inds) if k is not None: inds = inds[:k] sample_ids = list(self.current_sample_ids[inds]) if self.config.patches_field is not None: label_ids = list(self.current_label_ids[inds]) else: label_ids = None if return_dists: dists = list(dists[inds]) return sample_ids, label_ids, dists return sample_ids, label_ids def _parse_neighbors_query(self, query): # Full index if query is None: return None, None, True, False if etau.is_str(query): query_ids = [query] single_query = True else: query = np.asarray(query) # Query vector(s) if np.issubdtype(query.dtype, np.number): single_query = query.ndim == 1 if single_query: query = query[np.newaxis, :] return query, None, False, single_query query_ids = list(query) single_query = False # Retrieve indices into active `dists` matrix, if possible ids_to_inds = self._get_ids_to_inds(full=False) query_inds = [] for _id in query_ids: _ind = ids_to_inds.get(_id, None) if _ind is not None: query_inds.append(_ind) else: # At least one query ID is not in the 
active index query_inds = None break # Retrieve embeddings ids_to_inds = self._get_ids_to_inds(full=True) inds = [] bad_ids = [] for _id in query_ids: _ind = ids_to_inds.get(_id, None) if _ind is not None: inds.append(_ind) else: bad_ids.append(_id) inds = np.array(inds) if bad_ids: raise ValueError( "Query IDs %s do not exist in this index" % bad_ids ) query = self._embeddings[inds, :] if query_inds is not None: query_inds = np.array(query_inds) return query, query_inds, False, single_query def _get_ids_to_inds(self, full=False): if full: if self._ids_to_inds is None: if self.config.patches_field is not None: ids = self.label_ids else: ids = self.sample_ids self._ids_to_inds = {_id: i for i, _id in enumerate(ids)} return self._ids_to_inds if self._curr_ids_to_inds is None: if self.config.patches_field is not None: ids = self.current_label_ids else: ids = self.current_sample_ids self._curr_ids_to_inds = {_id: i for i, _id in enumerate(ids)} return self._curr_ids_to_inds def _get_neighbors(self, can_use_neighbors=True, can_use_dists=True): if self._neighbors_helper is None: self._neighbors_helper = NeighborsHelper( self._embeddings, self.config.metric ) return self._neighbors_helper.get_neighbors( keep_inds=self._current_inds, can_use_neighbors=can_use_neighbors, can_use_dists=can_use_dists, ) def _format_output( self, inds, dists, full_index, single_query, return_dists ): if full_index: if return_dists: return inds, dists return inds curr_sample_ids = self.current_sample_ids sample_ids = [[curr_sample_ids[i] for i in _inds] for _inds in inds] if single_query: sample_ids = sample_ids[0] if self.config.patches_field is not None: curr_label_ids = self.current_label_ids label_ids = [[curr_label_ids[i] for i in _inds] for _inds in inds] if single_query: label_ids = label_ids[0] else: label_ids = None if return_dists: dists = [list(d) for d in dists] if single_query: dists = dists[0] return sample_ids, label_ids, dists return sample_ids, label_ids @staticmethod def 
_parse_data( samples, config, embeddings=None, sample_ids=None, label_ids=None, ): if embeddings is None: samples = samples._dataset if samples.media_type == fom.GROUP: samples = samples.select_group_slices(_allow_mixed=True) embeddings, sample_ids, label_ids = fbu.get_embeddings( samples, patches_field=config.patches_field, embeddings_field=config.embeddings_field, ) elif sample_ids is None: sample_ids, label_ids = fbu.get_ids( samples, patches_field=config.patches_field, data=embeddings, data_type="embeddings", ) return embeddings, sample_ids, label_ids @classmethod def _from_dict(cls, d, samples, config, brain_key): embeddings = d.get("embeddings", None) if embeddings is not None: embeddings = np.array(embeddings) sample_ids = d.get("sample_ids", None) if sample_ids is not None: sample_ids = np.array(sample_ids) label_ids = d.get("label_ids", None) if label_ids is not None: label_ids = np.array(label_ids) return cls( samples, config, brain_key, embeddings=embeddings, sample_ids=sample_ids, label_ids=label_ids, ) class NeighborsHelper(object): _UNAVAILABLE = "UNAVAILABLE" def __init__(self, embeddings, metric): self.embeddings = embeddings self.metric = metric self._initialized = False self._full_dists = None self._curr_keep_inds = None self._curr_neighbors = None self._curr_dists = None def get_neighbors( self, keep_inds=None, can_use_neighbors=True, can_use_dists=True, ): iokay = self._same_keep_inds(keep_inds) nokay = not can_use_neighbors or self._curr_neighbors is not None dokay = not can_use_dists or self._curr_dists is not None if iokay and nokay and dokay: neighbors = self._curr_neighbors dists = self._curr_dists else: neighbors, dists = self._build( keep_inds=keep_inds, can_use_neighbors=can_use_neighbors, can_use_dists=can_use_dists, ) if not iokay: self._curr_keep_inds = keep_inds if self._curr_neighbors is None or not iokay: self._curr_neighbors = neighbors if self._curr_dists is None or not iokay: self._curr_dists = dists if not can_use_neighbors or 
neighbors is self._UNAVAILABLE: neighbors = None if not can_use_dists or dists is self._UNAVAILABLE: dists = None return neighbors, dists def _same_keep_inds(self, keep_inds): # This handles either argument being None return np.array_equal(keep_inds, self._curr_keep_inds) def _build( self, keep_inds=None, can_use_neighbors=True, can_use_dists=True ): if can_use_dists: if ( self._full_dists is None and len(self.embeddings) <= _MAX_PRECOMPUTE_DISTS ): self._full_dists = self._build_dists(self.embeddings) if self._full_dists is not None: if keep_inds is not None: dists = self._full_dists[keep_inds, :][:, keep_inds] else: dists = self._full_dists elif ( keep_inds is not None and len(keep_inds) <= _MAX_PRECOMPUTE_DISTS ): dists = self._build_dists(self.embeddings[keep_inds]) else: dists = self._UNAVAILABLE else: dists = None if can_use_neighbors: if not isinstance(dists, np.ndarray): embeddings = self.embeddings if keep_inds is not None: embeddings = embeddings[keep_inds] neighbors = self._build_neighbors(embeddings) else: neighbors = self._UNAVAILABLE else: neighbors = None return neighbors, dists def _build_dists(self, embeddings): logger.debug("Generating index for %d embeddings...", len(embeddings)) # Center embeddings embeddings = np.asarray(embeddings) embeddings -= embeddings.mean(axis=0, keepdims=True) dists = skm.pairwise_distances(embeddings, metric=self.metric) np.fill_diagonal(dists, np.nan) logger.debug("Index complete") return dists def _build_neighbors(self, embeddings): logger.debug( "Generating neighbors graph for %d embeddings...", len(embeddings), ) # Center embeddings embeddings = np.asarray(embeddings) embeddings -= embeddings.mean(axis=0, keepdims=True) metric = self.metric if metric == "cosine": # Nearest neighbors does not directly support cosine distance, so # we approximate via euclidean distance on unit-norm embeddings cosine_hack = True embeddings = skp.normalize(embeddings, axis=1) metric = "euclidean" else: cosine_hack = False neighbors = 
skn.NearestNeighbors(metric=metric) neighbors.fit(embeddings) setattr(neighbors, _COSINE_HACK_ATTR, cosine_hack) logger.debug("Index complete") return neighbors def _get_inds(ids, index_ids, ftype, allow_missing, warn_missing): if etau.is_str(ids): ids = [ids] ids_map = {_id: i for i, _id in enumerate(index_ids)} inds = [] bad_ids = [] for _id in ids: idx = ids_map.get(_id, None) if idx is not None: inds.append(idx) else: bad_ids.append(_id) num_missing = len(bad_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d %s IDs (eg '%s') that are not present in the index" % (num_missing, ftype, bad_ids[0]) ) if warn_missing: logger.warning( "Ignoring %d %s IDs that are not present in the index", num_missing, ftype, ) return np.array(inds) def _nanargmin(array, k=1): if k == 1: inds = np.nanargmin(array, axis=1) inds = [np.array([i]) for i in inds] else: inds = np.argsort(array, axis=1) inds = list(inds[:, :k]) return inds ================================================ FILE: fiftyone/brain/internal/core/uniqueness.py ================================================ """ Uniqueness methods. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import logging import numpy as np import eta.core.utils as etau import fiftyone.brain as fb import fiftyone.core.brain as fob import fiftyone.core.fields as fof import fiftyone.core.labels as fol import fiftyone.core.utils as fou import fiftyone.core.validation as fov import fiftyone.brain.internal.core.utils as fbu import fiftyone.brain.internal.models as fbm logger = logging.getLogger(__name__) _ALLOWED_ROI_FIELD_TYPES = ( fol.Detection, fol.Detections, fol.Polyline, fol.Polylines, ) _DEFAULT_MODEL = "simple-resnet-cifar10" _DEFAULT_BATCH_SIZE = 16 def compute_uniqueness( samples, uniqueness_field, roi_field, embeddings, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, ): """See ``fiftyone/brain/__init__.py``.""" # # Algorithm # # Uniqueness is computed based on a classification model. Each sample is # embedded into a vector space based on the model. Then, we compute the # knn's (k is a parameter of the uniqueness function). The uniqueness is # then proportional to these distances. The intuition is that a sample is # unique when it is far from other samples in the set. This is different # than, say, "representativeness" which would stress samples that are core # to dense clusters of related samples. 
# fov.validate_collection(samples) if roi_field is not None: fov.validate_collection_label_fields( samples, roi_field, _ALLOWED_ROI_FIELD_TYPES ) if etau.is_str(embeddings): embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings, patches_field=roi_field, data_type="embeddings", ) embeddings = None else: embeddings_field = None embeddings_exist = None if etau.is_str(similarity_index): similarity_index = samples.load_brain_results(similarity_index) if ( model is None and embeddings is None and similarity_index is None and not embeddings_exist ): model = fbm.load_model(_DEFAULT_MODEL) if batch_size is None: batch_size = _DEFAULT_BATCH_SIZE config = UniquenessConfig( uniqueness_field, roi_field=roi_field, embeddings_field=embeddings_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, ) brain_key = uniqueness_field brain_method = config.build() brain_method.ensure_requirements() brain_method.register_run(samples, brain_key, cleanup=False) if roi_field is not None: # @todo experiment with mean(), max(), abs().max(), etc agg_fcn = lambda e: np.mean(e, axis=0) else: agg_fcn = None embeddings, sample_ids, _ = fbu.get_embeddings( samples, model=model, model_kwargs=model_kwargs, patches_field=roi_field, embeddings_field=embeddings_field, embeddings=embeddings, similarity_index=similarity_index, force_square=force_square, alpha=alpha, handle_missing="image", agg_fcn=agg_fcn, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) if similarity_index is None: similarity_index = fb.compute_similarity( samples, backend="sklearn", embeddings=False ) similarity_index.add_to_index(embeddings, sample_ids) logger.info("Computing uniqueness...") uniqueness = _compute_uniqueness( embeddings, similarity_index, progress=progress ) # Ensure field exists, even if `uniqueness` is empty samples._dataset.add_sample_field(uniqueness_field, fof.FloatField) uniqueness = {_id: u for _id, u in 
zip(sample_ids, uniqueness)} if uniqueness: samples.set_values(uniqueness_field, uniqueness, key_field="id") brain_method.save_run_results(samples, brain_key, None) logger.info("Uniqueness computation complete") def _compute_uniqueness( embeddings, similarity_index, batch_size=10, progress=None ): K = 3 num_embeddings = len(embeddings) if num_embeddings <= K: return [1] * num_embeddings if similarity_index.config.method == "sklearn": _, dists = similarity_index._kneighbors(k=K + 1, return_dists=True) else: dists = [] with fou.ProgressBar(total=num_embeddings, progress=progress) as pb: for _embeddings in fou.iter_slices(embeddings, batch_size): _, _, _dists = similarity_index._kneighbors( query=_embeddings, k=K + 1, return_dists=True ) dists.extend(_dists) pb.update(len(_dists)) dists = np.array(dists) # @todo experiment on which method for assessing uniqueness is best # # To get something going, for now, just take a weighted mean # weights = [0.6, 0.3, 0.1] sample_dists = np.mean(dists[:, 1:] * weights, axis=1) # Normalize to keep the user on common footing across datasets sample_dists /= sample_dists.max() return sample_dists # @todo move to `fiftyone/brain/uniqueness.py` # Don't do this hastily; `get_brain_info()` on existing datasets has this # class's full path in it and may need migration class UniquenessConfig(fob.BrainMethodConfig): def __init__( self, uniqueness_field, roi_field=None, embeddings_field=None, similarity_index=None, model=None, model_kwargs=None, **kwargs, ): if similarity_index is not None and not etau.is_str(similarity_index): similarity_index = similarity_index.key if model is not None and not etau.is_str(model): model = etau.get_class_name(model) self.uniqueness_field = uniqueness_field self.roi_field = roi_field self.embeddings_field = embeddings_field self.similarity_index = similarity_index self.model = model self.model_kwargs = model_kwargs super().__init__(**kwargs) @property def type(self): return "uniqueness" @property def 
method(self): return "neighbors" class Uniqueness(fob.BrainMethod): def ensure_requirements(self): pass def get_fields(self, samples, brain_key): fields = [self.config.uniqueness_field] if self.config.roi_field is not None: fields.append(self.config.roi_field) if self.config.embeddings_field is not None: fields.append(self.config.embeddings_field) return fields def cleanup(self, samples, brain_key): uniqueness_field = self.config.uniqueness_field samples._dataset.delete_sample_fields(uniqueness_field, error_level=1) def _validate_run(self, samples, brain_key, existing_info): self._validate_fields_match( brain_key, "uniqueness_field", existing_info ) ================================================ FILE: fiftyone/brain/internal/core/utils.py ================================================ """ Utilities. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import itertools import logging import random import string import numpy as np import eta.core.utils as etau import fiftyone.brain as fob import fiftyone.core.fields as fof import fiftyone.core.labels as fol import fiftyone.core.models as fom import fiftyone.core.media as fomm import fiftyone.core.patches as fop import fiftyone.zoo as foz from fiftyone import ViewField as F logger = logging.getLogger(__name__) def parse_data( samples, patches_field=None, data=None, data_type="embeddings", allow_missing=True, warn_missing=True, ): if isinstance(data, fob.SimilarityIndex): return get_embeddings_from_index( samples, data, patches_field=None, allow_missing=True, warn_missing=True, ) _validate_args(samples, patches_field=patches_field) if patches_field is None: if isinstance(data, dict): sample_ids, data = zip(*data.items()) return np.array(data), np.array(sample_ids), None sample_ids, _ = get_ids(samples, data=data, data_type=data_type) return data, sample_ids, None if isinstance(data, dict): value = next(iter(data.values()), None) if isinstance(value, np.ndarray) and value.ndim == 1: label_ids, data = 
zip(*data.items()) return _parse_label_data( samples, patches_field, label_ids, data, data_type, allow_missing, warn_missing, ) sample_ids, label_ids = get_ids( samples, patches_field=patches_field, data=data, data_type=data_type, ) return data, sample_ids, label_ids def _parse_label_data( samples, patches_field, label_ids, data, data_type, allow_missing, warn_missing, ): if samples._is_patches: sample_id_path = "sample_id" else: sample_id_path = "id" label_type, label_id_path = samples._get_label_field_path( patches_field, "id" ) is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS) ref_sample_ids, ref_label_ids = samples._dataset.values( [sample_id_path, label_id_path] ) if is_list_field: ids_map = {} for _sample_id, _lids in zip(ref_sample_ids, ref_label_ids): if _lids: for _label_id in _lids: ids_map[_label_id] = _sample_id else: ids_map = dict(zip(ref_label_ids, ref_sample_ids)) _data = [] _sample_ids = [] _label_ids = [] _missing_ids = [] for _lid, _d in zip(label_ids, data): _sid = ids_map.get(_lid, None) if _sid is not None: _data.append(_d) _sample_ids.append(_sid) _label_ids.append(_lid) else: _missing_ids.append(_lid) num_missing = len(_missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Unable to retrieve sample IDs for %d label IDs (eg %s)" % (num_missing, _missing_ids[0]) ) if warn_missing: logger.warning( "Ignoring %s for %d label IDs (eg %s) for which sample IDs " "could not be retrieved", data_type, num_missing, _missing_ids[0], ) return np.array(_data), np.array(_sample_ids), np.array(_label_ids) def get_embeddings_from_index( samples, similarity_index, patches_field=None, allow_missing=True, warn_missing=True, ): if patches_field is None: if samples._is_patches: sample_id_path = "sample_id" else: sample_id_path = "id" sample_ids = samples.values(sample_id_path) label_ids = None else: if samples._is_patches: label_id_path = "id" else: _, label_id_path = samples._get_label_field_path( patches_field, "id" ) sample_ids 
= None label_ids = samples.values(label_id_path, unwind=True) logger.info("Retrieving embeddings from similarity index...") return similarity_index.get_embeddings( sample_ids=sample_ids, label_ids=label_ids, allow_missing=allow_missing, warn_missing=warn_missing, ) def get_ids( samples, patches_field=None, data=None, data_type="embeddings", handle_missing="skip", ref_sample_ids=None, ): _validate_args(samples, patches_field=patches_field) if patches_field is None: if ref_sample_ids is not None: sample_ids = ref_sample_ids else: sample_ids = samples.values("id") if data is not None and len(sample_ids) != len(data): raise ValueError( "The number of %s (%d) in these results no longer matches the " "number of samples (%d) in the collection. You must " "regenerate the results" % (data_type, len(data), len(sample_ids)) ) return np.array(sample_ids), None sample_ids, label_ids = _get_patch_ids( samples, patches_field, handle_missing=handle_missing, ref_sample_ids=ref_sample_ids, ) if data is not None and len(sample_ids) != len(data): raise ValueError( "The number of %s (%d) in these results no longer matches the " "number of labels (%d) in the '%s' field of the collection. 
You " "must regenerate the results" % (data_type, len(data), len(sample_ids), patches_field) ) return np.array(sample_ids), np.array(label_ids) def filter_ids( samples, index_sample_ids, index_label_ids, patches_field=None, allow_missing=True, warn_missing=False, ): _validate_args(samples, patches_field=patches_field) if patches_field is None: if samples._is_patches: sample_ids = np.array(samples.values("sample_id")) else: sample_ids = np.array(samples.values("id")) if index_sample_ids is None: return sample_ids, None, None, None keep_inds, good_inds, bad_ids = _parse_ids( sample_ids, index_sample_ids, "samples", allow_missing, warn_missing, ) if bad_ids is not None: sample_ids = sample_ids[good_inds] return sample_ids, None, keep_inds, good_inds sample_ids, label_ids = _get_patch_ids(samples, patches_field) if index_label_ids is None: return sample_ids, label_ids, None, None keep_inds, good_inds, bad_ids = _parse_ids( label_ids, index_label_ids, "labels", allow_missing, warn_missing, ) if bad_ids is not None: sample_ids = sample_ids[good_inds] label_ids = label_ids[good_inds] return sample_ids, label_ids, keep_inds, good_inds def _get_patch_ids( samples, patches_field, handle_missing="skip", ref_sample_ids=None ): if samples._is_patches: sample_id_path = "sample_id" else: sample_id_path = "id" label_type, label_id_path = samples._get_label_field_path( patches_field, "id" ) is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS) sample_ids, label_ids = samples.values([sample_id_path, label_id_path]) if ref_sample_ids is not None: sample_ids, label_ids = _apply_ref_sample_ids( sample_ids, label_ids, ref_sample_ids ) if is_list_field: sample_ids, label_ids = _flatten_list_ids( sample_ids, label_ids, handle_missing ) return np.array(sample_ids), np.array(label_ids) def _apply_ref_sample_ids(sample_ids, label_ids, ref_sample_ids): ref_label_ids = [None] * len(ref_sample_ids) inds_map = {_id: i for i, _id in enumerate(ref_sample_ids)} for _id, _lid in 
zip(sample_ids, label_ids): idx = inds_map.get(_id, None) if idx is not None: ref_label_ids[idx] = _lid return ref_sample_ids, ref_label_ids def _flatten_list_ids(sample_ids, label_ids, handle_missing): _sample_ids = [] _label_ids = [] _add_missing = handle_missing == "image" for _id, _lids in zip(sample_ids, label_ids): if _lids: for _lid in _lids: _sample_ids.append(_id) _label_ids.append(_lid) elif _add_missing: _sample_ids.append(_id) _label_ids.append(None) return _sample_ids, _label_ids def _parse_ids(ids, index_ids, ftype, allow_missing, warn_missing): if np.array_equal(ids, index_ids): return None, None, None inds_map = {_id: idx for idx, _id in enumerate(index_ids)} keep_inds = [] bad_inds = [] bad_ids = [] for _idx, _id in enumerate(ids): ind = inds_map.get(_id, None) if ind is not None: keep_inds.append(ind) else: bad_inds.append(_idx) bad_ids.append(_id) num_missing_index = len(index_ids) - len(keep_inds) if num_missing_index > 0: if not allow_missing: raise ValueError( "The index contains %d %s that are not present in the " "provided collection" % (num_missing_index, ftype) ) if warn_missing: logger.warning( "Ignoring %d %s from the index that are not present in the " "provided collection", num_missing_index, ftype, ) num_missing_collection = len(bad_ids) if num_missing_collection > 0: if not allow_missing: raise ValueError( "The provided collection contains %d %s not present in the " "index" % (num_missing_collection, ftype) ) if warn_missing: logger.warning( "Ignoring %d %s from the provided collection that are not " "present in the index", num_missing_collection, ftype, ) bad_inds = np.array(bad_inds, dtype=np.int64) good_inds = np.full(ids.shape, True) good_inds[bad_inds] = False else: good_inds = None bad_ids = None keep_inds = np.array(keep_inds, dtype=np.int64) return keep_inds, good_inds, bad_ids def skip_ids(samples, ids, patches_field=None, warn_existing=False): sample_ids, label_ids = get_ids(samples, patches_field=patches_field) if 
patches_field is not None: exclude_ids = list(set(label_ids) & set(ids)) num_existing = len(exclude_ids) if num_existing > 0: if warn_existing: logger.warning("Skipping %d existing label IDs", num_existing) samples = samples.exclude_labels( ids=exclude_ids, fields=patches_field ) else: exclude_ids = list(set(sample_ids) & set(ids)) num_existing = len(exclude_ids) if num_existing > 0: if warn_existing: logger.warning("Skipping %d existing sample IDs", num_existing) samples = samples.exclude(exclude_ids) return samples def add_ids( sample_ids, label_ids, index_sample_ids, index_label_ids, patches_field=None, overwrite=True, allow_existing=True, warn_existing=False, ): if patches_field is not None: ids = label_ids index_ids = index_label_ids else: ids = sample_ids index_ids = index_sample_ids ii = [] jj = [] ids_map = {_id: _i for _i, _id in enumerate(index_ids)} new_idx = len(index_ids) for _i, _id in enumerate(ids): _idx = ids_map.get(_id, None) if _idx is None: _idx = new_idx new_idx += 1 ii.append(_i) jj.append(_idx) ii = np.array(ii) jj = np.array(jj) n = len(index_sample_ids) if not overwrite: existing_inds = np.nonzero(jj < n)[0] num_existing = existing_inds.size if num_existing > 0: if not allow_existing: raise ValueError( "Found %d IDs (eg '%s') that are already present in the " "index" % (num_existing, ids[ii[0]]) ) elif warn_existing: logger.warning( "Ignoring %d IDs (eg '%s') that are already present in " "the index", num_existing, ids[ii[0]], ) ii = np.delete(ii, existing_inds) jj = np.delete(jj, existing_inds) if ii.size > 0: sample_ids = np.array(sample_ids) if patches_field is not None: label_ids = np.array(label_ids) m = max(jj) - n + 1 if n == 0: index_sample_ids = np.array([], dtype=sample_ids.dtype) if patches_field is not None: index_label_ids = np.array([], dtype=label_ids.dtype) if m > 0: index_sample_ids = np.concatenate( (index_sample_ids, np.empty(m, dtype=index_sample_ids.dtype)) ) if patches_field is not None: index_label_ids = 
np.concatenate( (index_label_ids, np.empty(m, dtype=index_label_ids.dtype)) ) index_sample_ids[jj] = sample_ids[ii] if patches_field is not None: index_label_ids[jj] = label_ids[ii] return index_sample_ids, index_label_ids, ii, jj def add_embeddings( samples, embeddings, sample_ids, label_ids, embeddings_field, patches_field=None, ): dataset = samples._dataset if dataset.media_type == fomm.GROUP: view = dataset.select_group_slices(_allow_mixed=True) else: view = dataset if patches_field is not None: _, embeddings_path = dataset._get_label_field_path( patches_field, embeddings_field ) values = dict(zip(label_ids, embeddings)) view.set_label_values(embeddings_path, values, dynamic=True) else: values = dict(zip(sample_ids, embeddings)) view.set_values(embeddings_field, values, key_field="id") def remove_ids( sample_ids, label_ids, index_sample_ids, index_label_ids, patches_field=None, allow_missing=True, warn_missing=False, ): rm_inds = [] if sample_ids is not None: rm_inds.extend( _find_ids( sample_ids, index_sample_ids, allow_missing, warn_missing, "sample", ) ) if label_ids is not None: rm_inds.extend( _find_ids( label_ids, index_label_ids, allow_missing, warn_missing, "label", ) ) rm_inds = np.array(rm_inds) if rm_inds.size > 0: index_sample_ids = np.delete(index_sample_ids, rm_inds) if patches_field is not None: index_label_ids = np.delete(index_label_ids, rm_inds) return index_sample_ids, index_label_ids, rm_inds def _find_ids(ids, index_ids, allow_missing, warn_missing, ftype): found_inds = [] missing_ids = [] ids_map = {_id: _i for _i, _id in enumerate(index_ids)} for _id in ids: ind = ids_map.get(_id, None) if ind is not None: found_inds.append(ind) elif not allow_missing: missing_ids.append(_id) num_missing = len(missing_ids) if num_missing > 0: if not allow_missing: raise ValueError( "Found %d %d IDs (eg '%s') that are not present in the index" % (num_missing, ftype, missing_ids[0]) ) if warn_missing: logger.warning( "Ignoring %d %d IDs (eg '%s') that are 
not present in the " "index", num_missing, ftype, missing_ids[0], ) return found_inds def remove_embeddings( samples, embeddings_field, sample_ids=None, label_ids=None, patches_field=None, ): dataset = samples._dataset if dataset.media_type == fomm.GROUP: view = dataset.select_group_slices(_allow_mixed=True) else: view = dataset if patches_field is not None: _, embeddings_path = dataset._get_label_field_path( patches_field, embeddings_field ) if sample_ids is not None and label_ids is None: _, id_path = dataset._get_label_field_path(patches_field, "id") label_ids = view.select(sample_ids).values(id_path, unwind=True) if label_ids is not None: values = dict(zip(label_ids, itertools.repeat(None))) view.set_label_values(embeddings_path, values) elif sample_ids is not None: values = dict(zip(sample_ids, itertools.repeat(None))) view.set_values(embeddings_field, values, key_field="id") def filter_values(values, keep_inds, patches_field=None): if patches_field: _values = list(itertools.chain.from_iterable(values)) else: _values = values _values = np.asarray(_values) if _values.size == keep_inds.size: _values = _values[keep_inds] else: num_expected = np.count_nonzero(keep_inds) if _values.size != num_expected: raise ValueError( "Expected %d raw values or %d pre-filtered values; found %d " "values" % (keep_inds.size, num_expected, values.size) ) # @todo we might need to re-ravel patch values here in the future # We currently do not do this because all downstream users of this data # will gracefully handle either flat or nested list data return _values def get_values(samples, path_or_expr, ids, patches_field=None): _validate_args( samples, patches_field=patches_field, path_or_expr=path_or_expr ) return samples._get_values_by_id( path_or_expr, ids, link_field=patches_field ) def parse_data_field( samples, data_field, patches_field=None, data_type="embeddings", ): if not etau.is_str(data_field): raise ValueError( "Invalid %s field '%s'; expected a string field name" % 
(data_type, data_field) ) if patches_field is None: _data_field, is_frame_field = samples._handle_frame_field(data_field) if "." in _data_field: root, _ = _data_field.rsplit(".", 1) if not samples.has_field(root): raise ValueError( "Invalid %s field '%s'; root field '%s' does not exist" % (data_type, data_field, root) ) data_exists = samples.has_field(data_field) return data_field, data_exists if data_field.startswith(patches_field + "."): _, root = samples._get_label_field_path(patches_field) if not data_field.startswith(root + "."): raise ValueError( "Invalid %s field '%s' for patches field '%s'" % (data_type, data_field, patches_field) ) data_field = data_field[len(root) + 1 :] if "." in data_field: _, root = samples._get_label_field_path(patches_field) root += data_field.rsplit(".", 1)[0] if not samples.has_field(root): raise ValueError( "Invalid %s field '%s'; root field '%s' does not exist" % (data_type, data_field, root) ) _, data_path = samples._get_label_field_path(patches_field, data_field) data_exists = samples.has_field(data_path) return data_field, data_exists def get_embeddings( samples, model=None, model_kwargs=None, patches_field=None, embeddings_field=None, embeddings=None, similarity_index=None, force_square=False, alpha=None, handle_missing="skip", agg_fcn=None, batch_size=None, num_workers=None, skip_failures=True, progress=None, ): _validate_args(samples, patches_field=patches_field) if ( model is None and embeddings_field is None and embeddings is None and similarity_index is None ): return _empty_embeddings(patches_field) if similarity_index is not None: return get_embeddings_from_index( samples, similarity_index, patches_field=patches_field, allow_missing=True, warn_missing=True, ) if ( embeddings is None and model is not None and not _has_embeddings_field(samples, embeddings_field, patches_field) ): if etau.is_str(model): model_kwargs = model_kwargs or {} model = foz.load_zoo_model(model, **model_kwargs) if patches_field is not None: 
logger.info("Computing patch embeddings...") embeddings = samples.compute_patch_embeddings( model, patches_field, embeddings_field=embeddings_field, force_square=force_square, alpha=alpha, handle_missing=handle_missing, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) else: logger.info("Computing embeddings...") embeddings = samples.compute_embeddings( model, embeddings_field=embeddings_field, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) if embeddings is None and embeddings_field is not None: embeddings, samples = _load_embeddings( samples, embeddings_field, patches_field=patches_field ) ref_sample_ids = None else: if isinstance(embeddings, dict): embeddings = [ embeddings.get(_id, None) for _id in samples.values("id") ] embeddings, ref_sample_ids = _handle_missing_embeddings( embeddings, samples ) if not isinstance(embeddings, np.ndarray) and not embeddings: return _empty_embeddings(patches_field) if patches_field is not None: if agg_fcn is not None: embeddings = np.stack([agg_fcn(e) for e in embeddings]) else: embeddings = np.concatenate(embeddings, axis=0) elif not isinstance(embeddings, np.ndarray): embeddings = np.stack(embeddings) if agg_fcn is not None: patches_field = None sample_ids, label_ids = get_ids( samples, patches_field=patches_field, data=embeddings, data_type="embeddings", handle_missing=handle_missing, ref_sample_ids=ref_sample_ids, ) return embeddings, sample_ids, label_ids def get_unique_name(name, ref_names_or_fcn, max_len=None): unique_name = _get_unique_name(name, ref_names_or_fcn) if max_len is not None: while name and len(unique_name) > max_len: name = name[:-1] unique_name = _get_unique_name(name, ref_names_or_fcn) return unique_name def _get_unique_name(name, ref_names_or_fcn): if etau.is_container(ref_names_or_fcn): return _get_unique_name_from_list(name, ref_names_or_fcn) return _get_unique_name_from_function(name, ref_names_or_fcn) 
def _get_unique_name_from_list(name, ref_names): ref_names = set(ref_names) if name not in ref_names: return name name += "-" + _get_random_characters(6) while name in ref_names: name += _get_random_characters(1) return name def _get_unique_name_from_function(name, exists_fcn): if not exists_fcn(name): return name name += "-" + _get_random_characters(6) while exists_fcn(name): name += _get_random_characters(1) return name def _get_random_characters(n): return "".join( random.choice(string.ascii_lowercase + string.digits) for _ in range(n) ) def _empty_embeddings(patches_field): embeddings = np.empty((0, 0), dtype=float) sample_ids = np.array([], dtype="`_ | """ # For backwards-compatibility with older versions of plugins like # https://github.com/voxel51/fiftyone-plugins/blob/5c800f1ded53c285f8e17f37e1ad9b2472fa93e7/plugins/brain/__init__.py#L25 from fiftyone.brain.visualization import ( Visualization, UMAPVisualization, TSNEVisualization, PCAVisualization, ManualVisualization, ) ================================================ FILE: fiftyone/brain/internal/models/.gitignore ================================================ cache/ ================================================ FILE: fiftyone/brain/internal/models/__init__.py ================================================ """ Brain models. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ from copy import deepcopy import logging import os from eta.core.config import ConfigError import eta.core.learning as etal import eta.core.models as etam import fiftyone.core.models as fom logger = logging.getLogger(__name__) _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _MODELS_MANIFEST_PATH = os.path.join(_THIS_DIR, "manifest.json") _MODELS_DIR = os.path.join(_THIS_DIR, "cache") def list_models(): """Returns the list of available models. 
Returns: a list of model names """ manifest = _load_models_manifest() return sorted([model.name for model in manifest]) def list_downloaded_models(): """Returns information about the models that have been downloaded. Returns: a dict mapping model names to (model path, ``eta.core.models.Model``) tuples """ manifest = _load_models_manifest() models = {} for model in manifest: if model.is_in_dir(_MODELS_DIR): model_path = model.get_path_in_dir(_MODELS_DIR) models[model.name] = (model_path, model) return models def is_model_downloaded(name): """Determines whether the model of the given name is downloaded. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used Returns: True/False """ model = _get_model(name) return model.is_in_dir(_MODELS_DIR) def download_model(name, overwrite=False): """Downloads the model of the given name. If the model is already downloaded, it is not re-downloaded unless ``overwrite == True`` is specified. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used. Call :func:`list_models` to see the available models overwrite (False): whether to overwrite any existing files Returns: tuple of - model: the ``eta.core.models.Model`` instance for the model - model_path: the path to the downloaded model on disk """ model, model_path = _get_model_in_dir(name) if not overwrite and is_model_downloaded(name): logger.info("Model '%s' is already downloaded", name) else: model.manager.download_model(model_path, force=overwrite) return model, model_path def install_model_requirements(name, error_level=0): """Installs any package requirements for the model with the given name. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. 
If no version is specified, the latest version of the model is used. Call :func:`list_models` to see the available models error_level: the error level to use, defined as: 0: raise error if a requirement install fails 1: log warning if a requirement install fails 2: ignore install fails """ model = _get_model(name) model.install_requirements(error_level=error_level) def ensure_model_requirements(name, error_level=0): """Ensures that the package requirements for the model with the given name are satisfied. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used. Call :func:`list_models` to see the available models error_level: the error level to use, defined as: 0: raise error if a requirement is not satisfied 1: log warning if a requirement is not satisfied 2: ignore unsatisfied requirements """ model = _get_model(name) model.ensure_requirements(error_level=error_level) def load_model( name, download_if_necessary=True, install_requirements=False, error_level=0, **kwargs ): """Loads the model of the given name. By default, the model will be downloaded if necessary. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used. Call :func:`list_models` to see the available models download_if_necessary (True): whether to download the model if it is not found in the models directory install_requirements (False): whether to install any requirements before loading the model. If False, the requirements are only checked
error_level: the error level to use, defined as: 0: raise error if a requirement is not satisfied 1: log warning if a requirement is not satisfied 2: ignore unsatisfied requirements **kwargs: keyword arguments to inject into the model's ``Config`` instance Returns: a :class:`fiftyone.core.models.Model` Raises: ValueError: if the model is not downloaded and ``download_if_necessary`` is False """ model = _get_model(name) if not model.is_in_dir(_MODELS_DIR): if not download_if_necessary: raise ValueError("Model '%s' is not downloaded" % name) download_model(name) if install_requirements: model.install_requirements(error_level=error_level) else: model.ensure_requirements(error_level=error_level) config_dict = deepcopy(model.default_deployment_config_dict) model_path = model.get_path_in_dir(_MODELS_DIR) return fom.load_model(config_dict, model_path=model_path, **kwargs) def find_model(name): """Returns the path to the model on disk. The model must be downloaded. Use :func:`download_model` to download models. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used Returns: the path to the model on disk Raises: ValueError: if the model does not exist or has not been downloaded """ model, model_path = _get_model_in_dir(name) if not model.is_model_downloaded(model_path): raise ValueError("Model '%s' is not downloaded" % name) return model_path def get_model(name): """Returns the ``eta.core.models.Model`` instance for the model with the given name. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model. If no version is specified, the latest version of the model is used Returns: an ``eta.core.models.Model`` Raises: ValueError: if no matching model is found in the manifest """ return _get_model(name) def delete_model(name): """Deletes the model from local disk, if necessary. Args: name: the name of the model, which can have ``@`` appended to refer to a specific version of the model.
If no version is specified, the latest version of the model is used """ model, model_path = _get_model_in_dir(name) model.flush_model(model_path) class HasBrainModel(etal.HasPublishedModel): """Mixin class for Config classes of :class:`fiftyone.core.models.Model` instances whose models are stored privately by the FiftyOne Brain. """ def download_model_if_necessary(self): """Populates ``model_path`` via :func:`download_model` if it is not already set. Raises: ConfigError: if neither ``model_name`` nor ``model_path`` is set """ # pylint: disable=attribute-defined-outside-init if not self.model_name and not self.model_path: raise ConfigError( "Either `model_name` or `model_path` must be provided" ) if self.model_path is None: self.model_path = download_model(self.model_name) @classmethod def _get_model(cls, model_name): """Returns the Brain model with the given name via :func:`get_model`.""" return get_model(model_name) def _load_models_manifest(): """Loads the models manifest from ``_MODELS_MANIFEST_PATH``.""" return etam.ModelsManifest.from_json(_MODELS_MANIFEST_PATH) def _get_model_in_dir(name): """Returns a ``(model, model_path)`` tuple, where ``model_path`` is the model's path within ``_MODELS_DIR``.""" model = _get_model(name) model_path = model.get_path_in_dir(_MODELS_DIR) return model, model_path def _get_model(name): """Returns the exact model version if ``name`` contains a version string, else the latest model whose base name is ``name``.""" if etam.Model.has_version_str(name): return _get_exact_model(name) return _get_latest_model(name) def _get_exact_model(name): """Returns the model with the given versioned name. Raises ``ValueError`` if the manifest lookup raises ``etam.ModelError``.""" manifest = _load_models_manifest() try: return manifest.get_model_with_name(name) except etam.ModelError: raise ValueError("No model with name '%s' was found" % name) def _get_latest_model(base_name): """Returns the latest model with the given base name. Raises ``ValueError`` if the manifest lookup raises ``etam.ModelError``.""" manifest = _load_models_manifest() try: return manifest.get_latest_model_with_base_name(base_name) except etam.ModelError: raise ValueError("No models found with base name '%s'" % base_name) ================================================ FILE: fiftyone/brain/internal/models/manifest.json ================================================ { "models": [ { "base_name": "simple-resnet-cifar10", "base_filename": "simple-resnet-cifar10.pth", "version": "1.0", "description": "Simple ResNet trained on CIFAR-10", "manager": { "type": "fiftyone.core.models.ModelManager", "config": { "google_drive_id": "1SIO9XreK0w1ja4EuhBWcR10CnWxCOsom" } }, "default_deployment_config_dict": { "type":
"fiftyone.brain.internal.models.torch.TorchImageModel", "config": { "entrypoint_fcn": "fiftyone.brain.internal.models.simple_resnet.simple_resnet", "output_processor_cls": "fiftyone.utils.torch.ClassifierOutputProcessor", "labels_string": "airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truck", "image_size": [32, 32], "image_mean": [0.4914, 0.4822, 0.4465], "image_std": [0.2023, 0.1994, 0.201], "embeddings_layer": "flatten", "use_half_precision": false, "cudnn_benchmark": true } }, "date_created": "2020-05-07 08:25:51" } ] } ================================================ FILE: fiftyone/brain/internal/models/simple_resnet.py ================================================ """ Implementation of a simple ResNet that is suitable only for smallish data. The original implementation of this is from David Page's work on fast model training with resnets at https://github.com/davidcpage/cifar10-fast. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ from collections import namedtuple import os import numpy as np import torch from torch import nn def simple_resnet( channels=None, weight=0.125, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=("layer1", "layer3"), ): channels = channels or { "prep": 64, "layer1": 128, "layer2": 256, "layer3": 512, } net = { "input": (None, []), "prep": conv_bn(3, channels["prep"]), "layer1": dict( conv_bn(channels["prep"], channels["layer1"]), pool=pool ), "layer2": dict( conv_bn(channels["layer1"], channels["layer2"]), pool=pool ), "layer3": dict( conv_bn(channels["layer2"], channels["layer3"]), pool=pool ), "pool": nn.MaxPool2d(4), "flatten": Flatten(), "linear": nn.Linear(channels["layer3"], 10, bias=False), "logits": Mul(weight), } for layer in res_layers: net[layer]["residual"] = residual(channels[layer]) for layer in extra_layers: net[layer]["extra"] = conv_bn(channels[layer], channels[layer]) return Network(net, input_layer="input", output_layer="logits") class Network(nn.Module): def __init__(self, net, 
input_layer=None, output_layer=None): super().__init__() self.input_layer = input_layer self.output_layer = output_layer self.graph = build_graph(net) for path, (val, _) in self.graph.items(): setattr(self, path.replace("/", "_"), val) def nodes(self): return (node for node, _ in self.graph.values()) def forward(self, inputs): if self.input_layer: outputs = {self.input_layer: inputs} else: outputs = dict(inputs) for k, (node, ins) in self.graph.items(): # only compute nodes that are not supplied as inputs. if k not in outputs: outputs[k] = node(*[outputs[x] for x in ins]) if self.output_layer: return outputs[self.output_layer] return outputs def half(self): for node in self.nodes(): if isinstance(node, nn.Module) and not isinstance( node, nn.BatchNorm2d ): node.half() return self def has_inputs(node): return type(node) is tuple def build_graph(net): flattened = pipeline(net) resolve_input = lambda rel_path, path, idx: ( os.path.normpath(os.path.sep.join((path, "..", rel_path))) if isinstance(rel_path, str) else flattened[idx + rel_path][0] ) return { path: ( node[0], [resolve_input(rel_path, path, idx) for rel_path in node[1]], ) for idx, (path, node) in enumerate(flattened) } def pipeline(net): return [ (os.path.sep.join(path), (node if has_inputs(node) else (node, [-1]))) for (path, node) in path_iter(net) ] class Crop(namedtuple("Crop", ("h", "w"))): def __call__(self, x, x0, y0): return x[..., y0 : y0 + self.h, x0 : x0 + self.w] def options(self, shape): *_, H, W = shape return [ {"x0": x0, "y0": y0} for x0 in range(W + 1 - self.w) for y0 in range(H + 1 - self.h) ] def output_shape(self, shape): *_, H, W = shape return (*_, self.h, self.w) class FlipLR(namedtuple("FlipLR", ())): def __call__(self, x, choice): if isinstance(x, np.ndarray): return x[..., ::-1].copy() return torch.flip(x, [-1]) if choice else x def options(self, shape): return [{"choice": b} for b in [True, False]] class Cutout(namedtuple("Cutout", ("h", "w"))): def __call__(self, x, x0, y0): 
x[..., y0 : y0 + self.h, x0 : x0 + self.w] = 0.0 return x def options(self, shape): *_, H, W = shape return [ {"x0": x0, "y0": y0} for x0 in range(W + 1 - self.w) for y0 in range(H + 1 - self.h) ] class PiecewiseLinear(namedtuple("PiecewiseLinear", ("knots", "vals"))): def __call__(self, t): return np.interp([t], self.knots, self.vals)[0] class Const(namedtuple("Const", ["val"])): def __call__(self, x): return self.val class Identity(namedtuple("Identity", [])): def __call__(self, x): return x class Add(namedtuple("Add", [])): def __call__(self, x, y): return x + y class AddWeighted(namedtuple("AddWeighted", ["wx", "wy"])): def __call__(self, x, y): return self.wx * x + self.wy * y class Mul(nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def __call__(self, x): return x * self.weight class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), x.size(1)) class Concat(nn.Module): def forward(self, *xs): return torch.cat(xs, 1) class BatchNorm(nn.BatchNorm2d): def __init__( self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, bias_init=0.0, ): super().__init__(num_features, eps=eps, momentum=momentum) if weight_init is not None: self.weight.data.fill_(weight_init) if bias_init is not None: self.bias.data.fill_(bias_init) self.weight.requires_grad = not weight_freeze self.bias.requires_grad = not bias_freeze def conv_bn(c_in, c_out): return { "conv": nn.Conv2d( c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False ), "bn": BatchNorm(c_out), "relu": nn.ReLU(True), } def residual(c): return { "in": Identity(), "res1": conv_bn(c, c), "res2": conv_bn(c, c), "add": (Add(), ["in", "res2/relu"]), } def path_iter(nested_dict, pfx=()): for name, val in nested_dict.items(): if isinstance(val, dict): yield from path_iter(val, (*pfx, name)) else: yield ((*pfx, name), val) MODEL = "model" VALID_MODEL = "valid_model" OUTPUT = "output" ================================================ 
FILE: fiftyone/brain/internal/models/torch.py ================================================ """ PyTorch utilities. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import fiftyone.utils.torch as fout from fiftyone.brain.internal.models import HasBrainModel import torch class TorchImageModelConfig(fout.TorchImageModelConfig, HasBrainModel): """Configuration for running a :class:`TorchImageModel`. See :class:`fiftyone.utils.torch.TorchImageModelConfig` for additional parameters. Args: model_name (None): the name of the Brain model state dict to load model_path (None): the path to a state dict on disk to load """ def __init__(self, d): d = self.init(d) super().__init__(d) class TorchImageModel(fout.TorchImageModel): """Wrapper for evaluating a Torch model on images whose state dict is stored privately by the Brain. Args: config: an :class:`TorchImageModelConfig` """ def _download_model(self, config): config.download_model_if_necessary() def _load_state_dict(self, model, config): state_dict = torch.load(config.model_path, map_location=self.device) model.load_state_dict(state_dict) ================================================ FILE: fiftyone/brain/similarity.py ================================================ """ Similarity interface. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ from collections import defaultdict from copy import deepcopy import inspect import logging from bson import ObjectId import numpy as np import eta.core.utils as etau import fiftyone.brain as fb import fiftyone.core.brain as fob import fiftyone.core.context as foc import fiftyone.core.dataset as fod import fiftyone.core.fields as fof import fiftyone.core.labels as fol import fiftyone.core.media as fom import fiftyone.core.patches as fop import fiftyone.core.stages as fos import fiftyone.core.utils as fou import fiftyone.core.view as fov import fiftyone.core.validation as fova import fiftyone.zoo as foz from fiftyone import ViewField as F fbu = fou.lazy_import("fiftyone.brain.internal.core.utils") logger = logging.getLogger(__name__) _ALLOWED_ROI_FIELD_TYPES = ( fol.Detection, fol.Detections, fol.Polyline, fol.Polylines, ) _DEFAULT_MODEL = "mobilenet-v2-imagenet-torch" _DEFAULT_BATCH_SIZE = None def compute_similarity( samples, patches_field, roi_field, embeddings, brain_key, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, backend, **kwargs, ): """See ``fiftyone/brain/__init__.py``.""" fova.validate_collection(samples) if roi_field is not None: fova.validate_collection_label_fields( samples, roi_field, _ALLOWED_ROI_FIELD_TYPES ) # Allow for `embeddings_field=XXX` and `embeddings=False` together embeddings_field = kwargs.pop("embeddings_field", None) if embeddings_field is not None or etau.is_str(embeddings): if embeddings_field is None: embeddings_field = embeddings embeddings = None embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings_field, patches_field=patches_field or roi_field, data_type="embeddings", ) else: embeddings_field = None embeddings_exist = None if model is None and embeddings is None and not embeddings_exist: model = _DEFAULT_MODEL if batch_size is None: batch_size = _DEFAULT_BATCH_SIZE if etau.is_str(model): _model_kwargs = model_kwargs or {} _model = 
foz.load_zoo_model(model, **_model_kwargs) else: _model = model try: supports_prompts = _model.can_embed_prompts except: supports_prompts = False if brain_key is not None and supports_prompts and not etau.is_str(model): logger.warning( "This index will not support prompt queries in the App or in " "future Python sessions. You can support this by providing the " "string name of a zoo model rather than a Model instance to " "compute_similarity(model=)." ) config = _parse_config( backend, embeddings_field=embeddings_field, patches_field=patches_field, roi_field=roi_field, model=model, model_kwargs=model_kwargs, supports_prompts=supports_prompts, **kwargs, ) brain_method = config.build() brain_method.ensure_requirements() # Similarity indexes can be modified after creation, so we always register # the index on the full dataset so that queries will always be performed # against the full index by default dataset = samples._root_dataset if samples._is_frames: dataset = samples._base_view if brain_key is not None: # Don't allow overwriting an existing run with same key, since we # need the existing run in order to perform workflows like # automatically cleaning up the backend's index brain_method.register_run(dataset, brain_key, overwrite=False) results = brain_method.initialize(dataset, brain_key) results._model = _model results._supports_prompts = supports_prompts get_embeddings = embeddings is not False if not results.is_external and results.total_index_size > 0: # No need to load embeddings because the index already has them get_embeddings = False if get_embeddings: # Don't immediatly store embeddings in DB; let `add_to_index()` do it if not embeddings_exist: embeddings_field = None if roi_field is not None: handle_missing = "image" agg_fcn = lambda e: np.mean(e, axis=0) else: handle_missing = "skip" agg_fcn = None embeddings, sample_ids, label_ids = fbu.get_embeddings( samples, model=_model, patches_field=patches_field or roi_field, embeddings=embeddings, 
embeddings_field=embeddings_field, force_square=force_square, alpha=alpha, handle_missing=handle_missing, agg_fcn=agg_fcn, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) else: embeddings = None sample_ids = None label_ids = None if embeddings is not None: results.add_to_index(embeddings, sample_ids, label_ids=label_ids) brain_method.save_run_results(dataset, brain_key, results) return results def _parse_config(name, **kwargs): if name is None: name = fb.brain_config.default_similarity_backend if inspect.isclass(name): return name(**kwargs) backends = fb.brain_config.similarity_backends if name not in backends: raise ValueError( "Unsupported backend '%s'. The available backends are %s" % (name, sorted(backends.keys())) ) params = deepcopy(backends[name]) config_cls = kwargs.pop("config_cls", None) if config_cls is None: config_cls = params.pop("config_cls", None) if config_cls is None: raise ValueError("Similarity backend '%s' has no `config_cls`" % name) if etau.is_str(config_cls): config_cls = etau.get_class(config_cls) params.update(**kwargs) return config_cls(**params) class SimilarityConfig(fob.BrainMethodConfig): """Similarity configuration. 
Args: embeddings_field (None): the sample field containing the embeddings, if one was provided model (None): the :class:`fiftyone.core.models.Model` or name of the zoo model that was used to compute embeddings, if known model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided patches_field (None): the sample field defining the patches being analyzed, if any roi_field (None): the sample field defining a region of interest within each image to use to compute embeddings, if any supports_prompts (False): whether this run supports prompt queries """ def __init__( self, embeddings_field=None, model=None, model_kwargs=None, patches_field=None, roi_field=None, supports_prompts=None, **kwargs, ): if model is not None and not etau.is_str(model): model = None # We can't declare permanent support for prompts because we don't # know how to load the model in future sessions supports_prompts = None self.embeddings_field = embeddings_field self.model = model self.model_kwargs = model_kwargs self.patches_field = patches_field self.roi_field = roi_field self.supports_prompts = supports_prompts super().__init__(**kwargs) @property def type(self): return "similarity" @property def method(self): """The name of the similarity backend.""" raise NotImplementedError("subclass must implement method") @property def max_k(self): """A maximum k value for nearest neighbor queries, or None if there is no limit. """ raise NotImplementedError("subclass must implement max_k") @property def supports_least_similarity(self): """Whether this backend supports least similarity queries.""" raise NotImplementedError( "subclass must implement supports_least_similarity" ) @property def supported_aggregations(self): """A tuple of supported values for the ``aggregation`` parameter of the backend's :meth:`sort_by_similarity() ` and :meth:`_kneighbors() ` methods. 
""" raise NotImplementedError( "subclass must implement supported_aggregations" ) def load_credentials(self, **kwargs): self._load_parameters(**kwargs) def _load_parameters(self, **kwargs): name = self.method parameters = fb.brain_config.similarity_backends.get(name, {}) for name, value in kwargs.items(): if value is None: value = parameters.get(name, None) if value is not None: setattr(self, name, value) class Similarity(fob.BrainMethod): """Base class for similarity factories. Args: config: a :class:`SimilarityConfig` """ def initialize(self, samples, brain_key): """Initializes a similarity index. Args: samples: a :class:`fiftyone.core.collections.SampleColllection` brain_key: the brain key Returns: a :class:`SimilarityIndex` """ raise NotImplementedError("subclass must implement initialize()") def get_fields(self, samples, brain_key): fields = [] if self.config.patches_field is not None: fields.append(self.config.patches_field) if self.config.embeddings_field is not None: fields.append(self.config.embeddings_field) return fields class SimilarityIndex(fob.BrainResults): """Base class for similarity indexes. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`SimilarityConfig` used brain_key: the brain key backend (None): a :class:`Similarity` backend """ def __init__(self, samples, config, brain_key, backend=None): super().__init__(samples, config, brain_key, backend=backend) self._model = None self._supports_prompts = None self._last_view = None self._last_views = [] self._curr_view = None self._curr_view_allow_missing = None self._curr_view_warn_missing = None self._curr_sample_ids = None self._curr_label_ids = None self._curr_keep_inds = None self._curr_missing_size = None self.use_view(samples) def __enter__(self): """Saves the view that was active before the most recent :meth:`use_view` call so that it can be restored on exit.""" self._last_views.append(self._last_view) return self def __exit__(self, *args): """Restores the view that was active before the context was entered, falling back to the full collection if no view was saved.""" try: last_view = self._last_views.pop() except: last_view = self._samples self.use_view(last_view) @property def config(self): """The :class:`SimilarityConfig` for these results.""" return self._config @property def supports_prompts(self): """Whether this similarity index supports prompt queries.""" if self._supports_prompts is not None: return self._supports_prompts return self.config.supports_prompts or False @property def is_external(self): """Whether this similarity index manages its own embeddings (True) or loads them directly from the ``embeddings_field`` of the dataset (False). """ return True # assume external unless explicitly overridden @property def sample_ids(self): """The sample IDs of the full index, or ``None`` if not supported.""" return None @property def label_ids(self): """The label IDs of the full index, or ``None`` if not applicable or not supported. """ return None @property def total_index_size(self): """The total number of data points in the index. If :meth:`use_view` has been called to restrict the index, this value may be larger than the current :meth:`index_size`.
""" raise NotImplementedError("subclass must implement total_index_size") @property def has_view(self): """Whether the index is currently restricted to a view. Use :meth:`use_view` to restrict the index to a view, and use :meth:`clear_view` to reset to the full index. """ # Full dataset if isinstance(self._curr_view, fod.Dataset): return False # Full group slices view if ( isinstance(self._curr_view, fov.DatasetView) and self._curr_view._root_dataset.media_type == fom.GROUP and len(self._curr_view._stages) == 1 and isinstance(self._curr_view._stages[0], fos.SelectGroupSlices) and self._curr_view._pipeline() == [] ): return False # Full patches view if ( self.config.patches_field is not None and isinstance(self._curr_view, fop.PatchesView) and len(self._curr_view._all_stages) == 1 ): return False return self._curr_view.view() != self._samples.view() @property def view(self): """The :class:`fiftyone.core.collections.SampleCollection` against which results are currently being generated. If :meth:`use_view` has been called, this view may be different than the collection on which the full index was generated. """ return self._curr_view @property def current_sample_ids(self): """The sample IDs of the currently active data points in the index. If :meth:`use_view` has been called, this may be a subset of the full index. If the index does not support full sample ID lists (i.e. if :meth:`sample_ids` is ``None``), then this will be all sample IDs in the current :meth:`view` regardless of whether all samples are indexed. """ self._apply_view_if_necessary() return self._curr_sample_ids @property def current_label_ids(self): """The label IDs of the currently active data points in the index, or ``None`` if not applicable. If :meth:`use_view` has been called, this may be a subset of the full index.
If the index does not support full label ID lists (i.e. if :meth:`label_ids` is ``None``), then this will be all label IDs in the current :meth:`view` regardless of whether all labels are indexed. """ self._apply_view_if_necessary() return self._curr_label_ids @property def _current_inds(self): """The indices of :meth:`current_sample_ids` in :meth:`sample_ids`, or ``None`` if not supported or if the full index is currently being used. """ self._apply_view_if_necessary() return self._curr_keep_inds @property def index_size(self): """The number of active data points in the index. If :meth:`use_view` has been called to restrict the index, this property will reflect the size of the active index. """ self._apply_view_if_necessary() return len(self._curr_sample_ids) @property def missing_size(self): """The total number of data points in :meth:`view` that are missing from this index, or ``None`` if unknown. This property is only applicable when :meth:`use_view` has been called, and it will be ``None`` if no data points are missing or when the backend does not support it. """ self._apply_view_if_necessary() return self._curr_missing_size def add_to_index( self, embeddings, sample_ids, label_ids=None, overwrite=True, allow_existing=True, warn_existing=False, reload=True, ): """Adds the given embeddings to the index.
Args: embeddings: a ``num_embeddings x num_dims`` array of embeddings sample_ids: a ``num_embeddings`` array of sample IDs label_ids (None): a ``num_embeddings`` array of label IDs, if applicable overwrite (True): whether to replace (True) or ignore (False) existing embeddings with the same sample/label IDs allow_existing (True): whether to ignore (True) or raise an error (False) when ``overwrite`` is False and a provided ID already exists in the index warn_existing (False): whether to log a warning if an embedding is not added to the index because its ID already exists reload (True): whether to call :meth:`reload` to refresh the current view after the update """ raise NotImplementedError("subclass must implement add_to_index()") def remove_from_index( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, reload=True, ): """Removes the specified embeddings from the index. Args: sample_ids (None): an array of sample IDs label_ids (None): an array of label IDs, if applicable allow_missing (True): whether to allow the index to not contain IDs that you provide (True) or whether to raise an error in this case (False) warn_missing (False): whether to log a warning if the index does not contain IDs that you provide reload (True): whether to call :meth:`reload` to refresh the current view after the update """ raise NotImplementedError( "subclass must implement remove_from_index()" ) def get_embeddings( self, sample_ids=None, label_ids=None, allow_missing=True, warn_missing=False, ): """Retrieves the embeddings for the given IDs from the index. If no IDs are provided, the entire index is returned.
Args: sample_ids (None): a sample ID or list of sample IDs for which to retrieve embeddings label_ids (None): a label ID or list of label IDs for which to retrieve embeddings allow_missing (True): whether to allow the index to not contain IDs that you provide (True) or whether to raise an error in this case (False) warn_missing (False): whether to log a warning if the index does not contain IDs that you provide Returns: a tuple of: - a ``num_embeddings x num_dims`` array of embeddings - a ``num_embeddings`` array of sample IDs - a ``num_embeddings`` array of label IDs, if applicable, or else ``None`` """ raise NotImplementedError("subclass must implement get_embeddings()") def use_view( self, samples, allow_missing=True, warn_missing=False, ): """Restricts the index to the provided view. Subsequent calls to methods on this instance will only contain results from the specified view rather than the full index. Use :meth:`clear_view` to reset to the full index. Or, equivalently, use the context manager interface as demonstrated below to automatically reset the view when the context exits.
Example usage:: import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz dataset = foz.load_zoo_dataset("quickstart") results = fob.compute_similarity(dataset) print(results.index_size) # 200 view = dataset.take(50) with results.use_view(view): print(results.index_size) # 50 results.find_unique(10) print(results.unique_ids) plot = results.visualize_unique() plot.show() Args: samples: a :class:`fiftyone.core.collections.SampleCollection` allow_missing (True): whether to allow the provided collection to contain data points that this index does not contain (True) or whether to raise an error in this case (False) warn_missing (False): whether to log a warning if the provided collection contains data points that this index does not contain Returns: self """ self._last_view = self._curr_view self._curr_view = samples self._curr_view_allow_missing = allow_missing self._curr_view_warn_missing = warn_missing self._curr_sample_ids = None self._curr_label_ids = None self._curr_keep_inds = None self._curr_missing_size = None return self def _apply_view(self): """Computes the IDs, kept indices, and missing count of the current :meth:`view`, filtering via ``fbu.filter_ids()`` unless the full index is in use and its sample IDs are known.""" sample_ids = self.sample_ids label_ids = self.label_ids if sample_ids is not None and not self.has_view: keep_inds = None good_inds = None else: sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids( self._curr_view, sample_ids, label_ids, patches_field=self.config.patches_field, allow_missing=self._curr_view_allow_missing, warn_missing=self._curr_view_warn_missing, ) if good_inds is not None: missing_size = good_inds.size - np.count_nonzero(good_inds) else: missing_size = None self._curr_sample_ids = sample_ids self._curr_label_ids = label_ids self._curr_keep_inds = keep_inds self._curr_missing_size = missing_size def _apply_view_if_necessary(self): """Lazily calls :meth:`_apply_view` if the current view's IDs have not yet been computed.""" if self._curr_sample_ids is None: self._apply_view() def clear_view(self): """Clears the view set by :meth:`use_view`, if any. Subsequent operations will be performed on the full index.
""" self.use_view(self._samples) def reload(self): """Reloads the index for the current view. Subclasses may override this method, but by default this method simply passes the current :meth:`view` back into :meth:`use_view`, which updates the index's current ID set based on any changes to the view since the index was last loaded. Note that, by default, this resets ``allow_missing`` and ``warn_missing`` to the defaults of :meth:`use_view`. """ self.use_view(self._curr_view) def cleanup(self): """Deletes the similarity index from the backend.""" raise NotImplementedError("subclass must implement cleanup()") def values(self, path_or_expr): """Extracts a flat list of values from the given field or expression corresponding to the current :meth:`view`. This method always returns values in the same order as :meth:`current_sample_ids` and :meth:`current_label_ids`. Args: path_or_expr: the values to extract, which can be: - the name of a sample field or ``embedded.field.name`` from which to extract numeric or string values - a :class:`fiftyone.core.expressions.ViewExpression` defining numeric or string values to compute via :meth:`fiftyone.core.collections.SampleCollection.values` Returns: a list of values """ samples = self.view patches_field = self.config.patches_field if patches_field is not None: ids = self.current_label_ids else: ids = self.current_sample_ids return fbu.get_values( samples, path_or_expr, ids, patches_field=patches_field ) def sort_by_similarity( self, query, k=None, reverse=False, aggregation="mean", dist_field=None, _mongo=False, ): """Returns a view that sorts the samples/labels in :meth:`view` by similarity to the specified query. When querying by IDs, the query can be any ID(s) in the full index of this instance, even if the current :meth:`view` contains a subset of the full index. Args: query: the query, which can be any of the following: - an ID or iterable of IDs - a ``num_dims`` vector or ``num_queries x num_dims`` array of vectors - a prompt or iterable of prompts (if supported by the index) k (None): the number of matches to return.
Some backends may support ``None``, in which case all samples will be sorted reverse (False): whether to sort by least similarity (True) or greatest similarity (False). Some backends may not support least similarity aggregation ("mean"): the aggregation method to use when multiple queries are provided. The default is ``"mean"``, which means that the query vectors are averaged prior to searching. Some backends may support additional options dist_field (None): the name of a float field in which to store the distance of each example to the specified query. The field is created if necessary Returns: a :class:`fiftyone.core.view.DatasetView` """ samples = self.view patches_field = self.config.patches_field selecting_samples = patches_field is None or isinstance( samples, fop.PatchesView ) kwargs = dict( query=self._parse_query(query), k=k, reverse=reverse, aggregation=aggregation, return_dists=dist_field is not None, ) if dist_field is not None: sample_ids, label_ids, dists = self._kneighbors(**kwargs) else: sample_ids, label_ids = self._kneighbors(**kwargs) if selecting_samples: if patches_field is not None: ids = label_ids else: ids = sample_ids else: ids = label_ids # Store query distances if dist_field is not None: if selecting_samples: values = dict(zip(ids, dists)) samples.set_values(dist_field, values, key_field="id") else: label_type, path = samples._get_label_field_path( patches_field, dist_field ) if issubclass(label_type, fol._LABEL_LIST_FIELDS): samples._set_list_values_by_id( path, sample_ids, label_ids, dists, path.rsplit(".", 1)[0], ) else: values = dict(zip(sample_ids, dists)) samples.set_values(path, values, key_field="id") # Construct sorted view stages = [] if selecting_samples: stage = fos.Select(ids, ordered=True) stages.append(stage) else: # Sorting by object similarity but this is not a patches view, so # arrange the samples in order of their first occurring label result_sample_ids = _unique_no_sort(sample_ids) stage = fos.Select(result_sample_ids,
ordered=True) stages.append(stage) if k is not None: _ids = [ObjectId(_id) for _id in ids] stage = fos.FilterLabels(patches_field, F("_id").is_in(_ids)) stages.append(stage) if _mongo: pipeline = [] for stage in stages: stage.validate(samples) pipeline.extend(stage.to_mongo(samples)) return pipeline view = samples for stage in stages: view = view.add_stage(stage) return view def _parse_query(self, query): """Validates ``query`` and converts it into an array of query vectors, an array of prompt embeddings computed by :meth:`get_model`, or a list of IDs. Raises: ValueError: if the query is empty, or if it contains prompts that this index does not support """ if query is None: raise ValueError("At least one query must be provided") if isinstance(query, np.ndarray): # Query by vector(s) if query.size == 0: raise ValueError("At least one query vector must be provided") return query if etau.is_str(query): query = [query] else: query = list(query) if not query: raise ValueError("At least one query must be provided") if etau.is_numeric(query[0]): return np.asarray(query) try: ObjectId(query[0]) is_prompts = False except: is_prompts = True if is_prompts: if not self.supports_prompts: raise ValueError( "Invalid query '%s'; this model does not support prompts" % query[0] ) model = self.get_model() with model: return model.embed_prompts(query) return query def _kneighbors( self, query=None, k=None, reverse=False, aggregation=None, return_dists=False, ): """Returns the k-nearest neighbors for the given query. This method should only return results from the current :meth:`view`. Args: query (None): the query, which can be any of the following: - an ID or list of IDs for which to return neighbors - an embedding or ``num_queries x num_dim`` array of embeddings for which to return neighbors - Some backends may also support ``None``, in which case the neighbors for all points in the current :meth:`view` are returned k (None): the number of neighbors to return. Some backends may enforce upper bounds on this parameter reverse (False): whether to sort by least similarity (True) or greatest similarity (False).
Some backends may not support least similarity aggregation (None): an optional aggregation method to use when multiple queries are provided. All backends must support ``"mean"``, which averages query vectors prior to searching. Backends may support additional options as well return_dists (False): whether to return query-neighbor distances Returns: the query result, in one of the following formats: - a ``(sample_ids, label_ids, dists)`` tuple, when ``return_dists`` is True - a ``(sample_ids, label_ids)`` tuple, when ``return_dists`` is False In the above, ``sample_ids`` and ``label_ids`` (if applicable) contain the IDs of the nearest neighbors, in one of the following formats: - a list of nearest neighbor IDs, when a single query ID or vector is provided, **or** when an ``aggregation`` is provided - a list of lists of nearest neighbor IDs, when multiple query IDs/vectors and no ``aggregation`` is provided and ``dists`` contains the corresponding query-neighbor distances for each result. If the backend supports full index queries (``query=None``), then ``inds`` are returned rather than ``(sample_ids, label_ids)``, in the following format: - a list of arrays of the **integer indexes** (not IDs) of nearest neighbor points for every vector in the index, when no query is provided """ raise NotImplementedError("subclass must implement _kneighbors()") def get_model(self): """Returns the stored model for this index.
Returns: a :class:`fiftyone.core.models.Model` Raises: ValueError: if these results don't have a stored model """ if self._model is None: model = self.config.model if model is None: raise ValueError("These results don't have a stored model") if etau.is_str(model): model_kwargs = self.config.model_kwargs or {} model = foz.load_zoo_model(model, **model_kwargs) self._model = model return self._model def compute_embeddings( self, samples, model=None, batch_size=None, num_workers=None, skip_failures=True, skip_existing=False, warn_existing=False, force_square=False, alpha=None, progress=None, ): """Computes embeddings for the given samples using this backend's model. Args: samples: a :class:`fiftyone.core.collections.SampleCollection` model (None): a :class:`fiftyone.core.models.Model` to apply. If not provided, these results must have been created with a stored model, which will be used by default batch_size (None): an optional batch size to use when computing embeddings. Only applicable when a ``model`` is provided num_workers (None): the number of workers to use when loading images. Only applicable when a Torch-based model is being used to compute embeddings skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample skip_existing (False): whether to skip generating embeddings for sample/label IDs that are already in the index warn_existing (False): whether to log a warning if any IDs already exist in the index force_square (False): whether to minimally manipulate the patch bounding boxes into squares prior to extraction. Only applicable when a ``model`` and ``patches_field`` are specified alpha (None): an optional expansion/contraction to apply to the patches before extracting them, in ``[-1, inf)``. If provided, the length and width of the box are expanded (or contracted, when ``alpha < 0``) by ``(100 * alpha)%``. For example, set ``alpha = 0.1`` to expand the boxes by 10%, and set ``alpha = -0.1`` to contract the boxes by 10%.
Only applicable when a ``model`` and ``patches_field`` are specified progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead Returns: a tuple of: - a ``num_embeddings x num_dims`` array of embeddings - a ``num_embeddings`` array of sample IDs - a ``num_embeddings`` array of label IDs, if applicable, or else ``None`` """ if model is None: model = self.get_model() if skip_existing: if self.config.patches_field is not None: index_ids = self.label_ids else: index_ids = self.sample_ids if index_ids is not None: samples = fbu.skip_ids( samples, index_ids, patches_field=self.config.patches_field, warn_existing=warn_existing, ) else: logger.warning( "This index does not support skipping existing IDs" ) if self.config.roi_field is not None: patches_field = self.config.roi_field handle_missing = "image" agg_fcn = lambda e: np.mean(e, axis=0) else: patches_field = self.config.patches_field handle_missing = "skip" agg_fcn = None return fbu.get_embeddings( samples, model=model, patches_field=patches_field, force_square=force_square, alpha=alpha, handle_missing=handle_missing, agg_fcn=agg_fcn, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) @classmethod def _from_dict(cls, d, samples, config, brain_key): """Builds a :class:`SimilarityIndex` from a JSON representation of it. Args: d: a JSON dict samples: the :class:`fiftyone.core.collections.SampleCollection` for the run config: the :class:`SimilarityConfig` for the run brain_key: the brain key Returns: a :class:`SimilarityIndex` """ raise NotImplementedError("subclass must implement _from_dict()") class DuplicatesMixin(object): """Mixin for :class:`SimilarityIndex` instances that support duplicate detection operations. Similarity backends can expose this mixin simply by implementing :meth:`_radius_neighbors`. 
""" def __init__(self): self._thresh = None self._unique_ids = None self._duplicate_ids = None self._neighbors_map = None @property def thresh(self): """The threshold used by the last call to :meth:`find_duplicates` or :meth:`find_unique`. """ return self._thresh @property def unique_ids(self): """A list of unique IDs from the last call to :meth:`find_duplicates` or :meth:`find_unique`. """ return self._unique_ids @property def duplicate_ids(self): """A list of duplicate IDs from the last call to :meth:`find_duplicates` or :meth:`find_unique`. """ return self._duplicate_ids @property def neighbors_map(self): """A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples from the last call to :meth:`find_duplicates`. """ return self._neighbors_map def _radius_neighbors(self, query=None, thresh=None, return_dists=False): """Returns the neighbors within the given distance threshold for the given query. This method should only return results from the current :meth:`view`. Args: query (None): the query, which can be any of the following: - an ID or list of IDs for which to return neighbors - an embedding or ``num_queries x num_dim`` array of embeddings for which to return neighbors - ``None``, in which case the neighbors for all points in the current :meth:`view` are returned thresh (None): the distance threshold to use return_dists (False): whether to return query-neighbor distances Returns: the query result, in one of the following formats: - a ``(sample_ids, label_ids, dists)`` tuple, when ``return_dists`` is True - a ``(sample_ids, label_ids)`` tuple, when ``return_dists`` is False In the above, ``sample_ids`` and ``label_ids`` (if applicable) contain the IDs of the nearest neighbors, in one of the following formats: - a list of nearest neighbor IDs, when a single query ID or vector is provided, **or** when an ``aggregation`` is provided - a list of lists of nearest neighbor IDs, when multiple query IDs/vectors and no ``aggregation`` is provided and ``dists`` 
contains the corresponding query-neighbor distances for each result. If the backend supports full index queries (``query=None``), then ``inds`` are returned rather than ``(sample_ids, label_ids)``, in the following format: - a list of arrays of the **integer indexes** (not IDs) of nearest neighbor points for every vector in the index, when no query is provided """ raise NotImplementedError( "subclass must implement _radius_neighbors()" ) def find_duplicates(self, thresh=None, fraction=None): """Queries the index to find near-duplicate examples based on the provided parameters. Calling this method populates the :meth:`unique_ids`, :meth:`duplicate_ids`, :attr:`neighbors_map`, and :attr:`thresh` properties of this object with the results of the query. Use :meth:`duplicates_view` and :meth:`visualize_duplicates` to analyze the results generated by this method. Args: thresh (None): a distance threshold to use to determine duplicates. If specified, the non-duplicate set will be the (approximately) largest set such that all pairwise distances between non-duplicate examples are greater than this threshold fraction (None): a desired fraction of images/patches to tag as duplicates, in ``[0, 1]``. 
In this case ``thresh`` is automatically tuned to achieve the desired fraction of duplicates """ if self.config.patches_field is not None: logger.info("Computing duplicate patches...") ids = self.current_label_ids else: logger.info("Computing duplicate samples...") ids = self.current_sample_ids # Detect duplicates if fraction is not None: num_keep = int(round(min(max(0, 1.0 - fraction), 1) * len(ids))) unique_ids, thresh = self._remove_duplicates_count( num_keep, ids, init_thresh=thresh ) else: unique_ids = self._remove_duplicates_thresh(thresh, ids) _unique_ids = set(unique_ids) duplicate_ids = [_id for _id in ids if _id not in _unique_ids] # Locate nearest non-duplicate for each duplicate if unique_ids and duplicate_ids: if self.config.patches_field is not None: unique_view = self._samples.select_labels( ids=unique_ids, fields=self.config.patches_field ) else: unique_view = self._samples.select(unique_ids) with self.use_view(unique_view): _sample_ids, _label_ids, dists = self._kneighbors( query=duplicate_ids, k=1, return_dists=True ) if self.config.patches_field is not None: nearest_ids = _label_ids else: nearest_ids = _sample_ids neighbors_map = defaultdict(list) for dup_id, _ids, _dists in zip(duplicate_ids, nearest_ids, dists): neighbors_map[_ids[0]].append((dup_id, _dists[0])) neighbors_map = { k: sorted(v, key=lambda t: t[1]) for k, v in neighbors_map.items() } else: neighbors_map = {} logger.info("Duplicates computation complete") self._thresh = thresh self._unique_ids = unique_ids self._duplicate_ids = duplicate_ids self._neighbors_map = neighbors_map def find_unique(self, count): """Queries the index to select a subset of examples of the specified size that are maximally unique with respect to each other. Calling this method populates the :meth:`unique_ids`, :meth:`duplicate_ids`, and :attr:`thresh` properties of this object with the results of the query. Use :meth:`unique_view` and :meth:`visualize_unique` to analyze the results generated by this method. 
Args: count: the desired number of unique examples """ if self.config.patches_field is not None: logger.info("Computing unique patches...") ids = self.current_label_ids else: logger.info("Computing unique samples...") ids = self.current_sample_ids unique_ids, thresh = self._remove_duplicates_count(count, ids) _unique_ids = set(unique_ids) duplicate_ids = [_id for _id in ids if _id not in _unique_ids] logger.info("Uniqueness computation complete") self._thresh = thresh self._unique_ids = unique_ids self._duplicate_ids = duplicate_ids self._neighbors_map = None def _remove_duplicates_count(self, num_keep, ids, init_thresh=None): if init_thresh is not None: thresh = init_thresh else: thresh = 1 if num_keep <= 0: logger.info( "threshold: -, kept: %d, target: %d", num_keep, num_keep ) return set(), None if num_keep >= len(ids): logger.info( "threshold: -, kept: %d, target: %d", num_keep, num_keep ) return set(ids), None thresh_lims = [0, None] num_target = num_keep num_keep = -1 while True: keep_ids = self._remove_duplicates_thresh(thresh, ids) num_keep_last = num_keep num_keep = len(keep_ids) logger.info( "threshold: %f, kept: %d, target: %d", thresh, num_keep, num_target, ) if num_keep == num_target or ( num_keep == num_keep_last and thresh_lims[1] is not None and thresh_lims[1] - thresh_lims[0] < 1e-6 ): break if num_keep < num_target: # Need to decrease threshold thresh_lims[1] = thresh thresh = 0.5 * (thresh_lims[0] + thresh) else: # Need to increase threshold thresh_lims[0] = thresh if thresh_lims[1] is not None: thresh = 0.5 * (thresh + thresh_lims[1]) else: thresh *= 2 return keep_ids, thresh def _remove_duplicates_thresh(self, thresh, ids): nearest_inds = self._radius_neighbors(thresh=thresh) n = len(ids) keep = set(range(n)) for ind in range(n): if ind in keep: keep -= {i for i in nearest_inds[ind] if i > ind} return [ids[i] for i in keep] def plot_distances(self, bins=100, log=False, backend="plotly", **kwargs): """Plots a histogram of the distance between 
each example and its nearest neighbor. If `:meth:`find_duplicates` or :meth:`find_unique` has been executed, the threshold used is also indicated on the plot. Args: bins (100): the number of bins to use log (False): whether to use a log scale y-axis backend ("plotly"): the plotting backend to use. Supported values are ``("plotly", "matplotlib")`` **kwargs: keyword arguments for the backend plotting method Returns: one of the following: - a :class:`fiftyone.core.plots.plotly.PlotlyNotebookPlot`, if you are working in a notebook context and the plotly backend is used - a plotly or matplotlib figure, otherwise """ metric = self.config.metric thresh = self.thresh _, dists = self._kneighbors(k=1, return_dists=True) dists = np.array([d[0] for d in dists]) if backend == "matplotlib": return _plot_distances_mpl( dists, metric, thresh, bins, log, **kwargs ) return _plot_distances_plotly( dists, metric, thresh, bins, log, **kwargs ) def duplicates_view( self, type_field=None, id_field=None, dist_field=None, sort_by="distance", reverse=False, ): """Returns a view that contains only the duplicate examples and their corresponding nearest non-duplicate examples generated by the last call to :meth:`find_duplicates`. If you are analyzing patches, the returned view will be a :class:`fiftyone.core.patches.PatchesView`. The examples are organized so that each non-duplicate is immediately followed by all duplicate(s) that are nearest to it. Args: type_field (None): the name of a string field in which to store ``"nearest"`` and ``"duplicate"`` labels. The field is created if necessary id_field (None): the name of a string field in which to store the ID of the nearest non-duplicate for each example in the view. The field is created if necessary dist_field (None): the name of a float field in which to store the distance of each example to its nearest non-duplicate example. The field is created if necessary sort_by ("distance"): specifies how to sort the groups of duplicate examples. 
The supported values are: - ``"distance"``: sort the groups by the distance between the non-duplicate and its (nearest, if multiple) duplicate - ``"count"``: sort the groups by the number of duplicate examples reverse (False): whether to sort in descending order Returns: a :class:`fiftyone.core.view.DatasetView` """ if self.neighbors_map is None: raise ValueError( "You must first call `find_duplicates()` to generate results" ) samples = self.view patches_field = self.config.patches_field neighbors_map = self.neighbors_map if patches_field is not None and not isinstance( samples, fop.PatchesView ): samples = samples.to_patches(patches_field) if sort_by == "distance": key = lambda kv: min(e[1] for e in kv[1]) elif sort_by == "count": key = lambda kv: len(kv[1]) else: raise ValueError( "Invalid sort_by='%s'; supported values are %s" % (sort_by, ("distance", "count")) ) existing_ids = set(samples.values("id")) neighbors = [ (k, v) for k, v in neighbors_map.items() if k in existing_ids ] ids = [] types = {} nearest_ids = {} dists = {} for _id, duplicates in sorted(neighbors, key=key, reverse=reverse): ids.append(_id) types[_id] = "nearest" nearest_ids[_id] = _id dists[_id] = 0.0 for dup_id, dist in duplicates: ids.append(dup_id) types[dup_id] = "duplicate" nearest_ids[dup_id] = _id dists[dup_id] = dist if type_field is not None: samples.set_values(type_field, types, key_field="id") if id_field is not None: samples.set_values(id_field, nearest_ids, key_field="id") if dist_field is not None: samples.set_values(dist_field, dists, key_field="id") return samples.select(ids, ordered=True) def unique_view(self): """Returns a view that contains only the unique examples generated by the last call to :meth:`find_duplicates` or :meth:`find_unique`. If you are analyzing patches, the returned view will be a :class:`fiftyone.core.patches.PatchesView`. 
Returns: a :class:`fiftyone.core.view.DatasetView` """ if self.unique_ids is None: raise ValueError( "You must first call `find_unique()` or `find_duplicates()` " "to generate results" ) samples = self.view patches_field = self.config.patches_field unique_ids = self.unique_ids if patches_field is not None and not isinstance( samples, fop.PatchesView ): samples = samples.to_patches(patches_field) return samples.select(unique_ids) def visualize_duplicates(self, visualization, backend="plotly", **kwargs): """Generates an interactive scatterplot of the results generated by the last call to :meth:`find_duplicates`. The ``visualization`` argument can be any visualization computed on the same dataset (or subset of it) as long as it contains every sample/object in the view whose results you are visualizing. The points are colored based on the following partition: - "duplicate": duplicate example - "nearest": nearest neighbor of a duplicate example - "unique": the remaining unique examples Edges are also drawn between each duplicate and its nearest non-duplicate neighbor. You can attach plots generated by this method to an App session via its :attr:`fiftyone.core.session.Session.plots` attribute, which will automatically sync the session's view with the currently selected points in the plot. Args: visualization: a :class:`fiftyone.brain.visualization.VisualizationResults` instance to use to visualize the results backend ("plotly"): the plotting backend to use. 
Supported values are ``("plotly", "matplotlib")`` **kwargs: keyword arguments for the backend plotting method: - "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot` - "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot` Returns: a :class:`fiftyone.core.plots.base.InteractivePlot` """ if self.neighbors_map is None: raise ValueError( "You must first call `find_duplicates()` to generate results" ) samples = self.view duplicate_ids = self.duplicate_ids neighbors_map = self.neighbors_map patches_field = self.config.patches_field dup_ids = set(duplicate_ids) nearest_ids = set(neighbors_map.keys()) with visualization.use_view(samples, allow_missing=True): if patches_field is not None: ids = visualization.current_label_ids else: ids = visualization.current_sample_ids labels = [] for _id in ids: if _id in dup_ids: label = "duplicate" elif _id in nearest_ids: label = "nearest" else: label = "unique" labels.append(label) if backend == "plotly": kwargs["edges"] = _build_edges(ids, neighbors_map) kwargs["edges_title"] = "neighbors" kwargs["labels_title"] = "type" return visualization.visualize( labels=labels, classes=["unique", "nearest", "duplicate"], backend=backend, **kwargs, ) def visualize_unique(self, visualization, backend="plotly", **kwargs): """Generates an interactive scatterplot of the results generated by the last call to :meth:`find_unique`. The ``visualization`` argument can be any visualization computed on the same dataset (or subset of it) as long as it contains every sample/object in the view whose results you are visualizing. The points are colored based on the following partition: - "unique": the unique examples - "other": the other examples You can attach plots generated by this method to an App session via its :attr:`fiftyone.core.session.Session.plots` attribute, which will automatically sync the session's view with the currently selected points in the plot. 
Args: visualization: a :class:`fiftyone.brain.visualization.VisualizationResults` instance to use to visualize the results backend ("plotly"): the plotting backend to use. Supported values are ``("plotly", "matplotlib")`` **kwargs: keyword arguments for the backend plotting method: - "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot` - "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot` Returns: a :class:`fiftyone.core.plots.base.InteractivePlot` """ if self.unique_ids is None: raise ValueError( "You must first call `find_unique()` to generate results" ) samples = self.view unique_ids = self.unique_ids patches_field = self.config.patches_field unique_ids = set(unique_ids) with visualization.use_view(samples, allow_missing=True): if patches_field is not None: ids = visualization.current_label_ids else: ids = visualization.current_sample_ids labels = [] for _id in ids: if _id in unique_ids: label = "unique" else: label = "other" labels.append(label) return visualization.visualize( labels=labels, classes=["other", "unique"], backend=backend, **kwargs, ) def _unique_no_sort(values): seen = set() return [v for v in values if v not in seen and not seen.add(v)] def _build_edges(ids, neighbors_map): inds_map = {_id: idx for idx, _id in enumerate(ids)} edges = [] for nearest_id, duplicates in neighbors_map.items(): nearest_ind = inds_map[nearest_id] for dup_id, _ in duplicates: dup_ind = inds_map[dup_id] edges.append((dup_ind, nearest_ind)) return np.array(edges) def _plot_distances_plotly(dists, metric, thresh, bins, log, **kwargs): import plotly.graph_objects as go import fiftyone.core.plots.plotly as fopl counts, edges = np.histogram(dists, bins=bins) left_edges = edges[:-1] widths = edges[1:] - edges[:-1] customdata = np.stack((edges[:-1], edges[1:]), axis=1) hover_lines = [ "count: %{y}", "distance: [%{customdata[0]:.2f}, %{customdata[1]:.2f}]", ] hovertemplate = "
".join(hover_lines) + "" bar = go.Bar( x=left_edges, y=counts, width=widths, customdata=customdata, offset=0, marker_color="#FF6D04", hovertemplate=hovertemplate, showlegend=False, ) traces = [bar] if thresh is not None: line = go.Scatter( x=[thresh, thresh], y=[0, max(counts)], mode="lines", line=dict(color="#17191C", width=3), hovertemplate="thresh: %{x}", showlegend=False, ) traces.append(line) figure = go.Figure(traces) figure.update_layout( xaxis_title="nearest neighbor distance (%s)" % metric, yaxis_title="count", hovermode="x", yaxis_rangemode="tozero", ) if log: figure.update_layout(yaxis_type="log") figure.update_layout(**fopl._DEFAULT_LAYOUT) figure.update_layout(**kwargs) if foc.is_jupyter_context(): figure = fopl.PlotlyNotebookPlot(figure) return figure def _plot_distances_mpl( dists, metric, thresh, bins, log, ax=None, figsize=None, **kwargs ): import matplotlib.pyplot as plt if ax is None: fig, ax = plt.subplots() else: fig = ax.figure counts, edges = np.histogram(dists, bins=bins) left_edges = edges[:-1] widths = edges[1:] - edges[:-1] ax.bar( left_edges, counts, width=widths, align="edge", color="#FF6D04", **kwargs, ) if thresh is not None: ax.vlines(thresh, 0, max(counts), color="#17191C", linewidth=3) if log: ax.set_yscale("log") ax.set_xlabel("nearest neighbor distance (%s)" % metric) ax.set_ylabel("count") if figsize is not None: fig.set_size_inches(*figsize) plt.tight_layout() return fig ================================================ FILE: fiftyone/brain/visualization.py ================================================ """ Visualization interface. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ from copy import deepcopy import inspect import logging from packaging import version import numpy as np import sklearn import sklearn.decomposition as skd import sklearn.manifold as skm import eta.core.utils as etau import fiftyone.brain as fb import fiftyone.core.brain as fob import fiftyone.core.expressions as foe import fiftyone.core.fields as fof import fiftyone.core.plots as fop import fiftyone.core.utils as fou import fiftyone.core.validation as fov fbu = fou.lazy_import("fiftyone.brain.internal.core.utils") umap = fou.lazy_import("umap") logger = logging.getLogger(__name__) _DEFAULT_MODEL = "mobilenet-v2-imagenet-torch" _DEFAULT_BATCH_SIZE = None def compute_visualization( samples, patches_field, embeddings, points, create_index, points_field, brain_key, num_dims, method, similarity_index, model, model_kwargs, force_square, alpha, batch_size, num_workers, skip_failures, progress, **kwargs, ): """See ``fiftyone/brain/__init__.py``.""" fov.validate_collection(samples) if method == "manual" and points is None: raise ValueError( "You must provide your own `points` when `method='manual'`" ) if points is not None: method = "manual" model = None embeddings = None embeddings_field = None num_dims = _get_dimension(points) if create_index and points_field is None: points_field = brain_key if points_field is not None and num_dims != 2: raise ValueError("`points_field` is only supported when `num_dims=2`") if etau.is_str(embeddings): embeddings_field, embeddings_exist = fbu.parse_data_field( samples, embeddings, patches_field=patches_field, data_type="embeddings", ) embeddings = None else: embeddings_field = None embeddings_exist = None if points_field is not None: points_field, _ = fbu.parse_data_field( samples, points_field, patches_field=patches_field, data_type="points", ) if etau.is_str(similarity_index): similarity_index = samples.load_brain_results(similarity_index) if ( model is None and points is None and embeddings is None and 
similarity_index is None and not embeddings_exist ): model = _DEFAULT_MODEL if batch_size is None: batch_size = _DEFAULT_BATCH_SIZE config = _parse_config( method, embeddings_field=embeddings_field, points_field=points_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, patches_field=patches_field, num_dims=num_dims, **kwargs, ) brain_method = config.build() brain_method.ensure_requirements() if brain_key is not None: brain_method.register_run(samples, brain_key) if points is None: embeddings, sample_ids, label_ids = fbu.get_embeddings( samples, model=model, model_kwargs=model_kwargs, patches_field=patches_field, embeddings_field=embeddings_field, embeddings=embeddings, similarity_index=similarity_index, force_square=force_square, alpha=alpha, batch_size=batch_size, num_workers=num_workers, skip_failures=skip_failures, progress=progress, ) logger.info("Generating visualization...") points = brain_method.fit(embeddings) else: points, sample_ids, label_ids = fbu.parse_data( samples, patches_field=patches_field, data=points, data_type="points", ) if points_field is not None: _generate_spatial_index( samples, points, points_field, sample_ids, label_ids=label_ids, patches_field=patches_field, create_index=create_index, progress=progress, ) results = VisualizationResults( samples, config, brain_key, points, sample_ids=sample_ids, label_ids=label_ids, ) brain_method.save_run_results(samples, brain_key, results) return results def values(results, path_or_expr): samples = results.view patches_field = results.config.patches_field if patches_field is not None: ids = results.current_label_ids else: ids = results.current_sample_ids return fbu.get_values( samples, path_or_expr, ids, patches_field=patches_field ) def visualize( results, labels=None, sizes=None, classes=None, backend="plotly", **kwargs, ): points = results.current_points samples = results.view patches_field = results.config.patches_field good_inds = results._curr_good_inds if 
patches_field is not None: ids = results.current_label_ids else: ids = results.current_sample_ids if good_inds is not None: if etau.is_container(labels) and not _is_expr(labels): labels = fbu.filter_values( labels, good_inds, patches_field=patches_field ) if etau.is_container(sizes) and not _is_expr(sizes): sizes = fbu.filter_values( sizes, good_inds, patches_field=patches_field ) if labels is not None and _is_expr(labels): labels = fbu.get_values( samples, labels, ids, patches_field=patches_field ) if sizes is not None and _is_expr(sizes): sizes = fbu.get_values( samples, sizes, ids, patches_field=patches_field ) return fop.scatterplot( points, samples=samples, ids=ids, link_field=patches_field, labels=labels, sizes=sizes, classes=classes, backend=backend, **kwargs, ) def _is_expr(arg): return isinstance(arg, (foe.ViewExpression, dict)) def _parse_config(name, **kwargs): if name is None: name = fb.brain_config.default_visualization_method if inspect.isclass(name): return name(**kwargs) methods = fb.brain_config.visualization_methods if name not in methods: raise ValueError( "Unsupported method '%s'. 
The available methods are %s" % (name, sorted(methods.keys())) ) params = deepcopy(methods[name]) config_cls = kwargs.pop("config_cls", None) if config_cls is None: config_cls = params.pop("config_cls", None) if config_cls is None: raise ValueError( "Visualization method '%s' has no `config_cls`" % name ) if etau.is_str(config_cls): config_cls = etau.get_class(config_cls) params.update(**kwargs) return config_cls(**params) def _get_dimension(points): if isinstance(points, dict): points = next(iter(points.values()), None) if isinstance(points, list): points = next(iter(points), None) if points is None: return 2 return points.shape[-1] def _generate_spatial_index( samples, points, points_field, sample_ids, label_ids=None, patches_field=None, create_index=True, progress=False, ): # Indexes are not currently usable on patch visualizations if create_index and patches_field is not None: create_index = False dataset = samples._root_dataset if patches_field is not None: _, points_field = dataset._get_label_field_path( patches_field, points_field ) logger.info("Generating spatial index in field '%s'...", points_field) dataset.add_sample_field( points_field, fof.ListField, subfield=fof.FloatField ) points = points.astype(float) if create_index: min_val, max_val = points.min(), points.max() dataset.create_index([(points_field, "2d")], min=min_val, max=max_val) points = points.tolist() if patches_field is not None: values = dict(zip(label_ids, points)) dataset.set_label_values(points_field, values, progress=progress) else: values = dict(zip(sample_ids, points)) dataset.set_values( points_field, values, key_field="id", progress=progress ) class VisualizationResults(fob.BrainResults): """Class storing the results of :meth:`fiftyone.brain.compute_visualization`. 
Args: samples: the :class:`fiftyone.core.collections.SampleCollection` used config: the :class:`VisualizationConfig` used brain_key: the brain key points: a ``num_points x num_dims`` array of visualization points sample_ids (None): a ``num_points`` array of sample IDs label_ids (None): a ``num_points`` array of label IDs, if applicable backend (None): a :class:`Visualization` backend """ def __init__( self, samples, config, brain_key, points, sample_ids=None, label_ids=None, backend=None, ): super().__init__(samples, config, brain_key, backend=backend) if sample_ids is None: sample_ids, label_ids = fbu.get_ids( samples, patches_field=config.patches_field, data=points, data_type="points", ) self.points = points self.sample_ids = sample_ids self.label_ids = label_ids self._last_view = None self._curr_view = None self._curr_points = None self._curr_sample_ids = None self._curr_label_ids = None self._curr_keep_inds = None self._curr_good_inds = None self.use_view(samples) def __enter__(self): self._last_view = self.view return self def __exit__(self, *args): self.use_view(self._last_view) self._last_view = None @property def config(self): """The :class:`VisualizationConfig` for the results.""" return self._config @property def index_size(self): """The number of active points in the index. If :meth:`use_view` has been called to restrict the index, this property will reflect the size of the active index. """ return len(self._curr_sample_ids) @property def total_index_size(self): """The total number of data points in the index. If :meth:`use_view` has been called to restrict the index, this value may be larger than the current :meth:`index_size`. """ return len(self.points) @property def missing_size(self): """The total number of data points in :meth:`view` that are missing from this index. This property is only applicable when :meth:`use_view` has been called, and it will be ``None`` if no data points are missing. 
""" good = self._curr_good_inds if good is None: return None return good.size - np.count_nonzero(good) @property def current_points(self): """The currently active points in the index. If :meth:`use_view` has been called, this may be a subset of the full index. """ return self._curr_points @property def current_sample_ids(self): """The sample IDs of the currently active points in the index. If :meth:`use_view` has been called, this may be a subset of the full index. """ return self._curr_sample_ids @property def current_label_ids(self): """The label IDs of the currently active points in the index, or ``None`` if not applicable. If :meth:`use_view` has been called, this may be a subset of the full index. """ return self._curr_label_ids @property def view(self): """The :class:`fiftyone.core.collections.SampleCollection` against which results are currently being generated. If :meth:`use_view` has been called, this view may be different than the collection on which the full index was generated. """ return self._curr_view @property def has_spatial_index(self): """Whether these results have a spatial index. Use :meth:`index_points` to add a spatial index to an existing set of visualization results. """ return self.config.points_field is not None def use_view( self, sample_collection, allow_missing=True, warn_missing=False ): """Restricts the index to the provided view. Subsequent calls to methods on this instance will only contain results from the specified view rather than the full index. Use :meth:`clear_view` to reset to the full index. Or, equivalently, use the context manager interface as demonstrated below to automatically reset the view when the context exits. 
Example usage:: import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz dataset = foz.load_zoo_dataset("quickstart") results = fob.compute_visualization(dataset) print(results.index_size) # 200 view = dataset.take(50) with results.use_view(view): print(results.index_size) # 50 plot = results.visualize() plot.show() Args: sample_collection: a :class:`fiftyone.core.collections.SampleCollection` allow_missing (True): whether to allow the provided collection to contain data points that this index does not contain (True) or whether to raise an error in this case (False) warn_missing (False): whether to log a warning if the provided collection contains data points that this index does not contain Returns: self """ sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids( sample_collection, self.sample_ids, self.label_ids, patches_field=self._config.patches_field, allow_missing=allow_missing, warn_missing=warn_missing, ) if keep_inds is not None: points = self.points[keep_inds, :] else: points = self.points self._curr_view = sample_collection self._curr_points = points self._curr_sample_ids = sample_ids self._curr_label_ids = label_ids self._curr_keep_inds = keep_inds self._curr_good_inds = good_inds return self def clear_view(self): """Clears the view set by :meth:`use_view`, if any. Subsequent operations will be performed on the full index. """ self.use_view(self._samples) def values(self, path_or_expr): """Extracts a flat list of values from the given field or expression corresponding to the current :meth:`view`. This method always returns values in the same order as :meth:`current_points`, :meth:`current_sample_ids`, and :meth:`current_label_ids`. 
Args: path_or_expr: the values to extract, which can be: - the name of a sample field or ``embedded.field.name`` from which to extract numeric or string values - a :class:`fiftyone.core.expressions.ViewExpression` defining numeric or string values to compute via :meth:`fiftyone.core.collections.SampleCollection.values` Returns: a list of values """ return values(self, path_or_expr) def visualize( self, labels=None, sizes=None, classes=None, backend="plotly", **kwargs, ): """Generates an interactive scatterplot of the visualization results for the current :meth:`view`. This method supports 2D or 3D visualizations, but interactive point selection is only available in 2D. You can use the ``labels`` parameters to define a coloring for the points, and you can use the ``sizes`` parameter to scale the sizes of the points. You can attach plots generated by this method to an App session via its :attr:`fiftyone.core.session.Session.plots` attribute, which will automatically sync the session's view with the currently selected points in the plot. Args: labels (None): data to use to color the points. Can be any of the following: - the name of a sample field or ``embedded.field.name`` from which to extract numeric or string values - a :class:`fiftyone.core.expressions.ViewExpression` defining numeric or string values to compute via :meth:`fiftyone.core.collections.SampleCollection.values` - a list or array-like of numeric or string values - a list of lists of numeric or string values, if the data in this visualization corresponds to a label list field like :class:`fiftyone.core.labels.Detections` sizes (None): data to use to scale the sizes of the points. 
Can be any of the following: - the name of a sample field or ``embedded.field.name`` from which to extract numeric values - a :class:`fiftyone.core.expressions.ViewExpression` defining numeric values to compute via :meth:`fiftyone.core.collections.SampleCollection.values` - a list or array-like of numeric values - a list of lists of numeric values, if the data in this visualization corresponds to a label list field like :class:`fiftyone.core.labels.Detections` classes (None): an optional list of classes whose points to plot. Only applicable when ``labels`` contains strings backend ("plotly"): the plotting backend to use. Supported values are ``("plotly", "matplotlib")`` **kwargs: keyword arguments for the backend plotting method: - "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot` - "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot` Returns: an :class:`fiftyone.core.plots.base.InteractivePlot` """ return visualize( self, labels=labels, sizes=sizes, classes=classes, backend=backend, **kwargs, ) def index_points( self, points_field=None, create_index=True, progress=None, ): """Adds a spatial index for these visualization results to its dataset's samples. This method is useful if you want to add a spatial index to existing visualization results that don't yet have one. Spatial indexes are highly recommended for large datasets as they enable efficient querying when lassoing points in embeddings plots. Args: points_field (None): an optional field name in which to store the spatial index. 
The default is the result's ``brain_key`` create_index (True): whether to create a database index for the points progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead """ if points_field is None: if self.key is None: raise ValueError( "You must provide a `points_field` when indexing points " "that are not associated with a brain key" ) points_field = self.key _generate_spatial_index( self.samples, self.points, points_field, self.sample_ids, label_ids=self.label_ids, patches_field=self.config.patches_field, create_index=create_index, progress=progress, ) if self.key is not None: self.config.points_field = points_field self.save_config() def remove_index(self): """Removes the spatial index from these visualization results, if one exists. """ points_field = self.config.points_field if points_field is None: return dataset = self.samples._root_dataset if self.config.patches_field is not None: _, points_field = dataset._get_label_field_path( self.config.patches_field, points_field ) dataset.delete_sample_field(points_field, error_level=1) if self.key is not None: self.config.points_field = None self.save_config() @classmethod def _from_dict(cls, d, samples, config, brain_key): points = np.array(d["points"]) sample_ids = d.get("sample_ids", None) if sample_ids is not None: sample_ids = np.array(sample_ids) label_ids = d.get("label_ids", None) if label_ids is not None: label_ids = np.array(label_ids) return cls( samples, config, brain_key, points, sample_ids=sample_ids, label_ids=label_ids, ) class VisualizationConfig(fob.BrainMethodConfig): """Base class for configuring visualization methods. 
Args: embeddings_field (None): the sample field containing the embeddings, if one was provided points_field (None): the name of a field in which to store the visualization points, if requested similarity_index (None): the similarity index containing the embeddings, if one was provided model (None): the :class:`fiftyone.core.models.Model` or name of the zoo model that was used to compute embeddings, if known model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided patches_field (None): the sample field defining the patches being analyzed, if any num_dims (2): the dimension of the visualization space """ def __init__( self, embeddings_field=None, points_field=None, similarity_index=None, model=None, model_kwargs=None, patches_field=None, num_dims=2, **kwargs, ): if similarity_index is not None and not etau.is_str(similarity_index): similarity_index = similarity_index.key if model is not None and not etau.is_str(model): model = None self.embeddings_field = embeddings_field self.points_field = points_field self.similarity_index = similarity_index self.model = model self.model_kwargs = model_kwargs self.patches_field = patches_field self.num_dims = num_dims super().__init__(**kwargs) @property def type(self): return "visualization" class Visualization(fob.BrainMethod): def fit(self, embeddings): raise NotImplementedError("subclass must implement fit()") def get_fields(self, samples, brain_key): fields = [] if self.config.patches_field is not None: fields.append(self.config.patches_field) elif self.config.points_field is not None: fields.append(self.config.points_field) return fields def rename(self, samples, key, new_key): patches_field = self.config.patches_field points_field = self.config.points_field dataset = samples._root_dataset if points_field is not None and points_field == key: old_path = key new_path = new_key if patches_field is not None: _, old_path = dataset._get_label_field_path( 
patches_field, old_path ) _, new_path = dataset._get_label_field_path( patches_field, new_path ) self.config.points_field = new_key self.update_run_config(samples, key, self.config) dataset.rename_sample_field(old_path, new_path) def cleanup(self, samples, key): patches_field = self.config.patches_field points_field = self.config.points_field dataset = samples._root_dataset if points_field is not None: if patches_field is not None: _, points_field = dataset._get_label_field_path( patches_field, points_field ) dataset.delete_sample_field(points_field, error_level=1) class UMAPVisualizationConfig(VisualizationConfig): """Configuration for Uniform Manifold Approximation and Projection (UMAP) embedding visualization. See https://github.com/lmcinnes/umap for more information about the supported parameters. Args: embeddings_field (None): the sample field containing the embeddings, if one was provided points_field (None): the name of a field in which to store the visualization points, if requested similarity_index (None): the similarity index containing the embeddings, if one was provided model (None): the :class:`fiftyone.core.models.Model` or name of the zoo model that was used to compute embeddings, if known model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided patches_field (None): the sample field defining the patches being analyzed, if any num_dims (2): the dimension of the visualization space num_neighbors (15): the number of neighboring points used in local approximations of manifold structure. Larger values will result in more global structure being preserved at the loss of detailed local structure. Typical values are in ``[5, 50]`` metric ("euclidean"): the metric to use when calculating distance between embeddings. See the UMAP documentation for supported values min_dist (0.1): the effective minimum distance between embedded points. 
This controls how tightly the embedding is allowed compress points together. Larger values ensure embedded points are more evenly distributed, while smaller values allow the algorithm to optimise more accurately with regard to local structure. Typical values are in ``[0.001, 0.5]`` seed (None): a random seed verbose (True): whether to log progress """ def __init__( self, embeddings_field=None, points_field=None, similarity_index=None, model=None, model_kwargs=None, patches_field=None, num_dims=2, num_neighbors=15, metric="euclidean", min_dist=0.1, seed=None, verbose=True, **kwargs, ): super().__init__( embeddings_field=embeddings_field, points_field=points_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, patches_field=patches_field, num_dims=num_dims, **kwargs, ) self.num_neighbors = num_neighbors self.metric = metric self.min_dist = min_dist self.seed = seed self.verbose = verbose @property def method(self): return "umap" class UMAPVisualization(Visualization): def ensure_requirements(self): fou.ensure_package( "umap-learn>=0.5", error_msg=( "You must install the `umap-learn>=0.5` package in order to " "use UMAP-based visualization. This is recommended, as UMAP " "is awesome! If you do not wish to install UMAP, try " "`method='tsne'` instead" ), ) def fit(self, embeddings): _umap = umap.UMAP( n_components=self.config.num_dims, n_neighbors=self.config.num_neighbors, metric=self.config.metric, min_dist=self.config.min_dist, random_state=self.config.seed, verbose=self.config.verbose, ) return _umap.fit_transform(embeddings) class TSNEVisualizationConfig(VisualizationConfig): """Configuration for t-distributed Stochastic Neighbor Embedding (t-SNE) visualization. See https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html for more information about the supported parameters. 
Args: embeddings_field (None): the sample field containing the embeddings, if one was provided points_field (None): the name of a field in which to store the visualization points, if requested similarity_index (None): the similarity index containing the embeddings, if one was provided model (None): the :class:`fiftyone.core.models.Model` or name of the zoo model that was used to compute embeddings, if known model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided patches_field (None): the sample field defining the patches being analyzed, if any num_dims (2): the dimension of the visualization space pca_dims (50): the number of PCA dimensions to compute prior to running t-SNE. It is highly recommended to reduce the number of dimensions to a reasonable number (e.g. 50) before running t-SNE, as this will suppress some noise and speed up the computation of pairwise distances between samples svd_solver ("randomized"): the SVD solver to use when performing PCA. Consult the sklearn documentation for details metric ("euclidean"): the metric to use when calculating distance between embeddings. Must be a supported value for the ``metric`` argument of ``scipy.spatial.distance.pdist`` perplexity (30.0): the perplexity to use. Perplexity is related to the number of nearest neighbors that is used in other manifold learning algorithms. Larger datasets usually require a larger perplexity. Typical values are in ``[5, 50]`` learning_rate (200.0): the learning rate to use. Typical values are in ``[10, 1000]``. If the learning rate is too high, the data may look like a ball with any point approximately equidistant from its nearest neighbours. If the learning rate is too low, most points may look compressed in a dense cloud with few outliers. If the cost function gets stuck in a bad local minimum, increasing the learning rate may help max_iters (1000): the maximum number of iterations to run.
Should be at least 250 seed (None): a random seed verbose (True): whether to log progress """ def __init__( self, embeddings_field=None, points_field=None, similarity_index=None, model=None, model_kwargs=None, patches_field=None, num_dims=2, pca_dims=50, svd_solver="randomized", metric="euclidean", perplexity=30.0, learning_rate=200.0, max_iters=1000, seed=None, verbose=True, **kwargs, ): super().__init__( embeddings_field=embeddings_field, points_field=points_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, patches_field=patches_field, num_dims=num_dims, **kwargs, ) self.pca_dims = pca_dims self.svd_solver = svd_solver self.metric = metric self.perplexity = perplexity self.learning_rate = learning_rate self.max_iters = max_iters self.seed = seed self.verbose = verbose @property def method(self): return "tsne" class TSNEVisualization(Visualization): def fit(self, embeddings): if self.config.pca_dims is not None: _pca = skd.PCA( n_components=self.config.pca_dims, svd_solver=self.config.svd_solver, random_state=self.config.seed, ) embeddings = _pca.fit_transform(embeddings) embeddings = embeddings.astype(np.float32, copy=False) verbose = 2 if self.config.verbose else 0 sklearn_version = version.parse(sklearn.__version__) iter_param = ( "max_iter" if sklearn_version >= version.parse("1.5.0") else "n_iter" ) _tsne = skm.TSNE( n_components=self.config.num_dims, perplexity=self.config.perplexity, learning_rate=self.config.learning_rate, metric=self.config.metric, init="pca", random_state=self.config.seed, verbose=verbose, **{iter_param: self.config.max_iters}, ) return _tsne.fit_transform(embeddings) class PCAVisualizationConfig(VisualizationConfig): """Configuration for principal component analysis (PCA) embedding visualization. See https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html for more information about the supported parameters. 
Args: embeddings_field (None): the sample field containing the embeddings, if one was provided points_field (None): the name of a field in which to store the visualization points, if requested similarity_index (None): the similarity index containing the embeddings, if one was provided model (None): the :class:`fiftyone.core.models.Model` or name of the zoo model that was used to compute embeddings, if known model_kwargs (None): a dictionary of optional keyword arguments to pass to the model's ``Config`` when a model name is provided patches_field (None): the sample field defining the patches being analyzed, if any num_dims (2): the dimension of the visualization space svd_solver ("randomized"): the SVD solver to use. Consult the sklearn docmentation for details seed (None): a random seed """ def __init__( self, embeddings_field=None, points_field=None, similarity_index=None, model=None, model_kwargs=None, patches_field=None, num_dims=2, svd_solver="randomized", seed=None, **kwargs, ): super().__init__( embeddings_field=embeddings_field, points_field=points_field, similarity_index=similarity_index, model=model, model_kwargs=model_kwargs, patches_field=patches_field, num_dims=num_dims, **kwargs, ) self.svd_solver = svd_solver self.seed = seed @property def method(self): return "pca" class PCAVisualization(Visualization): def fit(self, embeddings): _pca = skd.PCA( n_components=self.config.num_dims, svd_solver=self.config.svd_solver, random_state=self.config.seed, ) return _pca.fit_transform(embeddings) class ManualVisualizationConfig(VisualizationConfig): """Configuration for manually-provided low-dimensional visualizations. 
Args: patches_field (None): the sample field defining the patches being analyzed, if any num_dims (2): the dimension of the visualization space """ def __init__(self, patches_field=None, num_dims=2, **kwargs): super().__init__( patches_field=patches_field, num_dims=num_dims, **kwargs ) @property def method(self): return "manual" class ManualVisualization(Visualization): def fit(self, embeddings): raise NotImplementedError( "The low-dimensional representation must be manually provided " "when using this method" ) ================================================ FILE: install.bat ================================================ @echo off :: Installs the `fiftyone-brain` package and its dependencies. :: :: Usage: :: .\install.bat :: :: Copyright 2017-2026, Voxel51, Inc. :: voxel51.com :: :: Commands: :: -h Display help message :: -d Install developer dependencies. set SHOW_HELP=false set DEV_INSTALL=false :parse IF "%~1"=="" GOTO endparse IF "%~1"=="-h" GOTO helpmessage IF "%~1"=="-d" set DEV_INSTALL=true SHIFT GOTO parse :endparse echo ***** INSTALLING FIFTYONE-BRAIN ***** IF %DEV_INSTALL%==true ( echo Performing dev install pip install -r requirements/dev.txt pre-commit install pip install -e . ) else ( pip install -r requirements.txt pip install . ) echo ***** INSTALLATION COMPLETE ***** exit /b :helpmessage echo Additional Arguments: echo -h Display help message echo -d Install developer dependencies. exit /b ================================================ FILE: install.sh ================================================ #!/bin/sh # Installs the `fiftyone-brain` package and its dependencies. # # Usage: # sh install.sh # # Copyright 2017-2026, Voxel51, Inc. # voxel51.com # # Show usage information set -e usage() { echo "Usage: sh $0 [-h] [-d] Getting help: -h Display this help message. Custom installations: -d Install developer dependencies. 
" } # Parse flags SHOW_HELP=false DEV_INSTALL=false while getopts "hd" FLAG; do case "${FLAG}" in h) SHOW_HELP=true ;; d) DEV_INSTALL=true ;; *) usage ;; esac done [ ${SHOW_HELP} = true ] && usage && exit 0 OS=$(uname -s) echo "***** INSTALLING FIFTYONE-BRAIN *****" if [ ${DEV_INSTALL} = true ]; then echo "Performing dev install" pip install -r requirements/dev.txt pre-commit install pip install -e . else pip install -r requirements.txt pip install . fi echo "***** INSTALLATION COMPLETE *****" ================================================ FILE: pylintrc ================================================ [MASTER] # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore=CVS # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=1 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist= # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. 
It has one side effect: the resulting # AST will be different from the real one. This option is deprecated # and it will be removed in Pylint 2.0. optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifiers separated by comma (,) or put this option # multiple times (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. #enable= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then re-enable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities".
If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W" #disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating disable=too-few-public-methods,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,non-iterator-returned,too-many-statements,useless-object-inheritance,abstract-method,too-many-ancestors,too-many-branches,unnecessary-pass,too-many-public-methods,bad-continuation [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. output-format=colorized # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file named "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. files-output=no # Tells whether to display a full report or only the messages reports=no score=no # Python expression which should return a note less than 10 (10 is the highest # note).
You have access to the variables error, warning, refactor, convention, and statement, which # respectively contain the number of messages in each category and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a Python new-style format string # used to format the message information. See the documentation for all details #msg-template= [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=i,j,k # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty # Regular expression matching correct function names function-rgx=[a-z_]([a-z0-9_]{0,30})$ # Naming hint for function names function-name-hint=[a-z_]([a-z0-9_]{0,30})$ # Regular expression matching correct variable names variable-rgx=[a-z_]([a-z0-9_]{0,30})$ # Naming hint for variable names variable-name-hint=[a-z_]([a-z0-9_]{0,30})$ # Regular expression matching correct constant names const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Regular expression matching correct attribute names attr-rgx=[a-z_]([a-z0-9_]{0,30})$ # Naming hint for attribute names attr-name-hint=[a-z_]([a-z0-9_]{0,30})$ # Regular expression matching correct argument names argument-rgx=[a-z_]([a-z0-9_]{0,30})$ # Naming hint for argument names argument-name-hint=[a-z_]([a-z0-9_]{0,30})$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct method names method-rgx=[a-z_]([a-z0-9_]{0,30})$ # Naming hint for method names method-name-hint=[a-z_]([a-z0-9_]{0,30})$ # Regular expression which should only match function or class names that do # not require a docstring. 
no-docstring-rgx=^_ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 [ELIF] # Maximum number of nested blocks for function / method body max-nested-blocks=5 [FORMAT] # Maximum number of characters on a single line. max-line-length=79 # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )?<?https?://\S+>?$ # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check=trailing-comma,dict-separator # Maximum number of lines in a module max-module-lines=1000 # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1 # tab). indent-string='    ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [MISCELLANEOUS] # List of note tags to take into consideration, separated by a comma. notes=FIXME,XXX,TODO [SIMILARITIES] # Minimum number of lines of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it work, # install the python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains a private dictionary; one word per line.
spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members=torch.*,fiftyone.*,fo.* # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. 
exclude-protected=_asdict,_fields,_replace,_source,_make [DESIGN] # Maximum number of arguments for function / method max-args=5 # Argument names that match this expression will be ignored. Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=15 # Maximum number of return / yield for function / method body max-returns=6 # Maximum number of branch for function / method body max-branches=12 # Maximum number of statements in function / method body max-statements=50 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Minimum number of public methods for a class (see R0903). min-public-methods=2 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of boolean expressions in a if statement max-bool-expr=5 [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=regsub,TERMIOS,Bastion,rexec # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [EXCEPTIONS] # Exceptions that will emit a warning when being caught. 
Defaults to # "Exception" overgeneral-exceptions=Exception ================================================ FILE: pyproject.toml ================================================ [tool.black] line-length = 79 include = '\.pyi?$' exclude = ''' /( | \.git )/ ''' ================================================ FILE: pytest.ini ================================================ [pytest] python_files = *test*.py filterwarnings = ignore:dns.hash module will be removed in future versions:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning ignore:numpy.* size changed, may indicate binary incompatibility:RuntimeWarning ================================================ FILE: requirements/build.txt ================================================ -r common.txt pytest==5.4.3 twine>=3 ================================================ FILE: requirements/common.txt ================================================ numpy scipy scikit-learn ================================================ FILE: requirements/dev.txt ================================================ -r common.txt flickrapi==2.4.0 imageio==2.8.0 ipython>=7.16.1 pandas pre-commit==2.0.1 pylint==2.3.1 pytest==7.3.1 twine>=3 voxel51-eta[storage] ================================================ FILE: requirements/prod.txt ================================================ -r common.txt ================================================ FILE: requirements.txt ================================================ -r requirements/prod.txt ================================================ FILE: setup.py ================================================ #!/usr/bin/env python """ Installs `fiftyone-brain`. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import os from setuptools import setup VERSION = "0.21.4" def get_version(): if "RELEASE_VERSION" in os.environ: version = os.environ["RELEASE_VERSION"] if not version.startswith(VERSION): raise ValueError( "Release version doest not match version: %s and %s" % (version, VERSION) ) return version return VERSION with open("README.md", "r") as fh: long_description = fh.read() setup( name="fiftyone-brain", version=get_version(), description="FiftyOne Brain", author="Voxel51, Inc.", author_email="info@voxel51.com", url="https://github.com/voxel51/fiftyone-brain", license="Apache", long_description=long_description, long_description_content_type="text/markdown", packages=["fiftyone.brain"], include_package_data=True, install_requires=["numpy", "scipy>=1.2.0", "scikit-learn"], classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Processing", "Topic :: Scientific/Engineering :: Image Recognition", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Visualization", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], scripts=[], python_requires=">=3.9", ) ================================================ FILE: tests/README.md ================================================ # FiftyOne-Brain Tests The brain currently uses both [unittest](https://docs.python.org/3/library/unittest.html) and [pytest](https://docs.pytest.org/en/stable) to implement its tests. 
## Contents | File | Description | | -------------------- | -------------------------------------------------------- | | `test_uniqueness.py` | Tests of the uniqueness capability | | `models/*.py` | Tests of the various models used by the brain | | `intensive/*.py` | Intensive tests that are not included in automated tests | ## Running tests To run all tests in this directory, execute: ```shell pytest . -s ``` To run a specific set of tests, execute: ```shell pytest <file>.py -s ``` To run a specific test case, execute: ```shell pytest <file>.py -s -k <test_name> ``` ## Copyright Copyright 2017-2026, Voxel51, Inc.
voxel51.com ================================================ FILE: tests/intensive/test_interface.py ================================================ """ Brain interface tests. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import unittest import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz def test_uniqueness(): dataset = foz.load_zoo_dataset("quickstart").clone() fob.compute_uniqueness(dataset) print(dataset.list_brain_runs()) print(dataset.get_brain_info("uniqueness")) print(dataset.bounds("uniqueness")) dataset.delete_brain_runs() print(dataset) def test_detection_mistakenness(): dataset = foz.load_zoo_dataset("quickstart").clone() fob.compute_mistakenness( dataset, "predictions", label_field="ground_truth", copy_missing=True ) print(dataset.list_brain_runs()) print(dataset.get_brain_info("mistakenness")) # should be non-trivial print(dataset.bounds("mistakenness")) print(dataset.bounds("possible_missing")) print(dataset.bounds("possible_spurious")) print(dataset.bounds("ground_truth.detections.mistakenness")) print(dataset.bounds("ground_truth.detections.mistakenness_loc")) print(dataset.count_values("ground_truth.detections.possible_spurious")) print(dataset.count_values("predictions.detections.possible_missing")) print(dataset.count_values("ground_truth.detections.possible_missing")) dataset.delete_brain_runs() print(dataset) # should be None print(dataset.bounds("ground_truth.detections.mistakenness")) print(dataset.bounds("ground_truth.detections.mistakenness_loc")) print(dataset.count_values("ground_truth.detections.possible_spurious")) print(dataset.count_values("predictions.detections.possible_missing")) print(dataset.count_values("ground_truth.detections.possible_missing")) def test_classification_mistakenness_confidence(): dataset = foz.load_zoo_dataset("quickstart").clone() test_view = dataset.take(10) # labels proxy model = foz.load_zoo_model("alexnet-imagenet-torch") test_view.apply_model(model, "alexnet") # 
predictions proxy model = foz.load_zoo_model("resnet50-imagenet-torch") test_view.apply_model(model, "resnet50") fob.compute_mistakenness(test_view, "resnet50", label_field="alexnet") print(dataset.list_brain_runs()) print(dataset.load_brain_view("mistakenness")) print(dataset.bounds("mistakenness")) dataset.delete_brain_runs() print(dataset) def test_classification_mistakenness_logits(): dataset = foz.load_zoo_dataset("quickstart").clone() test_view = dataset.take(10) # labels proxy model = foz.load_zoo_model("alexnet-imagenet-torch") test_view.apply_model(model, "alexnet") # predictions proxy model = foz.load_zoo_model("resnet50-imagenet-torch") test_view.apply_model(model, "resnet50", store_logits=True) fob.compute_mistakenness( test_view, "resnet50", label_field="alexnet", use_logits=True ) print(dataset.list_brain_runs()) print(dataset.load_brain_view("mistakenness")) print(dataset.bounds("mistakenness")) dataset.delete_brain_runs() print(dataset) def test_hardness(): dataset = foz.load_zoo_dataset("quickstart").clone() test_view = dataset.take(10) model = foz.load_zoo_model("alexnet-imagenet-torch") test_view.apply_model(model, "alexnet", store_logits=True) fob.compute_hardness(test_view, "alexnet") print(dataset.list_brain_runs()) print(dataset.get_brain_info("hardness")) print(dataset.load_brain_view("hardness")) print(dataset.bounds("hardness")) dataset.delete_brain_runs() print(dataset) if __name__ == "__main__": fo.config.show_progress_bars = True unittest.main(verbosity=2) ================================================ FILE: tests/intensive/test_similarity.py ================================================ """ Similarity tests. 
Usage:: # Optional: specific backends to test export SIMILARITY_BACKENDS=qdrant,pinecone,milvus,redis,elasticsearch,mosaic,pgvector,lancedb pytest tests/intensive/test_similarity.py -s -k test_XXX Qdrant setup:: docker pull qdrant/qdrant docker run -p 6333:6333 qdrant/qdrant pip install qdrant-client Pinecone setup:: # Sign up at https://www.pinecone.io # Download API key and environment pip install pinecone-client Milvus setup:: # Instructions from: https://milvus.io/docs/install_standalone-docker.md wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standalone-docker-compose.yml -O docker-compose.yml docker compose up -d pip install pymilvus LanceDB setup:: pip install lancedb Redis setup:: brew tap redis-stack/redis-stack brew install redis-stack redis-stack-server pip install redis Elasticsearch setup:: # Instructions from: https://www.elastic.co/guide/en/elasticsearch/reference/current/getting-started.html#run-elasticsearch docker run -p 127.0.0.1:9200:9200 -d \ --name elasticsearch \ -e ELASTIC_PASSWORD=elastic \ -e "discovery.type=single-node" \ -e "xpack.security.http.ssl.enabled=false" \ -e "xpack.license.self_generated.type=trial" \ docker.elastic.co/elasticsearch/elasticsearch:8.15.0 pip install elasticsearch Mosaic setup:: # In your databricks workspace, generate a personal access token for authentication. # You will also need to create a catalog and schema in your workspace. 
# You will have to create an endpoint under `compute` -> `vector search` pip install databricks-vectorsearch PGVector setup:: # Run a postgres instance locally with pgvector extension docker pull pgvector/pgvector:pg17 docker run --name postgres -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword -d pgvector/pgvector:pg17 # Enter the container and create the vector extension docker exec -it postgres ./bin/psql -U postgres CREATE EXTENSION IF NOT EXISTS vector; # run in container pip install psycopg2 Brain config setup at `~/.fiftyone/brain_config.json`:: { "similarity_backends": { "pinecone": { "api_key": "XXXXXXXX", "cloud": "aws", "region": "us-east-1", "environment": "us-east-1-aws" }, "qdrant": { "url": "http://localhost:6333" }, "milvus": { "uri": "http://localhost:19530" }, "lancedb": { "uri": "/tmp/lancedb" }, "redis": { "host": "localhost", "port": 6379 }, "elasticsearch": { "hosts": "http://localhost:9200", "username": "elastic", "password": "elastic" }, "mosaic": { "workspace_url": "https://<workspace>.cloud.databricks.com/", "personal_access_token": "<personal_access_token>", "catalog_name": "<catalog_name>", "schema_name": "<schema_name>", "endpoint_name": "<endpoint_name>" }, "pgvector": { "connection_string": "postgresql://postgres:mysecretpassword@localhost:5432/postgres" } } } | Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_ | """ import random import os import time import unittest import numpy as np import fiftyone as fo import fiftyone.brain as fob # pylint: disable=import-error,no-name-in-module import fiftyone.zoo as foz from fiftyone import ViewField as F CUSTOM_BACKENDS = [ "qdrant", "pinecone", "milvus", "redis", "elasticsearch", "mosaic", "pgvector", "lancedb", ] def get_custom_backends(): if "SIMILARITY_BACKENDS" in os.environ: return os.environ["SIMILARITY_BACKENDS"].split(",") return CUSTOM_BACKENDS def test_brain_config(): similarity_backends = fob.brain_config.similarity_backends assert "sklearn" in similarity_backends for backend in get_custom_backends(): if backend == "qdrant": assert "qdrant" in similarity_backends # this isn't mandatory # assert "url" in similarity_backends["qdrant"] if backend == "pinecone": assert "pinecone" in similarity_backends # this isn't mandatory # assert "api_key" in similarity_backends["pinecone"] # assert "cloud" in similarity_backends["pinecone"] # assert "region" in similarity_backends["pinecone"] # assert "environment" in similarity_backends["pinecone"] if backend == "milvus": assert "milvus" in similarity_backends # this isn't mandatory # assert "uri" in similarity_backends["milvus"] if backend == "lancedb": assert "lancedb" in similarity_backends # this isn't mandatory # assert "uri" in similarity_backends["lancedb"] if backend == "redis": assert "redis" in similarity_backends # this isn't mandatory # assert "host" in similarity_backends["redis"] # assert "port" in similarity_backends["redis"] if backend == "elasticsearch": assert "elasticsearch" in similarity_backends def test_image_similarity_backends(): dataset = foz.load_zoo_dataset( "quickstart", dataset_name="quickstart-test-similarity-image", drop_existing_dataset=True, ) # sklearn backend ########################################################################### index1 = fob.compute_similarity( dataset, model="clip-vit-base32-torch", metric="euclidean", 
embeddings=False, backend="sklearn", brain_key="clip_sklearn", ) embeddings, sample_ids, _ = index1.compute_embeddings(dataset) index1.add_to_index(embeddings, sample_ids) index1.save() index1.reload() assert index1.total_index_size == 200 assert index1.index_size == 200 assert index1.missing_size is None prompt = "kites high in the air" view1 = dataset.sort_by_similarity(prompt, k=10, brain_key="clip_sklearn") assert len(view1) == 10 del index1 dataset.clear_cache() print(dataset.get_brain_info("clip_sklearn")) index1 = dataset.load_brain_results("clip_sklearn") assert index1.total_index_size == 200 embeddings1, sample_ids1, _ = index1.get_embeddings() assert embeddings1.shape == (200, 512) assert sample_ids1.shape == (200,) ids = random.sample(list(index1.sample_ids), 100) embeddings1, sample_ids1, _ = index1.get_embeddings(sample_ids=ids) assert embeddings1.shape == (100, 512) assert sample_ids1.shape == (100,) index1.remove_from_index(sample_ids=ids) assert index1.total_index_size == 100 index1.cleanup() dataset.delete_brain_run("clip_sklearn") # custom backends ########################################################################### for backend in get_custom_backends(): brain_key = "clip_" + backend index2 = fob.compute_similarity( dataset, model="clip-vit-base32-torch", metric="euclidean", embeddings=False, backend=backend, brain_key=brain_key, ) index2.add_to_index(embeddings, sample_ids) assert _verify_total_index_size(index=index2, expected_size=200) assert index2.total_index_size == 200 assert index2.index_size == 200 assert index2.missing_size is None view2 = dataset.sort_by_similarity(prompt, k=10, brain_key=brain_key) assert len(view2) == 10 del index2 dataset.clear_cache() print(dataset.get_brain_info(brain_key)) index2 = dataset.load_brain_results(brain_key) assert index2.total_index_size == 200 # Pinecone and Milvus require IDs, so this method is not supported if backend not in ("pinecone", "milvus"): embeddings2, sample_ids2, _ = 
index2.get_embeddings() assert embeddings2.shape == (200, 512) assert sample_ids2.shape == (200,) embeddings2, sample_ids2, _ = index2.get_embeddings(sample_ids=ids) assert embeddings2.shape == (100, 512) assert sample_ids2.shape == (100,) assert set(sample_ids1) == set(sample_ids2) embeddings2_dict = dict(zip(sample_ids2, embeddings2)) _embeddings2 = np.array([embeddings2_dict[i] for i in sample_ids1]) assert np.allclose(embeddings1, _embeddings2) index2.remove_from_index(sample_ids=ids) # Collection size is known to be wrong in Milvus after deletions # As of July 5, 2023 this has not been fixed # https://github.com/milvus-io/milvus/issues/17193 if backend != "milvus": assert index2.total_index_size == 100 index2.cleanup() dataset.delete_brain_run(brain_key) dataset.delete() def test_patch_similarity_backends(): dataset = foz.load_zoo_dataset( "quickstart", dataset_name="quickstart-test-similarity-patch", drop_existing_dataset=True, ) # sklearn backend ########################################################################### index1 = fob.compute_similarity( dataset, patches_field="ground_truth", model="clip-vit-base32-torch", metric="euclidean", embeddings=False, backend="sklearn", brain_key="gt_clip_sklearn", ) embeddings, sample_ids, label_ids = index1.compute_embeddings(dataset) index1.add_to_index(embeddings, sample_ids, label_ids=label_ids) index1.save() index1.reload() assert index1.total_index_size == 1232 assert index1.index_size == 1232 assert index1.missing_size is None view = dataset.to_patches("ground_truth") prompt = "cute puppies" view1 = view.sort_by_similarity(prompt, k=10, brain_key="gt_clip_sklearn") assert len(view1) == 10 del index1 dataset.clear_cache() print(dataset.get_brain_info("gt_clip_sklearn")) index1 = dataset.load_brain_results("gt_clip_sklearn") assert index1.total_index_size == 1232 embeddings1, sample_ids1, label_ids1 = index1.get_embeddings() assert embeddings1.shape == (1232, 512) assert sample_ids1.shape == (1232,) assert 
label_ids1.shape == (1232,) ids = random.sample(list(index1.label_ids), 100) embeddings1, sample_ids1, label_ids1 = index1.get_embeddings(label_ids=ids) assert embeddings1.shape == (100, 512) assert sample_ids1.shape == (100,) assert label_ids1.shape == (100,) index1.remove_from_index(label_ids=ids) assert index1.total_index_size == 1132 index1.cleanup() dataset.delete_brain_run("gt_clip_sklearn") # custom backends ########################################################################### for backend in get_custom_backends(): brain_key = "gt_clip_" + backend index2 = fob.compute_similarity( dataset, patches_field="ground_truth", model="clip-vit-base32-torch", metric="euclidean", embeddings=False, backend=backend, brain_key=brain_key, ) index2.add_to_index(embeddings, sample_ids, label_ids=label_ids) assert _verify_total_index_size(index=index2, expected_size=1232) assert index2.total_index_size == 1232 assert index2.index_size == 1232 assert index2.missing_size is None view2 = view.sort_by_similarity(prompt, k=10, brain_key=brain_key) assert len(view2) == 10 del index2 dataset.clear_cache() print(dataset.get_brain_info(brain_key)) index2 = dataset.load_brain_results(brain_key) assert index2.total_index_size == 1232 # Pinecone and Milvus require IDs, so this method is not supported if backend not in ("pinecone", "milvus"): embeddings2, sample_ids2, label_ids2 = index2.get_embeddings() assert embeddings2.shape == (1232, 512) assert sample_ids2.shape == (1232,) assert label_ids2.shape == (1232,) embeddings2, sample_ids2, label_ids2 = index2.get_embeddings( label_ids=ids ) assert embeddings2.shape == (100, 512) assert sample_ids2.shape == (100,) assert label_ids2.shape == (100,) assert set(label_ids1) == set(label_ids2) embeddings2_dict = dict(zip(label_ids2, embeddings2)) _embeddings2 = np.array([embeddings2_dict[i] for i in label_ids1]) assert np.allclose(embeddings1, _embeddings2) index2.remove_from_index(label_ids=ids) # Collection size is known to be wrong in 
Milvus after deletions # As of July 5, 2023 this has not been fixed # https://github.com/milvus-io/milvus/issues/17193 if backend != "milvus": assert index2.total_index_size == 1132 index2.cleanup() dataset.delete_brain_run(brain_key) dataset.delete() def test_qdrant_backend_config(): """ - *_similarity_backends tests run with custom backends as "externally" configured - To test varying connection details (eg with qdrant), re-configure externally and re-run tests - This test white-box tests that gRPC-related config settings are applied to QdrantClient """ backend = "qdrant" if backend not in get_custom_backends(): return dataset = foz.load_zoo_dataset("quickstart", max_samples=5) brain_key = "clip_" + backend index = fob.compute_similarity( dataset, model="clip-vit-base32-torch", metric="euclidean", embeddings=False, backend=backend, brain_key=brain_key, ) qclient = index.client qremote = qclient._client qdrant_config = fob.brain_config.similarity_backends["qdrant"] if "prefer_grpc" in qdrant_config: prefer_grpc = qdrant_config["prefer_grpc"] assert qremote._prefer_grpc == prefer_grpc print(f"Applied qdrant config prefer_grpc={prefer_grpc}") else: print("Qdrant config prefer_grpc unset") if "grpc_port" in qdrant_config: grpc_port = qdrant_config["grpc_port"] assert qremote._grpc_port == grpc_port print(f"Applied qdrant config grpc_port={grpc_port}") else: print("Qdrant config grpc_port unset") dataset.delete() def test_images(): dataset = _load_images_dataset() index = dataset.load_brain_results("img_sim") assert index.total_index_size == len(dataset) assert set(dataset.values("id")) == set(index.sample_ids) def test_images_subset(): dataset = _load_images_dataset() index = dataset.load_brain_results("img_sim") view = dataset.take(10) index.use_view(view) assert index.index_size == len(view) assert set(view.values("id")) == set(index.current_sample_ids) def test_images_missing(): dataset = _load_images_dataset().limit(4).clone() dataset.add_samples( [ 
fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) sample_ids = dataset[:4].values("id") index = fob.compute_similarity(dataset, batch_size=1) assert index.total_index_size == 4 assert set(sample_ids) == set(index.sample_ids) model = foz.load_zoo_model("inception-v3-imagenet-torch") index = fob.compute_similarity( dataset, model=model, embeddings="embeddings_missing", batch_size=1, ) assert len(dataset.exists("embeddings_missing")) == 4 assert index.index_size == 4 assert set(sample_ids) == set(index.sample_ids) def test_images_embeddings(): dataset = foz.load_zoo_dataset( "quickstart", max_samples=10, drop_existing_dataset=True ) model = foz.load_zoo_model("clip-vit-base32-torch") n = len(dataset) # Embeddings are computed on-the-fly and stored on dataset index1 = fob.compute_similarity( dataset, embeddings="embeddings", model="clip-vit-base32-torch", brain_key="img_sim1", backend="sklearn", ) assert index1.total_index_size == n assert index1.config.supports_prompts is True assert "embeddings" not in index1.serialize() # Embeddings already exist on dataset dataset.compute_embeddings(model, embeddings_field="embeddings2") index2 = fob.compute_similarity( dataset, embeddings="embeddings2", model="clip-vit-base32-torch", brain_key="img_sim2", backend="sklearn", ) assert index2.total_index_size == n assert index2.config.supports_prompts is True assert "embeddings" not in index2.serialize() # Embeddings stored in index itself index3 = fob.compute_similarity( dataset, model="clip-vit-base32-torch", brain_key="img_sim3", backend="sklearn", ) assert index3.total_index_size == n assert index3.config.supports_prompts is True assert "embeddings" in index3.serialize() # Embeddings stored on dataset (but field doesn't initially exist) index4 = fob.compute_similarity( dataset, embeddings="embeddings4", brain_key="img_sim4", backend="sklearn", ) embeddings = 
np.random.randn(n, 512) sample_ids = dataset.values("id") index4.add_to_index(embeddings, sample_ids) assert index4.total_index_size == n assert index4.config.supports_prompts is not True assert "embeddings" not in index4.serialize() dataset.delete() def test_patches(): dataset = _load_patches_dataset() index = dataset.load_brain_results("gt_sim") label_ids = dataset.values("ground_truth.detections.id", unwind=True) assert index.total_index_size == len(label_ids) assert set(label_ids) == set(index.label_ids) def test_patches_subset(): dataset = _load_patches_dataset() index = dataset.load_brain_results("gt_sim") label_ids = dataset.values("ground_truth.detections.id", unwind=True) assert index.total_index_size == len(label_ids) assert set(label_ids) == set(index.label_ids) view = dataset.filter_labels("ground_truth", F("label") == "person") index.use_view(view) label_ids = view.values("ground_truth.detections.id", unwind=True) assert index.index_size == len(label_ids) assert set(label_ids) == set(index.current_label_ids) def test_patches_missing(): dataset = _load_patches_dataset().limit(4).clone() dataset.add_samples( [ fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) for sample in dataset[4:]: sample["ground_truth"] = fo.Detections( detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])] ) sample.save() index = fob.compute_similarity( dataset, patches_field="ground_truth", batch_size=1 ) num_patches = dataset[:4].count("ground_truth.detections") label_ids = dataset[:4].values("ground_truth.detections.id", unwind=True) assert index.total_index_size == num_patches assert set(label_ids) == set(index.label_ids) model = foz.load_zoo_model("inception-v3-imagenet-torch") index = fob.compute_similarity( dataset, model=model, patches_field="ground_truth", embeddings="embeddings_missing", batch_size=1, ) view = dataset.filter_labels( "ground_truth", 
F("embeddings_missing") != None ) assert view.count("ground_truth.detections") == num_patches assert index.total_index_size == num_patches assert set(label_ids) == set(index.label_ids) def test_patches_embeddings(): dataset = foz.load_zoo_dataset( "quickstart", max_samples=10, drop_existing_dataset=True ) model = foz.load_zoo_model("clip-vit-base32-torch") n = dataset.count("ground_truth.detections") # Embeddings are computed on-the-fly and stored on dataset index1 = fob.compute_similarity( dataset, patches_field="ground_truth", embeddings="embeddings", model="clip-vit-base32-torch", brain_key="gt_sim1", backend="sklearn", ) assert index1.total_index_size == n assert index1.config.supports_prompts is True assert "embeddings" not in index1.serialize() # Embeddings already exist on dataset dataset.compute_patch_embeddings( model, "ground_truth", embeddings_field="embeddings2" ) index2 = fob.compute_similarity( dataset, patches_field="ground_truth", embeddings="embeddings2", model="clip-vit-base32-torch", brain_key="gt_sim2", backend="sklearn", ) assert index2.total_index_size == n assert index2.config.supports_prompts is True assert "embeddings" not in index2.serialize() # Embeddings stored in index itself index3 = fob.compute_similarity( dataset, patches_field="ground_truth", model="clip-vit-base32-torch", brain_key="gt_sim3", backend="sklearn", ) assert index3.total_index_size == n assert index3.config.supports_prompts is True assert "embeddings" in index3.serialize() # Embeddings stored on dataset (but field doesn't initially exist) index4 = fob.compute_similarity( dataset, patches_field="ground_truth", embeddings="embeddings4", brain_key="gt_sim4", backend="sklearn", ) embeddings = np.random.randn(n, 512) view = dataset.to_patches("ground_truth") sample_ids, label_ids = view.values(["sample_id", "id"]) index4.add_to_index(embeddings, sample_ids, label_ids=label_ids) assert index4.total_index_size == n assert index4.config.supports_prompts is not True assert 
"embeddings" not in index4.serialize() dataset.delete() def _load_images_dataset(): name = "test-similarity-images" if fo.dataset_exists(name): return fo.load_dataset(name) return _make_images_dataset(name) def _load_patches_dataset(): name = "test-similarity-patches" if fo.dataset_exists(name): return fo.load_dataset(name) return _make_patches_dataset(name) def _make_images_dataset(name): dataset = foz.load_zoo_dataset( "quickstart", max_samples=20, dataset_name=name ) model = foz.load_zoo_model("inception-v3-imagenet-torch") # Embed images dataset.compute_embeddings( model, embeddings_field="embeddings", batch_size=8 ) # Image similarity fob.compute_similarity( dataset, embeddings="embeddings", brain_key="img_sim" ) return dataset def _make_patches_dataset(name): dataset = foz.load_zoo_dataset( "quickstart", max_samples=20, dataset_name=name ) model = foz.load_zoo_model("inception-v3-imagenet-torch") # Embed ground truth patches dataset.compute_patch_embeddings( model, "ground_truth", embeddings_field="embeddings", batch_size=8, force_square=True, ) # Patch similarity fob.compute_similarity( dataset, patches_field="ground_truth", embeddings="embeddings", brain_key="gt_sim", ) return dataset def _verify_total_index_size(index, expected_size, timeout=10, interval=1): elapsed_time = 0 while index.total_index_size != expected_size and elapsed_time < timeout: time.sleep(interval) elapsed_time += interval return index.total_index_size == expected_size if __name__ == "__main__": fo.config.show_progress_bars = True unittest.main(verbosity=2) ================================================ FILE: tests/intensive/test_uniqueness.py ================================================ """ Uniqueness tests. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import unittest import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz def test_uniqueness(): _run_uniqueness() def test_uniqueness_torch(): model = foz.load_zoo_model("inception-v3-imagenet-torch") _run_uniqueness(model=model, batch_size=16) def test_uniqueness_tf(): model = foz.load_zoo_model("resnet-v2-50-imagenet-tf1") _run_uniqueness(model=model, batch_size=16) def test_uniqueness_missing(): dataset = fo.Dataset() dataset.add_samples( [ fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) fob.compute_uniqueness(dataset, batch_size=1) view = dataset.exists("uniqueness") assert dataset.has_field("uniqueness") assert len(view) == 0 def test_roi_uniqueness(): _run_uniqueness(roi_field="ground_truth") def test_roi_uniqueness_torch(): model = foz.load_zoo_model("inception-v3-imagenet-torch") _run_uniqueness(roi_field="ground_truth", model=model, batch_size=16) def test_roi_uniqueness_tf(): model = foz.load_zoo_model("resnet-v2-50-imagenet-tf1") _run_uniqueness(roi_field="ground_truth", model=model, batch_size=16) def test_roi_uniqueness_missing(): dataset = fo.Dataset() dataset.add_samples( [ fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) for sample in dataset: sample["ground_truth"] = fo.Detections( detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])] ) sample.save() fob.compute_uniqueness(dataset, roi_field="ground_truth", batch_size=1) view = dataset.exists("uniqueness") assert dataset.has_field("uniqueness") assert len(view) == 0 def test_uniqueness_similarity_index(): dataset = foz.load_zoo_dataset( "quickstart", dataset_name=fo.get_default_dataset_name() ) dataset.delete_sample_field("uniqueness") # Full similarity index similarity_index = fob.compute_similarity( dataset, 
brain_key="sklearn_index", backend="sklearn" ) fob.compute_uniqueness(dataset, similarity_index=similarity_index) assert dataset.has_field("uniqueness") dataset.clear_cache() dataset.delete_sample_field("uniqueness") fob.compute_uniqueness(dataset, similarity_index="sklearn_index") assert dataset.has_field("uniqueness") # Partial similarity index view = dataset.take(100, seed=51) similarity_index2 = fob.compute_similarity( view, brain_key="sklearn_index2", backend="sklearn" ) fob.compute_uniqueness( dataset, uniqueness_field="uniqueness2", similarity_index="sklearn_index2", ) assert len(dataset.exists("uniqueness2")) == len(view) def _run_uniqueness(roi_field=None, model=None, batch_size=None): dataset = foz.load_zoo_dataset( "quickstart", dataset_name=fo.get_default_dataset_name() ) dataset.delete_sample_field("uniqueness") view = dataset.take(50) num_samples = len(view) fob.compute_uniqueness( view, roi_field=roi_field, model=model, batch_size=batch_size ) num_uniqueness = dataset.count("uniqueness") assert num_uniqueness == num_samples bounds = dataset.bounds("uniqueness") assert bounds[0] >= 0 assert bounds[1] <= 1 if __name__ == "__main__": fo.config.show_progress_bars = True unittest.main(verbosity=2) ================================================ FILE: tests/intensive/test_visualization.py ================================================ """ Visualization tests. All of these tests are designed to be run manually via:: pytest tests/intensive/test_visualization.py -s -k test_ | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import unittest import cv2 import numpy as np import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz from fiftyone import ViewField as F def test_mnist(): dataset = foz.load_zoo_dataset("mnist", split="test") # pylint: disable=no-member embeddings = np.array( [ cv2.imread(f, cv2.IMREAD_UNCHANGED).ravel() for f in dataset.values("filepath") ] ) results = fob.compute_visualization( dataset, embeddings=embeddings, num_dims=2, verbose=True, seed=51, ) plot = results.visualize(labels="ground_truth.label") plot.show() input("Press enter to continue...") def test_images(): dataset = _load_images_dataset() results = dataset.load_brain_results("img_viz") assert results.total_index_size == len(dataset) assert set(dataset.values("id")) == set(results.sample_ids) plot = results.visualize(labels="uniqueness") plot.show() input("Press enter to continue...") def test_images_subset(): dataset = _load_images_dataset() results = dataset.load_brain_results("img_viz") view = dataset.take(10) results.use_view(view) assert results.index_size == len(view) assert set(view.values("id")) == set(results.current_sample_ids) plot = results.visualize(labels="uniqueness") plot.show() input("Press enter to continue...") def test_images_missing(): dataset = _load_images_dataset().limit(4).clone() dataset.add_samples( [ fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) sample_ids = dataset[:4].values("id") results = fob.compute_visualization(dataset, batch_size=1) assert results.total_index_size == 4 assert set(sample_ids) == set(results.sample_ids) model = foz.load_zoo_model("inception-v3-imagenet-torch") results = fob.compute_visualization( dataset, model=model, embeddings="embeddings_missing", batch_size=1, ) assert len(dataset.exists("embeddings_missing")) == 4 assert results.total_index_size == 4 assert set(sample_ids) == 
set(results.sample_ids) def test_patches(): dataset = _load_patches_dataset() results = dataset.load_brain_results("gt_viz") label_ids = dataset.values("ground_truth.detections.id", unwind=True) assert results.total_index_size == len(label_ids) assert set(label_ids) == set(results.label_ids) plot = results.visualize(labels="ground_truth.detections.label") plot.show() input("Press enter to continue...") def test_patches_subset(): dataset = _load_patches_dataset() results = dataset.load_brain_results("gt_viz") plot = results.visualize( labels="ground_truth.detections.label", classes=["person"], ) plot.show() input("Press enter to continue...") view = dataset.filter_labels("ground_truth", F("label") == "person") results.use_view(view) label_ids = view.values("ground_truth.detections.id", unwind=True) assert results.index_size == len(label_ids) assert set(label_ids) == set(results.current_label_ids) plot = results.visualize(labels="ground_truth.detections.label") plot.show() input("Press enter to continue...") def test_patches_missing(): dataset = _load_patches_dataset().limit(4).clone() dataset.add_samples( [ fo.Sample(filepath="non-existent1.png"), fo.Sample(filepath="non-existent2.png"), fo.Sample(filepath="non-existent3.png"), fo.Sample(filepath="non-existent4.png"), ] ) for sample in dataset[4:]: sample["ground_truth"] = fo.Detections( detections=[fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])] ) sample.save() results = fob.compute_visualization( dataset, patches_field="ground_truth", batch_size=1 ) num_patches = dataset[:4].count("ground_truth.detections") label_ids = dataset[:4].values("ground_truth.detections.id", unwind=True) assert results.total_index_size == num_patches assert set(label_ids) == set(results.label_ids) model = foz.load_zoo_model("inception-v3-imagenet-torch") results = fob.compute_visualization( dataset, model=model, patches_field="ground_truth", embeddings="embeddings_missing", batch_size=1, ) view = dataset.filter_labels( "ground_truth", 
F("embeddings_missing") != None ) assert view.count("ground_truth.detections") == num_patches assert results.total_index_size == num_patches assert set(label_ids) == set(results.label_ids) def test_points(): dataset = foz.load_zoo_dataset("quickstart") n = len(dataset) p = dataset.count("ground_truth.detections") d = 512 points1 = np.random.rand(n, d) results1 = fob.compute_visualization( dataset, points=points1, brain_key="test1", ) assert results1.points.shape == (n, d) points2 = {_id: np.random.rand(d) for _id in dataset.values("id")} results2 = fob.compute_visualization( dataset, points=points2, brain_key="test2", ) assert results2.points.shape == (n, d) points3 = np.random.rand(p, d) results3 = fob.compute_visualization( dataset, patches_field="ground_truth", points=points3, brain_key="test3", ) assert results3.points.shape == (p, d) points4 = { _id: np.random.rand(d) for _id in dataset.values("ground_truth.detections.id", unwind=True) } results4 = fob.compute_visualization( dataset, patches_field="ground_truth", points=points4, brain_key="test4", ) assert results4.points.shape == (p, d) dataset.delete() def test_similarity_index(): dataset = foz.load_zoo_dataset( "quickstart", dataset_name=fo.get_default_dataset_name() ) # Full similarity index similarity_index = fob.compute_similarity( dataset, brain_key="sklearn_index", backend="sklearn" ) results = fob.compute_visualization( dataset, brain_key="img_viz", similarity_index=similarity_index, ) assert len(results.points) == len(dataset) # Partial similarity index view = dataset.take(100, seed=51) similarity_index2 = fob.compute_similarity( view, brain_key="sklearn_index2", backend="sklearn" ) results2 = fob.compute_visualization( dataset, brain_key="img_viz2", similarity_index="sklearn_index2", ) assert len(results2.points) == len(view) def test_points_field(): dataset = _load_images_dataset() num_points = len(dataset) points = np.random.randn(num_points, 2) brain_key = "test_points" points_field = brain_key 
fob.compute_visualization( dataset, brain_key=brain_key, points=points, create_index=True, ) dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field == points_field assert dataset.has_sample_field(points_field) assert points_field in dataset.list_indexes() sample_points = dataset.first()[points_field] assert isinstance(sample_points, list) assert len(sample_points) == 2 assert isinstance(sample_points[0], float) points = results.points assert len(points) == num_points assert len(points[0]) == 2 all_points = dataset.values(points_field) assert np.allclose(points, all_points) dataset.delete_brain_run(brain_key) assert not dataset.has_sample_field(points_field) assert points_field not in dataset.list_indexes() def test_points_field_patches(): dataset = _load_patches_dataset() num_points = dataset.count("ground_truth.detections") points = np.random.randn(num_points, 2) brain_key = "test_points" points_field = brain_key points_path = f"ground_truth.detections.{points_field}" fob.compute_visualization( dataset, brain_key=brain_key, points=points, patches_field="ground_truth", create_index=True, ) dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field == points_field assert dataset.has_sample_field(points_path) # Patch visualizations can't currently make use of database indexes assert points_path not in dataset.list_indexes() label_points = dataset.first().ground_truth.detections[0][points_field] assert isinstance(label_points, list) assert len(label_points) == 2 assert isinstance(label_points[0], float) points = results.points assert len(points) == num_points assert len(points[0]) == 2 all_points = dataset.values(f"ground_truth.detections[].{points_field}") assert np.allclose(points, all_points) dataset.delete_brain_run(brain_key) assert not dataset.has_sample_field(points_path) def test_index_points(): dataset = _load_images_dataset() num_points = len(dataset) points = 
np.random.randn(num_points, 2) brain_key = "test_points" points_field = brain_key fob.compute_visualization(dataset, brain_key=brain_key, points=points) dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field is None assert not dataset.has_sample_field(points_field) assert points_field not in dataset.list_indexes() results.index_points() dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field == points_field assert dataset.has_sample_field(points_field) assert points_field in dataset.list_indexes() points = results.points all_points = dataset.values(points_field) assert np.allclose(points, all_points) results.remove_index() dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field is None assert not dataset.has_sample_field(points_field) assert points_field not in dataset.list_indexes() def test_index_points_patches(): dataset = _load_patches_dataset() num_points = dataset.count("ground_truth.detections") points = np.random.randn(num_points, 2) brain_key = "test_points" points_field = brain_key points_path = f"ground_truth.detections.{points_field}" fob.compute_visualization( dataset, brain_key=brain_key, points=points, patches_field="ground_truth", ) dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field is None assert not dataset.has_sample_field(points_path) results.index_points() dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field == points_field assert dataset.has_sample_field(points_path) points = results.points all_points = dataset.values(f"ground_truth.detections[].{points_field}") assert np.allclose(points, all_points) results.remove_index() dataset.clear_cache() results = dataset.load_brain_results(brain_key) assert results.config.points_field is None assert not dataset.has_sample_field(points_path) def 
_load_images_dataset(): name = "test-visualization-images" if fo.dataset_exists(name): return fo.load_dataset(name) return _make_images_dataset(name) def _load_patches_dataset(): name = "test-visualization-patches" if fo.dataset_exists(name): return fo.load_dataset(name) return _make_patches_dataset(name) def _make_images_dataset(name): dataset = foz.load_zoo_dataset( "quickstart", max_samples=20, dataset_name=name ) model = foz.load_zoo_model("inception-v3-imagenet-torch") # Embed images dataset.compute_embeddings( model, embeddings_field="embeddings", batch_size=8 ) # Image visualization fob.compute_visualization( dataset, embeddings="embeddings", num_dims=2, verbose=True, seed=51, brain_key="img_viz", ) return dataset def _make_patches_dataset(name): dataset = foz.load_zoo_dataset( "quickstart", max_samples=20, dataset_name=name ) model = foz.load_zoo_model("inception-v3-imagenet-torch") # Embed ground truth patches dataset.compute_patch_embeddings( model, "ground_truth", embeddings_field="embeddings", batch_size=8, force_square=True, ) # Patch visualization fob.compute_visualization( dataset, patches_field="ground_truth", embeddings="embeddings", num_dims=2, verbose=True, seed=51, brain_key="gt_viz", ) return dataset if __name__ == "__main__": fo.config.show_progress_bars = True unittest.main(verbosity=2) ================================================ FILE: tests/models/test_simple_resnet.py ================================================ """ Tests for :mod:`fiftyone.brain.internal.models.simple_resnet`. | Copyright 2017-2026, Voxel51, Inc. 
| `voxel51.com `_ | """ import imageio from PIL import Image import torch import eta.core.image as etai import fiftyone as fo import fiftyone.core.utils as fou import fiftyone.zoo as foz import fiftyone.brain.internal.models as fbm def _transpose(x, source, target): return x.permute([source.index(d) for d in target]) def _check_prediction(actual, expected): assert isinstance(actual, fo.Classification) assert isinstance(expected, fo.Classification) # @todo fix me on 3.9 # assert actual.label == expected.label def test_simple_resnet(): dataset = foz.load_zoo_dataset( "cifar10", split="test", dataset_name=fo.get_default_dataset_name(), shuffle=True, max_samples=1, ) sample = dataset.first() filepath = sample.filepath print("Working on image at %s" % filepath) img_pil = Image.open(filepath) print("img_pil is type %s" % type(img_pil)) img_numpy = imageio.imread(filepath) print("img_numpy is type %s" % type(img_numpy)) print(img_numpy.shape) img_torch = torch.from_numpy(img_numpy) img_torch = _transpose(img_torch, "HWC", "CHW") print("img_torch is type %s" % type(img_torch)) print(img_torch.shape) assert tuple(reversed(img_torch.shape)) == img_numpy.shape img_eta = etai.read(filepath) print("img_eta is type %s" % type(img_eta)) print(img_eta.shape) assert tuple(img_eta.shape) == img_numpy.shape model = fbm.load_model("simple-resnet-cifar10") with model: print("PIL") p_pil = model.predict(img_pil) print(p_pil) print("IMAGEIO") p_numpy = model.predict(img_numpy) print(p_numpy) _check_prediction(p_numpy, p_pil) print("ETA") p_eta = model.predict(img_eta) print(p_eta) _check_prediction(p_eta, p_pil) print("PIL (manual preprocessing)") with fou.SetAttributes(model, preprocess=False): img_tensor = model.transforms(img_pil) p_pil2 = model.predict(img_tensor) print(p_pil2) _check_prediction(p_pil2, p_pil) print("IMAGEIO (manual preprocessing)") with fou.SetAttributes(model, preprocess=False): img_tensor = model.transforms(img_numpy) p_numpy2 = model.predict(img_tensor) 
print(p_numpy2) _check_prediction(p_numpy2, p_numpy) if __name__ == "__main__": test_simple_resnet() ================================================ FILE: tests/test_uniqueness.py ================================================ """ Uniqueness tests. | Copyright 2017-2026, Voxel51, Inc. | `voxel51.com `_ | """ import os import unittest import eta.core.storage as etas import eta.core.utils as etau import fiftyone as fo import fiftyone.brain as fob import fiftyone.zoo as foz def test_uniqueness(): dataset = foz.load_zoo_dataset("cifar10", split="test") assert "uniqueness" not in dataset.get_field_schema() view = dataset.view().take(100) fob.compute_uniqueness(view) print(dataset) assert "uniqueness" in dataset.get_field_schema() def test_gray(): """Test default support for handling grayscale images. Requires Voxel51 Google Drive credentials to download the test data. """ with etau.TempDir() as tmpdir: tmp_zip = os.path.join(tmpdir, "data.zip") tmp_data = os.path.join(tmpdir, "brain_grayscale_test_data") client = etas.GoogleDriveStorageClient() client.download("1ECeNnLmKQCHxlVdRqGefV5eXOD_OkmWx", tmp_zip) etau.extract_zip(tmp_zip, delete_zip=True) dataset = fo.Dataset.from_dir(tmp_data, fo.types.ImageDirectory) fob.compute_uniqueness(dataset) print(dataset) if __name__ == "__main__": fo.config.show_progress_bars = True unittest.main(verbosity=2)