Repository: voxel51/fiftyone-brain
Branch: develop
Commit: 05fccee0ae1c
Files: 61
Total size: 549.0 KB
Directory structure:
gitextract_zu49vpz0/
├── .github/
│ ├── CODEOWNERS
│ ├── dependabot.yml
│ ├── pull_request_template.md
│ └── workflows/
│ └── build.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .prettierrc
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── RELEASING.md
├── STYLE_GUIDE.md
├── fiftyone/
│ ├── __init__.py
│ └── brain/
│ ├── __init__.py
│ ├── config.py
│ ├── internal/
│ │ ├── __init__.py
│ │ ├── core/
│ │ │ ├── __init__.py
│ │ │ ├── duplicates.py
│ │ │ ├── elasticsearch.py
│ │ │ ├── hardness.py
│ │ │ ├── lancedb.py
│ │ │ ├── leaky_splits.py
│ │ │ ├── milvus.py
│ │ │ ├── mistakenness.py
│ │ │ ├── mongodb.py
│ │ │ ├── mosaic.py
│ │ │ ├── pgvector.py
│ │ │ ├── pinecone.py
│ │ │ ├── qdrant.py
│ │ │ ├── redis.py
│ │ │ ├── representativeness.py
│ │ │ ├── sklearn.py
│ │ │ ├── uniqueness.py
│ │ │ ├── utils.py
│ │ │ └── visualization.py
│ │ └── models/
│ │ ├── .gitignore
│ │ ├── __init__.py
│ │ ├── manifest.json
│ │ ├── simple_resnet.py
│ │ └── torch.py
│ ├── similarity.py
│ └── visualization.py
├── install.bat
├── install.sh
├── pylintrc
├── pyproject.toml
├── pytest.ini
├── requirements/
│ ├── build.txt
│ ├── common.txt
│ ├── dev.txt
│ └── prod.txt
├── requirements.txt
├── setup.py
└── tests/
├── README.md
├── intensive/
│ ├── test_interface.py
│ ├── test_similarity.py
│ ├── test_uniqueness.py
│ └── test_visualization.py
├── models/
│ └── test_simple_resnet.py
└── test_uniqueness.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/CODEOWNERS
================================================
* @voxel51/developers
# Aloha!
.github/ @voxel51/aloha-shirts
pyproject.toml @voxel51/aloha-shirts
RELEASING.md @voxel51/aloha-shirts
setup.py @voxel51/aloha-shirts
================================================
FILE: .github/dependabot.yml
================================================
---
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
day: "wednesday"
time: "14:00"
timezone: "UTC"
================================================
FILE: .github/pull_request_template.md
================================================
# Rationale
## Changes
## Testing
================================================
FILE: .github/workflows/build.yml
================================================
name: Build
on:
pull_request:
branches:
- develop
types: [opened, synchronize]
push:
branches:
- develop
tags:
- v*
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Clone fiftyone-brain
uses: actions/checkout@v6
with:
submodules: true
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: 3.9
- name: Check Python version
run: |
python --version
pip --version
- name: Install dependencies
run: |
pip install --upgrade pip setuptools wheel
pip install -r requirements/build.txt
- name: Set environment
  env:
    RELEASE_TAG: ${{ github.ref }}
  # Read the ref from the RELEASE_TAG environment variable instead of
  # interpolating the github.ref expression into the script, so that ref
  # names are never evaluated as shell syntax. For release tags, strip the
  # leading "refs/tags/v" to get the bare version (e.g. v0.21.4 -> 0.21.4).
  run: |
    if [[ "$RELEASE_TAG" =~ ^refs/tags/v ]]; then
      echo "RELEASE_VERSION=${RELEASE_TAG#refs/tags/v}" >> "$GITHUB_ENV"
    fi
- name: Build wheel
run: |
python setup.py sdist bdist_wheel
- name: Upload wheel
uses: actions/upload-artifact@v7
with:
name: dist
path: dist/
retention-days: 1
test:
needs: [build]
runs-on: ubuntu-latest
env:
FIFTYONE_DATASET_ZOO_DIR: ${{ github.workspace }}/.fiftyone
FIFTYONE_DO_NOT_TRACK: true
FIFTYONE_MODEL_ZOO_DIR: ${{ github.workspace }}/.fiftyone
permissions:
contents: read
id-token: write
strategy:
fail-fast: false
matrix:
python:
- "3.9"
- "3.10"
- "3.11"
steps:
- name: Clone fiftyone-brain
uses: actions/checkout@v6
with:
submodules: true
- name: Clone fiftyone
uses: actions/checkout@v6
with:
fetch-depth: 1
path: fiftyone-src
ref: develop
repository: voxel51/fiftyone
- name: Clone voxel51-eta
uses: actions/checkout@v6
if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }}
with:
fetch-depth: 1
path: eta
ref: develop
repository: voxel51/eta
# ETA tests will create a storage client which,
# in its __init__, tries to log in to GCP.
# See tests/test_uniqueness.py
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v3
with:
project_id: ${{ secrets.REPO_GCP_PROJECT }}
service_account: ${{ secrets.REPO_GCP_SERVICE_ACCOUNT }}
workload_identity_provider: ${{ secrets.REPO_GOOGLE_WORKLOAD_IDP }}
- name: Set Up Cloud SDK
uses: google-github-actions/setup-gcloud@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python }}
- name: Free Disk Space (Ubuntu) # standard runner's 14 GB available disk size isn't enough. Need at least 22 GB free.
uses: jlumbroso/free-disk-space@v1.3.1
- name: Install dependencies
run: |
pip install --upgrade pip setuptools wheel
- name: Download fiftyone-brain wheel
uses: actions/download-artifact@v8
with:
name: dist
path: dist/
- name: Install fiftyone
working-directory: fiftyone-src
run: |
python setup.py bdist_wheel
pip install voxel51-eta[storage] fiftyone-db
pip install ./dist/*.whl
- name: Install ETA from source
working-directory: eta
# Don't install from source if this is a release.
# Install from PyPI
if: ${{ !startsWith(github.ref, 'refs/heads/rel') && !startsWith(github.ref, 'refs/tags/') }}
run: |
echo "Installing ETA from source because github.ref = ${{ github.ref }} (not a release)"
python setup.py bdist_wheel
pip install ./dist/*.whl --force-reinstall
- name: Reinstall fiftyone-brain
run: |
pip install --force-reinstall --no-deps dist/*.whl
- name: Install test dependencies
run: |
pip install imageio pytest torch torchvision
- name: Cache Zoo
id: fiftyone-cache
uses: actions/cache@v5
with:
path: |
.fiftyone
key: zoo-${{ hashFiles('tests/**') }}
- name: Run tests
run: |
pytest --verbose tests/ --ignore tests/intensive/
publish:
needs: [build, test]
if: startsWith(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
environment: release # For trusted publishing. See below.
permissions:
contents: read
id-token: write
steps:
- name: Download wheels
uses: actions/download-artifact@v8
with:
name: dist
path: dist/
# Use PyPI
# [trusted publishers](https://docs.pypi.org/trusted-publishers/),
# which authenticate via OIDC, to publish the contents of dist/ to PyPI.
# See the
# [fiftyone-brain publishing settings](https://pypi.org/manage/project/fiftyone-brain/settings/publishing/)
- name: Publish
uses: pypa/gh-action-pypi-publish@v1.14.0
================================================
FILE: .gitignore
================================================
__pycache__
.DS_Store
.ipynb_checkpoints
*~
*.egg-info
*.py[cod]
*.pth
*.swp
.idea
.project
.pydevproject
build/
dist/
/fiftyone/brain/internal/models/cache/**/*
!/fiftyone/brain/internal/models/cache/manifest.json
*.pth
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/asottile/blacken-docs
rev: v1.12.0
hooks:
- id: blacken-docs
additional_dependencies: [black==21.12b0]
args: ["-l 79"]
- repo: https://github.com/ambv/black
rev: 22.3.0
hooks:
- id: black
language_version: python3
args: ["-l 79"]
- repo: local
hooks:
- id: pylint
name: pylint
language: system
files: \.py$
entry: pylint
args: ["--errors-only"]
- repo: local
hooks:
- id: ipynb-strip
name: ipynb-strip
language: system
files: \.ipynb$
entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True
args: ["--log-level=ERROR"]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.6.2
hooks:
- id: prettier
language_version: system
================================================
FILE: .prettierrc
================================================
{
"overrides": [
{
"files": "*.md",
"options": {
"printWidth": 79,
"proseWrap": "always",
"tabWidth": 4
}
},
{
"files": "*.json",
"options": {
"tabWidth": 4
}
}
]
}
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to FiftyOne Brain
All Brain contributions should follow the practices established in
[FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/CONTRIBUTING.md).
## Adding new public methods to the Brain package
The `fiftyone.brain` package should expose all core user-functionality at the
base level. For example, for hardness, the user should be able to execute calls
in the following way:
```py
# Users should be able to do this
import fiftyone.brain as fob
fob.compute_hardness(...)
# And NOT have to do this
import fiftyone.brain.hardness as fobh
fobh.compute_hardness(...)
```
To achieve this, follow the existing pattern of declaring new public methods in
[`fiftyone/brain/__init__.py`](https://github.com/voxel51/fiftyone-brain/blob/develop/fiftyone/brain/__init__.py).
Be sure to include a detailed docstring for all methods in this file, as they
are pulled in by FiftyOne documentation builds and are made available in the
[public docs](https://docs.voxel51.com/api/fiftyone.brain.html).
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017-2026, Voxel51, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
global-include *
prune fiftyone/brain/internal/models/cache/
include fiftyone/brain/internal/models/cache/manifest.json
================================================
FILE: README.md
================================================
**Open Source AI from [Voxel51](https://voxel51.com)**
FiftyOne Website •
FiftyOne Docs •
FiftyOne Brain Docs •
Blog •
Community
[](https://pypi.org/project/fiftyone-brain)
[](https://pypi.org/project/fiftyone-brain)
[](https://pepy.tech/project/fiftyone-brain)
[](LICENSE)
[](https://discord.gg/fiftyone-community)
[](https://huggingface.co/Voxel51)
[](https://voxel51.com/blog)
[](https://share.hsforms.com/1zpJ60ggaQtOoVeBqIZdaaA2ykyk)
[](https://www.linkedin.com/company/voxel51)
[](https://x.com/voxel51)
[](https://medium.com/voxel51)
---
FiftyOne Brain contains the open source AI/ML capabilities for the
[FiftyOne ecosystem](https://github.com/voxel51/fiftyone), enabling users to
automatically analyze and manipulate their datasets and models. FiftyOne Brain
includes features like visual similarity search, query by text, finding unique
and representative samples, finding media quality problems and annotation
mistakes, and more 🚀
## Documentation
Public documentation for the FiftyOne Brain is
[available here](https://docs.voxel51.com/user_guide/brain.html).
## Installation
The FiftyOne Brain is distributed via the `fiftyone-brain` package, and a
suitable version is automatically included with every `fiftyone` install:
```shell
pip install fiftyone
pip show fiftyone-brain
```
### Installing from source
If you wish to do a source install of the latest FiftyOne Brain version, simply
clone this repository:
```shell
git clone https://github.com/voxel51/fiftyone-brain
cd fiftyone-brain
```
and run the install script:
```shell
# Mac or Linux
bash install.sh
# Windows
.\install.bat
```
### Developer installation
If you are a developer contributing to this repository, you should perform a
developer installation using the `-d` flag of the install script:
```shell
# Mac or Linux
bash install.sh -d
# Windows
.\install.bat -d
```
Check out the [contribution guide](CONTRIBUTING.md) to get started.
## Uninstallation
```shell
pip uninstall fiftyone-brain
```
## Repository layout
- `fiftyone/brain/` definition of the `fiftyone.brain` namespace
- `requirements/` Python requirements for the project
- `tests/` tests for the various components of the Brain
## Citation
If you use the FiftyOne Brain in your research, please cite the project:
```bibtex
@article{moore2020fiftyone,
title={FiftyOne},
author={Moore, B. E. and Corso, J. J.},
journal={GitHub. Note: https://github.com/voxel51/fiftyone-brain},
year={2020}
}
```
================================================
FILE: RELEASING.md
================================================
# Releasing the Brain package
> [!NOTE]
> These steps are to be performed by authorized Voxel51 engineers.
The `fiftyone-brain` repository follows `Gitflow`.
Releases will be initiated when a teammate submits a
pull request from their respective `release/v*` branch to `main`.
We can see an example PR for
[version 0.21.4](https://github.com/voxel51/fiftyone-brain/pull/265).
Reviewers should always check that the version in the `setup.py`
matches the branch version.
The release engineer will merge the pull request once it is approved.
The PyPI uploads will be triggered when a release tag is pushed to the
repository:
1. Navigate to the
[releases page](https://github.com/voxel51/fiftyone-brain/releases).
1. Select `Draft a new release`.
1. Select `Create new tag` with the appropriate version and set the target to
`main`.
1. The tag format is `v<version>`.
For example, `v0.21.4`.
This should match the `setup.py` and release branch.
1. Select `Generate release notes`.
1. Select `Set as the latest release`.
1. Select `Publish release`.
This will create a new tag in the repository and will trigger the
[build/publish workflow](https://github.com/voxel51/fiftyone-brain/actions/workflows/build.yml).
This workflow will build the `.whl` artifacts and publish them to
[PyPI](https://pypi.org/project/fiftyone-brain/).
Once the builds have finished, submit a PR from `main` to `develop` to complete
the `Gitflow` process.
================================================
FILE: STYLE_GUIDE.md
================================================
# FiftyOne Brain Style Guide
The Brain follows the same style guidelines as
[FiftyOne](https://github.com/voxel51/fiftyone/blob/develop/STYLE_GUIDE.md).
================================================
FILE: fiftyone/__init__.py
================================================
from pkgutil import extend_path
#
# This statement allows multiple `fiftyone.XXX` packages to be installed in the
# same environment and used simultaneously.
#
# https://docs.python.org/3/library/pkgutil.html#pkgutil.extend_path
#
__path__ = extend_path(__path__, __name__)
# Re-export the names defined by ``fiftyone.__public__`` at the top level of
# the ``fiftyone`` package
from fiftyone.__public__ import *
================================================
FILE: fiftyone/brain/__init__.py
================================================
"""
The brains behind FiftyOne: a powerful package for dataset curation, analysis,
and visualization.
See https://github.com/voxel51/fiftyone for more information.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import fiftyone.brain.config as _foc
from .similarity import (
Similarity,
SimilarityConfig,
SimilarityIndex,
)
from .visualization import (
Visualization,
VisualizationConfig,
VisualizationResults,
)
# Load the Brain config once, when this package is imported, and expose it as
# the module-level attribute ``fiftyone.brain.brain_config``
brain_config = _foc.load_brain_config()
def compute_hardness(
    samples,
    label_field,
    hardness_field="hardness",
    progress=None,
):
    """Adds a hardness field to each sample scoring how difficult it was for
    the model whose predictions are stored in the specified label field to
    classify the sample.

    Hardness is a measure computed based on model prediction output (through
    logits) that summarizes the uncertainty the model had with the sample.
    This makes hardness quantitative and can be used to detect things like
    hard samples, annotation errors during noisy training, and more.

    All classifications must have their
    :attr:`logits <fiftyone.core.labels.Classification.logits>` attributes
    populated in order to use this method.

    .. note::

        Runs of this method can be referenced later via brain key
        ``hardness_field``.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        label_field: the :class:`fiftyone.core.labels.Classification` or
            :class:`fiftyone.core.labels.Classifications` field to use from
            each sample
        hardness_field ("hardness"): the field name to use to store the
            hardness value for each sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
    """
    # Deferred import: the implementation module is only loaded when this
    # function is actually called
    import fiftyone.brain.internal.core.hardness as fbh

    return fbh.compute_hardness(samples, label_field, hardness_field, progress)
def compute_mistakenness(
    samples,
    pred_field,
    label_field,
    mistakenness_field="mistakenness",
    missing_field="possible_missing",
    spurious_field="possible_spurious",
    use_logits=False,
    copy_missing=False,
    progress=None,
):
    """Computes the mistakenness (likelihood of being incorrect) of the labels
    in ``label_field`` based on the predicted labels in ``pred_field``.

    Mistakenness is measured based on either the ``confidence`` or ``logits``
    of the predictions in ``pred_field``. This measure can be used to detect
    things like annotation errors and unusually hard samples.

    For classifications, a ``mistakenness_field`` field is populated on each
    sample that quantifies the likelihood that the label in the ``label_field``
    of that sample is incorrect.

    For objects (detections, polylines, keypoints, etc), the mistakenness of
    each object in ``label_field`` is computed, using
    :meth:`fiftyone.core.collections.SampleCollection.evaluate_detections` to
    locate corresponding objects in ``pred_field``. Three types of mistakes
    are identified:

    -   **(Mistakes)** Objects in ``label_field`` with a match in
        ``pred_field`` are assigned a mistakenness value in their
        ``mistakenness_field`` that captures the likelihood that the class
        label of the object in ``label_field`` is a mistake. A
        ``mistakenness_field + "_loc"`` field is also populated that captures
        the likelihood that the object in ``label_field`` is a mistake due
        to its localization (bounding box).

    -   **(Missing)** Objects in ``pred_field`` with no matches in
        ``label_field`` but which are likely to be correct will have their
        ``missing_field`` attribute set to True. In addition, if
        ``copy_missing`` is True, copies of these objects are *added* to the
        ground truth ``label_field``.

    -   **(Spurious)** Objects in ``label_field`` with no matches in
        ``pred_field`` but which are likely to be incorrect will have their
        ``spurious_field`` attribute set to True.

    In addition, for objects, the following sample-level fields are populated:

    -   **(Mistakes)** The ``mistakenness_field`` of each sample is populated
        with the maximum mistakenness of the objects in ``label_field``.

    -   **(Missing)** The ``missing_field`` of each sample is populated with
        the number of objects that were deemed missing from ``label_field``.

    -   **(Spurious)** The ``spurious_field`` of each sample is populated with
        the number of objects in ``label_field`` that were deemed spurious.

    .. note::

        Runs of this method can be referenced later via brain key
        ``mistakenness_field``.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        pred_field: the name of the predicted label field to use from each
            sample. Can be of type
            :class:`fiftyone.core.labels.Classification`,
            :class:`fiftyone.core.labels.Classifications`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polylines`,
            :class:`fiftyone.core.labels.Keypoints`, or
            :class:`fiftyone.core.labels.TemporalDetections`
        label_field: the name of the "ground truth" label field that you want
            to test for mistakes with respect to the predictions in
            ``pred_field``. Must have the same type as ``pred_field``
        mistakenness_field ("mistakenness"): the field name to use to store the
            mistakenness value for each sample
        missing_field ("possible_missing"): the field in which to store
            per-sample counts of potential missing objects
        spurious_field ("possible_spurious"): the field in which to store
            per-sample counts of potential spurious objects
        use_logits (False): whether to use logits (True) or confidence (False)
            to compute mistakenness. Logits typically yield better results,
            when they are available
        copy_missing (False): whether to copy predicted objects that were
            deemed to be missing into ``label_field``
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
    """
    # Deferred import: the implementation module is only loaded when this
    # function is actually called
    import fiftyone.brain.internal.core.mistakenness as fbm

    return fbm.compute_mistakenness(
        samples,
        pred_field,
        label_field,
        mistakenness_field,
        missing_field,
        spurious_field,
        use_logits,
        copy_missing,
        progress,
    )
def compute_uniqueness(
    samples,
    uniqueness_field="uniqueness",
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Adds a uniqueness field to each sample scoring how unique it is with
    respect to the rest of the samples.

    This function only uses the pixel data and can therefore process labeled or
    unlabeled samples.

    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a
    default model is used to generate embeddings.

    .. note::

        Runs of this method can be referenced later via brain key
        ``uniqueness_field``.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        uniqueness_field ("uniqueness"): the field name to use to store the
            uniqueness value for each sample
        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines` field defining a region of
            interest within each image to use to compute uniqueness
        embeddings (None): if no ``model`` is provided, this argument specifies
            pre-computed embeddings to use, which can be any of the following:

            -   a ``num_samples x num_dims`` array of embeddings
            -   if ``roi_field`` is specified, a dict mapping sample IDs to
                ``num_patches x num_dims`` arrays of patch embeddings
            -   the name of a dataset field containing the embeddings to use

            If a ``model`` is provided, this argument specifies the name of a
            field in which to store the computed embeddings. In either case,
            when working with patch embeddings, you can provide either the
            fully-qualified path to the patch embeddings or just the name of
            the label attribute in ``roi_field``
        similarity_index (None): a
            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key
            of a similarity index to use to load pre-computed embeddings
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/model_zoo/index.html>`_
            to use to generate embeddings. The model must expose embeddings
            (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``roi_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``roi_field`` are specified
        batch_size (None): a batch size to use when computing embeddings. Only
            applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
    """
    # Deferred import: the implementation module is only loaded when this
    # function is actually called
    import fiftyone.brain.internal.core.uniqueness as fbu

    return fbu.compute_uniqueness(
        samples,
        uniqueness_field,
        roi_field,
        embeddings,
        similarity_index,
        model,
        model_kwargs,
        force_square,
        alpha,
        batch_size,
        num_workers,
        skip_failures,
        progress,
    )
def compute_representativeness(
    samples,
    representativeness_field="representativeness",
    method="cluster-center",
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Adds a representativeness field to each sample scoring how representative
    of nearby samples it is.

    This function only uses the pixel data and can therefore process labeled or
    unlabeled samples.

    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a
    default model is used to generate embeddings.

    .. note::

        Runs of this method can be referenced later via brain key
        ``representativeness_field``.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        representativeness_field ("representativeness"): the field name to use
            to store the representativeness value for each sample
        method ("cluster-center"): the name of the method to use to compute the
            representativeness. The supported values are
            ``["cluster-center", "cluster-center-downweight"]``.
            ``"cluster-center"`` will make a sample's representativeness
            proportional to its proximity to cluster centers, while
            ``"cluster-center-downweight"`` will ensure more diversity in
            representative samples
        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines` field defining a region of
            interest within each image to use to compute representativeness
        embeddings (None): if no ``model`` is provided, this argument specifies
            pre-computed embeddings to use, which can be any of the following:

            -   a ``num_samples x num_dims`` array of embeddings
            -   if ``roi_field`` is specified, a dict mapping sample IDs to
                ``num_patches x num_dims`` arrays of patch embeddings
            -   the name of a dataset field containing the embeddings to use

            If a ``model`` is provided, this argument specifies the name of a
            field in which to store the computed embeddings. In either case,
            when working with patch embeddings, you can provide either the
            fully-qualified path to the patch embeddings or just the name of
            the label attribute in ``roi_field``
        similarity_index (None): a
            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key
            of a similarity index to use to load pre-computed embeddings
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
            to use to generate embeddings. The model must expose embeddings
            (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``roi_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``roi_field`` are specified
        batch_size (None): a batch size to use when computing embeddings. Only
            applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
    """
    import fiftyone.brain.internal.core.representativeness as fbr

    # Arguments are forwarded positionally, so their order here must match the
    # signature of the internal implementation
    return fbr.compute_representativeness(
        samples,
        representativeness_field,
        method,
        roi_field,
        embeddings,
        similarity_index,
        model,
        model_kwargs,
        force_square,
        alpha,
        batch_size,
        num_workers,
        skip_failures,
        progress,
    )
def compute_visualization(
    samples,
    patches_field=None,
    embeddings=None,
    points=None,
    create_index=False,
    points_field=None,
    brain_key=None,
    num_dims=2,
    method=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
    **kwargs,
):
    """Computes a low-dimensional representation of the samples' media or their
    patches that can be interactively visualized.

    The representation can be visualized by calling the
    :meth:`visualize() <fiftyone.brain.visualization.VisualizationResults.visualize>`
    method of the returned
    :class:`fiftyone.brain.visualization.VisualizationResults` object.

    If no ``embeddings``, ``similarity_index``, or ``model`` is provided, a
    default model is used to generate embeddings.

    You can use the ``method`` parameter to select the dimensionality reduction
    method to use, and you can optionally customize the method by passing
    additional parameters for the method's
    :class:`fiftyone.brain.visualization.VisualizationConfig` class as
    ``kwargs``.

    The builtin ``method`` values and their associated config classes are:

    -   ``"umap"``: :class:`fiftyone.brain.visualization.UMAPVisualizationConfig`
    -   ``"tsne"``: :class:`fiftyone.brain.visualization.TSNEVisualizationConfig`
    -   ``"pca"``: :class:`fiftyone.brain.visualization.PCAVisualizationConfig`
    -   ``"manual"``: :class:`fiftyone.brain.visualization.ManualVisualizationConfig`

    You can pass ``create_index=True`` to create a spatial index of the
    computed points on your dataset's samples. This is highly recommended for
    large datasets as it enables efficient querying when lassoing points in
    embeddings plots. By default, spatial indexes are created in a field with
    name ``points_field=brain_key``, but you can customize this by manually
    providing a ``points_field``.

    You can also provide a ``points_field`` with ``create_index=False`` to
    store the points on your dataset without explicitly creating a database
    index. This will allow lasso callbacks to leverage point data rather than
    relying on ID selection, but without the added benefit of a database index
    to further optimize performance.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        patches_field (None): a sample field defining the image patches in each
            sample that have been/will be embedded. Must be of type
            :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines`
        embeddings (None): if no ``model`` is provided, this argument specifies
            pre-computed embeddings to use, which can be any of the following:

            -   a dict mapping sample IDs to embedding vectors
            -   a ``num_samples x num_embedding_dims`` array of embeddings
                corresponding to the samples in ``samples``
            -   if ``patches_field`` is specified, a dict mapping label IDs to
                embedding vectors
            -   if ``patches_field`` is specified, a dict mapping sample IDs
                to ``num_patches x num_embedding_dims`` arrays of patch
                embeddings
            -   the name of a dataset field containing the embeddings to use

            If a ``model`` is provided, this argument specifies the name of a
            field in which to store the computed embeddings. In either case,
            when working with patch embeddings, you can provide either the
            fully-qualified path to the patch embeddings or just the name of
            the label attribute in ``patches_field``
        points (None): a pre-computed low-dimensional representation to use. If
            provided, no embeddings will be used/computed. Can be any of the
            following:

            -   a dict mapping sample IDs to points vectors
            -   a ``num_samples x num_dims`` array of points corresponding to
                the samples in ``samples``
            -   if ``patches_field`` is specified, a dict mapping label IDs to
                points vectors
            -   if ``patches_field`` is specified, a ``num_patches x num_dims``
                array of points whose rows correspond to the flattened list of
                patches whose IDs are shown below::

                    # The list of patch IDs that the rows of `points` must match
                    _, id_field = samples._get_label_field_path(patches_field, "id")
                    patch_ids = samples.values(id_field, unwind=True)

        create_index (False): whether to create a spatial index for the
            computed points on your dataset
        points_field (None): an optional field name in which to store the
            spatial index. When ``create_index=True``, this defaults to
            ``points_field=brain_key``. When working with patches, you can
            provide either the fully-qualified path to the points field or just
            the name of the label attribute in ``patches_field``
        brain_key (None): a brain key under which to store the results of this
            method
        num_dims (2): the dimension of the visualization space
        method (None): the dimensionality reduction method to use. The
            supported values are
            ``fiftyone.brain.brain_config.visualization_methods.keys()`` and
            the default is
            ``fiftyone.brain.brain_config.default_visualization_method``
        similarity_index (None): a
            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key
            of a similarity index to use to load pre-computed embeddings
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
            to use to generate embeddings. The model must expose embeddings
            (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``patches_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``patches_field`` are specified
        batch_size (None): an optional batch size to use when computing
            embeddings. Only applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
        **kwargs: optional keyword arguments for the constructor of the
            :class:`fiftyone.brain.visualization.VisualizationConfig`
            being used

    Returns:
        a :class:`fiftyone.brain.visualization.VisualizationResults`
    """
    import fiftyone.brain.visualization as fbv

    # Arguments are forwarded positionally, so their order here must match the
    # signature of the implementation; method-specific options pass through
    # ``kwargs`` untouched
    return fbv.compute_visualization(
        samples,
        patches_field,
        embeddings,
        points,
        create_index,
        points_field,
        brain_key,
        num_dims,
        method,
        similarity_index,
        model,
        model_kwargs,
        force_square,
        alpha,
        batch_size,
        num_workers,
        skip_failures,
        progress,
        **kwargs,
    )
def compute_similarity(
    samples,
    patches_field=None,
    roi_field=None,
    embeddings=None,
    brain_key=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
    backend=None,
    **kwargs,
):
    """Uses embeddings to index the samples or their patches so that you can
    query/sort by similarity.

    Calling this method only creates the index. You can then call the methods
    exposed on the returned :class:`fiftyone.brain.similarity.SimilarityIndex`
    object to perform the following operations:

    -   :meth:`sort_by_similarity() <fiftyone.brain.similarity.SimilarityIndex.sort_by_similarity>`:
        Sort the samples in the collection by similarity to a specific example
        or example(s)

    All indexes support querying by image similarity by passing sample IDs to
    :meth:`sort_by_similarity() <fiftyone.brain.similarity.SimilarityIndex.sort_by_similarity>`.
    In addition, if you pass the name of a model from the
    `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
    like ``model="clip-vit-base32-torch"`` that can embed prompts to this
    method, then you can query the index by text similarity as well.

    In addition, if the backend supports it, you can call the following
    duplicate detection methods:

    -   :meth:`find_duplicates() <fiftyone.brain.similarity.DuplicatesMixin.find_duplicates>`:
        Query the index to find all examples with near-duplicates in the
        collection

    -   :meth:`find_unique() <fiftyone.brain.similarity.DuplicatesMixin.find_unique>`:
        Query the index to select a subset of examples of a specified size that
        are maximally unique with respect to each other

    If no ``embeddings`` or ``model`` is provided, a default model is used to
    generate embeddings.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        patches_field (None): a sample field defining the image patches in each
            sample that have been/will be embedded. Must be of type
            :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines`
        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines` field defining a region of
            interest within each image to use to compute embeddings
        embeddings (None): embeddings to feed the index. This argument's
            behavior depends on whether a ``model`` is provided, as described
            below.

            If no ``model`` is provided, this argument specifies pre-computed
            embeddings to use:

            -   a ``num_samples x num_dims`` array of embeddings
            -   if ``patches_field``/``roi_field`` is specified, a dict
                mapping sample IDs to ``num_patches x num_dims`` arrays of
                patch embeddings
            -   the name of a dataset field from which to load embeddings
            -   ``None``: use the default model to compute embeddings
            -   ``False``: **do not** compute embeddings right now

            If a ``model`` is provided, this argument specifies where to store
            the model's embeddings:

            -   the name of a field in which to store the computed embeddings
            -   ``False``: **do not** compute embeddings right now

            In either case, when working with patch embeddings, you can provide
            either the fully-qualified path to the patch embeddings or just the
            name of the label attribute in ``patches_field``/``roi_field``
        brain_key (None): a brain key under which to store the results of this
            method
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
            to use, or that was already used, to generate embeddings. The model
            must expose embeddings (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``patches_field``/``roi_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``patches_field``/``roi_field`` are specified
        batch_size (None): an optional batch size to use when computing
            embeddings. Only applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead
        backend (None): the similarity backend to use. The supported values are
            ``fiftyone.brain.brain_config.similarity_backends.keys()`` and the
            default is
            ``fiftyone.brain.brain_config.default_similarity_backend``
        **kwargs: keyword arguments for the
            :class:`fiftyone.brain.similarity.SimilarityConfig` subclass of the
            backend being used

    Returns:
        a :class:`fiftyone.brain.similarity.SimilarityIndex`
    """
    import fiftyone.brain.similarity as fbs

    # Arguments are forwarded positionally, so their order here must match the
    # signature of the implementation; backend-specific options pass through
    # ``kwargs`` untouched
    return fbs.compute_similarity(
        samples,
        patches_field,
        roi_field,
        embeddings,
        brain_key,
        model,
        model_kwargs,
        force_square,
        alpha,
        batch_size,
        num_workers,
        skip_failures,
        progress,
        backend,
        **kwargs,
    )
def compute_near_duplicates(
    samples,
    threshold=0.2,
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Detects potential duplicates in the given sample collection.

    Calling this method only initializes the index. You can then call the
    methods exposed on the returned object to perform the following operations:

    -   :meth:`duplicate_ids <fiftyone.brain.similarity.DuplicatesMixin.duplicate_ids>`:
        A list of duplicate IDs

    -   :meth:`neighbors_map <fiftyone.brain.similarity.DuplicatesMixin.neighbors_map>`:
        A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples

    -   :meth:`duplicates_view() <fiftyone.brain.similarity.DuplicatesMixin.duplicates_view>`:
        Returns a view of all duplicates in the input collection

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        threshold (0.2): the similarity distance threshold to use when
            detecting duplicates. Values in ``[0.1, 0.25]`` work well for the
            default setup
        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines` field defining a region of
            interest within each image to use to compute embeddings
        embeddings (None): if no ``model`` is provided, this argument specifies
            pre-computed embeddings to use, which can be any of the following:

            -   a ``num_samples x num_dims`` array of embeddings
            -   if ``roi_field`` is specified, a dict mapping sample IDs to
                ``num_patches x num_dims`` arrays of patch embeddings
            -   the name of a dataset field containing the embeddings to use

            If a ``model`` is provided, this argument specifies the name of a
            field in which to store the computed embeddings. In either case,
            when working with patch embeddings, you can provide either the
            fully-qualified path to the patch embeddings or just the name of
            the label attribute in ``roi_field``
        similarity_index (None): a
            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key
            of a similarity index to use to load pre-computed embeddings
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
            to use to generate embeddings. The model must expose embeddings
            (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``roi_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``roi_field`` are specified
        batch_size (None): a batch size to use when computing embeddings. Only
            applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead

    Returns:
        a :class:`fiftyone.brain.similarity.SimilarityIndex`
    """
    import fiftyone.brain.internal.core.duplicates as fbd

    return fbd.compute_near_duplicates(
        samples,
        threshold=threshold,
        roi_field=roi_field,
        embeddings=embeddings,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
        force_square=force_square,
        alpha=alpha,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        progress=progress,
    )
def compute_exact_duplicates(
    samples,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Finds samples in a collection whose media are exact duplicates.

    Duplicates are identified by comparing the filehashes of the samples'
    media, so only exact copies are reported. To find near-duplicates instead,
    use :meth:`compute_near_duplicates`.

    For each group of duplicates, the sample that appears first in ``samples``
    is used as a key of the returned dictionary, and the remaining members of
    the group are listed in its value.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        num_workers (None): an optional number of processes to use
        skip_failures (True): whether to gracefully ignore samples whose
            filehash cannot be computed
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead

    Returns:
        a dictionary mapping IDs of samples with exact duplicates to lists of
        IDs of the duplicates for the corresponding sample
    """
    # Deferred import: the implementation module is only loaded on first use
    from fiftyone.brain.internal.core import duplicates as _duplicates

    # The implementation takes these positionally, in exactly this order
    forwarded = (samples, num_workers, skip_failures, progress)
    return _duplicates.compute_exact_duplicates(*forwarded)
def compute_leaky_splits(
    samples,
    splits,
    threshold=0.2,
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Computes potential leaks between splits of the given sample collection.

    Calling this method only initializes the index. You can then call the
    methods exposed on the returned object to perform the following operations:

    -   :meth:`leaks_view() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.leaks_view>`:
        Returns a view of all leaks in the input collection

    -   :meth:`no_leaks_view() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.no_leaks_view>`:
        Returns the subset of the input collection without any leaks

    -   :meth:`leaks_for_sample() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.leaks_for_sample>`:
        Returns a view with leaks corresponding to the given sample

    -   :meth:`tag_leaks() <fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.tag_leaks>`:
        Tags leaks in the dataset as leaks

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        splits: the dataset splits, specified in one of the following ways:

            -   a list of tag strings
            -   the name of a string/list field that encodes the split
                memberships
            -   a dict mapping split names to
                :class:`fiftyone.core.view.DatasetView` instances

        threshold (0.2): the similarity distance threshold to use when
            detecting leaks. Values in ``[0.1, 0.25]`` work well for the
            default setup
        roi_field (None): an optional :class:`fiftyone.core.labels.Detection`,
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polyline`, or
            :class:`fiftyone.core.labels.Polylines` field defining a region of
            interest within each image to use to compute leaks
        embeddings (None): if no ``model`` is provided, this argument specifies
            pre-computed embeddings to use, which can be any of the following:

            -   a ``num_samples x num_dims`` array of embeddings
            -   if ``roi_field`` is specified, a dict mapping sample IDs to
                ``num_patches x num_dims`` arrays of patch embeddings
            -   the name of a dataset field containing the embeddings to use

            If a ``model`` is provided, this argument specifies the name of a
            field in which to store the computed embeddings. In either case,
            when working with patch embeddings, you can provide either the
            fully-qualified path to the patch embeddings or just the name of
            the label attribute in ``roi_field``
        similarity_index (None): a
            :class:`fiftyone.brain.similarity.SimilarityIndex` or the brain key
            of a similarity index to use to load pre-computed embeddings
        model (None): a :class:`fiftyone.core.models.Model` or the name of a
            model from the
            `FiftyOne Model Zoo <https://docs.voxel51.com/user_guide/model_zoo/index.html>`_
            to use to generate embeddings. The model must expose embeddings
            (``model.has_embeddings = True``)
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only applicable
            when a ``model`` and ``roi_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the patches
            before extracting them, in ``[-1, inf)``. If provided, the length
            and width of the box are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable when
            a ``model`` and ``roi_field`` are specified
        batch_size (None): a batch size to use when computing embeddings. Only
            applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading images.
            Only applicable when a Torch-based model is being used to compute
            embeddings
        skip_failures (True): whether to gracefully continue without raising an
            error if embeddings cannot be generated for a sample
        progress (None): whether to render a progress bar (True/False), use the
            default value ``fiftyone.config.show_progress_bars`` (None), or a
            progress callback function to invoke instead

    Returns:
        a :class:`fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex`
    """
    import fiftyone.brain.internal.core.leaky_splits as fbl

    return fbl.compute_leaky_splits(
        samples,
        splits,
        threshold=threshold,
        roi_field=roi_field,
        embeddings=embeddings,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
        force_square=force_square,
        alpha=alpha,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        progress=progress,
    )
================================================
FILE: fiftyone/brain/config.py
================================================
"""
Brain config.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
import os
from fiftyone.core.config import EnvConfig
class BrainConfig(EnvConfig):
    """FiftyOne brain configuration settings.

    Instances are typically created via :func:`load_brain_config`, which reads
    the JSON file returned by :func:`locate_brain_config` when it exists.
    Similarity backends and visualization methods, and their parameters, can
    also be configured via ``FIFTYONE_BRAIN_*`` environment variables.
    """

    # Builtin similarity backends, mapping backend names to default parameters
    _BUILTIN_SIMILARITY_BACKENDS = {
        "sklearn": {
            "config_cls": "fiftyone.brain.internal.core.sklearn.SklearnSimilarityConfig",
        },
        "pinecone": {
            "config_cls": "fiftyone.brain.internal.core.pinecone.PineconeSimilarityConfig",
        },
        "qdrant": {
            "config_cls": "fiftyone.brain.internal.core.qdrant.QdrantSimilarityConfig",
        },
        "milvus": {
            "config_cls": "fiftyone.brain.internal.core.milvus.MilvusSimilarityConfig",
        },
        "lancedb": {
            "config_cls": "fiftyone.brain.internal.core.lancedb.LanceDBSimilarityConfig",
        },
        "redis": {
            "config_cls": "fiftyone.brain.internal.core.redis.RedisSimilarityConfig",
        },
        "mongodb": {
            "config_cls": "fiftyone.brain.internal.core.mongodb.MongoDBSimilarityConfig",
        },
        "elasticsearch": {
            "config_cls": "fiftyone.brain.internal.core.elasticsearch.ElasticsearchSimilarityConfig",
        },
        "pgvector": {
            "config_cls": "fiftyone.brain.internal.core.pgvector.PgVectorSimilarityConfig",
        },
        "mosaic": {
            "config_cls": "fiftyone.brain.internal.core.mosaic.MosaicSimilarityConfig",
        },
    }

    # Builtin visualization methods, mapping method names to default parameters
    _BUILTIN_VISUALIZATION_METHODS = {
        "umap": {
            "config_cls": "fiftyone.brain.visualization.UMAPVisualizationConfig",
        },
        "tsne": {
            "config_cls": "fiftyone.brain.visualization.TSNEVisualizationConfig",
        },
        "pca": {
            "config_cls": "fiftyone.brain.visualization.PCAVisualizationConfig",
        },
        "manual": {
            "config_cls": "fiftyone.brain.visualization.ManualVisualizationConfig",
        },
    }

    def __init__(self, d=None):
        """Creates a BrainConfig.

        Args:
            d (None): an optional dict of config settings
        """
        if d is None:
            d = {}

        self.default_similarity_backend = self.parse_string(
            d,
            "default_similarity_backend",
            env_var="FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND",
            default="sklearn",
        )
        self.similarity_backends = self._parse_similarity_backends(d)

        # If the requested default is not available, fall back to the
        # alphabetically-first available backend, or None if there are none
        if self.default_similarity_backend not in self.similarity_backends:
            self.default_similarity_backend = next(
                iter(sorted(self.similarity_backends.keys())), None
            )

        self.default_visualization_method = self.parse_string(
            d,
            "default_visualization_method",
            env_var="FIFTYONE_BRAIN_DEFAULT_VISUALIZATION_METHOD",
            default="umap",
        )
        self.visualization_methods = self._parse_visualization_methods(d)

        # If the requested default is not available, fall back to the
        # alphabetically-first available method, or None if there are none
        if self.default_visualization_method not in self.visualization_methods:
            self.default_visualization_method = next(
                iter(sorted(self.visualization_methods.keys())), None
            )

    def _parse_similarity_backends(self, d):
        """Parses the available similarity backends and their parameters.

        ``FIFTYONE_BRAIN_SIMILARITY_BACKENDS`` can declare which backends are
        exposed. Parameters can be provided via environment variables of the
        form ``FIFTYONE_BRAIN_SIMILARITY_<BACKEND>_<PARAMETER>``.

        Args:
            d: the config dict

        Returns:
            a dict mapping backend names to parameter dicts
        """
        return self._parse_registry(
            d,
            "similarity_backends",
            "FIFTYONE_BRAIN_SIMILARITY_BACKENDS",
            "FIFTYONE_BRAIN_SIMILARITY",
            self._BUILTIN_SIMILARITY_BACKENDS,
        )

    def _parse_visualization_methods(self, d):
        """Parses the available visualization methods and their parameters.

        ``FIFTYONE_BRAIN_VISUALIZATION_METHODS`` can declare which methods are
        exposed. Parameters can be provided via environment variables of the
        form ``FIFTYONE_BRAIN_VISUALIZATION_<METHOD>_<PARAMETER>``.

        Args:
            d: the config dict

        Returns:
            a dict mapping method names to parameter dicts
        """
        return self._parse_registry(
            d,
            "visualization_methods",
            "FIFTYONE_BRAIN_VISUALIZATION_METHODS",
            "FIFTYONE_BRAIN_VISUALIZATION",
            self._BUILTIN_VISUALIZATION_METHODS,
        )

    @staticmethod
    def _parse_registry(d, key, list_env_var, env_prefix, builtins):
        """Parses a registry of named components and their parameters.

        Args:
            d: the config dict
            key: the key of ``d`` containing the user-configured components
            list_env_var: the name of an environment variable that, if set,
                contains a comma-separated list of the component names to
                expose. If the list contains ``"*"``, the builtin components
                are exposed in addition to the listed ones
            env_prefix: a prefix such that environment variables of the form
                ``<env_prefix>_<NAME>_<PARAMETER>`` provide parameters for the
                component ``<name>``
            builtins: a dict mapping builtin component names to dicts of
                default parameters

        Returns:
            a dict mapping component names to parameter dicts
        """
        configured = d.get(key, {})
        env_vars = dict(os.environ)

        #
        # `list_env_var` can be used to declare which components are exposed.
        # This may exclude builtin components and/or declare new components
        #
        if list_env_var in env_vars:
            # Tolerate whitespace around names (e.g. "sklearn, qdrant") and
            # ignore empty entries (e.g. from a trailing comma)
            names = [n.strip() for n in env_vars[list_env_var].split(",")]
            names = [n for n in names if n]

            # Special syntax to append rather than override the builtins
            if "*" in names:
                names = [n for n in names if n != "*"]
                names += [n for n in builtins if n not in names]
        else:
            names = list(configured)
            names += [n for n in builtins if n not in configured]

        registry = {}
        for name in names:
            params = configured.get(name, {})
            if isinstance(params, dict):
                params = dict(params)  # don't mutate the caller's config
            registry[name] = params

        #
        # Extract parameters from any environment variables of the form
        # `<env_prefix>_<NAME>_<PARAMETER>`
        #
        for name, params in registry.items():
            prefix = "%s_%s_" % (env_prefix, name.upper())
            for env_name, env_value in env_vars.items():
                if env_name.startswith(prefix):
                    param = env_name[len(prefix) :].lower()
                    params[param] = _parse_env_value(env_value)

        #
        # Set default parameters for any exposed builtin components
        #
        for name, defaults in builtins.items():
            if name not in registry:
                continue

            params = registry[name]
            for param, value in defaults.items():
                if param not in params:
                    params[param] = value

        return registry
def locate_brain_config():
    """Returns the path to the :class:`BrainConfig` file on disk.

    If the ``FIFTYONE_BRAIN_CONFIG_PATH`` environment variable is set, its
    value is returned verbatim (even if empty). Otherwise the default location
    ``~/.fiftyone/brain_config.json`` is returned.

    The returned path is not guaranteed to exist.

    Returns:
        the path to the :class:`BrainConfig` on disk
    """
    override = os.environ.get("FIFTYONE_BRAIN_CONFIG_PATH", None)
    if override is not None:
        return override

    home_dir = os.path.expanduser("~")
    return os.path.join(home_dir, ".fiftyone", "brain_config.json")
def load_brain_config():
    """Loads the FiftyOne brain config.

    The config is read from the path returned by :func:`locate_brain_config`
    if a file exists there; otherwise a default :class:`BrainConfig` is
    constructed.

    Returns:
        a :class:`BrainConfig` instance
    """
    config_path = locate_brain_config()
    if not os.path.isfile(config_path):
        return BrainConfig()

    return BrainConfig.from_json(config_path)
def _parse_env_value(value):
try:
return int(value)
except:
pass
try:
return float(value)
except:
pass
if value in ("True", "true"):
return True
if value in ("False", "false"):
return False
if value in ("None", ""):
return None
if "," in value:
return [_parse_env_value(v) for v in value.split(",")]
return value
================================================
FILE: fiftyone/brain/internal/__init__.py
================================================
"""
Internal FiftyOne Brain package.
Contains all non-public code powering the ``fiftyone.brain`` public namespace.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
================================================
FILE: fiftyone/brain/internal/core/__init__.py
================================================
"""
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
================================================
FILE: fiftyone/brain/internal/core/duplicates.py
================================================
"""
Duplicates methods.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from collections import defaultdict
import itertools
import logging
import multiprocessing
import eta.core.utils as etau
import fiftyone.core.media as fom
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
import fiftyone.brain as fb
import fiftyone.brain.similarity as fbs
import fiftyone.brain.internal.core.utils as fbu
logger = logging.getLogger(__name__)
# Model passed to ``fb.compute_similarity()`` by compute_near_duplicates() when
# no model, embeddings, existing embeddings field, or similarity index is given
_DEFAULT_MODEL = "resnet18-imagenet-torch"
def compute_near_duplicates(
    samples,
    threshold=None,
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """See ``fiftyone/brain/__init__.py``.

    Runs ``find_duplicates(thresh=threshold)`` on a similarity index and
    returns that index. When no index is provided, a new ``sklearn`` index is
    built via :func:`fiftyone.brain.compute_similarity`. Provided indexes must
    implement :class:`fiftyone.brain.similarity.DuplicatesMixin`.
    """
    fov.validate_collection(samples)

    # A string ``embeddings`` is interpreted as the name of a field
    embeddings_field = None
    embeddings_exist = None
    if etau.is_str(embeddings):
        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples, embeddings, data_type="embeddings"
        )
        embeddings = None

    # A string ``similarity_index`` is interpreted as a brain key
    if etau.is_str(similarity_index):
        similarity_index = samples.load_brain_results(similarity_index)

    # Use the default model only when there is no other source of embeddings
    has_embeddings_source = (
        model is not None
        or embeddings is not None
        or similarity_index is not None
        or embeddings_exist
    )
    if not has_embeddings_source:
        model = _DEFAULT_MODEL

    if similarity_index is not None:
        if not isinstance(similarity_index, fbs.DuplicatesMixin):
            raise ValueError(
                "This method only supports similarity indexes that implement the "
                "%s mixin" % fbs.DuplicatesMixin
            )
    else:
        similarity_index = fb.compute_similarity(
            samples,
            backend="sklearn",
            roi_field=roi_field,
            embeddings=embeddings_field or embeddings,
            model=model,
            model_kwargs=model_kwargs,
            force_square=force_square,
            alpha=alpha,
            batch_size=batch_size,
            num_workers=num_workers,
            skip_failures=skip_failures,
            progress=progress,
        )

    similarity_index.find_duplicates(thresh=threshold)
    return similarity_index
def compute_exact_duplicates(samples, num_workers, skip_failures, progress):
    """See ``fiftyone/brain/__init__.py``.

    Groups samples whose media files have identical filehashes. Videos are
    hashed with ``md5`` and, by default, with one worker per CPU; other media
    use the default hash method of ``fou.compute_filehash()`` and one worker.

    Returns:
        a dict mapping the ID of the first sample observed with a given hash
        to the list of IDs of the other samples with that hash. When multiple
        workers are used, results arrive in completion order, so which sample
        is "first" is not deterministic

    Raises:
        ValueError: if some filehashes could not be computed and
            ``skip_failures`` is False
    """
    fov.validate_collection(samples)

    is_video = samples.media_type == fom.VIDEO

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() if is_video else 1

    logger.info("Computing filehashes...")

    method = "md5" if is_video else None

    if num_workers <= 1:
        hashes = _compute_filehashes(samples, method, progress)
    else:
        hashes = _compute_filehashes_multi(
            samples, method, num_workers, progress
        )

    # ``hashes`` maps sample IDs -> filehash (None on failure), so count the
    # failures over the values, not the keys
    num_missing = sum(h is None for h in hashes.values())
    if num_missing > 0:
        msg = "Failed to compute %d filehashes" % num_missing
        if skip_failures:
            logger.warning(msg)
        else:
            raise ValueError(msg)

    neighbors_map = defaultdict(list)

    observed_hashes = {}
    for _id, _hash in hashes.items():
        if _hash is None:
            continue

        if _hash in observed_hashes:
            neighbors_map[observed_hashes[_hash]].append(_id)
        else:
            observed_hashes[_hash] = _id

    return dict(neighbors_map)
def _compute_filehashes(samples, method, progress):
    # Serially hashes each sample's media file; returns a dict mapping sample
    # IDs to filehashes (None for files that could not be hashed)
    ids, filepaths = samples.values(["id", "filepath"])

    filehashes = {}
    with fou.ProgressBar(total=len(ids), progress=progress) as pb:
        for sample_id, filepath in pb(zip(ids, filepaths)):
            filehashes[sample_id] = _compute_filehash(filepath, method)

    return filehashes
def _compute_filehashes_multi(samples, method, num_workers, progress):
    # Hashes each sample's media file in a process pool. Results are
    # collected with ``imap_unordered``, so the returned dict is ordered by
    # completion rather than by sample
    ids, filepaths = samples.values(["id", "filepath"])
    tasks = list(zip(ids, filepaths, itertools.repeat(method)))

    with fou.ProgressBar(total=len(tasks), progress=progress) as pb:
        with multiprocessing.Pool(processes=num_workers) as pool:
            results = pool.imap_unordered(_do_compute_filehash, tasks)
            return dict(pb(results))
def _compute_filehash(filepath, method):
    """Returns the filehash of ``filepath``, or None if it cannot be computed.

    Failures are deliberately swallowed so that one unreadable file does not
    abort the whole run; the caller counts the resulting ``None`` values.
    Only ``Exception`` subclasses are caught, so KeyboardInterrupt and
    SystemExit still propagate.
    """
    try:
        return fou.compute_filehash(filepath, method=method)
    except Exception:
        return None


def _do_compute_filehash(args):
    """Worker entry point for ``multiprocessing.Pool.imap_unordered()``.

    Args:
        args: an ``(id, filepath, method)`` tuple

    Returns:
        an ``(id, filehash)`` tuple, where ``filehash`` may be None
    """
    _id, filepath, method = args
    return _id, _compute_filehash(filepath, method)
================================================
FILE: fiftyone/brain/internal/core/elasticsearch.py
================================================
"""
Elasticsearch similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
from fiftyone import ViewField as F
import fiftyone.core.utils as fou
import fiftyone.brain.internal.core.utils as fbu
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
# The ``elasticsearch`` package is loaded via ``fou.lazy_import()``
# (presumably deferred until first attribute access)
es = fou.lazy_import("elasticsearch")
logger = logging.getLogger(__name__)
# Maps the metric names accepted by ElasticsearchSimilarityConfig to the
# ``similarity`` values written into the index's ``dense_vector`` mapping
_SUPPORTED_METRICS = {
"cosine": "cosine",
"dotproduct": "dot_product",
"euclidean": "l2_norm",
"innerproduct": "max_inner_product",
}
class ElasticsearchSimilarityConfig(SimilarityConfig):
    """Configuration for an Elasticsearch similarity instance.

    Args:
        index_name (None): the name of the Elasticsearch index to use or
            create. If none is provided, a new index will be created
        metric ("cosine"): the embedding distance metric to use when creating a
            new index. Supported values are
            ``("cosine", "dotproduct", "euclidean", "innerproduct")``
        hosts (None): the full Elasticsearch server address(es) to use. Can be
            a string or list of strings
        cloud_id (None): the Cloud ID of an Elastic Cloud to connect to
        username (None): a username to use
        password (None): a password to use
        api_key (None): an API key to use
        ca_certs (None): a path to a CA certificate
        bearer_auth (None): a bearer token to use
        ssl_assert_fingerprint (None): a SHA256 fingerprint to use
        verify_certs (None): whether to verify SSL certificates
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is not supported
    """

    def __init__(
        self,
        index_name=None,
        metric="cosine",
        hosts=None,
        cloud_id=None,
        username=None,
        password=None,
        api_key=None,
        ca_certs=None,
        bearer_auth=None,
        ssl_assert_fingerprint=None,
        verify_certs=None,
        **kwargs,
    ):
        if metric not in _SUPPORTED_METRICS:
            supported = tuple(_SUPPORTED_METRICS.keys())
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, supported)
            )

        super().__init__(**kwargs)

        self.index_name = index_name
        self.metric = metric

        # Connection settings live in private attributes and are exposed
        # through the properties below (presumably so that they are not
        # serialized with the config)
        self._hosts = hosts
        self._cloud_id = cloud_id
        self._username = username
        self._password = password
        self._api_key = api_key
        self._ca_certs = ca_certs
        self._bearer_auth = bearer_auth
        self._ssl_assert_fingerprint = ssl_assert_fingerprint
        self._verify_certs = verify_certs

    @property
    def method(self):
        """The name of this similarity backend."""
        return "elasticsearch"

    @property
    def hosts(self):
        """The Elasticsearch server address(es)."""
        return self._hosts

    @hosts.setter
    def hosts(self, value):
        self._hosts = value

    @property
    def cloud_id(self):
        """The Elastic Cloud ID."""
        return self._cloud_id

    @cloud_id.setter
    def cloud_id(self, value):
        self._cloud_id = value

    @property
    def username(self):
        """The username."""
        return self._username

    @username.setter
    def username(self, value):
        self._username = value

    @property
    def password(self):
        """The password."""
        return self._password

    @password.setter
    def password(self, value):
        self._password = value

    @property
    def api_key(self):
        """The API key."""
        return self._api_key

    @api_key.setter
    def api_key(self, value):
        self._api_key = value

    @property
    def ca_certs(self):
        """The path to a CA certificate."""
        return self._ca_certs

    @ca_certs.setter
    def ca_certs(self, value):
        self._ca_certs = value

    @property
    def bearer_auth(self):
        """The bearer token."""
        return self._bearer_auth

    @bearer_auth.setter
    def bearer_auth(self, value):
        self._bearer_auth = value

    @property
    def ssl_assert_fingerprint(self):
        """The SHA256 fingerprint to assert."""
        return self._ssl_assert_fingerprint

    @ssl_assert_fingerprint.setter
    def ssl_assert_fingerprint(self, value):
        self._ssl_assert_fingerprint = value

    @property
    def verify_certs(self):
        """Whether to verify SSL certificates."""
        return self._verify_certs

    @verify_certs.setter
    def verify_certs(self, value):
        self._verify_certs = value

    @property
    def max_k(self):
        """The maximum number of neighbors per query."""
        return 10000  # Elasticsearch limit

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self,
        hosts=None,
        cloud_id=None,
        username=None,
        password=None,
        api_key=None,
        ca_certs=None,
        bearer_auth=None,
        ssl_assert_fingerprint=None,
        verify_certs=None,
    ):
        """Loads connection settings via ``self._load_parameters()``."""
        self._load_parameters(
            hosts=hosts,
            cloud_id=cloud_id,
            username=username,
            password=password,
            api_key=api_key,
            ca_certs=ca_certs,
            bearer_auth=bearer_auth,
            ssl_assert_fingerprint=ssl_assert_fingerprint,
            verify_certs=verify_certs,
        )
class ElasticsearchSimilarity(Similarity):
    """Elasticsearch similarity factory.

    Args:
        config: an :class:`ElasticsearchSimilarityConfig`
    """

    def ensure_requirements(self):
        """Ensures that the ``elasticsearch`` package is installed."""
        fou.ensure_package("elasticsearch")

    def ensure_usage_requirements(self):
        """Ensures that the ``elasticsearch`` package is installed."""
        fou.ensure_package("elasticsearch")

    def initialize(self, samples, brain_key):
        """Returns an :class:`ElasticsearchSimilarityIndex` for the samples."""
        index = ElasticsearchSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class ElasticsearchSimilarityIndex(SimilarityIndex):
    """Class for interacting with Elasticsearch similarity indexes.

    Documents are keyed by label ID (patch indexes) or sample ID (otherwise)
    and store a ``vector`` and a ``sample_id``.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`ElasticsearchSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`ElasticsearchSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._client = None
        self._metric = None
        self._initialize()

    @property
    def total_index_size(self):
        """The number of documents in the index, or 0 if they cannot be
        counted (for example, because the index has not been created yet).
        """
        try:
            return self._client.count(index=self.config.index_name)["count"]
        except Exception:
            return 0

    @property
    def client(self):
        """The ``elasticsearch.Elasticsearch`` instance for this index."""
        return self._client

    def _initialize(self):
        # Only forward the connection settings that were actually provided
        kwargs = {}
        for key in (
            "hosts",
            "cloud_id",
            "username",
            "password",
            "api_key",
            "ca_certs",
            "bearer_auth",
            "ssl_assert_fingerprint",
            "verify_certs",
        ):
            value = getattr(self.config, key, None)
            if value is not None:
                kwargs[key] = value

        # Username/password are passed to the client as ``basic_auth``
        username = kwargs.pop("username", None)
        password = kwargs.pop("password", None)
        if username is not None and password is not None:
            kwargs["basic_auth"] = (username, password)

        try:
            self._client = es.Elasticsearch(**kwargs)
        except Exception as e:
            raise ValueError(
                "Failed to connect to Elasticsearch backend. Refer to "
                "https://docs.voxel51.com/integrations/elasticsearch.html for more "
                "information"
            ) from e

        # Generate (and persist) a unique index name if none was provided
        if self.config.index_name is None:
            root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
            index_name = fbu.get_unique_name(root, self._get_index_names())
            self.config.index_name = index_name
            self.save_config()

    def _get_index_names(self):
        return self._client.indices.get_alias().keys()

    def _get_index_ids(self, batch_size=1000):
        # Returns ``(sample_ids, doc_ids)`` for every document that has both a
        # ``vector`` and a ``sample_id``. ``doc_ids`` are the document IDs,
        # i.e. label IDs for patch indexes and sample IDs otherwise.
        # NOTE(review): ``from``/``size`` pagination is limited by the index's
        # ``max_result_window`` (10,000 by default); very large indexes may
        # need ``search_after`` instead
        sample_ids = []
        label_ids = []
        for batch in range(0, self.total_index_size, batch_size):
            response = self._client.search(
                index=self.config.index_name,
                body={
                    "fields": ["sample_id"],
                    "from": batch,
                    "query": {
                        "bool": {
                            "must": [
                                {"exists": {"field": "vector"}},
                                {"exists": {"field": "sample_id"}},
                            ]
                        }
                    },
                },
                source=False,
                size=batch_size,
            )
            for doc in response["hits"]["hits"]:
                sample_id = doc["fields"]["sample_id"][0]
                sample_or_label_id = doc["_id"]
                sample_ids.append(sample_id)
                label_ids.append(sample_or_label_id)

        return sample_ids, label_ids

    def _get_dimension(self):
        if self.total_index_size == 0:
            return None

        if self.config.patches_field is not None:
            embeddings, _, _ = self.get_embeddings(
                label_ids=self._label_ids[:1]
            )
        else:
            embeddings, _, _ = self.get_embeddings(
                sample_ids=self._sample_ids[:1]
            )

        return embeddings.shape[1]

    def _get_metric(self):
        if self._metric is None:
            try:
                # We must ask ES rather than using `self.config.metric` because
                # we may be working with a preexisting index
                self._metric = self._client.indices.get_mapping(
                    index=self.config.index_name
                )[self.config.index_name]["mappings"]["properties"]["vector"][
                    "similarity"
                ]
            except Exception:
                logger.warning(
                    "Failed to infer similarity metric from index '%s'",
                    self.config.index_name,
                )

        return self._metric

    def _index_exists(self):
        if self.config.index_name is None:
            return False

        return self.config.index_name in self._get_index_names()

    def _create_index(self, dimension):
        metric = _SUPPORTED_METRICS[self.config.metric]
        mappings = {
            "properties": {
                "vector": {
                    "type": "dense_vector",
                    "dims": dimension,
                    "index": "true",
                    "similarity": metric,
                }
            }
        }
        self._client.indices.create(
            index=self.config.index_name, mappings=mappings
        )
        self._metric = metric

    def _get_existing_ids(self, ids):
        ids = list(ids)
        if not ids:
            # Nothing to look up (Elasticsearch rejects an empty `mget`)
            return []

        docs = [{"_index": self.config.index_name, "_id": i} for i in ids]
        resp = self._client.mget(docs=docs)
        return [d["_id"] for d in resp["docs"] if d["found"]]

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=500,
    ):
        """Adds embeddings to the index, creating the index if necessary.

        Args:
            embeddings: a ``num_embeddings x num_dims`` array of embeddings
            sample_ids: the sample IDs of the embeddings
            label_ids (None): the label IDs of the embeddings, for patch
                indexes. When provided, these are used as the document IDs
            overwrite (True): whether to replace (True) or skip (False)
                embeddings whose IDs already exist in the index
            allow_existing (True): whether to allow IDs that already exist
            warn_existing (False): whether to log a warning about IDs that
                already exist
            reload (True): whether to call :meth:`reload` afterwards
            batch_size (500): the number of documents per bulk request

        Raises:
            ValueError: if ``allow_existing`` is False and some IDs already
                exist in the index
        """
        if not self._index_exists():
            self._create_index(embeddings.shape[1])

        ids = label_ids if label_ids is not None else sample_ids

        if warn_existing or not allow_existing or not overwrite:
            # A set gives O(1) membership checks when filtering below
            existing_ids = set(self._get_existing_ids(ids))
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

        if self._get_metric() == _SUPPORTED_METRICS["dotproduct"]:
            # `dot_product` similarity requires unit-length vectors. Normalize
            # out-of-place so the caller's array is not modified
            norms = np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
            embeddings = embeddings / norms

        embeddings = [e.tolist() for e in embeddings]
        sample_ids = list(sample_ids)
        if label_ids is not None:
            ids = list(label_ids)
        else:
            ids = list(sample_ids)

        for _embeddings, _ids, _sample_ids in zip(
            fou.iter_batches(embeddings, batch_size),
            fou.iter_batches(ids, batch_size),
            fou.iter_batches(sample_ids, batch_size),
        ):
            operations = []
            for _e, _id, _sid in zip(_embeddings, _ids, _sample_ids):
                operations.append(
                    {"index": {"_index": self.config.index_name, "_id": _id}}
                )
                operations.append({"sample_id": _sid, "vector": _e})

            self._client.bulk(
                index=self.config.index_name,
                operations=operations,
                refresh=True,
            )

        if reload:
            self.reload()

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        """Removes documents from the index.

        Args:
            sample_ids (None): the sample IDs to remove
            label_ids (None): the label IDs to remove, for patch indexes. When
                provided, ``sample_ids`` is not used
            allow_missing (True): whether to allow IDs that are not in the
                index
            warn_missing (False): whether to log a warning about missing IDs
            reload (True): whether to call :meth:`reload` afterwards

        Raises:
            ValueError: if ``allow_missing`` is False and some IDs are not in
                the index
        """
        ids = label_ids if label_ids is not None else sample_ids

        if not allow_missing or warn_missing:
            existing_ids = self._get_existing_ids(ids)
            missing_ids = set(ids) - set(existing_ids)
            num_missing = len(missing_ids)

            if num_missing > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that are not present in the "
                        "index" % (num_missing, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Ignoring %d IDs that are not present in the index",
                        num_missing,
                    )

            ids = existing_ids

        operations = [
            {"delete": {"_index": self.config.index_name, "_id": i}}
            for i in ids
        ]

        # A bulk request must contain at least one operation
        if operations:
            self._client.bulk(body=operations, refresh=True)

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        """Retrieves embeddings from the index.

        Args:
            sample_ids (None): a list of sample IDs. For patch indexes, all
                patches of these samples are returned
            label_ids (None): a list of label IDs (patch indexes only). When
                provided, ``sample_ids`` is ignored
            allow_missing (True): whether to allow IDs that are not in the
                index
            warn_missing (False): whether to log a warning about missing IDs

        Returns:
            a tuple of ``(embeddings, sample_ids, label_ids)`` arrays, where
            ``label_ids`` is None for non-patch indexes

        Raises:
            ValueError: if label IDs are provided for a non-patch index, or if
                ``allow_missing`` is False and some IDs are missing
        """
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )

        if self.config.patches_field is None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)
        elif label_ids is None and sample_ids is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        else:
            # Label IDs take precedence over sample IDs, as warned above
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def _parse_embeddings_response(self, response, label_id=True):
        # Extracts (embeddings, sample IDs, label IDs) from `mget` docs or
        # search hits. Search hits have no "found" key, so they default to
        # found
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []
        for r in response:
            if r.get("found", True):
                found_embeddings.append(r["_source"]["vector"])
                if label_id:
                    found_sample_ids.append(r["_source"]["sample_id"])
                    found_label_ids.append(r["_id"])
                else:
                    found_sample_ids.append(r["_id"])

        return found_embeddings, found_sample_ids, found_label_ids

    def _get_sample_embeddings(self, sample_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []

        if sample_ids is None:
            sample_ids, _ = self._get_index_ids(batch_size=batch_size)

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            response = self._client.mget(
                index=self.config.index_name, ids=batch_ids, source=True
            )

            (
                _found_embeddings,
                _found_sample_ids,
                _,
            ) = self._parse_embeddings_response(
                response["docs"], label_id=False
            )

            found_embeddings += _found_embeddings
            found_sample_ids += _found_sample_ids

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        if label_ids is None:
            _, label_ids = self._get_index_ids(batch_size=batch_size)

        for batch_ids in fou.iter_batches(label_ids, batch_size):
            response = self._client.mget(
                index=self.config.index_name, ids=batch_ids, source=True
            )

            (
                _found_embeddings,
                _found_sample_ids,
                _found_label_ids,
            ) = self._parse_embeddings_response(response["docs"])

            found_embeddings += _found_embeddings
            found_sample_ids += _found_sample_ids
            found_label_ids += _found_label_ids

        missing_ids = list(set(label_ids) - set(found_label_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=100, page_size=1000
    ):
        # Returns all patch embeddings whose `sample_id` is in `sample_ids`.
        # Each batch of sample IDs may match many patches, so the matching
        # hits are paginated rather than relying on the default search size
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        if sample_ids is None:
            sample_ids, _ = self._get_index_ids(batch_size=batch_size)

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            start = 0
            while True:
                response = self._client.search(
                    index=self.config.index_name,
                    body={
                        "query": {"terms": {"sample_id": list(batch_ids)}},
                        "from": start,
                    },
                    size=page_size,
                )
                hits = response["hits"]["hits"]

                (
                    _found_embeddings,
                    _found_sample_ids,
                    _found_label_ids,
                ) = self._parse_embeddings_response(hits)

                found_embeddings += _found_embeddings
                found_sample_ids += _found_sample_ids
                found_label_ids += _found_label_ids

                if len(hits) < page_size:
                    break

                start += page_size

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def cleanup(self):
        self._client.indices.delete(
            index=self.config.index_name, ignore_unavailable=True
        )

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError(
                "Elasticsearch does not support full index neighbors"
            )

        if reverse is True:
            raise ValueError(
                "Elasticsearch does not support least similarity queries"
            )

        if aggregation not in (None, "mean"):
            raise ValueError(
                f"Elasticsearch does not support {aggregation} aggregation"
            )

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        # Restrict the search to the current view, if necessary
        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = self.current_label_ids
            else:
                index_ids = self.current_sample_ids

            _filter = {"terms": {"_id": list(index_ids)}}
        else:
            _filter = None

        normalize = self._get_metric() == _SUPPORTED_METRICS["dotproduct"]

        sample_ids = []
        label_ids = [] if self.config.patches_field is not None else None
        dists = []
        for q in query:
            if normalize:
                # Out-of-place, since `q` may be a view of the caller's array
                q = q / np.linalg.norm(q)

            knn = {
                "field": "vector",
                "query_vector": q.tolist(),
                "k": k,
                "num_candidates": 10 * k,
            }
            if _filter:
                knn["filter"] = _filter

            # `_source` is only needed to read `sample_id` for patch indexes
            source = self.config.patches_field is not None
            response = self._client.search(
                index=self.config.index_name,
                knn=knn,
                size=k,
                source=source,
            )

            if self.config.patches_field is not None:
                sample_ids.append(
                    [
                        r["_source"]["sample_id"]
                        for r in response["hits"]["hits"]
                    ]
                )
                label_ids.append([r["_id"] for r in response["hits"]["hits"]])
            else:
                sample_ids.append([r["_id"] for r in response["hits"]["hits"]])

            if return_dists:
                dists.append([r["_score"] for r in response["hits"]["hits"]])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        # Returns query vector(s): numeric input is used as-is, otherwise the
        # input is treated as ID(s) whose stored vectors are fetched
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        response = self._client.mget(
            index=self.config.index_name, ids=query_ids, source=True
        )
        query = np.array(
            [r["_source"]["vector"] for r in response["docs"] if r["found"]]
        )

        if query.size == 0:
            raise ValueError(
                "Query IDs %s were not found in the index" % query_ids
            )

        if single_query:
            query = query[0, :]

        return query

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/hardness.py
================================================
"""
Hardness methods.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
from scipy.special import softmax
from scipy.stats import entropy
import fiftyone.core.brain as fob
import fiftyone.core.labels as fol
import fiftyone.core.media as fom
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
logger = logging.getLogger(__name__)
# Label types accepted by compute_hardness(); their ``logits`` must be set
_ALLOWED_TYPES = (fol.Classification, fol.Classifications)
def compute_hardness(samples, label_field, hardness_field, progress):
    """See ``fiftyone/brain/__init__.py``.

    Hardness is computed directly as the entropy of the softmax of each
    label's logits. When ``label_field`` is a frame-level field, each frame's
    hardness is stored on the frame and the sample's hardness is the max over
    its frames. Samples with no scorable labels get a hardness of None.
    """
    fov.validate_collection(samples)
    fov.validate_collection_label_fields(samples, label_field, _ALLOWED_TYPES)

    if samples.media_type == fom.VIDEO:
        hardness_field, _ = samples._handle_frame_field(hardness_field)

    config = HardnessConfig(label_field, hardness_field)
    brain_key = hardness_field
    brain_method = config.build()
    brain_method.ensure_requirements()
    brain_method.register_run(samples, brain_key, cleanup=False)
    brain_method.register_samples(samples)

    view = samples.select_fields(label_field)
    processing_frames = samples._is_frame_field(label_field)

    logger.info("Computing hardness...")
    for sample in view.iter_samples(progress=progress):
        images = sample.frames.values() if processing_frames else [sample]

        scores = []
        for image in images:
            score = brain_method.process_image(image)
            if score is None:
                continue

            scores.append(score)
            if processing_frames:
                image[hardness_field] = score

        sample[hardness_field] = np.max(scores) if scores else None
        sample.save()

    brain_method.save_run_results(samples, brain_key, None)

    logger.info("Hardness computation complete")
# @todo move to `fiftyone/brain/hardness.py`
class HardnessConfig(fob.BrainMethodConfig):
    """Configuration for the entropy-based hardness brain method.

    Args:
        label_field: the label field whose logits are scored
        hardness_field: the field in which to store hardness scores
        **kwargs: keyword arguments for
            :class:`fiftyone.core.brain.BrainMethodConfig`
    """

    def __init__(self, label_field, hardness_field, **kwargs):
        self.label_field = label_field
        self.hardness_field = hardness_field
        super().__init__(**kwargs)

    @property
    def type(self):
        # This run computes hardness; "mistakenness" was a copy-paste error
        return "hardness"

    @property
    def method(self):
        return "entropy"
class Hardness(fob.BrainMethod):
    """Brain method that scores labels by the entropy of the softmax of their
    logits.

    Args:
        config: a :class:`HardnessConfig`
    """

    def __init__(self, config):
        super().__init__(config)
        # Resolved by register_samples()
        self.label_field = None

    def ensure_requirements(self):
        # No optional packages are required
        pass

    def register_samples(self, samples):
        """Resolves the label field name via ``_handle_frame_field()`` so it
        can be read from samples or frames.
        """
        resolved_field, _ = samples._handle_frame_field(
            self.config.label_field
        )
        self.label_field = resolved_field

    def process_image(self, sample_or_frame):
        """Returns the hardness of a sample or frame, or None if it has no
        label.
        """
        label = _get_data(sample_or_frame, self.label_field)
        if label is None:
            return None

        probs = softmax(np.asarray(label.logits))
        return entropy(probs)

    def get_fields(self, samples, brain_key):
        label_field = self.config.label_field
        hardness_field = self.config.hardness_field

        fields = [label_field, hardness_field]
        if samples._is_frame_field(label_field):
            # Frame-level runs also populate a frame-level hardness field
            fields.append(samples._FRAMES_PREFIX + hardness_field)

        return fields

    def cleanup(self, samples, brain_key):
        label_field = self.config.label_field
        hardness_field = self.config.hardness_field

        samples._dataset.delete_sample_fields(hardness_field, error_level=1)
        if samples._is_frame_field(label_field):
            samples._dataset.delete_frame_fields(hardness_field, error_level=1)

    def _validate_run(self, samples, brain_key, existing_info):
        self._validate_fields_match(brain_key, "hardness_field", existing_info)
def _get_data(sample, label_field):
label = sample[label_field]
if label is None:
return None
if label.logits is None:
raise ValueError(
"Sample '%s' field '%s' has no logits" % (sample.id, label_field)
)
return label
================================================
FILE: fiftyone/brain/internal/core/lancedb.py
================================================
"""
LanceDB similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.storage as fos
import fiftyone.core.utils as fou
import fiftyone.brain.internal.core.utils as fbu
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
# Optional dependencies, loaded via ``fou.lazy_import()``
lancedb = fou.lazy_import("lancedb")
pa = fou.lazy_import("pyarrow")
# Metric names accepted by LanceDBSimilarityConfig (the keys).
# NOTE(review): the values look like LanceDB's metric names, but they are not
# read anywhere in the code visible here — confirm where they are used
_SUPPORTED_METRICS = {
"cosine": "cosine",
"euclidean": "l2",
}
logger = logging.getLogger(__name__)
class LanceDBSimilarityConfig(SimilarityConfig):
    """Configuration for a LanceDB similarity instance.

    Args:
        table_name (None): the name of the LanceDB table to use. If none is
            provided, a new table will be created
        metric ("cosine"): the embedding distance metric to use when creating a
            new index. Supported values are ``("cosine", "euclidean")``
        uri ("/tmp/lancedb"): the database URI to use
        **kwargs: keyword arguments for :class:`SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is not supported
    """

    def __init__(
        self,
        table_name=None,
        metric="cosine",
        uri="/tmp/lancedb",
        **kwargs,
    ):
        if metric not in _SUPPORTED_METRICS:
            supported = tuple(_SUPPORTED_METRICS.keys())
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, supported)
            )

        super().__init__(**kwargs)

        self.table_name = table_name
        self.metric = metric

        # store privately so these aren't serialized
        self._uri = uri

    @property
    def method(self):
        """The name of this similarity backend."""
        return "lancedb"

    @property
    def uri(self):
        """The LanceDB database URI."""
        return self._uri

    @uri.setter
    def uri(self, value):
        self._uri = value

    @property
    def max_k(self):
        """The maximum number of neighbors per query (None means no limit)."""
        return None

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(self, uri=None):
        """Loads the database URI via ``self._load_parameters()``."""
        self._load_parameters(uri=uri)
class LanceDBSimilarity(Similarity):
    """LanceDB similarity factory.

    Args:
        config: a :class:`LanceDBSimilarityConfig`
    """

    def ensure_requirements(self):
        """Ensures that the ``lancedb`` package is installed."""
        fou.ensure_package("lancedb")

    def ensure_usage_requirements(self):
        """Ensures that the ``lancedb`` package is installed."""
        fou.ensure_package("lancedb")

    def initialize(self, samples, brain_key):
        """Returns a :class:`LanceDBSimilarityIndex` for the samples."""
        index = LanceDBSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class LanceDBSimilarityIndex(SimilarityIndex):
"""Class for interacting with LanceDB similarity indexes.
Args:
samples: the :class:`fiftyone.core.collections.SampleCollection` used
config: the :class:`LanceDBSimilarityConfig` used
brain_key: the brain key
backend (None): a :class:`LanceDBSimilarity` instance
"""
def __init__(self, samples, config, brain_key, backend=None):
    """Creates the index and connects to the LanceDB database."""
    super().__init__(samples, config, brain_key, backend=backend)

    # Both are populated by _initialize()
    self._table = None
    self._db = None

    self._initialize()
def _initialize(self):
    # Connects to the database, generates (and persists) a unique table name
    # if none was configured, and opens the table if it already exists
    uri = self.config.uri
    try:
        db = lancedb.connect(uri)
    except Exception as e:
        raise ValueError(
            "Failed to connect to LanceDB backend at URI '%s'. Refer to "
            "https://docs.voxel51.com/integrations/lancedb.html for more "
            "information" % uri
        ) from e

    existing_tables = db.table_names()

    if self.config.table_name is None:
        root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
        self.config.table_name = fbu.get_unique_name(root, existing_tables)
        self.save_config()

    table_name = self.config.table_name
    table = db.open_table(table_name) if table_name in existing_tables else None

    self._db = db
    self._table = table
@property
def table(self):
    """The ``lancedb.LanceTable`` instance for this index, or None if the
    table has not been created yet.
    """
    return self._table
@property
def total_index_size(self):
    """The number of rows in the table (0 if it does not exist yet)."""
    table = self._table
    return 0 if table is None else len(table)
def add_to_index(
    self,
    embeddings,
    sample_ids,
    label_ids=None,
    overwrite=True,
    allow_existing=True,
    warn_existing=False,
    reload=True,
):
    """Adds embeddings to the LanceDB table, creating it if necessary.

    The table is rewritten in full: existing rows are loaded, merged with the
    new rows, and the table is recreated with ``mode="overwrite"``.

    Args:
        embeddings: a ``num_embeddings x num_dims`` array of embeddings
        sample_ids: the sample IDs of the embeddings
        label_ids (None): the label IDs of the embeddings, for patch indexes.
            When provided, these are used as the row IDs
        overwrite (True): whether to replace (True) or skip (False) embeddings
            whose IDs already exist in the table
        allow_existing (True): whether to allow IDs that already exist
        warn_existing (False): whether to log a warning about IDs that already
            exist
        reload (True): whether to call ``reload()`` afterwards

    Raises:
        ValueError: if ``allow_existing`` is False and some IDs already exist
    """
    if self._table is None:
        pa_table = pa.Table.from_arrays(
            [[], [], []], names=["id", "sample_id", "vector"]
        )
    else:
        pa_table = self._table.to_arrow()

    if label_ids is not None:
        ids = label_ids
    else:
        ids = sample_ids

    prev_ids = pa_table["id"].to_pylist()

    if warn_existing or not allow_existing or not overwrite:
        existing_ids = set(prev_ids) & set(ids)
        num_existing = len(existing_ids)

        if num_existing > 0:
            if not allow_existing:
                raise ValueError(
                    "Found %d IDs (eg %s) that already exist in the index"
                    % (num_existing, next(iter(existing_ids)))
                )

            if warn_existing:
                if overwrite:
                    logger.warning(
                        "Overwriting %d IDs that already exist in the "
                        "index",
                        num_existing,
                    )
                else:
                    logger.warning(
                        "Skipping %d IDs that already exist in the index",
                        num_existing,
                    )
    else:
        existing_ids = set()

    if existing_ids and not overwrite:
        del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

        embeddings = np.delete(embeddings, del_inds, axis=0)
        sample_ids = np.delete(sample_ids, del_inds)
        if label_ids is not None:
            label_ids = np.delete(label_ids, del_inds)

    if label_ids is not None:
        ids = list(label_ids)
    else:
        ids = list(sample_ids)

    # Always work with a list; `sample_ids` may be an array (e.g. after
    # `np.delete()` above), and `list + ndarray` does not concatenate
    sample_ids = list(sample_ids)

    dim = embeddings.shape[1]
    if prev_ids:
        # Drop previous rows whose IDs are being (re)added so that overwriting
        # replaces them instead of creating duplicate rows
        new_ids = set(ids)
        keep = [i for i, _id in enumerate(prev_ids) if _id not in new_ids]

        prev_embeddings = np.concatenate(
            pa_table["vector"].to_numpy()
        ).reshape(-1, dim)
        prev_sample_ids = pa_table["sample_id"].to_pylist()

        embeddings = np.concatenate([prev_embeddings[keep], embeddings])
        ids = [prev_ids[i] for i in keep] + ids
        sample_ids = [prev_sample_ids[i] for i in keep] + sample_ids

    embeddings = pa.array(embeddings.reshape(-1), type=pa.float32())
    embeddings = pa.FixedSizeListArray.from_arrays(embeddings, dim)

    pa_table = pa.Table.from_arrays(
        [ids, sample_ids, embeddings], names=["id", "sample_id", "vector"]
    )
    self._table = self._db.create_table(
        self.config.table_name, pa_table, mode="overwrite"
    )

    if reload:
        self.reload()
def remove_from_index(
    self,
    sample_ids=None,
    label_ids=None,
    allow_missing=True,
    warn_missing=False,
    reload=True,
):
    """Removes rows from the LanceDB table.

    The table is rewritten without the removed rows via ``mode="overwrite"``.

    Args:
        sample_ids (None): the sample IDs to remove
        label_ids (None): the label IDs to remove, for patch indexes. When
            provided, ``sample_ids`` is not used
        allow_missing (True): whether to allow IDs that are not in the table
        warn_missing (False): whether to log a warning about missing IDs
        reload (True): whether to call ``reload()`` afterwards

    Raises:
        ValueError: if ``allow_missing`` is False and some IDs are not in the
            table
    """
    if label_ids is not None:
        ids = label_ids
    else:
        ids = sample_ids

    # Load the current rows once; they are needed both to check which IDs
    # exist and to rewrite the table. (The previous implementation called a
    # nonexistent `self._index.fetch()` here.)
    if self._table is not None:
        df = self._table.to_pandas()
        index_ids = set(df["id"])
    else:
        df = None
        index_ids = set()

    if not allow_missing or warn_missing:
        existing_ids = [_id for _id in ids if _id in index_ids]
        missing_ids = set(ids) - set(existing_ids)
        num_missing = len(missing_ids)

        if num_missing > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that are not present in the "
                    "index" % (num_missing, next(iter(missing_ids)))
                )

            if warn_missing:
                logger.warning(
                    "Ignoring %d IDs that are not present in the index",
                    num_missing,
                )

        ids = existing_ids

    # Nothing to rewrite if the table has not been created yet
    if df is not None:
        df = df[~df["id"].isin(ids)]
        self._table = self._db.create_table(
            self.config.table_name, df, mode="overwrite"
        )

    if reload:
        self.reload()
def get_embeddings(
    self,
    sample_ids=None,
    label_ids=None,
    allow_missing=True,
    warn_missing=False,
):
    """Retrieves embeddings from the LanceDB table backing this index.

    For patch indexes (``config.patches_field`` is set), either label IDs or
    sample IDs may be provided. A sample ID returns every patch embedding
    stored for that sample. When no IDs are provided, all embeddings are
    returned.

    Args:
        sample_ids (None): a sample ID or list of sample IDs. Ignored, with
            a warning, when ``label_ids`` is provided
        label_ids (None): a label ID or list of label IDs (patch indexes
            only)
        allow_missing (True): whether to allow IDs that are not in the index
        warn_missing (False): whether to log a warning about missing IDs

    Returns:
        a tuple of

        -   an array of embeddings
        -   an array of sample IDs
        -   an array of label IDs for patch indexes, otherwise None

    Raises:
        ValueError: if label IDs are given for a non-patch index, or if
            ``allow_missing`` is False and some IDs are missing
    """
    is_patches = self.config.patches_field is not None

    if label_ids is not None:
        if not is_patches:
            raise ValueError("This index does not support label IDs")

        if sample_ids is not None:
            logger.warning(
                "Ignoring sample IDs when label IDs are provided"
            )
            # Actually honor the warning above
            sample_ids = None

    df = self._table.to_pandas()

    found_embeddings = []
    found_sample_ids = []
    found_label_ids = []
    missing_ids = []

    if is_patches and sample_ids is not None:
        if not etau.is_container(sample_ids):
            sample_ids = [sample_ids]

        # A sample may own several patches, so "sample_id" is not a unique
        # key; map each sample ID to the positions of all of its rows
        positions = df.groupby("sample_id", sort=False).indices
        for sample_id in sample_ids:
            if sample_id not in positions:
                missing_ids.append(sample_id)
                continue

            rows = df.iloc[positions[sample_id]]
            found_embeddings.extend(rows["vector"].tolist())
            found_sample_ids.extend([sample_id] * len(rows))
            found_label_ids.extend(rows["id"].tolist())
    elif is_patches:
        df = df.set_index("id", drop=False)
        if label_ids is None:
            label_ids = list(df.index)
        elif not etau.is_container(label_ids):
            label_ids = [label_ids]

        for label_id in label_ids:
            if label_id in df.index:
                found_embeddings.append(df.loc[label_id]["vector"])
                found_sample_ids.append(df.loc[label_id]["sample_id"])
                found_label_ids.append(label_id)
            else:
                missing_ids.append(label_id)
    else:
        df = df.set_index("id", drop=False)
        if sample_ids is None:
            sample_ids = list(df.index)
        elif not etau.is_container(sample_ids):
            sample_ids = [sample_ids]

        for sample_id in sample_ids:
            if sample_id in df.index:
                found_embeddings.append(df.loc[sample_id]["vector"])
                found_sample_ids.append(sample_id)
            else:
                missing_ids.append(sample_id)

    num_missing_ids = len(missing_ids)
    if num_missing_ids > 0:
        if not allow_missing:
            raise ValueError(
                "Found %d IDs (eg %s) that do not exist in the index"
                % (num_missing_ids, missing_ids[0])
            )

        if warn_missing:
            logger.warning(
                "Skipping %d IDs that do not exist in the index",
                num_missing_ids,
            )

    embeddings = np.array(found_embeddings)
    sample_ids = np.array(found_sample_ids)

    # Patch indexes always report the label IDs they found, including when
    # the lookup was by sample ID
    label_ids = np.array(found_label_ids) if is_patches else None

    return embeddings, sample_ids, label_ids
def cleanup(self):
    """Drops this index's LanceDB table and its ``"_filter"`` companion
    table (when they exist), then forgets the cached table handle.

    Does nothing when no LanceDB connection is held.
    """
    db = self._db
    if db is None:
        return

    base_name = self.config.table_name
    for name in [base_name, base_name + "_filter"]:
        # Re-list on each pass, exactly as before, in case a drop changes
        # the set of tables
        if name not in db.table_names():
            continue

        db.drop_table(name)

    self._table = None
def _kneighbors(
    self,
    query=None,
    k=None,
    reverse=False,
    aggregation=None,
    return_dists=False,
):
    """Runs a nearest-neighbor search against the LanceDB table.

    When the index is restricted to a view, the matching rows are first
    copied into a ``"<table_name>_filter"`` table, which is then searched.

    Args:
        query (None): query vector(s) or ID(s); see
            ``_parse_neighbors_query()``
        k (None): number of neighbors; defaults to ``self.index_size``
        reverse (False): must be False (least-similarity unsupported)
        aggregation (None): None or ``"mean"``
        return_dists (False): whether to also return the ``_distance``
            values reported by LanceDB

    Returns:
        ``(sample_ids, label_ids)`` or ``(sample_ids, label_ids, dists)``;
        ``label_ids`` is None for non-patch indexes, and each list is
        flattened when a single query is issued
    """
    if query is None:
        raise ValueError("LanceDB does not support full index neighbors")

    if reverse is True:
        raise ValueError(
            "LanceDB does not support least similarity queries"
        )

    if aggregation not in (None, "mean"):
        raise ValueError(
            f"LanceDB does not support {aggregation} aggregation"
        )

    if k is None:
        k = self.index_size

    query = self._parse_neighbors_query(query)
    if aggregation == "mean" and query.ndim == 2:
        query = query.mean(axis=0)

    single_query = query.ndim == 1
    queries = [query] if single_query else query

    is_patches = self.config.patches_field is not None

    search_table = self._table
    if self.has_view:
        if is_patches:
            allowed_ids = list(self.current_label_ids)
        else:
            allowed_ids = list(self.current_sample_ids)

        view_df = search_table.to_pandas()
        view_df = view_df[view_df["id"].isin(allowed_ids)]
        search_table = self._db.create_table(
            self.config.table_name + "_filter", view_df, mode="overwrite"
        )

    metric = _SUPPORTED_METRICS[self.config.metric]

    all_sample_ids = []
    all_label_ids = [] if is_patches else None
    all_dists = []
    for q in queries:
        hits = search_table.search(q).metric(metric).limit(k).to_df()

        if is_patches:
            all_sample_ids.append(hits.sample_id.tolist())
            all_label_ids.append(hits.id.tolist())
        else:
            all_sample_ids.append(hits.id.tolist())

        if return_dists:
            all_dists.append(hits._distance.tolist())

    if single_query:
        all_sample_ids = all_sample_ids[0]
        if all_label_ids is not None:
            all_label_ids = all_label_ids[0]
        if return_dists:
            all_dists = all_dists[0]

    if not return_dists:
        return all_sample_ids, all_label_ids

    return all_sample_ids, all_label_ids, all_dists
def _parse_neighbors_query(self, query):
if etau.is_str(query):
query_ids = [query]
single_query = True
else:
query = np.asarray(query)
# Query by vector(s)
if np.issubdtype(query.dtype, np.number):
return query
query_ids = list(query)
single_query = False
# Query by ID(s)
df = self._table.to_pandas()
df = df[df["id"].isin(query_ids)]
query = np.array([v for v in df["vector"]])
if query.size == 0:
raise ValueError(
"Query IDs %s were not found in the index" % query_ids
)
if single_query:
query = query[0, :]
return query
@classmethod
def _from_dict(cls, d, samples, config, brain_key):
return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/leaky_splits.py
================================================
"""
Finds leaks between splits.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import eta.core.utils as etau
import fiftyone.core.brain as fob
import fiftyone.core.fields as fof
import fiftyone.core.validation as fov
import fiftyone.zoo as foz
from fiftyone import ViewField as F
import fiftyone.brain as fb
import fiftyone.brain.similarity as fbs
import fiftyone.brain.internal.core.utils as fbu
logger = logging.getLogger(__name__)
# Zoo model loaded by compute_leaky_splits() when no model, embeddings, or
# similarity index were provided and no stored embeddings exist
_DEFAULT_MODEL = "resnet18-imagenet-torch"
# Substituted by compute_leaky_splits() when ``batch_size=None``; since it is
# itself None, the value forwarded to compute_similarity() stays None
_DEFAULT_BATCH_SIZE = None
def compute_leaky_splits(
    samples,
    splits,
    threshold=None,
    roi_field=None,
    embeddings=None,
    similarity_index=None,
    model=None,
    model_kwargs=None,
    force_square=False,
    alpha=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """See ``fiftyone/brain/__init__.py``."""
    fov.validate_collection(samples)

    # An embeddings field name is resolved to (field, exists); the raw
    # ``embeddings`` argument is then no longer used
    embeddings_field = None
    embeddings_exist = None
    if etau.is_str(embeddings):
        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples,
            embeddings,
            data_type="embeddings",
        )
        embeddings = None

    if etau.is_str(similarity_index):
        similarity_index = samples.load_brain_results(similarity_index)

    # Only fall back to the default zoo model when nothing else can supply
    # embeddings
    has_embedding_source = (
        model is not None
        or embeddings is not None
        or similarity_index is not None
        or embeddings_exist
    )
    if not has_embedding_source:
        model = foz.load_zoo_model(_DEFAULT_MODEL)

    batch_size = _DEFAULT_BATCH_SIZE if batch_size is None else batch_size

    brain_method = LeakySplitsConfig(
        splits=splits,
        embeddings_field=embeddings_field,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
    ).build()
    brain_method.ensure_requirements()

    if similarity_index is None:
        similarity_index = fb.compute_similarity(
            samples,
            backend="sklearn",
            roi_field=roi_field,
            embeddings=embeddings_field or embeddings,
            model=model,
            model_kwargs=model_kwargs,
            force_square=force_square,
            alpha=alpha,
            batch_size=batch_size,
            num_workers=num_workers,
            skip_failures=skip_failures,
            progress=progress,
        )
    elif not isinstance(similarity_index, fbs.DuplicatesMixin):
        raise ValueError(
            "This method only supports similarity indexes that implement the "
            "%s mixin" % fbs.DuplicatesMixin
        )

    results = brain_method.initialize(
        samples, similarity_index, _to_split_views(samples, splits)
    )

    if threshold is not None:
        results.find_leaks(threshold)

    return results
class LeakySplitsConfig(fob.BrainMethodConfig):
    """Configuration for leaky splits computations.

    Args:
        splits (None): the splits specification. Dict values are not stored
            on the config; ``splits`` is set to None in that case
        embeddings_field (None): the field holding precomputed embeddings
        similarity_index (None): a similarity index or its brain key; index
            objects are stored by their ``key``
        model (None): a model or model name; non-string models are stored by
            class name
        model_kwargs (None): optional keyword arguments for the model
        **kwargs: keyword arguments for the base config
    """

    def __init__(
        self,
        splits=None,
        embeddings_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        **kwargs,
    ):
        def _as_name(obj, to_name):
            # None and strings pass through; other objects become a name
            if obj is None or etau.is_str(obj):
                return obj

            return to_name(obj)

        self.splits = None if isinstance(splits, dict) else splits
        self.embeddings_field = embeddings_field
        self.similarity_index = _as_name(similarity_index, lambda idx: idx.key)
        self.model = _as_name(model, etau.get_class_name)
        self.model_kwargs = model_kwargs

        super().__init__(**kwargs)

    @property
    def type(self):
        return "leakage"

    @property
    def method(self):
        return "similarity"
class LeakySplits(fob.BrainMethod):
    """Brain method that builds :class:`LeakySplitsIndex` results."""

    def initialize(self, samples, similarity_index, split_views):
        """Creates the results object for the given splits and index."""
        results = LeakySplitsIndex(
            samples, self.config, similarity_index, split_views
        )
        return results

    def get_fields(self, samples, _):
        """Returns the embeddings field used by this run, if any."""
        embeddings_field = self.config.embeddings_field
        if embeddings_field is None:
            return []

        return [embeddings_field]
class LeakySplitsIndex(fob.BrainResults):
    """Results of a leaky splits computation.

    A leak is a group of near-duplicates, as reported by the similarity
    index's ``neighbors_map``, whose members belong to different splits.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`LeakySplitsConfig` used
        similarity_index: the similarity index whose ``find_duplicates()``,
            ``thresh``, ``neighbors_map`` and ``sample_ids`` are used
        split_views: a dict mapping split names to views
    """

    def __init__(self, samples, config, similarity_index, split_views):
        super().__init__(samples, config, None)
        self._similarity_index = similarity_index
        self._split_views = split_views
        self._id2split = None
        self._thresh = None
        self._leak_ids = None
        self._initialize()

    @property
    def split_views(self):
        """A dict mapping split names to views."""
        return self._split_views

    @property
    def thresh(self):
        """The threshold used by the last call to :meth:`find_leaks`."""
        return self._thresh

    @property
    def leak_ids(self):
        """The list of leaky sample IDs from the last call to
        :meth:`find_leaks`.
        """
        return self._leak_ids

    def find_leaks(self, thresh):
        """Scans the index for leaks between splits.

        Calling this again with the same threshold is a no-op.

        Args:
            thresh: the similarity distance threshold to use when detecting
                potential leaks
        """
        if thresh == self._thresh:
            return

        self._thresh = thresh

        # Reuse the similarity index's duplicates if they were already
        # computed at this threshold
        if self._similarity_index.thresh != self._thresh:
            self._similarity_index.find_duplicates(self._thresh)

        # Keep only duplicate groups with members in a different split
        leak_ids = []
        neighbors_map = self._similarity_index.neighbors_map
        for sample_id, neighbors in neighbors_map.items():
            sample_split = self._id2split.get(sample_id, None)
            if sample_split is None:
                continue

            # Each neighbor entry's first element is the neighbor's ID
            _leak_ids = [
                n[0] for n in neighbors if self._is_leak(n[0], sample_split)
            ]
            if _leak_ids:
                leak_ids.append(sample_id)
                leak_ids.extend(_leak_ids)

        self._leak_ids = leak_ids

    def leaks_view(self):
        """Returns a view containing all potential leaks generated by the
        last call to :meth:`find_leaks`.

        Returns:
            a :class:`fiftyone.core.view.DatasetView`

        Raises:
            ValueError: if :meth:`find_leaks` has not been called
        """
        if self._thresh is None:
            raise ValueError("You must first call `find_leaks()`")

        return self.samples.select(self._leak_ids, ordered=True)

    def leaks_for_sample(self, sample_or_id):
        """Returns a view that contains all leaks related to the given sample.

        The given sample is always first in the returned view, followed by
        any related leaks, i.e. members of its duplicate group that belong to
        a different split. A sample that belongs to no split has no leaks.

        Args:
            sample_or_id: a :class:`fiftyone.core.sample.Sample` or sample ID

        Returns:
            a :class:`fiftyone.core.view.DatasetView`

        Raises:
            ValueError: if :meth:`find_leaks` has not been called
        """
        if self._thresh is None:
            raise ValueError("You must first call `find_leaks()`")

        if etau.is_str(sample_or_id):
            sample_id = sample_or_id
        else:
            sample_id = sample_or_id.id

        sample_split = self._id2split.get(sample_id, None)
        if sample_split is None:
            leak_ids = []
        else:
            leak_ids = self._find_sample_leaks(sample_id, sample_split)

        return self.samples.select([sample_id] + leak_ids, ordered=True)

    def no_leaks_view(self, view=None):
        """Returns a view with leaks excluded.

        Args:
            view (None): an optional :class:`fiftyone.core.view.DatasetView`
                from which to exclude. By default, :meth:`samples` is used

        Returns:
            a :class:`fiftyone.core.view.DatasetView`

        Raises:
            ValueError: if :meth:`find_leaks` has not been called
        """
        if self._thresh is None:
            raise ValueError("You must first call `find_leaks()`")

        if view is None:
            view = self.samples

        return view.exclude(self._leak_ids)

    def tag_leaks(self, tag="leak"):
        """Tags all potential leaks in :meth:`leaks_view` with the given tag.

        Args:
            tag ("leak"): the tag string to apply
        """
        self.leaks_view().tag_samples(tag)

    def _is_leak(self, other_id, sample_split):
        # True when ``other_id`` belongs to a split other than
        # ``sample_split``; IDs outside every split are never leaks
        other_split = self._id2split.get(other_id, None)
        return other_split is not None and other_split != sample_split

    def _find_sample_leaks(self, sample_id, sample_split):
        # Returns the IDs in ``sample_id``'s duplicate group that are leaks
        # relative to ``sample_split``
        neighbors_map = self._similarity_index.neighbors_map

        if sample_id in neighbors_map:
            return [
                n[0]
                for n in neighbors_map[sample_id]
                if self._is_leak(n[0], sample_split)
            ]

        for unique_id, neighbors in neighbors_map.items():
            neighbor_ids = [n[0] for n in neighbors]
            if sample_id not in neighbor_ids:
                continue

            leak_ids = [
                _id for _id in neighbor_ids if self._is_leak(_id, sample_split)
            ]

            # The group's representative is only a leak if it lives in a
            # different split
            if self._is_leak(unique_id, sample_split):
                leak_ids.append(unique_id)

            return leak_ids

        return []

    def _initialize(self):
        id2split = {}
        split_ids = {}
        for split_name, split_view in self.split_views.items():
            sample_ids = set(split_view.values("id"))
            split_ids[split_name] = sample_ids
            id2split.update({sid: split_name for sid in sample_ids})

        # Check for overlapping splits
        split_names = list(split_ids.keys())
        for idx, split1 in enumerate(split_names):
            for split2 in split_names[idx + 1 :]:
                overlap = split_ids[split1] & split_ids[split2]
                if overlap:
                    logger.warning(
                        "The '%s' and '%s' splits contain %d overlapping "
                        "samples. Use dataset.match_tags('%s')"
                        ".match_tags('%s') to identify them",
                        split1,
                        split2,
                        len(overlap),
                        split1,
                        split2,
                    )

        # Check for samples not in index
        index_ids = self._similarity_index.sample_ids
        if index_ids is not None:
            index_ids = set(index_ids)
            all_split_ids = set(id2split.keys())
            missing_ids = all_split_ids - index_ids
            if missing_ids:
                logger.warning(
                    "The provided splits contain %d samples (eg '%s') that "
                    "are not present in the index",
                    len(missing_ids),
                    next(iter(missing_ids)),
                )

        self._id2split = id2split
def _to_split_views(samples, splits):
if etau.is_container(splits):
return {tag: samples.match_tags(tag) for tag in splits}
if isinstance(splits, str):
field = samples.get_field(splits)
if isinstance(field, fof.ListField):
return {
value: samples.exists(splits).match(F(splits).contains(value))
for value in samples.distinct(splits)
}
else:
return {
value: samples.match(F(splits) == value)
for value in samples.distinct(splits)
}
================================================
FILE: fiftyone/brain/internal/core/milvus.py
================================================
"""
Milvus similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
from uuid import uuid4
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
# Imported via fou.lazy_import(); the backend calls
# fou.ensure_package("pymilvus") before use
pymilvus = fou.lazy_import("pymilvus")
logger = logging.getLogger(__name__)
# Maps this backend's metric names to the Milvus ``metric_type`` strings used
# by MilvusSimilarityConfig.index_params and search_params
_SUPPORTED_METRICS = {
    "cosine": "COSINE",
    "dotproduct": "IP",
    "euclidean": "L2",
}
class MilvusSimilarityConfig(SimilarityConfig):
    """Configuration for the Milvus similarity backend.

    Args:
        collection_name (None): the name of a Milvus collection to use or
            create. If none is provided, a new collection will be created
        metric ("dotproduct"): the embedding distance metric to use when
            creating a new index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        consistency_level ("Session"): the consistency level to use. Supported
            values are ``("Session", "Strong", "Bounded", "Eventually")``
        uri (None): a full Milvus server address to use, like
            ``"http://localhost:19530"``,
            ``"tcp:localhost:19530"``, or
            ``"https://ok.s3.south.com:19530"``
        user (None): a username to use
        password (None): a password to use
        secure (None): whether to enable TLS
        token (None): a header token for RPC calls
        db_name (None): a database name for the connection
        client_key_path (None): a client.key path for TLS two-way
        client_pem_path (None): a client.pem path for TLS two-way
        ca_pem_path (None): a ca.pem path for TLS two-way
        server_pem_path (None): a server.pem path for TLS one-way
        server_name (None): the server name, for TLS
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`
    """

    def __init__(
        self,
        collection_name=None,
        metric="dotproduct",
        consistency_level="Session",
        uri=None,
        user=None,
        password=None,
        secure=None,
        token=None,
        db_name=None,
        client_key_path=None,
        client_pem_path=None,
        ca_pem_path=None,
        server_pem_path=None,
        server_name=None,
        **kwargs,
    ):
        if metric not in _SUPPORTED_METRICS:
            supported = tuple(_SUPPORTED_METRICS.keys())
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, supported)
            )

        super().__init__(**kwargs)

        self.collection_name = collection_name
        self.metric = metric
        self.consistency_level = consistency_level

        # Connection settings are kept on underscore-prefixed attributes so
        # that they aren't serialized with the rest of the config
        connection_settings = {
            "uri": uri,
            "user": user,
            "password": password,
            "secure": secure,
            "token": token,
            "db_name": db_name,
            "client_key_path": client_key_path,
            "client_pem_path": client_pem_path,
            "ca_pem_path": ca_pem_path,
            "server_pem_path": server_pem_path,
            "server_name": server_name,
        }
        for name, value in connection_settings.items():
            setattr(self, "_" + name, value)

    @property
    def method(self):
        return "milvus"

    def _private_property(name):
        # Builds a read/write property backed by the ``_<name>`` attribute
        attr = "_" + name

        def _get(self):
            return getattr(self, attr)

        def _set(self, value):
            setattr(self, attr, value)

        return property(_get, _set)

    uri = _private_property("uri")
    user = _private_property("user")
    password = _private_property("password")
    secure = _private_property("secure")
    token = _private_property("token")
    db_name = _private_property("db_name")
    client_key_path = _private_property("client_key_path")
    client_pem_path = _private_property("client_pem_path")
    ca_pem_path = _private_property("ca_pem_path")
    server_pem_path = _private_property("server_pem_path")
    server_name = _private_property("server_name")

    del _private_property

    @property
    def max_k(self):
        return 16384

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    @property
    def index_params(self):
        metric_type = _SUPPORTED_METRICS[self.metric]
        return dict(
            metric_type=metric_type,
            index_type="HNSW",
            params=dict(M=8, efConstruction=64),
        )

    @property
    def search_params(self):
        metric_type = _SUPPORTED_METRICS[self.metric]
        return {"HNSW": dict(metric_type=metric_type, params=dict(ef=10))}

    def load_credentials(
        self,
        uri=None,
        user=None,
        password=None,
        secure=None,
        token=None,
        db_name=None,
        client_key_path=None,
        client_pem_path=None,
        ca_pem_path=None,
        server_pem_path=None,
        server_name=None,
    ):
        self._load_parameters(
            uri=uri,
            user=user,
            password=password,
            secure=secure,
            token=token,
            db_name=db_name,
            client_key_path=client_key_path,
            client_pem_path=client_pem_path,
            ca_pem_path=ca_pem_path,
            server_pem_path=server_pem_path,
            server_name=server_name,
        )
class MilvusSimilarity(Similarity):
    """Milvus similarity factory.

    Args:
        config: a :class:`MilvusSimilarityConfig`
    """

    def ensure_requirements(self):
        # Building an index needs the Milvus client package
        fou.ensure_package("pymilvus")

    def ensure_usage_requirements(self):
        # ...and so does querying an existing one
        fou.ensure_package("pymilvus")

    def initialize(self, samples, brain_key):
        """Returns a :class:`MilvusSimilarityIndex` bound to this backend."""
        index = MilvusSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class MilvusSimilarityIndex(SimilarityIndex):
    """Class for interacting with Milvus similarity indexes.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`MilvusSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`MilvusSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._alias = None
        self._collection = None
        self._initialize()

    def _initialize(self):
        kwargs = {}
        for key in (
            "uri",
            "user",
            "password",
            "secure",
            "token",
            "db_name",
            "client_key_path",
            "client_pem_path",
            "ca_pem_path",
            "server_pem_path",
            "server_name",
        ):
            value = getattr(self.config, key, None)
            if value is not None:
                kwargs[key] = value

        # A fresh alias is used whenever explicit connection settings were
        # provided; otherwise the "default" alias is used
        alias = uuid4().hex if kwargs else "default"

        try:
            pymilvus.connections.connect(alias=alias, **kwargs)
        except pymilvus.MilvusException as e:
            raise ValueError(
                "Failed to connect to Milvus backend at URI '%s'. Refer to "
                "https://docs.voxel51.com/integrations/milvus.html for more "
                "information" % self.config.uri
            ) from e

        collection_names = pymilvus.utility.list_collections(using=alias)

        if self.config.collection_name is None:
            # Milvus only supports numbers, letters and underscores
            root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
            root = root.replace("-", "_")
            collection_name = fbu.get_unique_name(root, collection_names)
            collection_name = collection_name.replace("-", "_")

            self.config.collection_name = collection_name
            self.save_config()

        if self.config.collection_name in collection_names:
            collection = pymilvus.Collection(
                self.config.collection_name, using=alias
            )
            collection.load()
        else:
            collection = None

        self._alias = alias
        self._collection = collection

    def _create_collection(self, dimension):
        schema = pymilvus.CollectionSchema(
            [
                pymilvus.FieldSchema(
                    "pk",
                    pymilvus.DataType.VARCHAR,
                    is_primary=True,
                    auto_id=False,
                    max_length=64000,
                ),
                pymilvus.FieldSchema(
                    "vector", pymilvus.DataType.FLOAT_VECTOR, dim=dimension
                ),
                pymilvus.FieldSchema(
                    "sample_id", pymilvus.DataType.VARCHAR, max_length=64000
                ),
            ]
        )

        collection = pymilvus.Collection(
            self.config.collection_name,
            schema,
            consistency_level=self.config.consistency_level,
            using=self._alias,
        )
        collection.create_index(
            "vector", index_params=self.config.index_params
        )
        collection.load()

        self._collection = collection

    @property
    def collection(self):
        """The ``pymilvus.Collection`` instance for this index."""
        return self._collection

    @property
    def total_index_size(self):
        if self._collection is None:
            return 0

        return self._collection.num_entities

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=100,
    ):
        if self._collection is None:
            self._create_collection(embeddings.shape[1])

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        # Milvus does not enforce primary key uniqueness on insert, so the
        # existing IDs are always needed: either to skip them or to delete
        # them before re-inserting
        existing_ids = set(self._get_existing_ids(ids))
        num_existing = len(existing_ids)

        if num_existing > 0:
            if not allow_existing:
                raise ValueError(
                    "Found %d IDs (eg %s) that already exist in the index"
                    % (num_existing, next(iter(existing_ids)))
                )

            if warn_existing:
                if overwrite:
                    logger.warning(
                        "Overwriting %d IDs that already exist in the "
                        "index",
                        num_existing,
                    )
                else:
                    logger.warning(
                        "Skipping %d IDs that already exist in the index",
                        num_existing,
                    )

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

            # Keep ``ids`` aligned with the filtered arrays
            if label_ids is not None:
                ids = label_ids
            else:
                ids = sample_ids
        elif existing_ids:
            self._delete_ids(existing_ids)

        embeddings = [e.tolist() for e in embeddings]
        sample_ids = list(sample_ids)
        ids = list(ids)

        for _embeddings, _ids, _sample_ids in zip(
            fou.iter_batches(embeddings, batch_size),
            fou.iter_batches(ids, batch_size),
            fou.iter_batches(sample_ids, batch_size),
        ):
            insert_data = [
                list(_ids),
                list(_embeddings),
                list(_sample_ids),
            ]
            self._collection.insert(insert_data)

        self._collection.flush()

        if reload:
            self.reload()

    @staticmethod
    def _ids_expr(field, ids):
        # Builds a Milvus boolean expression matching ``field`` against IDs
        values = ",".join('"' + str(_id) + '"' for _id in ids)
        return f"{field} in [{values}]"

    def _get_existing_ids(self, ids, batch_size=1000):
        # Returns the subset of ``ids`` whose primary keys exist in the
        # collection, as a list of ID values
        if self._collection is None:
            return []

        existing_ids = []
        for batch_ids in fou.iter_batches(list(ids), batch_size):
            response = self._collection.query(
                self._ids_expr("pk", batch_ids), output_fields=["pk"]
            )
            existing_ids.extend(r["pk"] for r in response)

        return existing_ids

    def _delete_ids(self, ids, batch_size=1000):
        # Deletes the given primary keys, in batches
        if self._collection is None:
            return

        for batch_ids in fou.iter_batches(list(ids), batch_size):
            self._collection.delete(self._ids_expr("pk", batch_ids))

        self._collection.flush()

    def _get_embeddings(self, ids):
        return self._collection.query(
            self._ids_expr("pk", ids),
            output_fields=["pk", "sample_id", "vector"],
        )

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if not allow_missing or warn_missing:
            existing_ids = self._get_existing_ids(ids)
            missing_ids = set(ids) - set(existing_ids)
            num_missing = len(missing_ids)

            if num_missing > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that are not present in the "
                        "index" % (num_missing, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Ignoring %d IDs that are not present in the index",
                        num_missing,
                    )

            ids = existing_ids

        self._delete_ids(ids)

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )
                # Actually honor the warning above
                sample_ids = None

        # Accept a single ID as well as a list of IDs
        if etau.is_str(sample_ids):
            sample_ids = [sample_ids]

        if etau.is_str(label_ids):
            label_ids = [label_ids]

        if sample_ids is not None and self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        elif self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)
        else:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def cleanup(self):
        pymilvus.utility.drop_collection(
            self.config.collection_name, using=self._alias
        )
        self._collection = None

    def _get_sample_embeddings(self, sample_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []

        if sample_ids is None:
            raise ValueError(
                "Milvus does not support retrieving all vectors in an index"
            )

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            response = self._get_embeddings(list(batch_ids))

            for r in response:
                found_embeddings.append(r["vector"])
                found_sample_ids.append(r["sample_id"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        if label_ids is None:
            raise ValueError(
                "Milvus does not support retrieving all vectors in an index"
            )

        for batch_ids in fou.iter_batches(label_ids, batch_size):
            response = self._get_embeddings(list(batch_ids))

            for r in response:
                found_embeddings.append(r["vector"])
                found_sample_ids.append(r["sample_id"])
                found_label_ids.append(r["pk"])

        missing_ids = list(set(label_ids) - set(found_label_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=100
    ):
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        # Filter on the ``sample_id`` field (the primary key holds label IDs)
        # and use a plain query so every patch of each sample is returned,
        # rather than a top-k vector search that can truncate results
        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            response = self._collection.query(
                self._ids_expr("sample_id", batch_ids),
                output_fields=["pk", "sample_id", "vector"],
            )

            for r in response:
                found_embeddings.append(r["vector"])
                found_sample_ids.append(r["sample_id"])
                found_label_ids.append(r["pk"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError("Milvus does not support full index neighbors")

        if reverse is True:
            raise ValueError(
                "Milvus does not support least similarity queries"
            )

        if k is None or k > self.config.max_k:
            raise ValueError("Milvus requires k<=%s" % self.config.max_k)

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = self.current_label_ids
            else:
                index_ids = self.current_sample_ids

            expr = self._ids_expr("pk", index_ids)
        else:
            expr = None

        sample_ids = []
        label_ids = [] if self.config.patches_field is not None else None
        dists = []
        for q in query:
            if self.config.patches_field is not None:
                output_fields = ["sample_id"]
            else:
                output_fields = None

            response = self._collection.search(
                data=[q.tolist()],
                anns_field="vector",
                limit=k,
                expr=expr,
                param=self.config.search_params,
                output_fields=output_fields,
            )

            if self.config.patches_field is not None:
                sample_ids.append(
                    [r.entity.get("sample_id") for r in response[0]]
                )
                label_ids.append([r.id for r in response[0]])
            else:
                sample_ids.append([r.id for r in response[0]])

            if return_dists:
                dists.append([r.score for r in response[0]])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s). Milvus returns matches in arbitrary order, so map
        # them back onto the order of ``query_ids``
        response = self._get_embeddings(query_ids)
        id2vector = {x["pk"]: x["vector"] for x in response}
        query = np.array(
            [id2vector[_id] for _id in query_ids if _id in id2vector]
        )

        if query.size == 0:
            raise ValueError(
                "Query IDs %s were not found in the index" % query_ids
            )

        if single_query:
            query = query[0, :]

        return query

    def _get_dimension(self):
        if self._collection is None:
            return None

        for field in self._collection.describe()["fields"]:
            if field["name"] == "vector":
                return field["params"]["dim"]

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/mistakenness.py
================================================
"""
Mistakenness methods.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
from math import exp
import numpy as np
from scipy.special import softmax
from scipy.stats import entropy
from fiftyone import ViewField as F
import fiftyone.core.brain as fob
import fiftyone.core.labels as fol
import fiftyone.core.media as fom
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
logger = logging.getLogger(__name__)
# Label types accepted for ``pred_field`` and ``label_field`` by
# compute_mistakenness(), which validates that both fields share one type
_ALLOWED_TYPES = (
    fol.Classification,
    fol.Classifications,
    fol.Detections,
    fol.Polylines,
    fol.Keypoints,
    fol.TemporalDetections,
)
# NOTE(review): used by code outside this view; presumably the prediction
# confidence above which an unmatched prediction is flagged as a missing
# ground truth label — confirm
_MISSED_CONFIDENCE_THRESHOLD = 0.95
# NOTE(review): used by code outside this view; presumably the IoU threshold
# used when matching predicted and ground truth objects — confirm
_DETECTION_IOU = 0.5
def compute_mistakenness(
    samples,
    pred_field,
    label_field,
    mistakenness_field,
    missing_field,
    spurious_field,
    use_logits,
    copy_missing,
    progress,
):
    """See ``fiftyone/brain/__init__.py``."""
    #
    # Algorithm
    #
    # A ground truth label is likely to be a mistake when the model disagrees
    # with it *confidently*. Confident agreement suggests the label is fine,
    # and an unconfident prediction says little either way.
    #
    # Confidence is measured as the negative entropy of the logits,
    # $c = -entropy(logits)$, which is large for low uncertainty and small for
    # high uncertainty. A modulator $m$ is $-1$ when the prediction agrees with
    # the label and $1$ otherwise. Mistakenness is then
    # $(m * exp(c) + 1) / 2$: confident agreement -> low, confident
    # disagreement -> high, low confidence -> middling.
    #
    fov.validate_collection_label_fields(
        samples, (pred_field, label_field), _ALLOWED_TYPES, same_type=True
    )

    if samples.media_type == fom.VIDEO:
        # Frame-level outputs are written via their un-prefixed names
        mistakenness_field, _ = samples._handle_frame_field(mistakenness_field)
        missing_field, _ = samples._handle_frame_field(missing_field)
        spurious_field, _ = samples._handle_frame_field(spurious_field)

    object_types = (
        fol.Detections,
        fol.Polylines,
        fol.Keypoints,
        fol.TemporalDetections,
    )
    is_objects = samples._is_label_field(pred_field, object_types)

    if not is_objects:
        eval_key = None
        config = ClassificationMistakennessConfig(
            pred_field, label_field, mistakenness_field, use_logits
        )
    else:
        eval_key = _make_eval_key(samples, mistakenness_field)
        config = DetectionMistakennessConfig(
            pred_field,
            label_field,
            mistakenness_field,
            missing_field,
            spurious_field,
            use_logits,
            copy_missing,
            eval_key,
        )

    brain_key = mistakenness_field
    brain_method = config.build()
    brain_method.ensure_requirements()
    brain_method.register_run(samples, brain_key, cleanup=False)
    brain_method.register_samples(samples)

    if is_objects:
        # Object matching is delegated to a temporary evaluation run, whose
        # per-object results are consumed by `process_image()`
        samples.evaluate_detections(
            pred_field,
            gt_field=label_field,
            eval_key=eval_key,
            classwise=False,
            iou=_DETECTION_IOU,
            progress=progress,
        )

    view = samples.select_fields([label_field, pred_field])
    processing_frames = samples._is_frame_field(label_field)

    logger.info("Computing mistakenness...")
    for sample in view.iter_samples(progress=progress):
        units = sample.frames.values() if processing_frames else [sample]

        scores = []
        missing_total = 0
        spurious_total = 0
        for unit in units:
            if is_objects:
                score, unit_missing, unit_spurious = brain_method.process_image(
                    unit, eval_key
                )
                missing_total += unit_missing
                spurious_total += unit_spurious

                if processing_frames:
                    unit[missing_field] = unit_missing
                    unit[spurious_field] = unit_spurious
            else:
                score = brain_method.process_image(unit)

            if score is not None:
                scores.append(score)

            if processing_frames:
                unit[mistakenness_field] = score

        # The sample-level score is the worst (max) score of its images/frames
        sample[mistakenness_field] = np.max(scores) if scores else None

        if is_objects:
            sample[missing_field] = missing_total
            sample[spurious_field] = spurious_total

        sample.save()

    if eval_key is not None:
        samples.delete_evaluation(eval_key)

    brain_method.save_run_results(samples, brain_key, None)

    logger.info("Mistakenness computation complete")
# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class MistakennessMethodConfig(fob.BrainMethodConfig):
    """Base configuration shared by all mistakenness methods.

    Args:
        pred_field: the name of the predicted label field
        label_field: the name of the ground truth label field
        mistakenness_field: the name of the field in which to store the
            mistakenness scores
        **kwargs: keyword arguments for
            :class:`fiftyone.core.brain.BrainMethodConfig`
    """

    def __init__(self, pred_field, label_field, mistakenness_field, **kwargs):
        super().__init__(**kwargs)

        # NOTE: attributes are assigned in the original order in case the
        # serialized config depends on it
        self.pred_field = pred_field
        self.label_field = label_field
        self.mistakenness_field = mistakenness_field

    @property
    def type(self):
        return "mistakenness"
class MistakennessMethod(fob.BrainMethod):
    """Base class for mistakenness methods.

    Args:
        config: a :class:`MistakennessMethodConfig`
    """

    def __init__(self, config):
        super().__init__(config)

        # Populated by `register_samples()`
        self.pred_field = None
        self.label_field = None
        self.label_type = None

    def ensure_requirements(self):
        # No additional packages are required
        pass

    def register_samples(self, samples):
        """Resolves the (possibly frame-level) field names and the label type
        of the predictions for the given collection.
        """
        config = self.config

        pred_path, _ = samples._handle_frame_field(config.pred_field)
        self.pred_field = pred_path

        label_path, _ = samples._handle_frame_field(config.label_field)
        self.label_field = label_path

        self.label_type = samples._get_label_field_type(config.pred_field)

    def _validate_run(self, samples, brain_key, existing_info):
        for field_name in ("pred_field", "label_field", "mistakenness_field"):
            self._validate_fields_match(brain_key, field_name, existing_info)
# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class ClassificationMistakennessConfig(MistakennessMethodConfig):
    """Configuration for computing mistakenness of classification labels.

    Args:
        pred_field: the name of the predicted label field
        label_field: the name of the ground truth label field
        mistakenness_field: the name of the field in which to store the
            mistakenness scores
        use_logits: whether to score predictions by their ``logits`` rather
            than their ``confidence``
        **kwargs: keyword arguments for :class:`MistakennessMethodConfig`
    """

    def __init__(
        self, pred_field, label_field, mistakenness_field, use_logits, **kwargs
    ):
        super().__init__(pred_field, label_field, mistakenness_field, **kwargs)
        self.use_logits = use_logits

    @property
    def method(self):
        return "classification"
class ClassificationMistakenness(MistakennessMethod):
    """Mistakenness method for classification labels."""

    def process_image(self, sample_or_frame):
        """Computes the mistakenness of the ground truth label of a sample or
        frame.

        Args:
            sample_or_frame: a sample or frame containing the predicted and
                ground truth labels

        Returns:
            the mistakenness score, or None if neither label is present
        """
        use_logits = self.config.use_logits
        pred_label, gt_label = _get_data(
            sample_or_frame, self.pred_field, self.label_field, use_logits
        )

        if pred_label is None and gt_label is None:
            return None

        if pred_label is None:
            # There is no prediction to compare against
            return 1.0

        # `m` is 1.0 when the prediction agrees with the ground truth and 0.0
        # otherwise (the `_compute_mistakenness_*` helpers map 1 -> "correct")
        if gt_label is None:
            # A prediction with no ground truth label is a disagreement: the
            # more confident the model, the more likely the annotation is
            # missing
            m = 0.0
        elif isinstance(pred_label, fol.Classifications):
            # For multilabel problems, all labels must match
            pred_labels = set(c.label for c in pred_label.classifications)
            gt_labels = set(c.label for c in gt_label.classifications)
            m = float(pred_labels == gt_labels)
        else:
            m = float(pred_label.label == gt_label.label)

        if use_logits:
            return _compute_mistakenness_class(pred_label.logits, m)

        return _compute_mistakenness_class_conf(pred_label.confidence, m)

    def get_fields(self, samples, brain_key):
        pred_field = self.config.pred_field
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field

        fields = [pred_field, label_field, mistakenness_field]

        if samples._is_frame_field(label_field):
            fields.append(samples._FRAMES_PREFIX + mistakenness_field)

        return fields

    def cleanup(self, samples, brain_key):
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field

        samples._dataset.delete_sample_fields(
            mistakenness_field, error_level=1
        )

        if samples._is_frame_field(label_field):
            samples._dataset.delete_frame_fields(
                mistakenness_field, error_level=1
            )
# @todo move to `fiftyone/brain/mistakenness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class DetectionMistakennessConfig(MistakennessMethodConfig):
    """Configuration for computing mistakenness of object-level labels.

    Args:
        pred_field: the name of the predicted label field
        label_field: the name of the ground truth label field
        mistakenness_field: the name of the field in which to store the
            mistakenness scores
        missing_field: the name of the field used to flag/count objects that
            appear to be missing from the ground truth
        spurious_field: the name of the field used to flag/count ground truth
            objects that appear to be spurious
        use_logits: whether to score predictions by their ``logits`` rather
            than their ``confidence``
        copy_missing: whether to copy predicted objects flagged as missing
            into the ground truth label field
        eval_key: the evaluation key used to match objects
        **kwargs: keyword arguments for :class:`MistakennessMethodConfig`
    """

    def __init__(
        self,
        pred_field,
        label_field,
        mistakenness_field,
        missing_field,
        spurious_field,
        use_logits,
        copy_missing,
        eval_key,
        **kwargs
    ):
        super().__init__(pred_field, label_field, mistakenness_field, **kwargs)

        # NOTE: attributes are assigned in the original order in case the
        # serialized config depends on it
        self.missing_field = missing_field
        self.spurious_field = spurious_field
        self.use_logits = use_logits
        self.copy_missing = copy_missing
        self.eval_key = eval_key

    @property
    def method(self):
        return "detection"
class DetectionMistakenness(MistakennessMethod):
    """Mistakenness method for object-level labels.

    Relies on the ``<eval_key>_id`` and ``<eval_key>_iou`` attributes that the
    matching evaluation run populates on each object.
    """

    def process_image(self, sample_or_frame, eval_key):
        """Computes object-level mistakenness for a sample or frame.

        The objects are annotated in-place:

        -   unmatched predictions whose confidence exceeds
            ``_MISSED_CONFIDENCE_THRESHOLD`` get ``missing_field = True``
        -   unmatched ground truth objects get ``spurious_field = True``
        -   matched ground truth objects get ``mistakenness_field`` and
            ``mistakenness_field + "_loc"`` scores

        Args:
            sample_or_frame: a sample or frame
            eval_key: the evaluation key used to match the objects

        Returns:
            a ``(mistakenness, num_missing, num_spurious)`` tuple, where
            ``mistakenness`` is the max score over matched objects, or -1 if
            there were no matches
        """
        missing_field = self.config.missing_field
        spurious_field = self.config.spurious_field
        mistakenness_field = self.config.mistakenness_field
        copy_missing = self.config.copy_missing
        use_logits = self.config.use_logits

        pred_label, gt_label = _get_data(
            sample_or_frame, self.pred_field, self.label_field, use_logits
        )

        list_field = self.label_type._LABEL_LIST_FIELD

        if pred_label is None:
            pred_label = self.label_type()

        if gt_label is None:
            gt_label = self.label_type()

        num_spurious = 0
        num_missing = 0
        missing_objects = {}
        image_mistakenness = []

        pred_map = {}
        for pred_obj in pred_label[list_field]:
            pred_map[pred_obj.id] = pred_obj
            gt_id = pred_obj[eval_key + "_id"]
            conf = pred_obj.confidence

            # Unmatched FPs with high confidence are likely missing from the
            # ground truth. `conf` may be None for label types whose
            # confidences `_get_data()` does not validate
            if (
                gt_id == ""
                and conf is not None
                and conf > _MISSED_CONFIDENCE_THRESHOLD
            ):
                pred_obj[missing_field] = True
                num_missing += 1
                missing_objects[pred_obj.id] = pred_obj

        for gt_obj in gt_label[list_field]:
            # Avoid adding the same unmatched FP predictions upon multiple runs
            # of this method
            if copy_missing and gt_obj.has_field(missing_field):
                if gt_obj.id in missing_objects:
                    del missing_objects[gt_obj.id]

                continue

            pred_id = gt_obj[eval_key + "_id"]
            if pred_id == "":
                # FN may be spurious
                gt_obj[spurious_field] = True
                num_spurious += 1
            else:
                # For matched objects, compute mistakenness
                iou = gt_obj[eval_key + "_iou"]
                pred_obj = pred_map[pred_id]
                m = float(gt_obj.label == pred_obj.label)
                if use_logits:
                    mistakenness_class = _compute_mistakenness_class(
                        pred_obj.logits, m
                    )
                    mistakenness_loc = _compute_mistakenness_loc(
                        pred_obj.logits, iou
                    )
                else:
                    mistakenness_class = _compute_mistakenness_class_conf(
                        pred_obj.confidence, m
                    )
                    mistakenness_loc = _compute_mistakenness_loc_conf(
                        pred_obj.confidence, iou
                    )

                gt_obj[mistakenness_field] = mistakenness_class
                gt_obj[mistakenness_field + "_loc"] = mistakenness_loc
                image_mistakenness.append(mistakenness_class)

        if copy_missing:
            gt_label[list_field].extend(missing_objects.values())

        sample_or_frame[self.label_field] = gt_label

        if image_mistakenness:
            mistakenness = np.max(image_mistakenness)
        else:
            mistakenness = -1

        return mistakenness, num_missing, num_spurious

    def get_fields(self, samples, brain_key):
        pred_field = self.config.pred_field
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field
        missing_field = self.config.missing_field
        spurious_field = self.config.spurious_field

        label_type = samples._get_label_field_type(pred_field)
        list_field = label_type._LABEL_LIST_FIELD

        fields = [
            mistakenness_field,
            missing_field,
            spurious_field,
            "%s.%s.%s" % (label_field, list_field, mistakenness_field),
            "%s.%s.%s_loc" % (label_field, list_field, mistakenness_field),
            "%s.%s.%s" % (pred_field, list_field, missing_field),
            "%s.%s.%s" % (label_field, list_field, spurious_field),
        ]

        if samples._is_frame_field(pred_field):
            fields.extend(
                [
                    samples._FRAMES_PREFIX + mistakenness_field,
                    samples._FRAMES_PREFIX + missing_field,
                    samples._FRAMES_PREFIX + spurious_field,
                ]
            )

        return fields

    def cleanup(self, samples, brain_key):
        pred_field = self.config.pred_field
        label_field = self.config.label_field
        mistakenness_field = self.config.mistakenness_field
        missing_field = self.config.missing_field
        spurious_field = self.config.spurious_field
        eval_key = self.config.eval_key

        label_type = samples._get_label_field_type(pred_field)
        list_field = label_type._LABEL_LIST_FIELD
        pred_field, is_frame_field = samples._handle_frame_field(pred_field)
        label_field, _ = samples._handle_frame_field(label_field)

        fields = [
            mistakenness_field,
            missing_field,
            spurious_field,
            "%s.%s.%s" % (label_field, list_field, mistakenness_field),
            "%s.%s.%s_loc" % (label_field, list_field, mistakenness_field),
            "%s.%s.%s" % (pred_field, list_field, missing_field),
            "%s.%s.%s" % (label_field, list_field, spurious_field),
        ]

        if self.config.copy_missing:
            # Remove the predicted objects that were copied into
            # `label_field`. `only_matches=False` is required so that
            # samples/frames whose labels are *all* copied objects remain in
            # the view and therefore have their copies removed on save
            samples._dataset.filter_labels(
                self.config.label_field,
                F(missing_field).exists(False),
                only_matches=False,
            ).save()

        if is_frame_field:
            samples._dataset.delete_sample_fields(
                [mistakenness_field, spurious_field, missing_field],
                error_level=1,
            )
            samples._dataset.delete_frame_fields(fields, error_level=1)
        else:
            samples._dataset.delete_sample_fields(fields, error_level=1)

        if eval_key in samples.list_evaluations():
            samples.delete_evaluation(eval_key)

    def _validate_run(self, samples, brain_key, existing_info):
        super()._validate_run(samples, brain_key, existing_info)
        self._validate_fields_match(brain_key, "missing_field", existing_info)
        self._validate_fields_match(brain_key, "spurious_field", existing_info)
        self._validate_fields_match(brain_key, "copy_missing", existing_info)
def _make_eval_key(samples, brain_key):
existing_eval_keys = samples.list_evaluations()
eval_key = brain_key + "_eval"
if eval_key not in existing_eval_keys:
return eval_key
idx = 2
while eval_key + str(idx) in existing_eval_keys:
idx += 1
return eval_key + str(idx)
def _get_data(sample, pred_field, label_field, use_logits):
    """Retrieves the predicted and ground truth labels of a sample/frame and
    validates that the predictions carry the scores needed downstream.

    Args:
        sample: a sample or frame
        pred_field: the name of the predicted label field
        label_field: the name of the ground truth label field
        use_logits: whether non-object predictions must provide ``logits``
            (True) or ``confidence`` (False)

    Returns:
        a ``(pred_label, label)`` tuple

    Raises:
        ValueError: if a required confidence or logits value is missing
    """
    pred_label = sample[pred_field]
    gt_label = sample[label_field]

    if pred_label is None:
        return pred_label, gt_label

    if isinstance(pred_label, fol.Detections):
        bad = next(
            (d for d in pred_label.detections if d.confidence is None), None
        )
        if bad is not None:
            raise ValueError(
                "Detection '%s' in sample '%s' field '%s' has no "
                "confidence" % (bad.id, sample.id, pred_field)
            )
    elif isinstance(pred_label, fol.Polylines):
        bad = next(
            (p for p in pred_label.polylines if p.confidence is None), None
        )
        if bad is not None:
            raise ValueError(
                "Polyline '%s' in sample '%s' field '%s' has no "
                "confidence" % (bad.id, sample.id, pred_field)
            )
    elif use_logits and pred_label.logits is None:
        raise ValueError(
            "Sample '%s' field '%s' has no logits"
            % (sample.id, pred_field)
        )
    elif not use_logits and pred_label.confidence is None:
        raise ValueError(
            "Sample '%s' field '%s' has no confidence"
            % (sample.id, pred_field)
        )

    return pred_label, gt_label
def _compute_mistakenness_class(logits, m):
# constrain m to either 1 (incorrect) or -1 (correct)
m = m * -2.0 + 1.0
c = -1.0 * entropy(softmax(np.asarray(logits)))
mistakenness = (m * exp(c) + 1.0) / 2.0
return mistakenness
def _compute_mistakenness_loc(logits, iou):
    """Computes localization mistakenness from logits.

    The certainty ``exp(-entropy(softmax(logits)))`` is used as the
    confidence of :func:`_compute_mistakenness_loc_conf`: with certainty 1 the
    score equals the normalized localization error (0 at IoU = 1, 1 at the
    matching IoU threshold), and with certainty 0 the score is 0.5.
    """
    certainty = exp(-1.0 * entropy(softmax(np.asarray(logits))))
    return _compute_mistakenness_loc_conf(certainty, iou)
def _compute_mistakenness_class_conf(confidence, m):
# constrain m to either 1 (incorrect) or -1 (correct)
m = m * -2.0 + 1.0
mistakenness = (m * confidence + 1.0) / 2.0
return mistakenness
def _compute_mistakenness_loc_conf(confidence, iou):
    """Computes localization mistakenness from a prediction confidence.

    With confidence 1 the score equals the normalized localization error, and
    with confidence 0 the score is 0.5. Lower IoUs give higher scores, and
    higher confidences push the score further from 0.5.
    """
    # Normalized localization error: 0 at IoU = 1, 1 at IoU = _DETECTION_IOU
    loc_error = (1.0 / (1.0 - _DETECTION_IOU)) * (1.0 - iou)

    return (confidence * ((2.0 * loc_error) - 1.0) + 1.0) / 2.0
================================================
FILE: fiftyone/brain/internal/core/mongodb.py
================================================
"""
MongoDB similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
import logging
from bson import ObjectId
import numpy as np
from pymongo.errors import OperationFailure
import eta.core.utils as etau
from fiftyone import ViewField as F
import fiftyone.core.fields as fof
import fiftyone.core.media as fom
import fiftyone.core.utils as fou
import fiftyone.brain.internal.core.utils as fbu
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
logger = logging.getLogger(__name__)
_SUPPORTED_METRICS = {
"cosine": "cosine",
"dotproduct": "dotProduct",
"euclidean": "euclidean",
}
class MongoDBSimilarityConfig(SimilarityConfig):
    """Configuration for a MongoDB similarity instance.

    Args:
        index_name (None): the name of the MongoDB vector index to use or
            create. If none is provided, a new index will be created
        metric ("cosine"): the embedding distance metric to use when creating
            a new index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if neither an embeddings field nor ``index_name`` is
            provided, if a ``patches_field`` is provided, or if ``metric`` is
            not supported
    """

    def __init__(self, index_name=None, metric="cosine", **kwargs):
        has_embeddings_field = kwargs.get("embeddings_field") is not None
        if not has_embeddings_field and index_name is None:
            raise ValueError(
                "You must provide either the name of a field to read/write "
                "embeddings for this index by passing the `embeddings` "
                "parameter, or you must provide the name of an existing "
                "vector search index via the `index_name` parameter"
            )

        # @todo support this. Will likely require copying embeddings to a new
        # collection as vector search indexes do not yet support array fields
        has_patches_field = kwargs.get("patches_field") is not None
        if has_patches_field:
            raise ValueError(
                "The MongoDB backend does not yet support patch embeddings"
            )

        if metric not in _SUPPORTED_METRICS:
            supported = tuple(_SUPPORTED_METRICS.keys())
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, supported)
            )

        super().__init__(**kwargs)

        self.index_name = index_name
        self.metric = metric

    @property
    def method(self):
        return "mongodb"

    @property
    def max_k(self):
        # MongoDB limit
        return 10000

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)
class MongoDBSimilarity(Similarity):
    """MongoDB similarity factory.

    Args:
        config: a :class:`MongoDBSimilarityConfig`
    """

    def ensure_requirements(self):
        # `create_search_index()` with `type="vectorSearch"` requires
        # pymongo>=4.7
        # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.create_search_index
        #
        # Could also validate that the user is connected to an Atlas cluster
        # here, eg Atlas clusters generally have hostnames which end in
        # "mongodb.net" (https://stackoverflow.com/q/73180110)
        fou.ensure_package("pymongo>=4.7")

    def ensure_usage_requirements(self):
        fou.ensure_package("pymongo>=4.7")

    def initialize(self, samples, brain_key):
        index = MongoDBSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class MongoDBSimilarityIndex(SimilarityIndex):
    """Class for interacting with MongoDB similarity indexes.

    Embeddings are stored in ``config.embeddings_field`` of the dataset's
    samples and queried via a MongoDB Atlas vector search index.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`MongoDBSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`MongoDBSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._dataset = samples._dataset

        # Cached IDs of the indexed embeddings; None means "recompute from
        # the database on demand"
        self._sample_ids = None
        self._label_ids = None

        # True once the search index is known (or assumed) to exist
        self._index = None

        self._initialize()

    @property
    def is_external(self):
        return False

    @property
    def total_index_size(self):
        if self._sample_ids is not None:
            return len(self._sample_ids)

        if self._dataset.media_type == fom.GROUP:
            samples = self._dataset.select_group_slices(_allow_mixed=True)
        else:
            samples = self._dataset

        patches_field = self.config.patches_field
        embeddings_field = self.config.embeddings_field

        if patches_field is not None:
            _, embeddings_path = self._dataset._get_label_field_path(
                patches_field, embeddings_field
            )
            samples = samples.filter_labels(
                patches_field, F(embeddings_field).exists()
            )
            return samples.count(embeddings_path)

        if samples.has_field(embeddings_field):
            return samples.exists(embeddings_field).count()

        return 0

    def _initialize(self):
        coll = self._dataset._sample_collection

        try:
            indexes = {
                i["name"]: i
                for i in coll.aggregate([{"$listSearchIndexes": {}}])
            }
        except OperationFailure:
            # https://www.mongodb.com/docs/manual/release-notes/7.0/#atlas-search-index-management
            # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview
            if self.config.index_name is None:
                raise ValueError(
                    "You must be running MongoDB Atlas 7.0 or later in order "
                    "to use vector search indexes"
                )

            # Must assume index exists because we can't use pymongo to check...
            self._index = True
            return

        if self.config.index_name is None:
            root = self.config.embeddings_field
            index_name = fbu.get_unique_name(root, list(indexes.keys()))

            self.config.index_name = index_name
            self.save_config()
        elif self.config.embeddings_field is None:
            info = indexes.get(self.config.index_name, None)
            if info is None:
                raise ValueError(
                    "Index '%s' does not exist" % self.config.index_name
                )

            embeddings_field = self._get_index_embeddings_field(
                info.get("latestDefinition", None) or {}
            )
            if embeddings_field is None:
                raise ValueError(
                    "Unable to determine the embeddings field of index '%s'"
                    % self.config.index_name
                )

            self.config.embeddings_field = embeddings_field
            self.save_config()

        if self.config.index_name in indexes:
            # Index already exists
            self._index = True
        elif self.total_index_size > 0:
            # Embeddings already exist but the index hasn't been declared yet
            dimension = self._get_dimension()
            self._create_index(dimension)
        else:
            # Index will be created when add_to_index() is called
            pass

    @staticmethod
    def _get_index_embeddings_field(definition):
        """Returns the embeddings path declared by a search index definition,
        or None if it cannot be determined.

        Supports both ``vectorSearch`` definitions (a ``fields`` list with a
        ``"type": "vector"`` entry, as created by :meth:`_create_index`) and
        legacy definitions that declare their field under
        ``mappings.fields``.
        """
        fields = definition.get("fields", None)
        if isinstance(fields, list):
            for field in fields:
                if field.get("type", None) == "vector":
                    return field.get("path", None)

        mappings = definition.get("mappings", None) or {}
        legacy_fields = mappings.get("fields", None) or {}
        return next(iter(legacy_fields), None)

    def _get_dimension(self):
        if self._dataset.media_type == fom.GROUP:
            samples = self._dataset.select_group_slices(_allow_mixed=True)
        else:
            samples = self._dataset

        patches_field = self.config.patches_field
        embeddings_field = self.config.embeddings_field

        if patches_field is not None:
            _, embeddings_path = self._dataset._get_label_field_path(
                patches_field, embeddings_field
            )
            view = samples.filter_labels(
                patches_field, F(embeddings_field).exists()
            ).limit(1)
            embeddings = view.values(embeddings_path, unwind=True)
        else:
            view = samples.exists(embeddings_field).limit(1)
            embeddings = view.values(embeddings_field)

        embedding = next(iter(embeddings), None)
        if embedding is None:
            return None

        return len(embedding)  # MongoDB requires list fields

    def _create_index(self, dimension):
        # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage
        # https://www.mongodb.com/docs/languages/python/pymongo-driver/current/indexes/atlas-search-index/
        from pymongo.operations import SearchIndexModel

        field = self._dataset.get_field(self.config.embeddings_field)
        if field is not None and not isinstance(field, fof.ListField):
            raise ValueError(
                "MongoDB vector search indexes require embeddings to be "
                "stored in list fields"
            )

        metric = _SUPPORTED_METRICS[self.config.metric]

        fields = [
            {
                "type": "vector",
                "numDimensions": dimension,
                "path": self.config.embeddings_field,
                "similarity": metric,
            },
            {
                "type": "filter",
                "path": "_id",
            },
        ]

        # Currently disabled: a filter on the group slice name, which would
        # allow queries on grouped datasets to be restricted to a slice
        #
        # if self._dataset.media_type == fom.GROUP:
        #     fields.append(
        #         {
        #             "type": "filter",
        #             "path": self._dataset.group_field + ".name",
        #         }
        #     )

        model = SearchIndexModel(
            name=self.config.index_name,
            type="vectorSearch",  # requires pymongo>=4.7
            definition={"fields": fields},
        )

        coll = self._dataset._sample_collection
        coll.create_search_index(model=model)
        self._index = True

    @property
    def ready(self):
        """Returns True/False whether the vector search index is ready to be
        queried, or None if the status cannot be determined (this requires
        MongoDB Atlas 7.0 or later).
        """
        if self._index is None:
            return False

        try:
            coll = self._dataset._sample_collection
            indexes = {
                i["name"]: i
                for i in coll.aggregate([{"$listSearchIndexes": {}}])
            }
        except OperationFailure:
            # requires MongoDB Atlas 7.0 or later
            return None

        info = indexes.get(self.config.index_name, {})
        return info.get("status", None) == "READY"

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
    ):
        # The search index is declared lazily, once the embedding dimension
        # is known
        if self._index is None:
            self._create_index(embeddings.shape[1])

        sample_ids = np.asarray(sample_ids)
        label_ids = np.asarray(label_ids) if label_ids is not None else None

        if not overwrite or not allow_existing or warn_existing:
            # Reconcile the new IDs with those already in the index
            if self._sample_ids is not None:
                _sample_ids, _label_ids = self._sample_ids, self._label_ids
            else:
                _sample_ids, _label_ids = self._parse_data(
                    self._dataset, self.config
                )

            index_sample_ids, index_label_ids, ii, _ = fbu.add_ids(
                sample_ids,
                label_ids,
                _sample_ids,
                _label_ids,
                patches_field=self.config.patches_field,
                overwrite=overwrite,
                allow_existing=allow_existing,
                warn_existing=warn_existing,
            )

            self._sample_ids = index_sample_ids
            self._label_ids = index_label_ids

            if ii.size == 0:
                return

            embeddings = embeddings[ii, :]
            sample_ids = sample_ids[ii]
            label_ids = label_ids[ii] if label_ids is not None else None
        else:
            index_sample_ids = None
            index_label_ids = None

        fbu.add_embeddings(
            self._dataset,
            embeddings.tolist(),  # MongoDB requires list fields
            sample_ids,
            label_ids,
            self.config.embeddings_field,
            patches_field=self.config.patches_field,
        )

        if reload:
            super().reload()

        # Cache the up-to-date IDs (None means they are recomputed on demand)
        self._sample_ids = index_sample_ids
        self._label_ids = index_label_ids

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        if not allow_missing or warn_missing:
            # Reconcile the requested IDs with those currently in the index
            if self._sample_ids is not None:
                _sample_ids, _label_ids = self._sample_ids, self._label_ids
            else:
                _sample_ids, _label_ids = self._parse_data(
                    self._dataset, self.config
                )

            index_sample_ids, index_label_ids, rm_inds = fbu.remove_ids(
                sample_ids,
                label_ids,
                _sample_ids,
                _label_ids,
                patches_field=self.config.patches_field,
                allow_missing=allow_missing,
                warn_missing=warn_missing,
            )

            self._sample_ids = index_sample_ids
            self._label_ids = index_label_ids

            if rm_inds.size == 0:
                return

            if self.config.patches_field is not None:
                sample_ids = None
                label_ids = _label_ids[rm_inds]
            else:
                sample_ids = _sample_ids[rm_inds]
                label_ids = None
        else:
            index_sample_ids = None
            index_label_ids = None

        fbu.remove_embeddings(
            self._dataset,
            self.config.embeddings_field,
            sample_ids=sample_ids,
            label_ids=label_ids,
            patches_field=self.config.patches_field,
        )

        if reload:
            super().reload()

        # Cache the up-to-date IDs (None means they are recomputed on demand)
        self._sample_ids = index_sample_ids
        self._label_ids = index_label_ids

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        """Retrieves embeddings from the index.

        When ``label_ids`` are provided they take precedence, and any
        ``sample_ids`` are ignored with a warning.

        Returns:
            an ``(embeddings, sample_ids, label_ids)`` tuple whose entries are
            aligned with each other
        """
        if self._dataset.media_type == fom.GROUP:
            samples = self._dataset.select_group_slices(_allow_mixed=True)
        else:
            samples = self._dataset

        # Label IDs are checked first so that they take precedence over
        # sample IDs, consistent with the index lookup below
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )

            samples = samples.select_labels(
                ids=label_ids, fields=self.config.patches_field
            )
        elif sample_ids is not None:
            samples = samples.select(sample_ids)

        _embeddings, _sample_ids, _label_ids = fbu.get_embeddings(
            samples,
            patches_field=self.config.patches_field,
            embeddings_field=self.config.embeddings_field,
        )

        if label_ids is not None:
            inds = _get_inds(
                label_ids,
                _label_ids,
                "label",
                allow_missing,
                warn_missing,
            )

            embeddings = _embeddings[inds, :]
            sample_ids = _sample_ids[inds]

            # Only return the label IDs that were actually found, so that
            # they stay aligned with `embeddings`
            label_ids = _label_ids[inds]
        elif sample_ids is not None:
            if etau.is_str(sample_ids):
                sample_ids = [sample_ids]

            if self.config.patches_field is not None:
                sample_ids = set(sample_ids)
                bools = [_id in sample_ids for _id in _sample_ids]
                inds = np.nonzero(bools)[0]
            else:
                inds = _get_inds(
                    sample_ids,
                    _sample_ids,
                    "sample",
                    allow_missing,
                    warn_missing,
                )

            embeddings = _embeddings[inds, :]
            sample_ids = _sample_ids[inds]
            if self.config.patches_field is not None:
                label_ids = _label_ids[inds]
            else:
                label_ids = None
        else:
            embeddings = _embeddings
            sample_ids = _sample_ids
            label_ids = _label_ids

        return embeddings, sample_ids, label_ids

    def reload(self):
        self._sample_ids = None
        self._label_ids = None
        super().reload()

    def cleanup(self):
        if self._index is None:
            return

        try:
            coll = self._dataset._sample_collection
            coll.drop_search_index(self.config.index_name)
        except OperationFailure:
            # requires MongoDB Atlas 7.0 or later
            pass

        self._index = None

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError("MongoDB does not support full index neighbors")

        if reverse is True:
            raise ValueError(
                "MongoDB does not support least similarity queries"
            )

        if aggregation not in (None, "mean"):
            raise ValueError(
                f"MongoDB does not support {aggregation} aggregation"
            )

        if k is None:
            k = min(self.index_size, self.config.max_k)

        # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage
        num_candidates = min(10 * k, self.config.max_k)

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        if self.has_view:
            index_ids = self.current_sample_ids
            # if self.config.patches_field is not None:
            #     index_ids = self.current_label_ids
        else:
            index_ids = None

        dataset = self._dataset

        sample_ids = []
        label_ids = None
        # if self.config.patches_field is not None:
        #     label_ids = []
        dists = []
        for q in query:
            search = {
                "index": self.config.index_name,
                "path": self.config.embeddings_field,
                "limit": k,
                "numCandidates": num_candidates,
                "queryVector": q.tolist(),
            }

            if index_ids is not None:
                search["filter"] = {
                    "_id": {"$in": [ObjectId(_id) for _id in index_ids]}
                }

            # Currently disabled: for grouped datasets, $vectorSearch must be
            # the first stage in all pipelines, so slice selection would have
            # to be incorporated as a filter:
            #
            # elif dataset.media_type == fom.GROUP:
            #     name_field = dataset.group_field + ".name"
            #     group_slice = self.view.group_slice or dataset.group_slice
            #     search["filter"] = {name_field: {"$eq": group_slice}}

            project = {"_id": 1}
            # if self.config.patches_field is not None:
            #     project["_sample_id"] = 1

            if return_dists:
                project["score"] = {"$meta": "vectorSearchScore"}

            # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage
            pipeline = [{"$vectorSearch": search}, {"$project": project}]

            try:
                matches = list(
                    dataset._aggregate(
                        pipeline=pipeline, manual_group_select=True
                    )
                )
            except OperationFailure as e:
                if index_ids is None:
                    raise e

                logger.warning(
                    "This legacy search index does not yet support views. "
                    "Please follow the instructions at "
                    "https://github.com/voxel51/fiftyone-brain/pull/248 "
                    "to upgrade it.\n\nIn the meantime, the full index will "
                    "instead be queried, which may result in fewer "
                    "matches in your current view"
                )

                # `pipeline` references `search`, so this also removes the
                # filter from the pipeline that is re-run below
                search.pop("filter")
                matches = list(
                    dataset._aggregate(
                        pipeline=pipeline, manual_group_select=True
                    )
                )

            sample_ids.append([str(m["_id"]) for m in matches])
            # if self.config.patches_field is not None:
            #     sample_ids.append([str(m["_sample_id"]) for m in matches])
            #     label_ids.append([str(m["_id"]) for m in matches])

            if return_dists:
                dists.append([m["score"] for m in matches])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        embeddings = self._get_embeddings(query_ids)

        num_missing = len(query_ids) - len(embeddings)
        for e in embeddings:
            num_missing += int(e is None)

        if num_missing > 0:
            if single_query:
                raise ValueError("The query ID does not exist in this index")
            else:
                raise ValueError(
                    f"{num_missing} query IDs do not exist in this index"
                )

        query = np.array(embeddings)
        if single_query:
            query = query[0, :]

        return query

    def _get_embeddings(self, query_ids):
        if self._dataset.media_type == fom.GROUP:
            samples = self._dataset.select_group_slices(_allow_mixed=True)
        else:
            samples = self._dataset

        patches_field = self.config.patches_field
        embeddings_field = self.config.embeddings_field

        if patches_field is not None:
            _, embeddings_path = self._dataset._get_label_field_path(
                patches_field, embeddings_field
            )
            view = samples.filter_labels(
                patches_field, F("_id").is_in(query_ids)
            )
            embeddings = view.values(embeddings_path, unwind=True)
        else:
            # `ordered=True` so the embeddings are returned in the same order
            # as `query_ids`, which keeps per-query results aligned
            view = samples.select(query_ids, ordered=True)
            embeddings = view.values(embeddings_field)

        return embeddings

    @staticmethod
    def _parse_data(samples, config):
        if samples.media_type == fom.GROUP:
            samples = samples.select_group_slices(_allow_mixed=True)

        if config.patches_field is not None:
            samples = samples.filter_labels(
                config.patches_field, F(config.embeddings_field).exists()
            )
        else:
            samples = samples.exists(config.embeddings_field)

        return fbu.get_ids(samples, patches_field=config.patches_field)

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
def _get_inds(ids, index_ids, ftype, allow_missing, warn_missing):
if etau.is_str(ids):
ids = [ids]
ids_map = {_id: i for i, _id in enumerate(index_ids)}
inds = []
bad_ids = []
for _id in ids:
idx = ids_map.get(_id, None)
if idx is not None:
inds.append(idx)
else:
bad_ids.append(_id)
num_missing = len(bad_ids)
if num_missing > 0:
if not allow_missing:
raise ValueError(
"Found %d %s IDs (eg '%s') that are not present in the index"
% (num_missing, ftype, bad_ids[0])
)
if warn_missing:
logger.warning(
"Ignoring %d %s IDs that are not present in the index",
num_missing,
ftype,
)
return np.array(inds)
================================================
FILE: fiftyone/brain/internal/core/mosaic.py
================================================
"""
Mosaic similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
vector_search_client = fou.lazy_import("databricks.vector_search.client")
logger = logging.getLogger(__name__)
# TODO: mark the arguments that are necessary to create the index table as required
class MosaicSimilarityConfig(SimilarityConfig):
    """Configuration for the Mosaic (Databricks Vector Search) similarity
    backend.

    The workspace URL and all credentials are held in private attributes so
    that they are not serialized with the rest of the config; they are
    exposed through read/write properties instead.

    Args:
        endpoint_name (None): the name of the vector search endpoint that was
            created in the Databricks workspace
        workspace_url (None): the URL of the Databricks workspace
        catalog_name (None): the name of the catalog in the Databricks
            workspace
        schema_name (None): the name of the schema in the Databricks
            workspace
        index_name (None): the name of the index to use. If none is
            provided, a unique name will be generated
        service_principal_client_id (None): the client ID of the service
            principal used for authentication
        service_principal_client_secret (None): the client secret of the
            service principal used for authentication
        personal_access_token (None): the personal access token used for
            authentication
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`
    """

    def __init__(
        self,
        endpoint_name=None,
        workspace_url=None,
        catalog_name=None,
        schema_name=None,
        index_name=None,
        service_principal_client_id=None,
        service_principal_client_secret=None,
        personal_access_token=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Public, serializable settings
        self.index_name = index_name
        self.endpoint_name = endpoint_name
        self.catalog_name = catalog_name
        self.schema_name = schema_name

        # Private so that they are not serialized
        self._workspace_url = workspace_url
        self._service_principal_client_id = service_principal_client_id
        self._service_principal_client_secret = (
            service_principal_client_secret
        )
        self._personal_access_token = personal_access_token

    @property
    def method(self):
        return "mosaic"

    @property
    def workspace_url(self):
        """The URL of the Databricks workspace."""
        return self._workspace_url

    @workspace_url.setter
    def workspace_url(self, value):
        self._workspace_url = value

    @property
    def service_principal_client_id(self):
        """The client ID of the authenticating service principal."""
        return self._service_principal_client_id

    @service_principal_client_id.setter
    def service_principal_client_id(self, value):
        self._service_principal_client_id = value

    @property
    def service_principal_client_secret(self):
        """The client secret of the authenticating service principal."""
        return self._service_principal_client_secret

    @service_principal_client_secret.setter
    def service_principal_client_secret(self, value):
        self._service_principal_client_secret = value

    @property
    def personal_access_token(self):
        """The personal access token used for authentication."""
        return self._personal_access_token

    @personal_access_token.setter
    def personal_access_token(self, value):
        self._personal_access_token = value

    @property
    def max_k(self):
        # No backend-imposed cap on the number of neighbors
        return None

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self,
        workspace_url=None,
        service_principal_client_id=None,
        service_principal_client_secret=None,
        personal_access_token=None,
    ):
        """Loads any missing credentials via ``_load_parameters()``."""
        self._load_parameters(
            workspace_url=workspace_url,
            service_principal_client_id=service_principal_client_id,
            service_principal_client_secret=service_principal_client_secret,
            personal_access_token=personal_access_token,
        )
class MosaicSimilarity(Similarity):
    """Factory that builds :class:`MosaicSimilarityIndex` instances.

    Args:
        config: a :class:`MosaicSimilarityConfig`
    """

    def ensure_requirements(self):
        # Requires the Databricks vector search client package
        fou.ensure_package("databricks-vectorsearch")

    def ensure_usage_requirements(self):
        # Requires the Databricks vector search client package
        fou.ensure_package("databricks-vectorsearch")

    def initialize(self, samples, brain_key):
        index = MosaicSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class MosaicSimilarityIndex(SimilarityIndex):
    """Class for interacting with Mosaic similarity indexes.

    Embeddings are stored in a Databricks direct access vector index whose
    rows have the columns ``foid`` (the primary key: the label ID for patch
    indexes, otherwise the sample ID), ``sample_id``, and
    ``embedding_vector``.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`MosaicSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`MosaicSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._client = None
        self._index = None
        self._initialize()

    @property
    def _index_prefix(self):
        """The ``"<catalog>.<schema>."`` prefix of fully-qualified index
        names.
        """
        return f"{self.config.catalog_name}.{self.config.schema_name}."

    @property
    def _full_index_name(self):
        """The fully-qualified ``"<catalog>.<schema>.<index>"`` name of this
        index.
        """
        return f"{self._index_prefix}{self.config.index_name}"

    def _initialize(self):
        """Creates the client, chooses an index name if none was configured,
        and loads the index if it already exists on the endpoint.

        Raises:
            ValueError: if the indexes on the endpoint cannot be listed
        """
        self._client = vector_search_client.VectorSearchClient(
            workspace_url=self.config.workspace_url,
            service_principal_client_id=self.config.service_principal_client_id,
            service_principal_client_secret=self.config.service_principal_client_secret,
            personal_access_token=self.config.personal_access_token,
        )

        try:
            index_names_result = self._client.list_indexes(
                self.config.endpoint_name
            )
        except Exception as e:
            raise ValueError(
                "Failed to list indexes from endpoint '%s'"
                % self.config.endpoint_name
            ) from e

        # Tolerate an empty response or one without a "vector_indexes" key
        vector_indexes = (index_names_result or {}).get(
            "vector_indexes"
        ) or []

        # Strip only the leading "<catalog>.<schema>." prefix; str.replace()
        # would also remove any later occurrence inside the name
        index_prefix = self._index_prefix
        index_names = [
            ind["name"][len(index_prefix) :]
            for ind in vector_indexes
            if ind["name"].startswith(index_prefix)
        ]

        if self.config.index_name is None:
            root = "fiftyone-" + fou.to_slug(self._samples._root_dataset.name)
            index_name = fbu.get_unique_name(root, index_names)
            self.config.index_name = index_name
            self.save_config()

        if self.config.index_name in index_names:
            self._index = self._client.get_index(
                endpoint_name=self.config.endpoint_name,
                index_name=self._full_index_name,
            )
        else:
            # Created lazily by the first call to add_to_index()
            self._index = None

    def _create_index(self, dimension):
        """Creates a new direct access index for ``dimension``-dimensional
        embeddings.
        """
        # NOTE(review): Databricks' documented schema examples declare vector
        # columns as "array<float>"; confirm that "array" is accepted
        self._index = self._client.create_direct_access_index(
            endpoint_name=self.config.endpoint_name,
            index_name=self._full_index_name,
            primary_key="foid",
            embedding_dimension=dimension,
            embedding_vector_column="embedding_vector",
            schema={
                "foid": "string",
                "sample_id": "string",
                "embedding_vector": "array",
            },
        )

    @property
    def client(self):
        """The ``databricks.vector_search.client.VectorSearchClient`` instance for this index."""
        return self._client

    @property
    def total_index_size(self):
        if self._index is None:
            return 0

        return self._index.describe()["status"]["indexed_row_count"]

    def _iter_index_docs(self, batch_size=200):
        """Yields every raw document in the index by paging through
        ``scan()``.

        Callers read each document's ``fields`` list positionally: ``foid``,
        then ``sample_id``, then ``embedding_vector``.

        Args:
            batch_size (200): the number of documents requested per page
        """
        if self._index is None:
            return

        result = self._index.scan(num_results=batch_size)
        while result:
            docs = result.get("data") or []
            yield from docs

            last_primary_key = result.get("last_primary_key")
            if not docs or last_primary_key is None:
                # Stop on an empty/unpageable response rather than
                # re-requesting the same page forever
                break

            result = self._index.scan(
                num_results=batch_size, last_primary_key=last_primary_key
            )

    def _scan_for_ids(self, ids, match_sample_ids, batch_size=200):
        """Scans the full index for documents that match ``ids``.

        Args:
            ids: an iterable of IDs to match, or None to match every document
            match_sample_ids: whether to match ``ids`` against the
                ``sample_id`` column (True) or the ``foid`` primary key
                (False)
            batch_size (200): the number of documents requested per page

        Returns:
            a tuple of ``(embeddings, sample_ids, foids)`` lists describing
            the matched documents, in scan order
        """
        wanted = None if ids is None else set(ids)

        embeddings = []
        sample_ids = []
        foids = []
        for doc in self._iter_index_docs(batch_size=batch_size):
            fields = doc["fields"]
            foid = fields[0]["value"]["string_value"]
            sample_id = fields[1]["value"]["string_value"]

            key = sample_id if match_sample_ids else foid
            if wanted is not None and key not in wanted:
                continue

            values = fields[2]["value"]["list_value"]["values"]
            embeddings.append([v["number_value"] for v in values])
            sample_ids.append(sample_id)
            foids.append(foid)

        return embeddings, sample_ids, foids

    @staticmethod
    def _get_missing_ids(requested_ids, found_ids):
        """Returns the requested IDs that were not found. A ``requested_ids``
        of None means "everything", so nothing can be missing.
        """
        if requested_ids is None:
            return []

        return list(set(requested_ids) - set(found_ids))

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=200,
    ):
        embeddings = np.asarray(embeddings)

        if self._index is None:
            self._create_index(embeddings.shape[1])

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if warn_existing or not allow_existing or not overwrite:
            index_ids = self._get_index_ids()

            existing_ids = set(ids) & set(index_ids)
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

            # Re-derive the primary keys so that they stay aligned with the
            # filtered embeddings and sample IDs
            ids = label_ids if label_ids is not None else sample_ids

        for _embeddings, _ids, _sample_ids in zip(
            fou.iter_batches(embeddings, batch_size),
            fou.iter_batches(ids, batch_size),
            fou.iter_batches(sample_ids, batch_size),
        ):
            # Convert to built-in floats so that the payload is
            # JSON-serializable (numpy float32 scalars are not)
            rows = [
                {
                    "foid": f,
                    "sample_id": s,
                    "embedding_vector": [float(x) for x in e],
                }
                for f, s, e in zip(_ids, _sample_ids, _embeddings)
            ]
            self._index.upsert(rows)

        if reload:
            self.reload()

    def _get_index_ids(self, batch_size=200):
        """Returns the (deduplicated) ``foid`` of every document in the
        index.
        """
        ids = set()
        for doc in self._iter_index_docs(batch_size=batch_size):
            ids.add(doc["fields"][0]["value"]["string_value"])

        return list(ids)

    def _get_values(self, ids, batch_size=200):
        """Returns the embeddings whose ``foid`` is in ``ids``, in the order
        of ``ids``. IDs that are not in the index are skipped.
        """
        ids = list(ids)
        embeddings, _, foids = self._scan_for_ids(
            ids, match_sample_ids=False, batch_size=batch_size
        )

        # Preserve the caller's order so that each returned embedding lines
        # up with its query ID
        lookup = dict(zip(foids, embeddings))
        return [lookup[_id] for _id in ids if _id in lookup]

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if ids is None:
            ids = []
        elif etau.is_str(ids):
            ids = [ids]
        else:
            ids = list(ids)

        if not allow_missing or warn_missing:
            existing_ids = set(self._get_index_ids())

            missing_ids = set(ids) - existing_ids
            num_missing = len(missing_ids)

            if num_missing > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that are not present in the "
                        "index" % (num_missing, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Ignoring %d IDs that are not present in the index",
                        num_missing,
                    )

                # Only delete the requested IDs that actually exist; the rest
                # of the index must be left untouched
                ids = [_id for _id in ids if _id in existing_ids]

        if ids and self._index is not None:
            self._index.delete(ids)

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )
                # Honor the warning: label IDs take precedence
                sample_ids = None

        if sample_ids is not None and self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        elif self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)
        else:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    # NOTE(review): deleting the remote index might be better exposed as an
    # argument of delete_brain_run()
    def cleanup(self):
        """Deletes the remote index, if one exists."""
        if self._index is not None:
            self._client.delete_index(
                self.config.endpoint_name, self._full_index_name
            )
            self._index = None

    def _get_sample_embeddings(self, sample_ids, batch_size=200):
        """Returns ``(embeddings, sample_ids, None, missing_ids)`` for a
        sample-level index; ``sample_ids=None`` selects every document.
        """
        embeddings, found_sample_ids, _ = self._scan_for_ids(
            sample_ids, match_sample_ids=True, batch_size=batch_size
        )
        missing_ids = self._get_missing_ids(sample_ids, found_sample_ids)
        return embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=200):
        """Returns ``(embeddings, sample_ids, label_ids, missing_ids)`` for
        the given label IDs; ``label_ids=None`` selects every document.
        """
        embeddings, found_sample_ids, found_label_ids = self._scan_for_ids(
            label_ids, match_sample_ids=False, batch_size=batch_size
        )
        missing_ids = self._get_missing_ids(label_ids, found_label_ids)
        return embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=200
    ):
        """Returns ``(embeddings, sample_ids, label_ids, missing_ids)`` for
        all patches belonging to the given samples.
        """
        embeddings, found_sample_ids, found_label_ids = self._scan_for_ids(
            sample_ids, match_sample_ids=True, batch_size=batch_size
        )
        missing_ids = self._get_missing_ids(sample_ids, found_sample_ids)
        return embeddings, found_sample_ids, found_label_ids, missing_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError("Mosaic does not support full index neighbors")

        if reverse is True:
            raise ValueError(
                "Mosaic does not support least similarity queries"
            )

        if k is None:
            k = self.index_size

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = self.current_label_ids
            else:
                index_ids = self.current_sample_ids

            # @todo apply filtering in similarity_search(), not post-hoc
            # As of this writing, filtering is supported in Mosaic but it is
            # not robust and cannot handle a large number of IDs
            logger.warning(
                "The Mosaic backend does not yet support view filters; the "
                "full index will instead be queried, which may result in "
                "fewer matches in your current view"
            )
            allowed_ids = set(index_ids)
        else:
            allowed_ids = None

        is_patches = self.config.patches_field is not None

        sample_ids = []
        label_ids = [] if is_patches else None
        dists = []
        for q in query:
            response = self._index.similarity_search(
                columns=["foid", "sample_id"],
                query_vector=[float(x) for x in q],
                num_results=k,
            )

            # Rows are [foid, sample_id, <score>]; tolerate a response that
            # omits "data_array" (presumably when nothing matched)
            results = (response.get("result") or {}).get("data_array") or []
            if allowed_ids is not None:
                results = [r for r in results if r[0] in allowed_ids]

            if is_patches:
                sample_ids.append([r[1] for r in results])
                label_ids.append([r[0] for r in results])
            else:
                sample_ids.append([r[0] for r in results])

            if return_dists:
                dists.append([r[2] for r in results])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        """Converts a query (ID(s) or vector(s)) into a numpy array of query
        vectors; a single ID yields a 1D array.
        """
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        embeddings = self._get_values(query_ids)
        if len(embeddings) == 0:
            raise ValueError(
                "Query IDs %s do not exist in this index" % query_ids
            )

        query = np.array(embeddings)
        if single_query:
            query = query[0, :]

        return query

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/pgvector.py
================================================
"""
PGVector similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
psycopg2 = fou.lazy_import("psycopg2")
psy_extras = fou.lazy_import("psycopg2.extras")
logger = logging.getLogger(__name__)
# Supported metrics for pgvector
_SUPPORTED_METRICS = {
"cosine": "vector_cosine_ops",
"dotproduct": "vector_ip_ops",
"euclidean": "vector_l2_ops",
"l1": "vector_l1_ops",
"jaccard": "vector_jaccard_ops",
"hamming": "vector_hamming_ops",
}
class PgVectorSimilarityConfig(SimilarityConfig):
    """Configuration for the PGVector similarity backend.

    The connection string is held in a private attribute and exposed through
    the ``connection_string`` property.

    Args:
        index_name (None): the name of the PGVector index to use or create.
            If none is provided, a default index name will be used
        table_name (None): the name of the table to use or create. If none is
            provided, a default table name will be used
        metric ("cosine"): the similarity metric to use. Supported values are
            ``("cosine", "dotproduct", "euclidean", "l1", "jaccard",
            "hamming")``
        connection_string (None): the connection string to the PostgreSQL
            database
        ssl_cert (None): the path to the SSL client certificate file
        ssl_key (None): the path to the secret key used for the client
            certificate
        ssl_root_cert (None): the path to the file containing the SSL
            certificate authority (CA) certificate(s)
        work_mem ("64MB"): the base maximum amount of memory to be used by a
            query operation (such as a sort or hash table) before writing to
            temporary disk files
        hnsw_m (16): the max number of connections per layer in the HNSW
            index
        hnsw_ef_construction (64): the size of the dynamic candidate list
            used when constructing the HNSW graph
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is not supported
    """

    def __init__(
        self,
        index_name=None,
        table_name=None,
        metric="cosine",
        connection_string=None,
        ssl_cert=None,
        ssl_key=None,
        ssl_root_cert=None,
        work_mem="64MB",
        hnsw_m=16,
        hnsw_ef_construction=64,
        **kwargs,
    ):
        if metric not in _SUPPORTED_METRICS:
            # List only the metric names, not the operator classes that
            # they map to
            raise ValueError(
                f"Unsupported metric '{metric}'. "
                f"Supported values are {tuple(_SUPPORTED_METRICS)}"
            )

        super().__init__(**kwargs)
        self.metric = metric
        self.ssl_cert = ssl_cert
        self.ssl_key = ssl_key
        self.ssl_root_cert = ssl_root_cert
        self.work_mem = work_mem
        self.index_name = index_name
        self.table_name = table_name
        self.hnsw_m = hnsw_m
        self.hnsw_ef_construction = hnsw_ef_construction
        self._connection_string = connection_string

    @property
    def method(self):
        return "pgvector"

    @property
    def connection_string(self):
        """The connection string to the PostgreSQL database."""
        return self._connection_string

    @connection_string.setter
    def connection_string(self, connection_string):
        self._connection_string = connection_string

    @property
    def max_k(self):
        return 10000

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self,
        connection_string=None,
    ):
        """Loads the connection string via ``_load_parameters()``."""
        self._load_parameters(connection_string=connection_string)
class PgVectorSimilarity(Similarity):
    """Factory that builds :class:`PgVectorSimilarityIndex` instances.

    Args:
        config: a :class:`PgVectorSimilarityConfig`
    """

    def ensure_requirements(self):
        # Either psycopg2 distribution satisfies the requirement
        fou.ensure_package("psycopg2|psycopg2-binary")

    def ensure_usage_requirements(self):
        # Either psycopg2 distribution satisfies the requirement
        fou.ensure_package("psycopg2|psycopg2-binary")

    def initialize(self, samples, brain_key):
        index = PgVectorSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class PgVectorSimilarityIndex(SimilarityIndex):
    """Class for interacting with PGVector similarity indexes.

    Embeddings are stored in a PostgreSQL table with the columns ``id`` (the
    primary key: the label ID for patch indexes, otherwise the sample ID),
    ``sample_id``, and ``embedding_vector``, and are searched through an HNSW
    index on ``embedding_vector``.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`PgVectorSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`PgVectorSimilarity` instance
    """

    # pgvector distance operator used to rank neighbors for each metric, so
    # that queries agree with the operator class of the HNSW index. For each
    # entry smaller means "more similar" (``<#>`` is the *negative* inner
    # product), so ascending order lists the nearest neighbors first.
    # Metrics without an entry fall back to L2 distance (``<->``)
    _DISTANCE_OPERATORS = {
        "cosine": "<=>",
        "dotproduct": "<#>",
        "euclidean": "<->",
        "l1": "<+>",
    }

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._conn = None
        self._cur = None
        self._initialize()

    def _ensure_connection(self):
        """(Re)connects to the database if there is no open connection."""
        if self._conn is None or self._conn.closed:
            self._initialize()

    def _safe_rollback(self):
        """Rolls back the current transaction, ignoring errors.

        After a failed statement, PostgreSQL rejects further statements in
        the same transaction until it is rolled back.
        """
        try:
            self._conn.rollback()
        except Exception:
            # Best-effort: the connection itself may already be unusable
            pass

    @property
    def total_index_size(self):
        self._ensure_connection()

        try:
            self._cur.execute(
                f"""SELECT COUNT(*) FROM "{self.config.table_name}";"""
            )
            return self._cur.fetchone()[0]
        except Exception as e:
            # e.g. the table has not been created yet
            self._safe_rollback()
            logger.error(f"Error getting index size: {str(e)}")
            return 0

    def _initialize(self):
        """Connects to the database, ensures the ``vector`` extension exists,
        and chooses table/index names if none were configured.
        """
        ssl_options = {}
        if self.config.ssl_cert:
            ssl_options["sslcert"] = self.config.ssl_cert
        if self.config.ssl_key:
            ssl_options["sslkey"] = self.config.ssl_key
        if self.config.ssl_root_cert:
            ssl_options["sslrootcert"] = self.config.ssl_root_cert

        logger.info("Connecting to PostgreSQL database")
        self._conn = psycopg2.connect(
            self.config.connection_string, **ssl_options
        )
        self._cur = self._conn.cursor()

        try:
            self._cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            self._conn.commit()
        except Exception as e:
            self._safe_rollback()
            logger.error(f"Error creating vector extension: {str(e)}")
            raise

        if self.config.table_name is None:
            table_names = self._get_table_names()
            root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
            table_name = fbu.get_unique_name(root, table_names)
            self.config.table_name = table_name
            self.save_config()

        if self.config.index_name is None:
            # Indexes share one namespace with tables within a schema, so the
            # new name must not collide with ANY existing index or table
            # (create_hnsw_index() drops an existing index of the same name)
            existing_names = self._get_index_names() + self._get_table_names()
            root = "fiftyone-index-" + fou.to_slug(
                self.samples._root_dataset.name
            )
            index_name = fbu.get_unique_name(root, existing_names)
            self.config.index_name = index_name
            self.save_config()

    def _get_table_names(self):
        """Returns the names of the tables in the ``public`` schema."""
        self._cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
        )
        return [row[0] for row in self._cur.fetchall()]

    def _get_index_names(self, table_name=None):
        """Returns the names of the indexes in the ``public`` schema,
        optionally restricted to those on ``table_name``.
        """
        if table_name is None:
            self._cur.execute(
                "SELECT indexname FROM pg_indexes WHERE schemaname = 'public';"
            )
        else:
            self._cur.execute(
                "SELECT indexname FROM pg_indexes "
                "WHERE tablename = %s AND schemaname = 'public';",
                (table_name,),
            )

        return [row[0] for row in self._cur.fetchall()]

    def _create_table(self, dimension):
        """Creates the embeddings table for ``dimension``-dimensional
        vectors, if it does not already exist.
        """
        try:
            self._cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS "{self.config.table_name}" (
                    id TEXT PRIMARY KEY,
                    sample_id TEXT,
                    embedding_vector VECTOR({dimension})
                );
                """
            )
            self._conn.commit()
        except Exception as e:
            self._safe_rollback()
            logger.error(
                f"Error creating table: {self.config.table_name} with dimension {dimension}: {str(e)}"
            )
            raise

    def create_hnsw_index(self):
        """(Re)builds the HNSW index on the embeddings column using the
        operator class of the configured metric.
        """
        self._ensure_connection()

        operator_class = _SUPPORTED_METRICS[self.config.metric]
        try:
            self._cur.execute(
                f"""DROP INDEX IF EXISTS "{self.config.index_name}";"""
            )
            self._conn.commit()

            self._cur.execute(
                f"""
                CREATE INDEX "{self.config.index_name}"
                ON "{self.config.table_name}" USING hnsw (embedding_vector {operator_class})
                WITH (m = %s, ef_construction = %s);
                """,
                (self.config.hnsw_m, self.config.hnsw_ef_construction),
            )
            self._conn.commit()
        except Exception as e:
            self._safe_rollback()
            logger.error(
                f"Error creating HNSW index on table {self.config.table_name}:{str(e)}"
            )
            raise

    def _get_index_ids(self, batch_size=1000):
        """Returns the ``id`` of every row, streamed in batches of
        ``batch_size`` through a server-side (named) cursor.
        """
        named_cursor = self._conn.cursor(name="id_cursor")
        existing_ids = []
        try:
            named_cursor.execute(
                f"""SELECT id FROM "{self.config.table_name}";"""
            )
            while True:
                rows = named_cursor.fetchmany(batch_size)
                if not rows:
                    break

                existing_ids.extend(row[0] for row in rows)
        except Exception:
            # Rolling back also discards the server-side cursor
            self._safe_rollback()
            raise

        named_cursor.close()
        return existing_ids

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=5000,
        close_conn=True,
    ):
        self._ensure_connection()

        # psycopg2 interpolates parameters client-side, so this also works
        # for SET, which does not accept server-side bind parameters
        self._cur.execute("SET work_mem TO %s", (self.config.work_mem,))

        if self.config.table_name not in self._get_table_names():
            self._create_table(np.asarray(embeddings).shape[1])

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if warn_existing or not allow_existing or not overwrite:
            index_ids = self._get_index_ids()

            existing_ids = set(ids) & set(index_ids)
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            # Existing rows are left untouched
            query = f"""
                INSERT INTO "{self.config.table_name}" (id, sample_id, embedding_vector)
                VALUES %s
                ON CONFLICT (id) DO NOTHING;
            """
        else:
            query = f"""
                INSERT INTO "{self.config.table_name}" (id, sample_id, embedding_vector)
                VALUES %s
                ON CONFLICT (id) DO UPDATE
                SET sample_id = EXCLUDED.sample_id,
                    embedding_vector = EXCLUDED.embedding_vector;
            """

        embeddings = [np.asarray(e).tolist() for e in embeddings]
        sample_ids = list(sample_ids)
        if label_ids is not None:
            ids = list(label_ids)
        else:
            ids = list(sample_ids)

        try:
            for _embeddings, _ids, _sample_ids in zip(
                fou.iter_batches(embeddings, batch_size),
                fou.iter_batches(ids, batch_size),
                fou.iter_batches(sample_ids, batch_size),
            ):
                data = list(zip(_ids, _sample_ids, _embeddings))
                psy_extras.execute_values(self._cur, query, data)
                self._conn.commit()
        except Exception:
            self._safe_rollback()
            raise

        self.create_hnsw_index()

        if close_conn:
            self.close_connections()

        if reload:
            self.reload()

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        self._ensure_connection()

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if ids is None:
            ids = []
        elif etau.is_str(ids):
            ids = [ids]
        else:
            ids = [str(_id) for _id in ids]

        table = self.config.table_name

        if warn_missing or not allow_missing:
            # The requested IDs are matched against the primary key
            try:
                self._cur.execute(
                    f"""SELECT id FROM "{table}" WHERE id = ANY(%s);""",
                    (ids,),
                )
            except Exception:
                self._safe_rollback()
                raise

            existing_ids = {row[0] for row in self._cur.fetchall()}
            missing_ids = set(ids) - existing_ids
            num_missing_ids = len(missing_ids)

            if num_missing_ids > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that do not exist in the index"
                        % (num_missing_ids, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Skipping %d IDs that do not exist in the index",
                        num_missing_ids,
                    )

        deleted_count = 0
        if ids:
            try:
                # ANY(%s) accepts an empty list, unlike "IN ()"
                self._cur.execute(
                    f"""DELETE FROM "{table}" WHERE id = ANY(%s);""",
                    (ids,),
                )
            except Exception as e:
                self._safe_rollback()
                logger.error(
                    f"Error removing embeddings for ids {ids}: {str(e)}"
                )
                raise

            deleted_count = self._cur.rowcount
            self._conn.commit()

        logger.info(f"Deleted {deleted_count} embeddings from the index.")

        if reload:
            self.reload()

    def close_connections(self):
        """Closes the cursor and the connection, if they are open."""
        if self._cur is not None and not self._cur.closed:
            self._cur.close()

        if self._conn is not None and not self._conn.closed:
            self._conn.close()

    @staticmethod
    def _parse_vector(value):
        """Converts a stored vector into a float32 numpy array.

        The value may arrive as text like ``"[1.2,3.4,5.6]"`` or as an
        already-numeric sequence.
        """
        if isinstance(value, str):
            return np.array(
                [float(x) for x in value.strip("[]").split(",")],
                dtype=np.float32,
            )

        return np.array(value, dtype=np.float32)

    def get_embeddings_by_id(self, sample_ids=None, label_ids=None):
        """Fetches stored rows, by primary key (``label_ids``) or by sample
        ID (``sample_ids``), or all rows if neither is provided.

        Returns:
            a tuple of ``(ids, sample_ids, embeddings)`` lists, in the order
            returned by the database
        """
        self._ensure_connection()

        select = (
            "SELECT id, sample_id, embedding_vector "
            f'FROM "{self.config.table_name}"'
        )
        if label_ids is not None:
            sql = select + " WHERE id = ANY(%s)"
            params = (list(label_ids),)
            target = f"labels {label_ids}"
        elif sample_ids is not None:
            sql = select + " WHERE sample_id = ANY(%s)"
            params = (list(sample_ids),)
            target = f"samples {sample_ids}"
        else:
            sql = select + ";"
            params = None
            target = "all samples"

        try:
            self._cur.execute(sql, params)
        except Exception as e:
            self._safe_rollback()
            logger.error(f"Error fetching embeddings for {target}: {str(e)}")
            raise

        fo_id = []
        sample_id = []
        embeddings = []
        for row in self._cur.fetchall():
            fo_id.append(row[0])
            sample_id.append(row[1])
            embeddings.append(self._parse_vector(row[2]))

        return fo_id, sample_id, embeddings

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        patches_field = self.config.patches_field

        if label_ids is not None:
            if patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )
                # Honor the warning: label IDs take precedence
                sample_ids = None

        if sample_ids is not None and patches_field is not None:
            (
                found_label_ids,
                found_sample_ids,
                embeddings,
            ) = self.get_embeddings_by_id(sample_ids=sample_ids)
            missing_ids = list(set(sample_ids) - set(found_sample_ids))
            sample_ids = found_sample_ids
            label_ids = found_label_ids
        elif patches_field is not None:
            (
                found_label_ids,
                sample_ids,
                embeddings,
            ) = self.get_embeddings_by_id(label_ids=label_ids)
            missing_ids = (
                list(set(label_ids) - set(found_label_ids))
                if label_ids is not None
                else []
            )
            label_ids = found_label_ids
        else:
            (
                _,
                found_sample_ids,
                embeddings,
            ) = self.get_embeddings_by_id(sample_ids=sample_ids)
            missing_ids = (
                list(set(sample_ids) - set(found_sample_ids))
                if sample_ids is not None
                else []
            )
            sample_ids = found_sample_ids

            # Sample-level indexes have no label IDs
            label_ids = None

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
        close_conn=True,
    ):
        self._ensure_connection()

        if query is None:
            raise ValueError("Postgres does not support full index neighbors")

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        if k is None:
            k = self.index_size

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        index_ids = None
        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = list(self.current_label_ids)
            else:
                index_ids = list(self.current_sample_ids)

        # Use the operator that matches the configured metric (and hence the
        # HNSW index's operator class)
        operator = self._DISTANCE_OPERATORS.get(self.config.metric, "<->")
        sort_order = "DESC" if reverse else "ASC"
        where = "WHERE id = ANY(%s)" if index_ids is not None else ""
        sql = f"""
            SELECT id, sample_id, embedding_vector {operator} %s::vector AS distance
            FROM "{self.config.table_name}"
            {where}
            ORDER BY distance {sort_order}
            LIMIT %s;
        """

        is_patches = self.config.patches_field is not None

        sample_ids = []
        label_ids = [] if is_patches else None
        dists = []
        try:
            for q in query:
                if index_ids is not None:
                    params = (q.tolist(), index_ids, k)
                else:
                    params = (q.tolist(), k)

                self._cur.execute(sql, params)
                results = self._cur.fetchall()

                if is_patches:
                    sample_ids.append([r[1] for r in results])
                    label_ids.append([r[0] for r in results])
                else:
                    sample_ids.append([r[0] for r in results])

                if return_dists:
                    dists.append([r[2] for r in results])
        except Exception:
            self._safe_rollback()
            raise
        finally:
            if close_conn:
                self.close_connections()

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        """Converts a query (ID(s) or vector(s)) into a numpy array of query
        vectors; a single ID yields a 1D array.
        """
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s); rows come back in database order, so restore the
        # caller's order so each embedding lines up with its query ID
        found_ids, _, found_embeddings = self.get_embeddings_by_id(
            label_ids=query_ids
        )
        lookup = dict(zip(found_ids, found_embeddings))
        embeddings = [lookup[_id] for _id in query_ids if _id in lookup]

        if len(embeddings) == 0:
            raise ValueError(
                "Query IDs %s do not exist in this index" % query_ids
            )

        query = np.array(embeddings)
        if single_query:
            query = query[0, :]

        return query

    def cleanup(self, drop_table=False):
        """Drops the HNSW index and, optionally, the embeddings table, then
        closes the database connection.

        Args:
            drop_table (False): whether to also drop the embeddings table
        """
        # Reconnect first: the connection is typically closed after
        # add_to_index() / _kneighbors() with close_conn=True
        self._ensure_connection()

        logger.info(
            f"Cleaning up: Deleting HNSW index '{self.config.index_name}'"
        )
        try:
            self._cur.execute(
                f"""DROP INDEX IF EXISTS "{self.config.index_name}";"""
            )

            if drop_table:
                self._cur.execute(
                    f"""DROP TABLE IF EXISTS "{self.config.table_name}";"""
                )
                logger.info(
                    f"{self.config.table_name} table deleted successfully."
                )

            self._conn.commit()
        except Exception:
            self._safe_rollback()
            raise
        finally:
            self.close_connections()

        logger.info("Database connection closed.")

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/pinecone.py
================================================
"""
Pinecone similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
# Lazily import the Pinecone client via fiftyone's lazy_import helper
pinecone = fou.lazy_import("pinecone")
logger = logging.getLogger(__name__)

# Metric names accepted by PineconeSimilarityConfig(metric=...). The chosen
# string is passed through unchanged as ``metric`` when creating an index
_SUPPORTED_METRICS = ("cosine", "dotproduct", "euclidean")
class PineconeSimilarityConfig(SimilarityConfig):
    """Settings that control how the Pinecone similarity backend creates and
    connects to its index.

    Args:
        index_name (None): name of the Pinecone index to use or create. When
            omitted, a new index is created
        index_type (None): kind of index to create when a new one is needed.
            Supported values are ``["serverless", "pod"]``; the default is
            ``"serverless"``
        namespace (None): namespace in which vectors added to the index are
            stored
        metric (None): embedding distance metric to use when creating a new
            index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        replicas (None): optional number of replicas for a new pod-based
            index
        shards (None): optional number of shards for a new pod-based index
        pods (None): optional number of pods for a new pod-based index
        pod_type (None): optional pod type for a new pod-based index
        api_key (None): Pinecone API key to use
        cloud (None): cloud to use when creating serverless indexes
        region (None): region to use when creating serverless indexes
        environment (None): environment to use when creating pod-based
            indexes
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is provided but is not supported
    """

    def __init__(
        self,
        index_name=None,
        index_type=None,
        namespace=None,
        metric=None,
        replicas=None,
        shards=None,
        pods=None,
        pod_type=None,
        api_key=None,
        cloud=None,
        region=None,
        environment=None,
        **kwargs,
    ):
        # ``None`` means "use the backend default", so only validate
        # explicitly provided metrics
        if not (metric is None or metric in _SUPPORTED_METRICS):
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, _SUPPORTED_METRICS)
            )

        super().__init__(**kwargs)

        self.index_name = index_name
        self.index_type = index_type
        self.namespace = namespace
        self.metric = metric
        self.replicas = replicas
        self.shards = shards
        self.pods = pods
        self.pod_type = pod_type

        # Credentials and deployment location live in private attributes so
        # that they are not serialized along with the config
        self._api_key = api_key
        self._cloud = cloud
        self._region = region
        self._environment = environment

    @property
    def method(self):
        return "pinecone"

    @property
    def api_key(self):
        return self._api_key

    @api_key.setter
    def api_key(self, new_value):
        self._api_key = new_value

    @property
    def cloud(self):
        return self._cloud

    @cloud.setter
    def cloud(self, new_value):
        self._cloud = new_value

    @property
    def region(self):
        return self._region

    @region.setter
    def region(self, new_value):
        self._region = new_value

    @property
    def environment(self):
        return self._environment

    @environment.setter
    def environment(self, new_value):
        self._environment = new_value

    @property
    def max_k(self):
        # Upper bound imposed by Pinecone on the number of neighbors
        return 10000

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self, api_key=None, cloud=None, region=None, environment=None
    ):
        credentials = dict(
            api_key=api_key,
            cloud=cloud,
            region=region,
            environment=environment,
        )
        self._load_parameters(**credentials)
class PineconeSimilarity(Similarity):
    """Factory that builds Pinecone-backed similarity indexes.

    Args:
        config: a :class:`PineconeSimilarityConfig`
    """

    def ensure_requirements(self):
        # No minimum client version is required at this stage
        fou.ensure_package("pinecone-client")

    def ensure_usage_requirements(self):
        # Using an index requires at least v3.2 of the client
        fou.ensure_package("pinecone-client>=3.2")

    def initialize(self, samples, brain_key):
        index = PineconeSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class PineconeSimilarityIndex(SimilarityIndex):
    """Class for interacting with Pinecone similarity indexes.

    All reads, writes and deletes target the namespace configured via the
    config's ``namespace`` attribute, unless a method that accepts a
    ``namespace`` argument is explicitly given another one.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`PineconeSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`PineconeSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._pinecone = None
        self._index = None
        self._initialize()

    def _initialize(self):
        self._pinecone = pinecone.Pinecone(api_key=self.config.api_key)

        # Listing the indexes doubles as a connectivity check
        try:
            index_names = [d["name"] for d in self._pinecone.list_indexes()]
        except Exception as e:
            raise ValueError(
                "Failed to connect to Pinecone backend. "
                "Refer to https://docs.voxel51.com/integrations/pinecone.html "
                "for more information"
            ) from e

        if self.config.index_name is None:
            # https://docs.pinecone.io/troubleshooting/restrictions-on-index-names
            root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
            index_name = fbu.get_unique_name(root, index_names, max_len=45)

            self.config.index_name = index_name
            self.save_config()

        if self.config.index_name in index_names:
            index = self._pinecone.Index(self.config.index_name)
        else:
            # The index is created lazily by add_to_index(), once the
            # embedding dimension is known
            index = None

        self._index = index

    def _get_namespace(self, namespace=None):
        """Returns ``namespace`` if provided, else the configured namespace."""
        if namespace is not None:
            return namespace

        return self.config.namespace

    def _create_index(self, dimension):
        index_type = self.config.index_type or "serverless"

        if index_type == "serverless":
            spec = pinecone.ServerlessSpec(
                self.config.cloud,
                self.config.region,
            )
        elif index_type == "pod":
            kwargs = dict(
                pod_type=self.config.pod_type,
                pods=self.config.pods,
                replicas=self.config.replicas,
                shards=self.config.shards,
            )
            # Only forward the pod options that were explicitly configured
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            spec = pinecone.PodSpec(self.config.environment, **kwargs)
        else:
            raise TypeError(
                f"Invalid index_type='{index_type}'. The supported values are "
                "['serverless', 'pod']"
            )

        metric = self.config.metric or "cosine"

        self._pinecone.create_index(
            name=self.config.index_name,
            dimension=dimension,
            metric=metric,
            spec=spec,
        )

        self._index = self._pinecone.Index(self.config.index_name)

    @property
    def index(self):
        """The ``pinecone.Index`` instance for this index."""
        return self._index

    @property
    def total_index_size(self):
        if self._index is None:
            return 0

        return self._index.describe_index_stats()["total_vector_count"]

    @property
    def ready(self):
        return self._pinecone.describe_index(self.config.index_name).status[
            "ready"
        ]

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=100,
        namespace=None,
    ):
        """Adds (or overwrites) embeddings in the index.

        ``namespace`` defaults to the configured namespace; the existence
        checks use the same namespace as the upserts.
        """
        namespace = self._get_namespace(namespace)

        if self._index is None:
            self._create_index(embeddings.shape[1])

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if warn_existing or not allow_existing or not overwrite:
            existing_ids = self._get_existing_ids(ids, namespace=namespace)
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

        embeddings = [e.tolist() for e in embeddings]
        sample_ids = list(sample_ids)
        if label_ids is not None:
            ids = list(label_ids)
        else:
            ids = list(sample_ids)

        for _embeddings, _ids, _sample_ids in zip(
            fou.iter_batches(embeddings, batch_size),
            fou.iter_batches(ids, batch_size),
            fou.iter_batches(sample_ids, batch_size),
        ):
            # IDs are also stored as metadata so that _kneighbors() can
            # filter on ``id`` and recover ``sample_id`` for patches
            _id_dicts = [
                {"id": _id, "sample_id": _sid}
                for _id, _sid in zip(_ids, _sample_ids)
            ]
            self._index.upsert(
                list(zip(_ids, _embeddings, _id_dicts)),
                namespace=namespace,
            )

        if reload:
            self.reload()

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
        namespace=None,
        batch_size=1000,
    ):
        """Removes embeddings from the index.

        ``namespace`` defaults to the configured namespace. Existence checks
        and deletions are issued in batches of ``batch_size`` IDs.
        """
        namespace = self._get_namespace(namespace)

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if not allow_missing or warn_missing:
            existing_ids = self._get_existing_ids(ids, namespace=namespace)
            missing_ids = set(ids) - existing_ids
            num_missing = len(missing_ids)

            if num_missing > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that are not present in the "
                        "index" % (num_missing, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Ignoring %d IDs that are not present in the index",
                        num_missing,
                    )

            ids = existing_ids

        # An index that was never created has nothing to delete
        if self._index is not None:
            for batch_ids in fou.iter_batches(list(ids), batch_size):
                self._index.delete(ids=list(batch_ids), namespace=namespace)

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )

        if sample_ids is not None and self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        elif self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)
        else:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def cleanup(self):
        self._pinecone.delete_index(self.config.index_name)
        self._index = None

    def _get_existing_ids(self, ids, batch_size=1000, namespace=None):
        """Returns the set of ``ids`` present in the given (or configured)
        namespace; empty if the index has not been created yet.
        """
        existing_ids = set()
        if self._index is None:
            return existing_ids

        namespace = self._get_namespace(namespace)
        for batch_ids in fou.iter_batches(ids, batch_size):
            response = self._index.fetch(
                ids=list(batch_ids), namespace=namespace
            )["vectors"]
            existing_ids.update(response.keys())

        return existing_ids

    def _get_sample_embeddings(self, sample_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []

        if sample_ids is None:
            raise ValueError(
                "Pinecone does not support retrieving all vectors in an index"
            )

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            response = self._index.fetch(
                ids=list(batch_ids), namespace=self.config.namespace
            )["vectors"]

            for r in response.values():
                found_embeddings.append(r["values"])
                found_sample_ids.append(r["id"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        if label_ids is None:
            raise ValueError(
                "Pinecone does not support retrieving all vectors in an index"
            )

        for batch_ids in fou.iter_batches(label_ids, batch_size):
            response = self._index.fetch(
                ids=list(batch_ids), namespace=self.config.namespace
            )["vectors"]

            for r in response.values():
                found_embeddings.append(r["values"])
                found_sample_ids.append(r["metadata"]["sample_id"])
                found_label_ids.append(r["id"])

        missing_ids = list(set(label_ids) - set(found_label_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=100
    ):
        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        # Pinecone has no "fetch by metadata", so a metadata-filtered query
        # against a constant vector is used to enumerate matching patches
        query_vector = [0.0] * self._get_dimension()
        top_k = min(batch_size, self.config.max_k)

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            response = self._index.query(
                vector=query_vector,
                filter={"sample_id": {"$in": list(batch_ids)}},
                top_k=top_k,
                namespace=self.config.namespace,
                include_values=True,
                include_metadata=True,
            )

            for r in response["matches"]:
                found_embeddings.append(r["values"])
                found_sample_ids.append(r["metadata"]["sample_id"])
                found_label_ids.append(r["id"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError("Pinecone does not support full index neighbors")

        if reverse is True:
            raise ValueError(
                "Pinecone does not support least similarity queries"
            )

        if k is None or k > self.config.max_k:
            raise ValueError("Pinecone requires k<=%s" % self.config.max_k)

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = self.current_label_ids
            else:
                index_ids = self.current_sample_ids

            # Relies on the ``id`` metadata written by add_to_index()
            _filter = {"id": {"$in": list(index_ids)}}
        else:
            _filter = None

        include_metadata = self.config.patches_field is not None

        sample_ids = []
        label_ids = [] if self.config.patches_field is not None else None
        dists = []
        for q in query:
            response = self._index.query(
                vector=q.tolist(),
                top_k=k,
                filter=_filter,
                namespace=self.config.namespace,
                include_metadata=include_metadata,
            )

            if self.config.patches_field is not None:
                sample_ids.append(
                    [r["metadata"]["sample_id"] for r in response["matches"]]
                )
                label_ids.append([r["id"] for r in response["matches"]])
            else:
                sample_ids.append([r["id"] for r in response["matches"]])

            if return_dists:
                dists.append([r["score"] for r in response["matches"]])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        response = self._index.fetch(
            ids=query_ids, namespace=self.config.namespace
        )["vectors"]

        # Skip IDs that were not found instead of raising a KeyError; an
        # error is raised below only if none of the IDs were found
        query = np.array(
            [response[_id]["values"] for _id in query_ids if _id in response]
        )

        if query.size == 0:
            raise ValueError(
                "Query IDs %s were not found in the index" % query_ids
            )

        if single_query:
            query = query[0, :]

        return query

    def _get_dimension(self):
        if self._index is None:
            return None

        return self._index.describe_index_stats().dimension

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/qdrant.py
================================================
"""
Qdrant similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
# Lazily import the Qdrant client and its models module via fiftyone's
# lazy_import helper
qdrant = fou.lazy_import("qdrant_client")
qmodels = fou.lazy_import("qdrant_client.http.models")
logger = logging.getLogger(__name__)

# Maps the metric names accepted by QdrantSimilarityConfig(metric=...) to the
# Qdrant distance values used when creating a collection.
# NOTE(review): these attribute accesses on ``qmodels`` run at module import
# time, which presumably forces qdrant_client to load as soon as this module
# is imported. Confirm whether that defeats the lazy import above.
_SUPPORTED_METRICS = {
    "cosine": qmodels.Distance.COSINE,
    "dotproduct": qmodels.Distance.DOT,
    "euclidean": qmodels.Distance.EUCLID,
}
class QdrantSimilarityConfig(SimilarityConfig):
    """Settings that control how the Qdrant similarity backend creates and
    connects to its collection.

    Args:
        collection_name (None): name of the Qdrant collection to use or
            create. When omitted, a new collection is created
        metric (None): embedding distance metric to use when creating a new
            index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        replication_factor (None): optional replication factor for a new
            index
        shard_number (None): optional number of shards for a new index
        write_consistency_factor (None): optional write consistency factor
            for a new index
        hnsw_config (None): optional dict of HNSW config parameters for a new
            index
        optimizers_config (None): optional dict of optimizer parameters for a
            new index
        wal_config (None): optional dict of WAL config parameters for a new
            index
        url (None): Qdrant server URL to use
        api_key (None): Qdrant API key to use
        grpc_port (None): port of the Qdrant gRPC interface
        prefer_grpc (None): whether to use the gRPC interface when possible
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is provided but is not supported
    """

    def __init__(
        self,
        collection_name=None,
        metric=None,
        replication_factor=None,
        shard_number=None,
        write_consistency_factor=None,
        hnsw_config=None,
        optimizers_config=None,
        wal_config=None,
        url=None,
        api_key=None,
        grpc_port=None,
        prefer_grpc=None,
        **kwargs,
    ):
        # ``None`` means "use the backend default", so only validate
        # explicitly provided metrics
        if not (metric is None or metric in _SUPPORTED_METRICS):
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, tuple(_SUPPORTED_METRICS))
            )

        super().__init__(**kwargs)

        self.collection_name = collection_name
        self.metric = metric
        self.replication_factor = replication_factor
        self.shard_number = shard_number
        self.write_consistency_factor = write_consistency_factor
        self.hnsw_config = hnsw_config
        self.optimizers_config = optimizers_config
        self.wal_config = wal_config

        # Connection settings live in private attributes so that they are
        # not serialized along with the config
        self._url = url
        self._api_key = api_key
        self._grpc_port = grpc_port
        self._prefer_grpc = prefer_grpc

    @property
    def method(self):
        return "qdrant"

    @property
    def url(self):
        return self._url

    @url.setter
    def url(self, new_value):
        self._url = new_value

    @property
    def api_key(self):
        return self._api_key

    @api_key.setter
    def api_key(self, new_value):
        self._api_key = new_value

    @property
    def grpc_port(self):
        return self._grpc_port

    @grpc_port.setter
    def grpc_port(self, new_value):
        self._grpc_port = new_value

    @property
    def prefer_grpc(self):
        return self._prefer_grpc

    @prefer_grpc.setter
    def prefer_grpc(self, new_value):
        self._prefer_grpc = new_value

    @property
    def max_k(self):
        return None

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self, url=None, api_key=None, grpc_port=None, prefer_grpc=None
    ):
        credentials = dict(
            url=url,
            api_key=api_key,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
        )
        self._load_parameters(**credentials)
class QdrantSimilarity(Similarity):
    """Factory that builds Qdrant-backed similarity indexes.

    Args:
        config: a :class:`QdrantSimilarityConfig`
    """

    def ensure_requirements(self):
        fou.ensure_package("qdrant-client")

    def ensure_usage_requirements(self):
        # Same requirement as ensure_requirements()
        fou.ensure_package("qdrant-client")

    def initialize(self, samples, brain_key):
        index = QdrantSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class QdrantSimilarityIndex(SimilarityIndex):
    """Class for interacting with Qdrant similarity indexes.

    Point IDs are FiftyOne IDs padded with eight trailing zeros (see
    :meth:`_to_qdrant_id`). Each point's ``"sample_id"`` payload stores the
    unconverted FiftyOne sample ID.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`QdrantSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`QdrantSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._client = None
        self._initialize()

    def _initialize(self):
        # QdrantClient does not appear to like passing None as defaults
        grpc_port = (
            self.config.grpc_port
            if self.config.grpc_port is not None
            else 6334
        )
        prefer_grpc = (
            self.config.prefer_grpc
            if self.config.prefer_grpc is not None
            else False
        )

        self._client = qdrant.QdrantClient(
            url=self.config.url,
            api_key=self.config.api_key,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
        )

        # Listing collections doubles as a connectivity check
        try:
            collection_names = self._get_collection_names()
        except Exception as e:
            raise ValueError(
                "Failed to connect to Qdrant backend at URL '%s'. Refer to "
                "https://docs.voxel51.com/integrations/qdrant.html for more "
                "information" % self.config.url
            ) from e

        if self.config.collection_name is None:
            root = "fiftyone-" + fou.to_slug(self.samples._root_dataset.name)
            collection_name = fbu.get_unique_name(root, collection_names)

            self.config.collection_name = collection_name
            self.save_config()

    def _get_collection_names(self):
        return [c.name for c in self._client.get_collections().collections]

    def _create_collection(self, dimension):
        if self.config.metric:
            metric = self.config.metric
        else:
            metric = "cosine"

        vectors_config = qmodels.VectorParams(
            size=dimension,
            distance=_SUPPORTED_METRICS[metric],
        )

        if self.config.hnsw_config:
            hnsw_config = qmodels.HnswConfig(**self.config.hnsw_config)
        else:
            hnsw_config = None

        if self.config.optimizers_config:
            optimizers_config = qmodels.OptimizersConfig(
                **self.config.optimizers_config
            )
        else:
            optimizers_config = None

        if self.config.wal_config:
            wal_config = qmodels.WalConfig(**self.config.wal_config)
        else:
            wal_config = None

        self._client.recreate_collection(
            collection_name=self.config.collection_name,
            vectors_config=vectors_config,
            shard_number=self.config.shard_number,
            replication_factor=self.config.replication_factor,
            hnsw_config=hnsw_config,
            optimizers_config=optimizers_config,
            wal_config=wal_config,
        )

    def _get_index_ids(self, batch_size=1000):
        """Returns the FiftyOne IDs of every point in the collection."""
        ids = []

        # ``None`` requests the first page; scroll() returns the offset of
        # the next page, or ``None`` when there are no more points
        offset = None
        while True:
            records, offset = self._client.scroll(
                collection_name=self.config.collection_name,
                offset=offset,
                limit=batch_size,
                with_payload=False,
                with_vectors=False,
            )
            ids.extend(self._to_fiftyone_id(r.id) for r in records)
            if offset is None:
                break

        return ids

    @property
    def total_index_size(self):
        # Best effort: e.g. the collection may not have been created yet.
        # Catch Exception rather than everything so that KeyboardInterrupt
        # and SystemExit still propagate
        try:
            return self._client.count(self.config.collection_name).count
        except Exception:
            return 0

    @property
    def client(self):
        """The ``qdrant.QdrantClient`` instance for this index."""
        return self._client

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
        batch_size=1000,
    ):
        if self.config.collection_name not in self._get_collection_names():
            self._create_collection(embeddings.shape[1])

        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        if warn_existing or not allow_existing or not overwrite:
            index_ids = self._get_index_ids()

            existing_ids = set(ids) & set(index_ids)
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

        embeddings = [e.tolist() for e in embeddings]
        sample_ids = list(sample_ids)
        if label_ids is not None:
            ids = list(label_ids)
        else:
            ids = list(sample_ids)

        for _embeddings, _ids, _sample_ids in zip(
            fou.iter_batches(embeddings, batch_size),
            fou.iter_batches(ids, batch_size),
            fou.iter_batches(sample_ids, batch_size),
        ):
            # Sample IDs are stored unconverted in the payload
            self._client.upsert(
                collection_name=self.config.collection_name,
                points=qmodels.Batch(
                    ids=self._to_qdrant_ids(_ids),
                    payloads=[{"sample_id": _id} for _id in _sample_ids],
                    vectors=_embeddings,
                ),
            )

        if reload:
            self.reload()

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        if label_ids is not None:
            ids = label_ids
        else:
            ids = sample_ids

        qids = self._to_qdrant_ids(ids)

        if warn_missing or not allow_missing:
            response = self._retrieve_points(qids, with_vectors=False)

            existing_ids = self._to_fiftyone_ids([r.id for r in response])
            missing_ids = set(ids) - set(existing_ids)
            num_missing_ids = len(missing_ids)

            if num_missing_ids > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that do not exist in the index"
                        % (num_missing_ids, next(iter(missing_ids)))
                    )

                # Reaching this point implies allow_missing=True, so the
                # warning only depends on warn_missing
                if warn_missing:
                    logger.warning(
                        "Skipping %d IDs that do not exist in the index",
                        num_missing_ids,
                    )

                qids = self._to_qdrant_ids(existing_ids)

        self._client.delete(
            collection_name=self.config.collection_name,
            points_selector=qmodels.PointIdsList(points=qids),
        )

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )

        if sample_ids is not None and self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        elif self.config.patches_field is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)
        else:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def cleanup(self):
        self._client.delete_collection(self.config.collection_name)

    def _retrieve_points(self, qids, with_vectors=True, with_payload=True):
        # @todo add batching?
        return self._client.retrieve(
            collection_name=self.config.collection_name,
            ids=qids,
            with_vectors=with_vectors,
            with_payload=with_payload,
        )

    def _get_sample_embeddings(self, sample_ids):
        if sample_ids is None:
            sample_ids = self._get_index_ids()

        response = self._retrieve_points(
            self._to_qdrant_ids(sample_ids),
            with_vectors=True,
        )

        found_embeddings = [r.vector for r in response]
        found_sample_ids = self._to_fiftyone_ids([r.id for r in response])
        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids):
        if label_ids is None:
            label_ids = self._get_index_ids()

        response = self._retrieve_points(
            self._to_qdrant_ids(label_ids),
            with_vectors=True,
            with_payload=True,
        )

        found_embeddings = [r.vector for r in response]
        found_sample_ids = [r.payload["sample_id"] for r in response]
        found_label_ids = self._to_fiftyone_ids([r.id for r in response])
        missing_ids = list(set(label_ids) - set(found_label_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=1000
    ):
        """Returns all patch embeddings whose ``sample_id`` payload is in
        ``sample_ids``, paging through every matching point.
        """
        sample_ids = list(sample_ids)

        # An empty ``should`` list would not restrict the scroll at all
        if not sample_ids:
            return [], [], [], []

        _filter = qmodels.Filter(
            should=[
                qmodels.FieldCondition(
                    key="sample_id", match=qmodels.MatchValue(value=sid)
                )
                for sid in sample_ids
            ]
        )

        # scroll() returns one page at a time, so follow the next-page
        # offset until it is exhausted
        records = []
        offset = None
        while True:
            page, offset = self._client.scroll(
                collection_name=self.config.collection_name,
                scroll_filter=_filter,
                offset=offset,
                limit=batch_size,
                with_vectors=True,
                with_payload=True,
            )
            records.extend(page)
            if offset is None:
                break

        found_embeddings = [r.vector for r in records]
        found_sample_ids = [r.payload["sample_id"] for r in records]

        # Point IDs must be converted back to FiftyOne label IDs
        found_label_ids = self._to_fiftyone_ids([r.id for r in records])
        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        if query is None:
            raise ValueError("Qdrant does not support full index neighbors")

        if reverse is True:
            raise ValueError(
                "Qdrant does not support least similarity queries"
            )

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        if k is None:
            k = self.index_size

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        if self.has_view:
            if self.config.patches_field is not None:
                index_ids = self.current_label_ids
            else:
                index_ids = self.current_sample_ids

            _filter = qmodels.Filter(
                must=[
                    qmodels.HasIdCondition(
                        has_id=self._to_qdrant_ids(index_ids)
                    )
                ]
            )
        else:
            _filter = None

        with_payload = self.config.patches_field is not None

        sample_ids = []
        label_ids = [] if self.config.patches_field is not None else None
        dists = []
        for q in query:
            results = self._client.search(
                collection_name=self.config.collection_name,
                query_vector=q,
                query_filter=_filter,
                with_payload=with_payload,
                limit=k,
            )

            if self.config.patches_field is not None:
                # Payload sample IDs are already FiftyOne IDs; only the
                # point IDs need converting
                sample_ids.append([r.payload["sample_id"] for r in results])
                label_ids.append(
                    self._to_fiftyone_ids([r.id for r in results])
                )
            else:
                sample_ids.append(
                    self._to_fiftyone_ids([r.id for r in results])
                )

            if return_dists:
                dists.append([r.score for r in results])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        qids = self._to_qdrant_ids(query_ids)
        response = self._retrieve_points(qids, with_vectors=True)
        query = np.array([r.vector for r in response])

        if query.size == 0:
            raise ValueError(
                "Query IDs %s were not found in the index" % query_ids
            )

        if single_query:
            query = query[0, :]

        return query

    def _to_qdrant_id(self, _id):
        # Pad the 24-char FiftyOne ID to a 32-char hex string
        return _id + "00000000"

    def _to_qdrant_ids(self, ids):
        return [self._to_qdrant_id(_id) for _id in ids]

    def _to_fiftyone_id(self, qid):
        # Strip any dashes, then the eight padding characters
        return qid.replace("-", "")[:-8]

    def _to_fiftyone_ids(self, qids):
        return [self._to_fiftyone_id(qid) for qid in qids]

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/redis.py
================================================
"""
Redis similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.core.utils as fou
from fiftyone.brain.similarity import (
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
# Lazily import the Redis client via fiftyone's lazy_import helper
redis = fou.lazy_import("redis")
logger = logging.getLogger(__name__)

# Maps the metric names accepted by RedisSimilarityConfig(metric=...) to
# Redis distance metric identifiers.
# NOTE(review): the values are presumably passed to Redis when the index
# schema is created; that code is outside this view, so confirm.
_SUPPORTED_METRICS = {
    "cosine": "COSINE",
    "dotproduct": "IP",
    "euclidean": "L2",
}
class RedisSimilarityConfig(SimilarityConfig):
    """Settings for the Redis similarity backend.

    Args:
        index_name (None): name of the Redis index to use or create. When
            omitted, a new index is created
        metric ("cosine"): embedding distance metric to use when creating a
            new index. Supported values are
            ``("cosine", "dotproduct", "euclidean")``
        algorithm ("FLAT"): search algorithm to use. Supported values are
            ``("FLAT", "HNSW")``
        host ("localhost"): host to connect to
        port (6379): port to connect to
        db (0): database to use
        username (None): username to authenticate with
        password (None): password to authenticate with
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`

    Raises:
        ValueError: if ``metric`` is not supported
    """

    def __init__(
        self,
        index_name=None,
        metric="cosine",
        algorithm="FLAT",
        host="localhost",
        port=6379,
        db=0,
        username=None,
        password=None,
        **kwargs,
    ):
        if metric not in _SUPPORTED_METRICS:
            raise ValueError(
                "Unsupported metric '%s'. Supported values are %s"
                % (metric, tuple(_SUPPORTED_METRICS))
            )

        super().__init__(**kwargs)

        self.index_name = index_name
        self.metric = metric
        self.algorithm = algorithm

        # Connection settings live in private attributes so that they are
        # not serialized along with the config
        self._host = host
        self._port = port
        self._db = db
        self._username = username
        self._password = password

    @property
    def method(self):
        return "redis"

    @property
    def host(self):
        return self._host

    @host.setter
    def host(self, new_value):
        self._host = new_value

    @property
    def port(self):
        return self._port

    @port.setter
    def port(self, new_value):
        self._port = new_value

    @property
    def db(self):
        return self._db

    @db.setter
    def db(self, new_value):
        self._db = new_value

    @property
    def username(self):
        return self._username

    @username.setter
    def username(self, new_value):
        self._username = new_value

    @property
    def password(self):
        return self._password

    @password.setter
    def password(self, new_value):
        self._password = new_value

    @property
    def max_k(self):
        return None

    @property
    def supports_least_similarity(self):
        return False

    @property
    def supported_aggregations(self):
        return ("mean",)

    def load_credentials(
        self, host=None, port=None, db=None, username=None, password=None
    ):
        credentials = dict(
            host=host,
            port=port,
            db=db,
            username=username,
            password=password,
        )
        self._load_parameters(**credentials)
class RedisSimilarity(Similarity):
    """Factory that builds :class:`RedisSimilarityIndex` instances.

    Args:
        config: a :class:`RedisSimilarityConfig`
    """

    def ensure_requirements(self):
        """Calls ``fou.ensure_package()`` for the ``redis`` package."""
        package = "redis"
        fou.ensure_package(package)

    def ensure_usage_requirements(self):
        """Calls ``fou.ensure_package()`` for the ``redis`` package."""
        package = "redis"
        fou.ensure_package(package)

    def initialize(self, samples, brain_key):
        """Returns a :class:`RedisSimilarityIndex` bound to this backend."""
        index = RedisSimilarityIndex(
            samples, self.config, brain_key, backend=self
        )
        return index
class RedisSimilarityIndex(SimilarityIndex):
    """Class for interacting with Redis similarity indexes.

    Each embedding is stored as a Redis JSON document at key
    ``"<index_name>:<id>"``, where ``<id>`` is the label ID for patch indexes
    and the sample ID otherwise. Each document has the fields ``foid`` (that
    same ID), ``sample_id``, and ``vector``.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`RedisSimilarityConfig` used
        brain_key: the brain key
        backend (None): a :class:`RedisSimilarity` instance
    """

    def __init__(self, samples, config, brain_key, backend=None):
        super().__init__(samples, config, brain_key, backend=backend)
        self._client = None
        self._index = None
        self._initialize()

    def _initialize(self):
        """Connects to Redis and loads the search index, if it exists.

        If no ``index_name`` is configured, a unique name derived from the
        root dataset's name is generated and saved to the config.
        """
        client = redis.Redis(
            host=self.config.host,
            port=self.config.port,
            db=self.config.db,
            username=self.config.username,
            password=self.config.password,
            decode_responses=True,
        )

        if self.config.index_name is None:

            def index_exists(index_name):
                try:
                    client.ft(index_name).info()
                    return True
                except Exception:
                    return False

            root = "fiftyone-" + fou.to_slug(self._samples._root_dataset.name)
            index_name = fbu.get_unique_name(root, index_exists)

            self.config.index_name = index_name
            self.save_config()

        try:
            index = client.ft(self.config.index_name)
            index.info()
        except Exception:
            # The index does not exist yet; add_to_index() creates it lazily
            index = None

        self._client = client
        self._index = index

    def _create_index(self, dimension):
        """Creates a search index over the JSON documents whose keys start
        with ``"<index_name>:"``.

        Args:
            dimension: the dimension of the embedding vectors
        """
        from redis.commands.search.field import TagField, VectorField
        from redis.commands.search.indexDefinition import (
            IndexDefinition,
            IndexType,
        )

        schema = (
            TagField("$.foid", as_name="foid"),
            TagField("$.sample_id", as_name="sample_id"),
            VectorField(
                "$.vector",
                self.config.algorithm,
                {
                    "TYPE": "FLOAT32",
                    "DIM": dimension,
                    "DISTANCE_METRIC": _SUPPORTED_METRICS[self.config.metric],
                },
                as_name="vector",
            ),
        )
        definition = IndexDefinition(
            prefix=[self.config.index_name + ":"],
            index_type=IndexType.JSON,
        )

        index = self._client.ft(self.config.index_name)
        index.create_index(fields=schema, definition=definition)
        self._index = index

    @property
    def client(self):
        """The ``redis.client.Redis`` instance for this index."""
        return self._client

    @property
    def total_index_size(self):
        """The number of documents in the index, or 0 if the index cannot be
        queried (for example, because it has not been created yet).
        """
        try:
            return int(self._index.info()["num_docs"])
        except Exception:
            return 0

    def add_to_index(
        self,
        embeddings,
        sample_ids,
        label_ids=None,
        overwrite=True,
        allow_existing=True,
        warn_existing=False,
        reload=True,
    ):
        """Adds embeddings to the index.

        Args:
            embeddings: a ``num_embeddings x num_dims`` array of embeddings
            sample_ids: a ``num_embeddings`` array of sample IDs
            label_ids (None): a ``num_embeddings`` array of label IDs, if
                applicable. When provided, these are the document IDs
            overwrite (True): whether to replace existing IDs (True) or skip
                them (False)
            allow_existing (True): whether to allow IDs that already exist
            warn_existing (False): whether to log a warning about existing IDs
            reload (True): whether to call :meth:`reload` afterwards

        Raises:
            ValueError: if ``allow_existing`` is False and some IDs exist
        """
        embeddings = np.asarray(embeddings)

        if self._index is None:
            self._create_index(embeddings.shape[1])

        ids = label_ids if label_ids is not None else sample_ids

        if warn_existing or not allow_existing or not overwrite:
            existing_ids = set(self._get_existing_ids(ids))
            num_existing = len(existing_ids)

            if num_existing > 0:
                if not allow_existing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that already exist in the index"
                        % (num_existing, next(iter(existing_ids)))
                    )

                if warn_existing:
                    if overwrite:
                        logger.warning(
                            "Overwriting %d IDs that already exist in the "
                            "index",
                            num_existing,
                        )
                    else:
                        logger.warning(
                            "Skipping %d IDs that already exist in the index",
                            num_existing,
                        )
        else:
            existing_ids = set()

        if existing_ids and not overwrite:
            del_inds = [i for i, _id in enumerate(ids) if _id in existing_ids]

            embeddings = np.delete(embeddings, del_inds, axis=0)
            sample_ids = np.delete(sample_ids, del_inds)
            if label_ids is not None:
                label_ids = np.delete(label_ids, del_inds)

            # `ids` must be filtered too so that it stays aligned with
            # `embeddings` and `sample_ids` in the write loop below
            ids = label_ids if label_ids is not None else sample_ids
        elif existing_ids and overwrite:
            self._delete_ids(existing_ids)

        pipeline = self._client.pipeline()
        for embedding, _id, sample_id in zip(embeddings, ids, sample_ids):
            key = f"{self.config.index_name}:{_id}"
            doc = {
                "foid": _id,
                "sample_id": sample_id,
                "vector": embedding.astype(np.float32).tolist(),
            }
            pipeline.json().set(key, "$", doc)

        pipeline.execute()

        if reload:
            self.reload()

    def _get_existing_ids(self, ids):
        """Returns the subset of ``ids`` that have documents in the index."""
        return [d["foid"] for d in self._get_values(ids)]

    def _delete_ids(self, ids):
        """Deletes the documents for the given IDs, if any."""
        keys = [f"{self.config.index_name}:{_id}" for _id in ids]

        # DEL requires at least one key
        if keys:
            self._client.delete(*keys)

    def _get_values(self, ids):
        """Fetches the JSON documents for ``ids``, omitting missing ones."""
        pipeline = self._client.pipeline()
        for _id in ids:
            pipeline.json().get(f"{self.config.index_name}:{_id}")

        return [d for d in pipeline.execute() if d is not None]

    def _get_all_ids(self):
        """Returns the IDs of all documents stored under this index's key
        prefix.
        """
        keys = self._client.keys(f"{self.config.index_name}:*")
        return [key.rsplit(":", 1)[1] for key in keys]

    def remove_from_index(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
        reload=True,
    ):
        """Removes the given IDs from the index.

        Args:
            sample_ids (None): sample IDs to remove
            label_ids (None): label IDs to remove. Takes precedence over
                ``sample_ids`` when provided
            allow_missing (True): whether to allow IDs that are not present
            warn_missing (False): whether to log a warning about missing IDs
            reload (True): whether to call :meth:`reload` afterwards

        Raises:
            ValueError: if ``allow_missing`` is False and some IDs are missing
        """
        ids = label_ids if label_ids is not None else sample_ids

        # Avoid iterating a single ID character by character
        if etau.is_str(ids):
            ids = [ids]

        if not allow_missing or warn_missing:
            existing_ids = self._get_existing_ids(ids)
            missing_ids = set(ids) - set(existing_ids)
            num_missing = len(missing_ids)

            if num_missing > 0:
                if not allow_missing:
                    raise ValueError(
                        "Found %d IDs (eg %s) that are not present in the "
                        "index" % (num_missing, next(iter(missing_ids)))
                    )

                if warn_missing:
                    logger.warning(
                        "Ignoring %d IDs that are not present in the index",
                        num_missing,
                    )

            ids = existing_ids

        self._delete_ids(ids=ids)

        if reload:
            self.reload()

    def get_embeddings(
        self,
        sample_ids=None,
        label_ids=None,
        allow_missing=True,
        warn_missing=False,
    ):
        """Retrieves embeddings from the index.

        Args:
            sample_ids (None): sample IDs to retrieve. For patch indexes, all
                patches of these samples are returned
            label_ids (None): label IDs to retrieve (patch indexes only). Take
                precedence over ``sample_ids``
            allow_missing (True): whether to allow IDs that are not present
            warn_missing (False): whether to log a warning about missing IDs

        Returns:
            a tuple ``(embeddings, sample_ids, label_ids)`` of arrays, where
            ``label_ids`` is None for non-patch indexes

        Raises:
            ValueError: if label IDs are given for a non-patch index, or if
                ``allow_missing`` is False and some IDs are missing
        """
        if label_ids is not None:
            if self.config.patches_field is None:
                raise ValueError("This index does not support label IDs")

            if sample_ids is not None:
                logger.warning(
                    "Ignoring sample IDs when label IDs are provided"
                )

        if etau.is_str(sample_ids):
            sample_ids = [sample_ids]

        if etau.is_str(label_ids):
            label_ids = [label_ids]

        if self.config.patches_field is None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_sample_embeddings(sample_ids)
        elif label_ids is None and sample_ids is not None:
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
        else:
            # Label IDs take precedence; None means "all patches"
            (
                embeddings,
                sample_ids,
                label_ids,
                missing_ids,
            ) = self._get_patch_embeddings_from_label_ids(label_ids)

        num_missing_ids = len(missing_ids)
        if num_missing_ids > 0:
            if not allow_missing:
                raise ValueError(
                    "Found %d IDs (eg %s) that do not exist in the index"
                    % (num_missing_ids, missing_ids[0])
                )

            if warn_missing:
                logger.warning(
                    "Skipping %d IDs that do not exist in the index",
                    num_missing_ids,
                )

        embeddings = np.array(embeddings)
        sample_ids = np.array(sample_ids)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return embeddings, sample_ids, label_ids

    def cleanup(self):
        """Drops the search index and its documents, if the index exists."""
        if self._index is None:
            return

        self._index.dropindex(delete_documents=True)
        self._index = None

    def _get_sample_embeddings(self, sample_ids, batch_size=1000):
        """Looks up embeddings by sample ID (all documents if None)."""
        if sample_ids is None:
            sample_ids = self._get_all_ids()

        found_embeddings = []
        found_sample_ids = []

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            for d in self._get_values(batch_ids):
                found_embeddings.append(d["vector"])
                found_sample_ids.append(d["sample_id"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, None, missing_ids

    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):
        """Looks up patch embeddings by label ID (all documents if None)."""
        if label_ids is None:
            label_ids = self._get_all_ids()

        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        for batch_ids in fou.iter_batches(label_ids, batch_size):
            for d in self._get_values(batch_ids):
                found_embeddings.append(d["vector"])
                found_sample_ids.append(d["sample_id"])
                found_label_ids.append(d["foid"])

        missing_ids = list(set(label_ids) - set(found_label_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _get_patch_embeddings_from_sample_ids(
        self, sample_ids, batch_size=100, page_size=1000
    ):
        """Looks up all patch embeddings belonging to the given samples.

        The search index is used only to find the label IDs of each sample's
        patches; the documents themselves are then fetched via
        :meth:`_get_values`. This way the vectors have the same format as in
        the other lookup methods.
        """
        from redis.commands.search.query import Query

        sample_ids = list(sample_ids)

        found_embeddings = []
        found_sample_ids = []
        found_label_ids = []

        if self._index is None:
            # No index has been created, so nothing can be found
            missing_ids = list(set(sample_ids))
            return (
                found_embeddings,
                found_sample_ids,
                found_label_ids,
                missing_ids,
            )

        for batch_ids in fou.iter_batches(sample_ids, batch_size):
            tag_filter = "@sample_id:{ " + " | ".join(batch_ids) + " }"

            # Page through all matches; a single search only returns a
            # limited number of results
            batch_label_ids = []
            offset = 0
            while True:
                query = (
                    Query(tag_filter)
                    .return_fields("foid")
                    .paging(offset, page_size)
                    .dialect(2)
                )
                result = self._index.search(query)
                batch_label_ids.extend(doc.foid for doc in result.docs)
                offset += len(result.docs)
                if not result.docs or offset >= result.total:
                    break

            for d in self._get_values(batch_label_ids):
                found_embeddings.append(d["vector"])
                found_sample_ids.append(d["sample_id"])
                found_label_ids.append(d["foid"])

        missing_ids = list(set(sample_ids) - set(found_sample_ids))

        return found_embeddings, found_sample_ids, found_label_ids, missing_ids

    def _kneighbors(
        self,
        query=None,
        k=None,
        reverse=False,
        aggregation=None,
        return_dists=False,
    ):
        """Runs a KNN vector search for each query.

        Only ``aggregation=None`` or ``"mean"`` is supported. Full-index
        queries and ``reverse=True`` are rejected.
        """
        from redis.commands.search.query import Query

        if query is None:
            raise ValueError("Redis does not support full index neighbors")

        if reverse is True:
            raise ValueError("Redis does not support least similarity queries")

        if k is None:
            k = self.index_size

        if aggregation not in (None, "mean"):
            raise ValueError("Unsupported aggregation '%s'" % aggregation)

        query = self._parse_neighbors_query(query)
        if aggregation == "mean" and query.ndim == 2:
            query = query.mean(axis=0)

        single_query = query.ndim == 1
        if single_query:
            query = [query]

        is_patches = self.config.patches_field is not None

        # Restrict the search to the IDs in the current view, if any
        if self.has_view:
            if is_patches:
                index_ids = list(self.current_label_ids)
            else:
                index_ids = list(self.current_sample_ids)

            if index_ids:
                id_filter = "@foid:{ " + " | ".join(index_ids) + " }"
            else:
                # An empty view matches nothing, and an empty tag list is
                # presumably not valid query syntax, so skip the search
                id_filter = None
        else:
            id_filter = "*"

        sample_ids = []
        label_ids = [] if is_patches else None
        dists = []
        for q in query:
            if id_filter is None:
                docs = []
            else:
                knn_query = (
                    Query(f"({id_filter})=>[KNN {k} @vector $query AS score]")
                    .sort_by("score")
                    .return_fields("score", "foid", "sample_id")
                    .dialect(2)
                    .paging(0, k)
                )
                params = {"query": q.astype(np.float32).tobytes()}
                docs = self._index.search(knn_query, params).docs

            if is_patches:
                sample_ids.append([doc.sample_id for doc in docs])
                label_ids.append([doc.foid for doc in docs])
            else:
                sample_ids.append([doc.foid for doc in docs])

            if return_dists:
                # With decode_responses=True the scores come back as text;
                # convert them so callers get numeric distances
                dists.append([float(doc.score) for doc in docs])

        if single_query:
            sample_ids = sample_ids[0]
            if label_ids is not None:
                label_ids = label_ids[0]
            if return_dists:
                dists = dists[0]

        if return_dists:
            return sample_ids, label_ids, dists

        return sample_ids, label_ids

    def _parse_neighbors_query(self, query):
        """Converts a query (vector(s) or ID(s)) into a numeric array.

        Raises:
            ValueError: if any query ID does not exist in the index
        """
        if etau.is_str(query):
            query_ids = [query]
            single_query = True
        else:
            query = np.asarray(query)

            # Query by vector(s)
            if np.issubdtype(query.dtype, np.number):
                return query

            query_ids = list(query)
            single_query = False

        # Query by ID(s)
        dicts = self._get_values(query_ids)
        found_ids = set(d["foid"] for d in dicts)
        missing_ids = [_id for _id in query_ids if _id not in found_ids]
        if missing_ids:
            raise ValueError(
                "Query IDs %s do not exist in this index" % missing_ids
            )

        query = np.array([d["vector"] for d in dicts])
        if single_query:
            query = query[0, :]

        return query

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        return cls(samples, config, brain_key)
================================================
FILE: fiftyone/brain/internal/core/representativeness.py
================================================
"""
Representativeness methods.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import copy
import numpy as np
import sklearn.cluster as skc
from scipy.spatial import cKDTree
import eta.core.utils as etau
import fiftyone.core.brain as fob
import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
import fiftyone.core.validation as fov
import fiftyone.brain.internal.core.utils as fbu
import fiftyone.brain.internal.models as fbm
logger = logging.getLogger(__name__)
# Label types that may be passed as ``roi_field``
_ALLOWED_ROI_FIELD_TYPES = (
    fol.Detection,
    fol.Detections,
    fol.Polyline,
    fol.Polylines,
)

# Model (and its batch size) used when no embeddings, model, or similarity
# index is provided
_DEFAULT_MODEL, _DEFAULT_BATCH_SIZE = "simple-resnet-cifar10", 16
def compute_representativeness(
    samples,
    representativeness_field,
    method,
    roi_field,
    embeddings,
    similarity_index,
    model,
    model_kwargs,
    force_square,
    alpha,
    batch_size,
    num_workers,
    skip_failures,
    progress,
):
    """See ``fiftyone/brain/__init__.py``."""

    #
    # Algorithm
    #
    # Cluster the embeddings (k-means by default; see `_cluster_ranker()`)
    # and score each sample by a normalized distance to its cluster center,
    # optionally followed by a downweighting pass that penalizes samples
    # near higher-ranked samples.
    #

    fov.validate_collection(samples)

    if roi_field is not None:
        fov.validate_collection_label_fields(
            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES
        )

    # A string `embeddings` names a field that holds (or will hold) them
    embeddings_field = None
    embeddings_exist = None
    if etau.is_str(embeddings):
        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples,
            embeddings,
            patches_field=roi_field,
            data_type="embeddings",
        )
        embeddings = None

    if etau.is_str(similarity_index):
        similarity_index = samples.load_brain_results(similarity_index)

    needs_default_model = (
        model is None
        and embeddings is None
        and similarity_index is None
        and not embeddings_exist
    )
    if needs_default_model:
        model = fbm.load_model(_DEFAULT_MODEL)
        if batch_size is None:
            batch_size = _DEFAULT_BATCH_SIZE

    config = RepresentativenessConfig(
        representativeness_field,
        method=method,
        roi_field=roi_field,
        embeddings_field=embeddings_field,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
    )
    brain_key = representativeness_field
    brain_method = config.build()
    brain_method.ensure_requirements()
    brain_method.register_run(samples, brain_key, cleanup=False)

    # Patch embeddings are averaged into one embedding per sample
    # @todo experiment with mean(), max(), abs().max(), etc
    agg_fcn = None if roi_field is None else (lambda e: np.mean(e, axis=0))

    embeddings, sample_ids, _ = fbu.get_embeddings(
        samples,
        model=model,
        model_kwargs=model_kwargs,
        patches_field=roi_field,
        embeddings_field=embeddings_field,
        embeddings=embeddings,
        similarity_index=similarity_index,
        force_square=force_square,
        alpha=alpha,
        handle_missing="image",
        agg_fcn=agg_fcn,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        progress=progress,
    )

    logger.info("Computing representativeness...")
    scores = _compute_representativeness(embeddings, method=method)

    # Create the field up front so it exists even when there are no scores
    samples._dataset.add_sample_field(representativeness_field, fof.FloatField)

    values_by_id = dict(zip(sample_ids, scores))
    if values_by_id:
        samples.set_values(
            representativeness_field, values_by_id, key_field="id"
        )

    brain_method.save_run_results(samples, brain_key, None)

    logger.info("Representativeness computation complete")
def _compute_representativeness(embeddings, method="cluster-center"):
#
# @todo experiment on which method for assessing representativeness
#
num_embeddings = len(embeddings)
logger.info(
"Computing clusters for %d embeddings; this may take awhile...",
num_embeddings,
)
initial_ranking, _ = _cluster_ranker(embeddings)
if method == "cluster-center":
final_ranking = initial_ranking
elif method == "cluster-center-downweight":
logger.info("Applying iterative downweighting...")
final_ranking = _adjust_rankings(
embeddings, initial_ranking, ball_radius=0.5
)
else:
raise ValueError(
(
"Method '%s' not supported. Please use one of "
"['cluster-center', 'cluster-center-downweight']"
)
% method
)
return final_ranking
def _cluster_ranker(
embeddings, cluster_algorithm="kmeans", N=20, norm_method="local"
):
# Cluster
if cluster_algorithm == "meanshift":
bandwidth = skc.estimate_bandwidth(
embeddings, quantile=0.8, n_samples=500
)
clusterer = skc.MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(
embeddings
)
elif cluster_algorithm == "kmeans":
clusterer = skc.KMeans(n_clusters=N, random_state=1234).fit(embeddings)
else:
raise ValueError(
(
"Clustering algorithm '%s' not supported. Please use one of "
"['meanshift', 'kmeans']"
)
% cluster_algorithm
)
cluster_centers = clusterer.cluster_centers_
cluster_ids = clusterer.labels_
# Get distance from each point to it's closest cluster center
sample_dists = np.linalg.norm(
embeddings - cluster_centers[cluster_ids], axis=1
)
centerness_ranking = 1 / (1 + sample_dists)
# Normalize per cluster vs globally
norm_method = "local"
if norm_method == "global":
centerness_ranking = centerness_ranking / centerness_ranking.max()
elif norm_method == "local":
unique_ids = np.unique(cluster_ids)
for unique_id in unique_ids:
cluster_indices = np.where(cluster_ids == unique_id)[0]
cluster_dists = sample_dists[cluster_indices]
cluster_dists /= cluster_dists.max()
sample_dists[cluster_indices] = cluster_dists
centerness_ranking = sample_dists
return centerness_ranking, clusterer
# Downweighting step: reduce the rankings of points near higher-ranked points to avoid redundancy
def _adjust_rankings(embeddings, initial_ranking, ball_radius=0.5):
tree = cKDTree(embeddings)
new_ranking = copy.deepcopy(initial_ranking)
ordered_ranking = np.argsort(new_ranking)[::-1]
visited_indices = set()
for ranked_index in ordered_ranking:
visited_indices.add(ranked_index)
query_embedding = embeddings[ranked_index, :]
nearby_indices = tree.query_ball_point(
query_embedding, ball_radius, return_sorted=True
)
filtered_indices = [
idx for idx in nearby_indices if idx not in visited_indices
]
visited_indices |= set(filtered_indices)
new_ranking[filtered_indices] = new_ranking[filtered_indices] * 0.7
new_ranking = new_ranking / new_ranking.max()
return new_ranking
# @todo move to `fiftyone/brain/representativeness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class RepresentativenessConfig(fob.BrainMethodConfig):
    """Configuration for a representativeness run.

    Args:
        representativeness_field: the sample field that stores the scores
        method (None): the representativeness method used
        roi_field (None): an optional label field defining regions of interest
        embeddings_field (None): an optional field containing embeddings
        similarity_index (None): a similarity index or its brain key. Non-string
            values are stored by their ``key``
        model (None): a model or its name. Non-string values are stored by
            their class name
        model_kwargs (None): optional keyword arguments for the model
        **kwargs: keyword arguments for the base config
    """

    def __init__(
        self,
        representativeness_field,
        method=None,
        roi_field=None,
        embeddings_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        **kwargs,
    ):
        # Keep only string identifiers so that the config is serializable
        if not (similarity_index is None or etau.is_str(similarity_index)):
            similarity_index = similarity_index.key

        if not (model is None or etau.is_str(model)):
            model = etau.get_class_name(model)

        self.representativeness_field = representativeness_field
        self._method = method
        self.roi_field = roi_field
        self.embeddings_field = embeddings_field
        self.similarity_index = similarity_index
        self.model = model
        self.model_kwargs = model_kwargs

        super().__init__(**kwargs)

    @property
    def type(self):
        """The brain run type."""
        return "representativeness"

    @property
    def method(self):
        """The representativeness method used for this run."""
        return self._method

    @classmethod
    def _virtual_attributes(cls):
        # 'method' is virtual by default; it is left out here so that it IS
        # serialized
        return ["cls", "type"]
class Representativeness(fob.BrainMethod):
    """Brain method for representativeness runs."""

    def ensure_requirements(self):
        """No additional requirements."""
        return None

    def get_fields(self, samples, brain_key):
        """Returns the scores field plus any ROI/embeddings fields used."""
        config = self.config
        optional_fields = (config.roi_field, config.embeddings_field)
        return [config.representativeness_field] + [
            field for field in optional_fields if field is not None
        ]

    def cleanup(self, samples, brain_key):
        """Deletes the representativeness field from the dataset."""
        field_name = self.config.representativeness_field
        samples._dataset.delete_sample_fields(field_name, error_level=1)

    def _validate_run(self, samples, brain_key, existing_info):
        self._validate_fields_match(
            brain_key, "representativeness_field", existing_info
        )
================================================
FILE: fiftyone/brain/internal/core/sklearn.py
================================================
"""
Sklearn similarity backend.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import sklearn.metrics as skm
import sklearn.neighbors as skn
import sklearn.preprocessing as skp
import eta.core.utils as etau
import fiftyone.core.media as fom
from fiftyone.brain.similarity import (
DuplicatesMixin,
SimilarityConfig,
Similarity,
SimilarityIndex,
)
import fiftyone.brain.internal.core.utils as fbu
logger = logging.getLogger(__name__)
_AGGREGATIONS = {
"mean": np.mean,
"post-mean": np.nanmean,
"post-min": np.nanmin,
"post-max": np.nanmax,
}
_MAX_PRECOMPUTE_DISTS = 15000 # ~1.7GB to store distance matrix in-memory
_COSINE_HACK_ATTR = "_cosine_hack"
class SklearnSimilarityConfig(SimilarityConfig):
    """Configuration for the sklearn similarity backend.

    Args:
        metric ("cosine"): the embedding distance metric to use. See
            ``sklearn.metrics.pairwise_distance`` for supported values
        **kwargs: keyword arguments for
            :class:`fiftyone.brain.similarity.SimilarityConfig`
    """

    def __init__(self, metric="cosine", **kwargs):
        super().__init__(**kwargs)
        self.metric = metric

    @property
    def method(self):
        """The name of this backend."""
        return "sklearn"

    @property
    def max_k(self):
        """No cap on the number of neighbors (``None``)."""
        return None

    @property
    def supports_least_similarity(self):
        """Least-similarity (reverse) queries are supported."""
        return True

    @property
    def supported_aggregations(self):
        """The names of the supported query aggregations."""
        return tuple(_AGGREGATIONS)
class SklearnSimilarity(Similarity):
    """Factory that builds :class:`SklearnSimilarityIndex` instances.

    Args:
        config: an :class:`SklearnSimilarityConfig`
    """

    def initialize(self, samples, brain_key):
        """Returns a new :class:`SklearnSimilarityIndex` for this backend."""
        index_cls = SklearnSimilarityIndex
        return index_cls(samples, self.config, brain_key, backend=self)
class SklearnSimilarityIndex(SimilarityIndex, DuplicatesMixin):
"""Class for interacting with sklearn similarity indexes.
Args:
samples: the :class:`fiftyone.core.collections.SampleCollection` used
config: the :class:`SklearnSimilarityConfig` used
brain_key: the brain key
embeddings (None): a ``num_embeddings x num_dims`` array of embeddings
sample_ids (None): a ``num_embeddings`` array of sample IDs
label_ids (None): a ``num_embeddings`` array of label IDs, if
applicable
backend (None): a :class:`SklearnSimilarity` instance
"""
def __init__(
self,
samples,
config,
brain_key,
embeddings=None,
sample_ids=None,
label_ids=None,
backend=None,
):
embeddings, sample_ids, label_ids = self._parse_data(
samples,
config,
embeddings=embeddings,
sample_ids=sample_ids,
label_ids=label_ids,
)
self._dataset = samples._dataset
self._embeddings = embeddings
self._sample_ids = sample_ids
self._label_ids = label_ids
self._ids_to_inds = None
self._curr_ids_to_inds = None
self._neighbors_helper = None
SimilarityIndex.__init__(
self, samples, config, brain_key, backend=backend
)
DuplicatesMixin.__init__(self)
@property
def is_external(self):
return self.config.embeddings_field is None
@property
def embeddings(self):
return self._embeddings
@property
def sample_ids(self):
return self._sample_ids
@property
def label_ids(self):
return self._label_ids
@property
def total_index_size(self):
return len(self._sample_ids)
def add_to_index(
self,
embeddings,
sample_ids,
label_ids=None,
overwrite=True,
allow_existing=True,
warn_existing=False,
reload=True,
):
sample_ids = np.asarray(sample_ids)
label_ids = np.asarray(label_ids) if label_ids is not None else None
_sample_ids, _label_ids, ii, jj = fbu.add_ids(
sample_ids,
label_ids,
self._sample_ids,
self._label_ids,
patches_field=self.config.patches_field,
overwrite=overwrite,
allow_existing=allow_existing,
warn_existing=warn_existing,
)
if ii.size == 0:
return
_embeddings = embeddings[ii, :]
if self.config.embeddings_field is not None:
fbu.add_embeddings(
self._dataset,
_embeddings,
sample_ids[ii],
label_ids[ii] if label_ids is not None else None,
self.config.embeddings_field,
patches_field=self.config.patches_field,
)
_e = self._embeddings
n = _e.shape[0]
if n == 0:
_e = np.empty((0, embeddings.shape[1]), dtype=embeddings.dtype)
d = _e.shape[1]
m = max(jj) - n + 1
if m > 0:
if _e.size > 0:
_e = np.concatenate((_e, np.empty((m, d), dtype=_e.dtype)))
else:
_e = np.empty_like(_embeddings)
_e[jj, :] = _embeddings
self._embeddings = _e
self._sample_ids = _sample_ids
self._label_ids = _label_ids
self._ids_to_inds = None
self._curr_ids_to_inds = None
self._neighbors_helper = None
if reload:
super().reload()
def remove_from_index(
self,
sample_ids=None,
label_ids=None,
allow_missing=True,
warn_missing=False,
reload=True,
):
_sample_ids, _label_ids, rm_inds = fbu.remove_ids(
sample_ids,
label_ids,
self._sample_ids,
self._label_ids,
patches_field=self.config.patches_field,
allow_missing=allow_missing,
warn_missing=warn_missing,
)
if rm_inds.size == 0:
return
if self.config.embeddings_field is not None:
if self.config.patches_field is not None:
rm_sample_ids = None
rm_label_ids = self._label_ids[rm_inds]
else:
rm_sample_ids = self._sample_ids[rm_inds]
rm_label_ids = None
fbu.remove_embeddings(
self._dataset,
self.config.embeddings_field,
sample_ids=rm_sample_ids,
label_ids=rm_label_ids,
patches_field=self.config.patches_field,
)
_embeddings = np.delete(self._embeddings, rm_inds, axis=0)
self._embeddings = _embeddings
self._sample_ids = _sample_ids
self._label_ids = _label_ids
self._ids_to_inds = None
self._curr_ids_to_inds = None
self._neighbors_helper = None
if reload:
super().reload()
def use_view(self, *args, **kwargs):
self._curr_ids_to_inds = None
return super().use_view(*args, **kwargs)
def get_embeddings(
self,
sample_ids=None,
label_ids=None,
allow_missing=True,
warn_missing=False,
):
if label_ids is not None:
if self.config.patches_field is None:
raise ValueError("This index does not support label IDs")
if sample_ids is not None:
logger.warning(
"Ignoring sample IDs when label IDs are provided"
)
inds = _get_inds(
label_ids,
self.label_ids,
"label",
allow_missing,
warn_missing,
)
embeddings = self._embeddings[inds, :]
sample_ids = self.sample_ids[inds]
label_ids = np.asarray(label_ids)
elif sample_ids is not None:
if etau.is_str(sample_ids):
sample_ids = [sample_ids]
if self.config.patches_field is not None:
sample_ids = set(sample_ids)
bools = [_id in sample_ids for _id in self.sample_ids]
inds = np.nonzero(bools)[0]
else:
inds = _get_inds(
sample_ids,
self.sample_ids,
"sample",
allow_missing,
warn_missing,
)
embeddings = self._embeddings[inds, :]
sample_ids = self.sample_ids[inds]
if self.config.patches_field is not None:
label_ids = self.label_ids[inds]
else:
label_ids = None
else:
embeddings = self._embeddings.copy()
sample_ids = self.sample_ids.copy()
if self.config.patches_field is not None:
label_ids = self.label_ids.copy()
else:
label_ids = None
return embeddings, sample_ids, label_ids
def reload(self):
if self.config.embeddings_field is not None:
embeddings, sample_ids, label_ids = self._parse_data(
self._dataset, self.config
)
self._embeddings = embeddings
self._sample_ids = sample_ids
self._label_ids = label_ids
self._ids_to_inds = None
self._curr_ids_to_inds = None
self._neighbors_helper = None
super().reload()
def cleanup(self):
pass
def attributes(self):
attrs = super().attributes()
if self.config.embeddings_field is None:
attrs.extend(["embeddings", "sample_ids", "label_ids"])
return attrs
def _kneighbors(
self,
query=None,
k=None,
reverse=False,
aggregation=None,
return_dists=False,
):
if aggregation is not None:
return self._kneighbors_aggregate(
query, k, reverse, aggregation, return_dists
)
(
query,
query_inds,
full_index,
single_query,
) = self._parse_neighbors_query(query)
can_use_dists = full_index or query_inds is not None
neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists)
if dists is not None:
# Use pre-computed distances
if query_inds is not None:
_dists = dists[query_inds, :]
else:
_dists = dists
# note: this must gracefully ignore nans
inds = _nanargmin(_dists, k=k)
if return_dists:
dists = [d[i] for i, d in zip(inds, _dists)]
else:
dists = None
else:
if return_dists:
dists, inds = neighbors.kneighbors(
X=query, n_neighbors=k, return_distance=True
)
inds = list(inds)
dists = list(dists)
else:
inds = neighbors.kneighbors(
X=query, n_neighbors=k, return_distance=False
)
inds = list(inds)
dists = None
return self._format_output(
inds, dists, full_index, single_query, return_dists
)
def _radius_neighbors(self, query=None, thresh=None, return_dists=False):
(
query,
query_inds,
full_index,
single_query,
) = self._parse_neighbors_query(query)
can_use_dists = full_index or query_inds is not None
neighbors, dists = self._get_neighbors(can_use_dists=can_use_dists)
# When not using brute force, we approximate cosine distance by
# computing Euclidean distance on unit-norm embeddings.
# ED = sqrt(2 * CD), so we need to scale the threshold appropriately
if getattr(neighbors, _COSINE_HACK_ATTR, False):
thresh = np.sqrt(2.0 * thresh)
if dists is not None:
# Use pre-computed distances
if query_inds is not None:
_dists = dists[query_inds, :]
else:
_dists = dists
# note: this must gracefully ignore nans
inds = [np.nonzero(d <= thresh)[0] for d in _dists]
if return_dists:
dists = [d[i] for i, d in zip(inds, _dists)]
else:
dists = None
else:
if return_dists:
dists, inds = neighbors.radius_neighbors(
X=query, radius=thresh, return_distance=True
)
else:
dists = None
inds = neighbors.radius_neighbors(
X=query, radius=thresh, return_distance=False
)
return self._format_output(
inds, dists, full_index, single_query, return_dists
)
def _kneighbors_aggregate(
self, query, k, reverse, aggregation, return_dists
):
if query is None:
raise ValueError("Full index queries do not support aggregation")
if aggregation not in _AGGREGATIONS:
raise ValueError(
"Unsupported aggregation method '%s'. Supported values are %s"
% (aggregation, tuple(_AGGREGATIONS.keys()))
)
query, query_inds, _, _ = self._parse_neighbors_query(query)
# Pre-aggregation
if aggregation == "mean":
if query.shape[0] > 1:
query = query.mean(axis=0, keepdims=True)
query_inds = None
aggregation = None
can_use_dists = query_inds is not None
_, dists = self._get_neighbors(
can_use_neighbors=False, can_use_dists=can_use_dists
)
if dists is not None:
# Use pre-computed distances
dists = dists[query_inds, :]
else:
keep_inds = self._current_inds
index_embeddings = self._embeddings
if keep_inds is not None:
index_embeddings = index_embeddings[keep_inds]
dists = skm.pairwise_distances(
query, index_embeddings, metric=self.config.metric
)
# Post-aggregation
if aggregation is not None:
# note: this must gracefully ignore nans
agg_fcn = _AGGREGATIONS[aggregation]
dists = agg_fcn(dists, axis=0)
else:
dists = dists[0, :]
if can_use_dists:
dists[np.isnan(dists)] = 0.0
inds = np.argsort(dists)
if reverse:
inds = np.flip(inds)
if k is not None:
inds = inds[:k]
sample_ids = list(self.current_sample_ids[inds])
if self.config.patches_field is not None:
label_ids = list(self.current_label_ids[inds])
else:
label_ids = None
if return_dists:
dists = list(dists[inds])
return sample_ids, label_ids, dists
return sample_ids, label_ids
def _parse_neighbors_query(self, query):
# Full index
if query is None:
return None, None, True, False
if etau.is_str(query):
query_ids = [query]
single_query = True
else:
query = np.asarray(query)
# Query vector(s)
if np.issubdtype(query.dtype, np.number):
single_query = query.ndim == 1
if single_query:
query = query[np.newaxis, :]
return query, None, False, single_query
query_ids = list(query)
single_query = False
# Retrieve indices into active `dists` matrix, if possible
ids_to_inds = self._get_ids_to_inds(full=False)
query_inds = []
for _id in query_ids:
_ind = ids_to_inds.get(_id, None)
if _ind is not None:
query_inds.append(_ind)
else:
# At least one query ID is not in the active index
query_inds = None
break
# Retrieve embeddings
ids_to_inds = self._get_ids_to_inds(full=True)
inds = []
bad_ids = []
for _id in query_ids:
_ind = ids_to_inds.get(_id, None)
if _ind is not None:
inds.append(_ind)
else:
bad_ids.append(_id)
inds = np.array(inds)
if bad_ids:
raise ValueError(
"Query IDs %s do not exist in this index" % bad_ids
)
query = self._embeddings[inds, :]
if query_inds is not None:
query_inds = np.array(query_inds)
return query, query_inds, False, single_query
def _get_ids_to_inds(self, full=False):
if full:
if self._ids_to_inds is None:
if self.config.patches_field is not None:
ids = self.label_ids
else:
ids = self.sample_ids
self._ids_to_inds = {_id: i for i, _id in enumerate(ids)}
return self._ids_to_inds
if self._curr_ids_to_inds is None:
if self.config.patches_field is not None:
ids = self.current_label_ids
else:
ids = self.current_sample_ids
self._curr_ids_to_inds = {_id: i for i, _id in enumerate(ids)}
return self._curr_ids_to_inds
def _get_neighbors(self, can_use_neighbors=True, can_use_dists=True):
if self._neighbors_helper is None:
self._neighbors_helper = NeighborsHelper(
self._embeddings, self.config.metric
)
return self._neighbors_helper.get_neighbors(
keep_inds=self._current_inds,
can_use_neighbors=can_use_neighbors,
can_use_dists=can_use_dists,
)
def _format_output(
self, inds, dists, full_index, single_query, return_dists
):
if full_index:
if return_dists:
return inds, dists
return inds
curr_sample_ids = self.current_sample_ids
sample_ids = [[curr_sample_ids[i] for i in _inds] for _inds in inds]
if single_query:
sample_ids = sample_ids[0]
if self.config.patches_field is not None:
curr_label_ids = self.current_label_ids
label_ids = [[curr_label_ids[i] for i in _inds] for _inds in inds]
if single_query:
label_ids = label_ids[0]
else:
label_ids = None
if return_dists:
dists = [list(d) for d in dists]
if single_query:
dists = dists[0]
return sample_ids, label_ids, dists
return sample_ids, label_ids
@staticmethod
def _parse_data(
    samples,
    config,
    embeddings=None,
    sample_ids=None,
    label_ids=None,
):
    """Resolves the embeddings and IDs used to build the index.

    When ``embeddings`` is provided, only the missing IDs are looked up.
    Otherwise the embeddings are loaded from ``config.embeddings_field`` on
    the whole dataset (all group slices for grouped datasets).

    Returns:
        a ``(embeddings, sample_ids, label_ids)`` tuple
    """
    if embeddings is not None:
        if sample_ids is None:
            sample_ids, label_ids = fbu.get_ids(
                samples,
                patches_field=config.patches_field,
                data=embeddings,
                data_type="embeddings",
            )

        return embeddings, sample_ids, label_ids

    source = samples._dataset
    if source.media_type == fom.GROUP:
        source = source.select_group_slices(_allow_mixed=True)

    embeddings, sample_ids, label_ids = fbu.get_embeddings(
        source,
        patches_field=config.patches_field,
        embeddings_field=config.embeddings_field,
    )

    return embeddings, sample_ids, label_ids
@classmethod
def _from_dict(cls, d, samples, config, brain_key):
embeddings = d.get("embeddings", None)
if embeddings is not None:
embeddings = np.array(embeddings)
sample_ids = d.get("sample_ids", None)
if sample_ids is not None:
sample_ids = np.array(sample_ids)
label_ids = d.get("label_ids", None)
if label_ids is not None:
label_ids = np.array(label_ids)
return cls(
samples,
config,
brain_key,
embeddings=embeddings,
sample_ids=sample_ids,
label_ids=label_ids,
)
class NeighborsHelper(object):
    """Lazily builds and caches the structures used to answer neighbor
    queries over a set of embeddings.

    Two structures may be built:

    -   a full pairwise distance matrix (with a NaN diagonal), built only
        when the number of embeddings is at most ``_MAX_PRECOMPUTE_DISTS``
    -   a ``NearestNeighbors`` graph (``skn``), built only when no distance
        matrix is available

    The most recently returned structures are cached together with the
    ``keep_inds`` subset they were built for.

    Args:
        embeddings: the embeddings; presumably a ``num x dim`` numpy array
            (TODO confirm)
        metric: the distance metric passed to ``skm.pairwise_distances()``
            and ``skn.NearestNeighbors``
    """

    # Sentinel cached in place of a structure that was requested but could
    # not (too many embeddings) or need not (a distance matrix exists) be
    # built
    _UNAVAILABLE = "UNAVAILABLE"

    def __init__(self, embeddings, metric):
        self.embeddings = embeddings
        self.metric = metric

        self._initialized = False
        # Distance matrix over ALL embeddings, when small enough to build
        self._full_dists = None

        # Cache of the most recent get_neighbors() result and its subset
        self._curr_keep_inds = None
        self._curr_neighbors = None
        self._curr_dists = None

    def get_neighbors(
        self,
        keep_inds=None,
        can_use_neighbors=True,
        can_use_dists=True,
    ):
        """Returns ``(neighbors, dists)`` for the given subset of embeddings.

        Args:
            keep_inds (None): indices of the embeddings to use, or ``None``
                for all of them
            can_use_neighbors (True): whether a neighbors graph may be
                returned
            can_use_dists (True): whether a distance matrix may be returned

        Returns:
            a tuple of ``(neighbors, dists)``, where each entry is ``None``
            if it was not requested or is unavailable
        """
        # The cache can be reused only if it covers the same subset and
        # already holds every requested structure
        iokay = self._same_keep_inds(keep_inds)
        nokay = not can_use_neighbors or self._curr_neighbors is not None
        dokay = not can_use_dists or self._curr_dists is not None

        if iokay and nokay and dokay:
            neighbors = self._curr_neighbors
            dists = self._curr_dists
        else:
            neighbors, dists = self._build(
                keep_inds=keep_inds,
                can_use_neighbors=can_use_neighbors,
                can_use_dists=can_use_dists,
            )

        # A new subset invalidates the cache; otherwise only fill empty slots
        if not iokay:
            self._curr_keep_inds = keep_inds

        if self._curr_neighbors is None or not iokay:
            self._curr_neighbors = neighbors

        if self._curr_dists is None or not iokay:
            self._curr_dists = dists

        # Never leak the sentinel or unrequested structures to the caller
        if not can_use_neighbors or neighbors is self._UNAVAILABLE:
            neighbors = None

        if not can_use_dists or dists is self._UNAVAILABLE:
            dists = None

        return neighbors, dists

    def _same_keep_inds(self, keep_inds):
        # This handles either argument being None
        return np.array_equal(keep_inds, self._curr_keep_inds)

    def _build(
        self, keep_inds=None, can_use_neighbors=True, can_use_dists=True
    ):
        """Builds the requested structures for the given subset.

        A distance matrix is preferred. The full matrix is built once
        (when small enough) and sliced for subsets; otherwise a matrix is
        built for a small-enough subset. A neighbors graph is built only
        when no distance matrix could be produced; otherwise ``neighbors``
        is the ``_UNAVAILABLE`` sentinel.
        """
        if can_use_dists:
            if (
                self._full_dists is None
                and len(self.embeddings) <= _MAX_PRECOMPUTE_DISTS
            ):
                self._full_dists = self._build_dists(self.embeddings)

            if self._full_dists is not None:
                if keep_inds is not None:
                    dists = self._full_dists[keep_inds, :][:, keep_inds]
                else:
                    dists = self._full_dists
            elif (
                keep_inds is not None
                and len(keep_inds) <= _MAX_PRECOMPUTE_DISTS
            ):
                dists = self._build_dists(self.embeddings[keep_inds])
            else:
                dists = self._UNAVAILABLE
        else:
            dists = None

        if can_use_neighbors:
            if not isinstance(dists, np.ndarray):
                embeddings = self.embeddings
                if keep_inds is not None:
                    embeddings = embeddings[keep_inds]

                neighbors = self._build_neighbors(embeddings)
            else:
                neighbors = self._UNAVAILABLE
        else:
            neighbors = None

        return neighbors, dists

    def _build_dists(self, embeddings):
        """Returns the pairwise distance matrix of the centered embeddings,
        with NaN on the diagonal so that self-matches are never nearest."""
        logger.debug("Generating index for %d embeddings...", len(embeddings))

        # Center embeddings
        # NOTE(review): `np.asarray` does not copy an ndarray input, so the
        # in-place `-=` mutates the caller's array. When called with
        # `self.embeddings` (full index), the stored embeddings are
        # overwritten with centered values; it also raises for integer
        # dtypes. Confirm whether this side effect is relied upon before
        # changing it
        embeddings = np.asarray(embeddings)
        embeddings -= embeddings.mean(axis=0, keepdims=True)

        dists = skm.pairwise_distances(embeddings, metric=self.metric)
        np.fill_diagonal(dists, np.nan)

        logger.debug("Index complete")

        return dists

    def _build_neighbors(self, embeddings):
        """Returns a ``NearestNeighbors`` model fit on the centered
        embeddings, tagged with ``_COSINE_HACK_ATTR`` to record whether the
        cosine metric was emulated."""
        logger.debug(
            "Generating neighbors graph for %d embeddings...",
            len(embeddings),
        )

        # Center embeddings
        # NOTE(review): same in-place mutation caveat as `_build_dists()`
        embeddings = np.asarray(embeddings)
        embeddings -= embeddings.mean(axis=0, keepdims=True)

        metric = self.metric

        if metric == "cosine":
            # Nearest neighbors does not directly support cosine distance, so
            # we approximate via euclidean distance on unit-norm embeddings
            cosine_hack = True
            embeddings = skp.normalize(embeddings, axis=1)
            metric = "euclidean"
        else:
            cosine_hack = False

        neighbors = skn.NearestNeighbors(metric=metric)
        neighbors.fit(embeddings)

        # Queries must apply the same normalization when this flag is set
        setattr(neighbors, _COSINE_HACK_ATTR, cosine_hack)

        logger.debug("Index complete")

        return neighbors
def _get_inds(ids, index_ids, ftype, allow_missing, warn_missing):
    """Returns the positions in ``index_ids`` of the given ID(s).

    Args:
        ids: a single ID string or an iterable of IDs
        index_ids: the IDs in the index
        ftype: a description of the ID type, used in messages
        allow_missing: whether IDs absent from the index are tolerated
        warn_missing: whether to log a warning about absent IDs

    Returns:
        a numpy array of the positions of the IDs that were found

    Raises:
        ValueError: if ``allow_missing`` is False and any ID is absent
    """
    if etau.is_str(ids):
        ids = [ids]

    positions = {_id: idx for idx, _id in enumerate(index_ids)}

    found = []
    missing = []
    for _id in ids:
        idx = positions.get(_id, None)
        if idx is None:
            missing.append(_id)
        else:
            found.append(idx)

    if missing:
        num_missing = len(missing)
        if not allow_missing:
            raise ValueError(
                "Found %d %s IDs (eg '%s') that are not present in the index"
                % (num_missing, ftype, missing[0])
            )

        if warn_missing:
            logger.warning(
                "Ignoring %d %s IDs that are not present in the index",
                num_missing,
                ftype,
            )

    return np.array(found)
def _nanargmin(array, k=1):
if k == 1:
inds = np.nanargmin(array, axis=1)
inds = [np.array([i]) for i in inds]
else:
inds = np.argsort(array, axis=1)
inds = list(inds[:, :k])
return inds
================================================
FILE: fiftyone/brain/internal/core/uniqueness.py
================================================
"""
Uniqueness methods.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import numpy as np
import eta.core.utils as etau
import fiftyone.brain as fb
import fiftyone.core.brain as fob
import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
import fiftyone.brain.internal.core.utils as fbu
import fiftyone.brain.internal.models as fbm
logger = logging.getLogger(__name__)
# Label types accepted for the ``roi_field`` argument of compute_uniqueness()
_ALLOWED_ROI_FIELD_TYPES = (
    fol.Detection,
    fol.Detections,
    fol.Polyline,
    fol.Polylines,
)

# Model loaded when no model, embeddings, or similarity index is provided
_DEFAULT_MODEL = "simple-resnet-cifar10"

# Batch size used with the default model when none is specified
_DEFAULT_BATCH_SIZE = 16
def compute_uniqueness(
    samples,
    uniqueness_field,
    roi_field,
    embeddings,
    similarity_index,
    model,
    model_kwargs,
    force_square,
    alpha,
    batch_size,
    num_workers,
    skip_failures,
    progress,
):
    """Computes a uniqueness score for every sample and stores it in
    ``uniqueness_field``.

    See ``fiftyone/brain/__init__.py`` for the public documentation of the
    arguments.

    The run is registered under the brain key ``uniqueness_field``. The
    field is added to the dataset even when there are no scores to write.
    """
    #
    # Algorithm
    #
    # Uniqueness is computed based on a classification model. Each sample is
    # embedded into a vector space based on the model. Then, we compute the
    # knn's (k is a parameter of the uniqueness function). The uniqueness is
    # then proportional to these distances. The intuition is that a sample is
    # unique when it is far from other samples in the set. This is different
    # than, say, "representativeness" which would stress samples that are core
    # to dense clusters of related samples.
    #

    fov.validate_collection(samples)

    if roi_field is not None:
        fov.validate_collection_label_fields(
            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES
        )

    # A string `embeddings` is the name of a field that holds (or will hold)
    # the embeddings
    if etau.is_str(embeddings):
        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples,
            embeddings,
            patches_field=roi_field,
            data_type="embeddings",
        )
        embeddings = None
    else:
        embeddings_field = None
        embeddings_exist = None

    # A string `similarity_index` is the brain key of an existing index
    if etau.is_str(similarity_index):
        similarity_index = samples.load_brain_results(similarity_index)

    # Fall back to the default model when no source of embeddings was given
    if (
        model is None
        and embeddings is None
        and similarity_index is None
        and not embeddings_exist
    ):
        model = fbm.load_model(_DEFAULT_MODEL)
        if batch_size is None:
            batch_size = _DEFAULT_BATCH_SIZE

    config = UniquenessConfig(
        uniqueness_field,
        roi_field=roi_field,
        embeddings_field=embeddings_field,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
    )
    brain_key = uniqueness_field
    brain_method = config.build()
    brain_method.ensure_requirements()
    brain_method.register_run(samples, brain_key, cleanup=False)

    if roi_field is not None:
        # Average each sample's patch embeddings into one vector per sample
        # @todo experiment with mean(), max(), abs().max(), etc
        agg_fcn = lambda e: np.mean(e, axis=0)
    else:
        agg_fcn = None

    # NOTE(review): when `similarity_index` is provided, fbu.get_embeddings()
    # returns embeddings from the index before `agg_fcn` is considered, so
    # patch-level embeddings may not be aggregated per sample; confirm
    embeddings, sample_ids, _ = fbu.get_embeddings(
        samples,
        model=model,
        model_kwargs=model_kwargs,
        patches_field=roi_field,
        embeddings_field=embeddings_field,
        embeddings=embeddings,
        similarity_index=similarity_index,
        force_square=force_square,
        alpha=alpha,
        handle_missing="image",
        agg_fcn=agg_fcn,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        progress=progress,
    )

    # Without a provided index, build an sklearn index over the embeddings
    if similarity_index is None:
        similarity_index = fb.compute_similarity(
            samples, backend="sklearn", embeddings=False
        )
        similarity_index.add_to_index(embeddings, sample_ids)

    logger.info("Computing uniqueness...")
    uniqueness = _compute_uniqueness(
        embeddings, similarity_index, progress=progress
    )

    # Ensure field exists, even if `uniqueness` is empty
    samples._dataset.add_sample_field(uniqueness_field, fof.FloatField)

    uniqueness = {_id: u for _id, u in zip(sample_ids, uniqueness)}
    if uniqueness:
        samples.set_values(uniqueness_field, uniqueness, key_field="id")

    brain_method.save_run_results(samples, brain_key, None)

    logger.info("Uniqueness computation complete")
def _compute_uniqueness(
embeddings, similarity_index, batch_size=10, progress=None
):
K = 3
num_embeddings = len(embeddings)
if num_embeddings <= K:
return [1] * num_embeddings
if similarity_index.config.method == "sklearn":
_, dists = similarity_index._kneighbors(k=K + 1, return_dists=True)
else:
dists = []
with fou.ProgressBar(total=num_embeddings, progress=progress) as pb:
for _embeddings in fou.iter_slices(embeddings, batch_size):
_, _, _dists = similarity_index._kneighbors(
query=_embeddings, k=K + 1, return_dists=True
)
dists.extend(_dists)
pb.update(len(_dists))
dists = np.array(dists)
# @todo experiment on which method for assessing uniqueness is best
#
# To get something going, for now, just take a weighted mean
#
weights = [0.6, 0.3, 0.1]
sample_dists = np.mean(dists[:, 1:] * weights, axis=1)
# Normalize to keep the user on common footing across datasets
sample_dists /= sample_dists.max()
return sample_dists
# @todo move to `fiftyone/brain/uniqueness.py`
# Don't do this hastily; `get_brain_info()` on existing datasets has this
# class's full path in it and may need migration
class UniquenessConfig(fob.BrainMethodConfig):
    """Configuration for a uniqueness brain run.

    Args:
        uniqueness_field: the sample field in which scores are stored
        roi_field (None): an optional label field defining regions of
            interest
        embeddings_field (None): an optional field holding the embeddings
        similarity_index (None): a similarity index or its brain key; an
            index object is stored by its ``key``
        model (None): a model or a model name; a model object is stored by
            its class name
        model_kwargs (None): optional kwargs used to load ``model``
        **kwargs: forwarded to ``fob.BrainMethodConfig``
    """

    def __init__(
        self,
        uniqueness_field,
        roi_field=None,
        embeddings_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        **kwargs,
    ):
        # Store non-string references by name
        if not (similarity_index is None or etau.is_str(similarity_index)):
            similarity_index = similarity_index.key

        if not (model is None or etau.is_str(model)):
            model = etau.get_class_name(model)

        self.uniqueness_field = uniqueness_field
        self.roi_field = roi_field
        self.embeddings_field = embeddings_field
        self.similarity_index = similarity_index
        self.model = model
        self.model_kwargs = model_kwargs

        super().__init__(**kwargs)

    @property
    def type(self):
        return "uniqueness"

    @property
    def method(self):
        return "neighbors"
class Uniqueness(fob.BrainMethod):
    """Brain method that stores per-sample uniqueness scores in a field."""

    def ensure_requirements(self):
        # Nothing to check for this method
        pass

    def get_fields(self, samples, brain_key):
        """Returns the uniqueness field, then any ROI and embeddings fields
        that were configured."""
        cfg = self.config
        optional = (cfg.roi_field, cfg.embeddings_field)
        return [cfg.uniqueness_field] + [f for f in optional if f is not None]

    def cleanup(self, samples, brain_key):
        """Deletes the uniqueness field from the dataset."""
        samples._dataset.delete_sample_fields(
            self.config.uniqueness_field, error_level=1
        )

    def _validate_run(self, samples, brain_key, existing_info):
        # An existing run may only be overwritten if it used the same field
        self._validate_fields_match(
            brain_key, "uniqueness_field", existing_info
        )
================================================
FILE: fiftyone/brain/internal/core/utils.py
================================================
"""
Utilities.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import itertools
import logging
import random
import string
import numpy as np
import eta.core.utils as etau
import fiftyone.brain as fob
import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
import fiftyone.core.models as fom
import fiftyone.core.media as fomm
import fiftyone.core.patches as fop
import fiftyone.zoo as foz
from fiftyone import ViewField as F
logger = logging.getLogger(__name__)
def parse_data(
    samples,
    patches_field=None,
    data=None,
    data_type="embeddings",
    allow_missing=True,
    warn_missing=True,
):
    """Parses data (e.g. embeddings) into ``(data, sample_ids, label_ids)``.

    Args:
        samples: the sample collection the data relates to
        patches_field (None): the label field, if the data is patch-level
        data (None): one of

            -   a ``SimilarityIndex``, from which embeddings are retrieved
            -   a dict mapping sample IDs to values or, for patch data, a
                dict mapping label IDs to 1D arrays
            -   values aligned with the samples/labels of ``samples``

        data_type ("embeddings"): a name for the data, used in messages
        allow_missing (True): whether IDs that cannot be resolved are
            tolerated
        warn_missing (True): whether to log a warning about such IDs

    Returns:
        a tuple of ``(data, sample_ids, label_ids)``; ``label_ids`` is
        ``None`` for sample-level data
    """
    if isinstance(data, fob.SimilarityIndex):
        # Forward the caller's options (previously hard-coded), so that
        # patch-level embeddings are retrieved when `patches_field` is given
        return get_embeddings_from_index(
            samples,
            data,
            patches_field=patches_field,
            allow_missing=allow_missing,
            warn_missing=warn_missing,
        )

    _validate_args(samples, patches_field=patches_field)

    if patches_field is None:
        if isinstance(data, dict):
            sample_ids, data = zip(*data.items())
            return np.array(data), np.array(sample_ids), None

        sample_ids, _ = get_ids(samples, data=data, data_type=data_type)
        return data, sample_ids, None

    if isinstance(data, dict):
        # A dict of 1D arrays is keyed by label ID
        value = next(iter(data.values()), None)
        if isinstance(value, np.ndarray) and value.ndim == 1:
            label_ids, data = zip(*data.items())
            return _parse_label_data(
                samples,
                patches_field,
                label_ids,
                data,
                data_type,
                allow_missing,
                warn_missing,
            )

    sample_ids, label_ids = get_ids(
        samples,
        patches_field=patches_field,
        data=data,
        data_type=data_type,
    )

    return data, sample_ids, label_ids
def _parse_label_data(
    samples,
    patches_field,
    label_ids,
    data,
    data_type,
    allow_missing,
    warn_missing,
):
    """Pairs label-level data with the IDs of the samples containing them.

    Entries whose label ID cannot be found in the dataset are dropped (or
    trigger an error when ``allow_missing`` is False).

    Returns:
        a tuple of numpy arrays ``(data, sample_ids, label_ids)``

    Raises:
        ValueError: if ``allow_missing`` is False and a label ID is unknown
    """
    sample_id_path = "sample_id" if samples._is_patches else "id"
    label_type, label_id_path = samples._get_label_field_path(
        patches_field, "id"
    )
    is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)

    ref_sample_ids, ref_label_ids = samples._dataset.values(
        [sample_id_path, label_id_path]
    )

    # Map every label ID to the ID of the sample that contains it
    if is_list_field:
        label_to_sample = {
            _label_id: _sample_id
            for _sample_id, _lids in zip(ref_sample_ids, ref_label_ids)
            if _lids
            for _label_id in _lids
        }
    else:
        label_to_sample = dict(zip(ref_label_ids, ref_sample_ids))

    kept_data = []
    kept_sample_ids = []
    kept_label_ids = []
    unknown_ids = []
    for _lid, _d in zip(label_ids, data):
        _sid = label_to_sample.get(_lid, None)
        if _sid is None:
            unknown_ids.append(_lid)
            continue

        kept_data.append(_d)
        kept_sample_ids.append(_sid)
        kept_label_ids.append(_lid)

    if unknown_ids:
        num_missing = len(unknown_ids)
        if not allow_missing:
            raise ValueError(
                "Unable to retrieve sample IDs for %d label IDs (eg %s)"
                % (num_missing, unknown_ids[0])
            )

        if warn_missing:
            logger.warning(
                "Ignoring %s for %d label IDs (eg %s) for which sample IDs "
                "could not be retrieved",
                data_type,
                num_missing,
                unknown_ids[0],
            )

    return (
        np.array(kept_data),
        np.array(kept_sample_ids),
        np.array(kept_label_ids),
    )
def get_embeddings_from_index(
    samples,
    similarity_index,
    patches_field=None,
    allow_missing=True,
    warn_missing=True,
):
    """Retrieves the embeddings for ``samples`` from a similarity index.

    Sample IDs are used as the lookup keys for sample-level indexes, and
    label IDs (of ``patches_field``) for patch-level ones.

    Returns:
        the result of ``similarity_index.get_embeddings()``
    """
    sample_ids = None
    label_ids = None

    if patches_field is None:
        id_path = "sample_id" if samples._is_patches else "id"
        sample_ids = samples.values(id_path)
    else:
        if samples._is_patches:
            id_path = "id"
        else:
            _, id_path = samples._get_label_field_path(patches_field, "id")

        label_ids = samples.values(id_path, unwind=True)

    logger.info("Retrieving embeddings from similarity index...")

    return similarity_index.get_embeddings(
        sample_ids=sample_ids,
        label_ids=label_ids,
        allow_missing=allow_missing,
        warn_missing=warn_missing,
    )
def get_ids(
    samples,
    patches_field=None,
    data=None,
    data_type="embeddings",
    handle_missing="skip",
    ref_sample_ids=None,
):
    """Returns the IDs that data computed on ``samples`` is aligned with.

    Args:
        samples: the sample collection
        patches_field (None): a label field, for patch-level data
        data (None): optional data whose length must match the IDs
        data_type ("embeddings"): a name for the data, used in messages
        handle_missing ("skip"): how to handle samples with no labels
        ref_sample_ids (None): optional sample IDs to use instead of
            reading them from ``samples``

    Returns:
        a tuple of numpy arrays ``(sample_ids, label_ids)``; ``label_ids``
        is ``None`` for sample-level data

    Raises:
        ValueError: if ``data`` does not have one entry per ID
    """
    _validate_args(samples, patches_field=patches_field)

    if patches_field is not None:
        sample_ids, label_ids = _get_patch_ids(
            samples,
            patches_field,
            handle_missing=handle_missing,
            ref_sample_ids=ref_sample_ids,
        )

        num_ids = len(sample_ids)
        if data is not None and num_ids != len(data):
            raise ValueError(
                "The number of %s (%d) in these results no longer matches the "
                "number of labels (%d) in the '%s' field of the collection. You "
                "must regenerate the results"
                % (data_type, len(data), num_ids, patches_field)
            )

        return np.array(sample_ids), np.array(label_ids)

    if ref_sample_ids is None:
        sample_ids = samples.values("id")
    else:
        sample_ids = ref_sample_ids

    num_ids = len(sample_ids)
    if data is not None and num_ids != len(data):
        raise ValueError(
            "The number of %s (%d) in these results no longer matches the "
            "number of samples (%d) in the collection. You must "
            "regenerate the results" % (data_type, len(data), num_ids)
        )

    return np.array(sample_ids), None
def filter_ids(
    samples,
    index_sample_ids,
    index_label_ids,
    patches_field=None,
    allow_missing=True,
    warn_missing=False,
):
    """Restricts the IDs in ``samples`` to those that are in the index.

    Returns:
        a tuple of ``(sample_ids, label_ids, keep_inds, good_inds)``, where
        ``keep_inds`` are positions in the index of the collection's IDs and
        ``good_inds`` is a mask over the collection's IDs. The last two are
        ``None`` when no index IDs are provided or when the IDs already
        match the index
    """
    _validate_args(samples, patches_field=patches_field)

    if patches_field is not None:
        sample_ids, label_ids = _get_patch_ids(samples, patches_field)
        if index_label_ids is None:
            return sample_ids, label_ids, None, None

        keep_inds, good_inds, bad_ids = _parse_ids(
            label_ids,
            index_label_ids,
            "labels",
            allow_missing,
            warn_missing,
        )

        if bad_ids is not None:
            sample_ids = sample_ids[good_inds]
            label_ids = label_ids[good_inds]

        return sample_ids, label_ids, keep_inds, good_inds

    id_path = "sample_id" if samples._is_patches else "id"
    sample_ids = np.array(samples.values(id_path))

    if index_sample_ids is None:
        return sample_ids, None, None, None

    keep_inds, good_inds, bad_ids = _parse_ids(
        sample_ids,
        index_sample_ids,
        "samples",
        allow_missing,
        warn_missing,
    )

    if bad_ids is not None:
        sample_ids = sample_ids[good_inds]

    return sample_ids, None, keep_inds, good_inds
def _get_patch_ids(
    samples, patches_field, handle_missing="skip", ref_sample_ids=None
):
    """Returns parallel arrays of ``(sample_ids, label_ids)`` for the labels
    in ``patches_field``.

    List fields (e.g. ``Detections``) are flattened to one entry per label.
    When ``ref_sample_ids`` is given, the label IDs are first aligned to
    that sample order.
    """
    sample_id_path = "sample_id" if samples._is_patches else "id"
    label_type, label_id_path = samples._get_label_field_path(
        patches_field, "id"
    )
    needs_flatten = issubclass(label_type, fol._LABEL_LIST_FIELDS)

    sample_ids, label_ids = samples.values([sample_id_path, label_id_path])

    if ref_sample_ids is not None:
        sample_ids, label_ids = _apply_ref_sample_ids(
            sample_ids, label_ids, ref_sample_ids
        )

    if needs_flatten:
        sample_ids, label_ids = _flatten_list_ids(
            sample_ids, label_ids, handle_missing
        )

    return np.array(sample_ids), np.array(label_ids)
def _apply_ref_sample_ids(sample_ids, label_ids, ref_sample_ids):
ref_label_ids = [None] * len(ref_sample_ids)
inds_map = {_id: i for i, _id in enumerate(ref_sample_ids)}
for _id, _lid in zip(sample_ids, label_ids):
idx = inds_map.get(_id, None)
if idx is not None:
ref_label_ids[idx] = _lid
return ref_sample_ids, ref_label_ids
def _flatten_list_ids(sample_ids, label_ids, handle_missing):
_sample_ids = []
_label_ids = []
_add_missing = handle_missing == "image"
for _id, _lids in zip(sample_ids, label_ids):
if _lids:
for _lid in _lids:
_sample_ids.append(_id)
_label_ids.append(_lid)
elif _add_missing:
_sample_ids.append(_id)
_label_ids.append(None)
return _sample_ids, _label_ids
def _parse_ids(ids, index_ids, ftype, allow_missing, warn_missing):
if np.array_equal(ids, index_ids):
return None, None, None
inds_map = {_id: idx for idx, _id in enumerate(index_ids)}
keep_inds = []
bad_inds = []
bad_ids = []
for _idx, _id in enumerate(ids):
ind = inds_map.get(_id, None)
if ind is not None:
keep_inds.append(ind)
else:
bad_inds.append(_idx)
bad_ids.append(_id)
num_missing_index = len(index_ids) - len(keep_inds)
if num_missing_index > 0:
if not allow_missing:
raise ValueError(
"The index contains %d %s that are not present in the "
"provided collection" % (num_missing_index, ftype)
)
if warn_missing:
logger.warning(
"Ignoring %d %s from the index that are not present in the "
"provided collection",
num_missing_index,
ftype,
)
num_missing_collection = len(bad_ids)
if num_missing_collection > 0:
if not allow_missing:
raise ValueError(
"The provided collection contains %d %s not present in the "
"index" % (num_missing_collection, ftype)
)
if warn_missing:
logger.warning(
"Ignoring %d %s from the provided collection that are not "
"present in the index",
num_missing_collection,
ftype,
)
bad_inds = np.array(bad_inds, dtype=np.int64)
good_inds = np.full(ids.shape, True)
good_inds[bad_inds] = False
else:
good_inds = None
bad_ids = None
keep_inds = np.array(keep_inds, dtype=np.int64)
return keep_inds, good_inds, bad_ids
def skip_ids(samples, ids, patches_field=None, warn_existing=False):
    """Returns a view of ``samples`` that excludes the given IDs.

    For patch data, the matching labels of ``patches_field`` are excluded;
    otherwise the matching samples are excluded.
    """
    sample_ids, label_ids = get_ids(samples, patches_field=patches_field)

    if patches_field is None:
        exclude_ids = list(set(sample_ids) & set(ids))
        if exclude_ids:
            if warn_existing:
                logger.warning(
                    "Skipping %d existing sample IDs", len(exclude_ids)
                )

            samples = samples.exclude(exclude_ids)

        return samples

    exclude_ids = list(set(label_ids) & set(ids))
    if exclude_ids:
        if warn_existing:
            logger.warning("Skipping %d existing label IDs", len(exclude_ids))

        samples = samples.exclude_labels(
            ids=exclude_ids, fields=patches_field
        )

    return samples
def add_ids(
    sample_ids,
    label_ids,
    index_sample_ids,
    index_label_ids,
    patches_field=None,
    overwrite=True,
    allow_existing=True,
    warn_existing=False,
):
    """Merges new IDs into an index's ID arrays.

    IDs already in the index keep their position (and are overwritten when
    ``overwrite`` is True). New IDs are appended in order. Label IDs are
    used for matching when ``patches_field`` is provided, sample IDs
    otherwise.

    Args:
        sample_ids: the sample IDs to add
        label_ids: the label IDs to add (patch indexes only)
        index_sample_ids: the index's current sample IDs
        index_label_ids: the index's current label IDs (patch indexes only)
        patches_field (None): the patches field, if this is a patch index
        overwrite (True): whether existing IDs are overwritten (True) or
            skipped (False)
        allow_existing (True): whether existing IDs are allowed when
            ``overwrite`` is False
        warn_existing (False): whether to warn about skipped existing IDs

    Returns:
        a tuple of

        -   index_sample_ids: the updated index sample IDs
        -   index_label_ids: the updated index label IDs
        -   ii: positions in the provided IDs that were written
        -   jj: the index positions they were written to

    Raises:
        ValueError: if ``overwrite`` and ``allow_existing`` are both False
            and some IDs already exist in the index
    """
    if patches_field is not None:
        ids = label_ids
        index_ids = index_label_ids
    else:
        ids = sample_ids
        index_ids = index_sample_ids

    ii = []
    jj = []
    ids_map = {_id: _i for _i, _id in enumerate(index_ids)}
    new_idx = len(index_ids)
    for _i, _id in enumerate(ids):
        _idx = ids_map.get(_id, None)
        if _idx is None:
            # Unseen IDs are appended after the current index entries
            _idx = new_idx
            new_idx += 1

        ii.append(_i)
        jj.append(_idx)

    ii = np.array(ii)
    jj = np.array(jj)

    n = len(index_sample_ids)

    if not overwrite:
        existing_inds = np.nonzero(jj < n)[0]
        num_existing = existing_inds.size

        if num_existing > 0:
            # Report an ID that actually exists in the index (previously the
            # first input ID was reported, which may be a new one)
            if not allow_existing:
                raise ValueError(
                    "Found %d IDs (eg '%s') that are already present in the "
                    "index" % (num_existing, ids[ii[existing_inds[0]]])
                )
            elif warn_existing:
                logger.warning(
                    "Ignoring %d IDs (eg '%s') that are already present in "
                    "the index",
                    num_existing,
                    ids[ii[existing_inds[0]]],
                )

            ii = np.delete(ii, existing_inds)
            jj = np.delete(jj, existing_inds)

    if ii.size > 0:
        sample_ids = np.array(sample_ids)
        if patches_field is not None:
            label_ids = np.array(label_ids)

        m = max(jj) - n + 1

        if n == 0:
            index_sample_ids = np.array([], dtype=sample_ids.dtype)
            if patches_field is not None:
                index_label_ids = np.array([], dtype=label_ids.dtype)

        if m > 0:
            index_sample_ids = _extend_id_array(index_sample_ids, sample_ids, m)
            if patches_field is not None:
                index_label_ids = _extend_id_array(
                    index_label_ids, label_ids, m
                )

        index_sample_ids[jj] = sample_ids[ii]
        if patches_field is not None:
            index_label_ids[jj] = label_ids[ii]

    return index_sample_ids, index_label_ids, ii, jj


def _extend_id_array(index_ids, new_ids, num):
    """Returns ``index_ids`` padded with ``num`` empty slots, widening the
    dtype if needed so new IDs are not truncated (e.g. ``<U1`` -> ``<U7``).
    """
    try:
        dtype = np.promote_types(index_ids.dtype, new_ids.dtype)
    except TypeError:
        # Incompatible dtypes: keep the index's dtype, as before
        dtype = index_ids.dtype

    return np.concatenate(
        (index_ids.astype(dtype, copy=False), np.empty(num, dtype=dtype))
    )
def add_embeddings(
    samples,
    embeddings,
    sample_ids,
    label_ids,
    embeddings_field,
    patches_field=None,
):
    """Writes embeddings to ``embeddings_field`` on the underlying dataset.

    Sample-level embeddings are keyed by sample ID. Patch embeddings are
    stored on the labels of ``patches_field``, keyed by label ID. Grouped
    datasets are written across all group slices.
    """
    dataset = samples._dataset
    if dataset.media_type == fomm.GROUP:
        view = dataset.select_group_slices(_allow_mixed=True)
    else:
        view = dataset

    if patches_field is None:
        view.set_values(
            embeddings_field,
            dict(zip(sample_ids, embeddings)),
            key_field="id",
        )
        return

    _, embeddings_path = dataset._get_label_field_path(
        patches_field, embeddings_field
    )
    view.set_label_values(
        embeddings_path, dict(zip(label_ids, embeddings)), dynamic=True
    )
def remove_ids(
    sample_ids,
    label_ids,
    index_sample_ids,
    index_label_ids,
    patches_field=None,
    allow_missing=True,
    warn_missing=False,
):
    """Removes the given sample and/or label IDs from an index's ID arrays.

    Returns:
        a tuple of ``(index_sample_ids, index_label_ids, rm_inds)``, where
        ``rm_inds`` holds the removed index positions
    """
    rm_inds = []
    for ids, index_ids, ftype in (
        (sample_ids, index_sample_ids, "sample"),
        (label_ids, index_label_ids, "label"),
    ):
        if ids is not None:
            rm_inds.extend(
                _find_ids(ids, index_ids, allow_missing, warn_missing, ftype)
            )

    rm_inds = np.array(rm_inds)

    if rm_inds.size > 0:
        index_sample_ids = np.delete(index_sample_ids, rm_inds)
        if patches_field is not None:
            index_label_ids = np.delete(index_label_ids, rm_inds)

    return index_sample_ids, index_label_ids, rm_inds
def _find_ids(ids, index_ids, allow_missing, warn_missing, ftype):
found_inds = []
missing_ids = []
ids_map = {_id: _i for _i, _id in enumerate(index_ids)}
for _id in ids:
ind = ids_map.get(_id, None)
if ind is not None:
found_inds.append(ind)
elif not allow_missing:
missing_ids.append(_id)
num_missing = len(missing_ids)
if num_missing > 0:
if not allow_missing:
raise ValueError(
"Found %d %d IDs (eg '%s') that are not present in the index"
% (num_missing, ftype, missing_ids[0])
)
if warn_missing:
logger.warning(
"Ignoring %d %d IDs (eg '%s') that are not present in the "
"index",
num_missing,
ftype,
missing_ids[0],
)
return found_inds
def remove_embeddings(
    samples,
    embeddings_field,
    sample_ids=None,
    label_ids=None,
    patches_field=None,
):
    """Clears (sets to ``None``) stored embeddings on the dataset.

    For patch data, embeddings are cleared on the given labels, or on all
    labels of the given samples when only ``sample_ids`` is provided.
    Otherwise the embeddings of the given samples are cleared.
    """
    dataset = samples._dataset
    if dataset.media_type == fomm.GROUP:
        view = dataset.select_group_slices(_allow_mixed=True)
    else:
        view = dataset

    if patches_field is None:
        if sample_ids is not None:
            view.set_values(
                embeddings_field, dict.fromkeys(sample_ids), key_field="id"
            )
        return

    _, embeddings_path = dataset._get_label_field_path(
        patches_field, embeddings_field
    )

    # Resolve the labels of the given samples when no label IDs are provided
    if sample_ids is not None and label_ids is None:
        _, id_path = dataset._get_label_field_path(patches_field, "id")
        label_ids = view.select(sample_ids).values(id_path, unwind=True)

    if label_ids is not None:
        view.set_label_values(embeddings_path, dict.fromkeys(label_ids))
def filter_values(values, keep_inds, patches_field=None):
    """Filters values with a boolean mask, accepting pre-filtered input.

    Args:
        values: the values; for patch data (``patches_field`` truthy), a
            nested per-sample list that is flattened first
        keep_inds: a boolean mask over the raw (unfiltered) values
        patches_field (None): the patches field, if any

    Returns:
        a numpy array of the filtered values. Input that already has one
        value per ``True`` entry of ``keep_inds`` is returned unfiltered

    Raises:
        ValueError: if the number of values matches neither the mask length
            nor its number of ``True`` entries
    """
    if patches_field:
        _values = list(itertools.chain.from_iterable(values))
    else:
        _values = values

    _values = np.asarray(_values)

    if _values.size == keep_inds.size:
        _values = _values[keep_inds]
    else:
        num_expected = np.count_nonzero(keep_inds)
        if _values.size != num_expected:
            # Use the array form: `values` may be a list with no `.size`
            raise ValueError(
                "Expected %d raw values or %d pre-filtered values; found %d "
                "values" % (keep_inds.size, num_expected, _values.size)
            )

    # @todo we might need to re-ravel patch values here in the future
    # We currently do not do this because all downstream users of this data
    # will gracefully handle either flat or nested list data

    return _values
def get_values(samples, path_or_expr, ids, patches_field=None):
    """Returns the values of ``path_or_expr`` for the given IDs, validating
    the arguments first.

    ``patches_field`` is passed as the ``link_field`` of
    ``samples._get_values_by_id()``.
    """
    _validate_args(
        samples, patches_field=patches_field, path_or_expr=path_or_expr
    )

    values = samples._get_values_by_id(
        path_or_expr, ids, link_field=patches_field
    )
    return values
def parse_data_field(
    samples,
    data_field,
    patches_field=None,
    data_type="embeddings",
):
    """Validates a field in which data (e.g. embeddings) is or will be stored.

    Args:
        samples: the sample collection
        data_field: the field name. For patch data, it may be given either
            relative to the label objects or as a full path under
            ``patches_field``
        patches_field (None): the label field, for patch-level data
        data_type ("embeddings"): a name for the data, used in messages

    Returns:
        a tuple of ``(data_field, data_exists)``; for patch data,
        ``data_field`` is relative to the label objects

    Raises:
        ValueError: if ``data_field`` is not a string, is not under
            ``patches_field``'s label path, or its parent field does not
            exist
    """
    if not etau.is_str(data_field):
        raise ValueError(
            "Invalid %s field '%s'; expected a string field name"
            % (data_type, data_field)
        )

    if patches_field is None:
        # NOTE(review): `is_frame_field` is unused; for frame fields the root
        # check below presumably looks up the root without its "frames."
        # prefix. Confirm whether that is intended
        _data_field, is_frame_field = samples._handle_frame_field(data_field)

        if "." in _data_field:
            root, _ = _data_field.rsplit(".", 1)
            if not samples.has_field(root):
                raise ValueError(
                    "Invalid %s field '%s'; root field '%s' does not exist"
                    % (data_type, data_field, root)
                )

        data_exists = samples.has_field(data_field)
        return data_field, data_exists

    # Convert a full path under the label root into a relative path
    if data_field.startswith(patches_field + "."):
        _, root = samples._get_label_field_path(patches_field)
        if not data_field.startswith(root + "."):
            raise ValueError(
                "Invalid %s field '%s' for patches field '%s'"
                % (data_type, data_field, patches_field)
            )

        data_field = data_field[len(root) + 1 :]

    if "." in data_field:
        _, root = samples._get_label_field_path(patches_field)
        # Join the relative parent path to the label root with a "."
        # (previously omitted, which made every nested path invalid)
        root += "." + data_field.rsplit(".", 1)[0]
        if not samples.has_field(root):
            raise ValueError(
                "Invalid %s field '%s'; root field '%s' does not exist"
                % (data_type, data_field, root)
            )

    _, data_path = samples._get_label_field_path(patches_field, data_field)
    data_exists = samples.has_field(data_path)

    return data_field, data_exists
def get_embeddings(
    samples,
    model=None,
    model_kwargs=None,
    patches_field=None,
    embeddings_field=None,
    embeddings=None,
    similarity_index=None,
    force_square=False,
    alpha=None,
    handle_missing="skip",
    agg_fcn=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    progress=None,
):
    """Returns embeddings for ``samples``, computing or loading them as
    necessary.

    Sources are checked in this order:

    -   ``similarity_index``: embeddings are retrieved from the index
    -   ``embeddings``: precomputed embeddings (an array/list, or a dict
        keyed by sample ID)
    -   ``embeddings_field``: embeddings loaded from this field. If ``model``
        is given and the field presumably does not exist yet (per
        ``_has_embeddings_field()``), the model is run first
    -   ``model``: embeddings are computed with the model (loaded from the
        zoo when given by name)

    For patch data, ``agg_fcn`` (if given) reduces each sample's patch
    embeddings to one vector, and the output is then sample-level.

    Returns:
        a tuple of ``(embeddings, sample_ids, label_ids)``; ``label_ids`` is
        ``None`` for sample-level output. Empty results are returned when no
        source is provided or no embeddings are found
    """
    _validate_args(samples, patches_field=patches_field)

    # Nothing to compute or load
    if (
        model is None
        and embeddings_field is None
        and embeddings is None
        and similarity_index is None
    ):
        return _empty_embeddings(patches_field)

    if similarity_index is not None:
        return get_embeddings_from_index(
            samples,
            similarity_index,
            patches_field=patches_field,
            allow_missing=True,
            warn_missing=True,
        )

    # Run the model unless embeddings were provided or are already stored
    if (
        embeddings is None
        and model is not None
        and not _has_embeddings_field(samples, embeddings_field, patches_field)
    ):
        if etau.is_str(model):
            model_kwargs = model_kwargs or {}
            model = foz.load_zoo_model(model, **model_kwargs)

        if patches_field is not None:
            logger.info("Computing patch embeddings...")
            embeddings = samples.compute_patch_embeddings(
                model,
                patches_field,
                embeddings_field=embeddings_field,
                force_square=force_square,
                alpha=alpha,
                handle_missing=handle_missing,
                batch_size=batch_size,
                num_workers=num_workers,
                skip_failures=skip_failures,
                progress=progress,
            )
        else:
            logger.info("Computing embeddings...")
            embeddings = samples.compute_embeddings(
                model,
                embeddings_field=embeddings_field,
                batch_size=batch_size,
                num_workers=num_workers,
                skip_failures=skip_failures,
                progress=progress,
            )

    if embeddings is None and embeddings_field is not None:
        # Embeddings live in (or were just written to) `embeddings_field`
        embeddings, samples = _load_embeddings(
            samples, embeddings_field, patches_field=patches_field
        )
        ref_sample_ids = None
    else:
        # Align a dict of embeddings to the collection's sample order
        if isinstance(embeddings, dict):
            embeddings = [
                embeddings.get(_id, None) for _id in samples.values("id")
            ]

        embeddings, ref_sample_ids = _handle_missing_embeddings(
            embeddings, samples
        )

    if not isinstance(embeddings, np.ndarray) and not embeddings:
        return _empty_embeddings(patches_field)

    if patches_field is not None:
        if agg_fcn is not None:
            # One aggregated vector per sample
            embeddings = np.stack([agg_fcn(e) for e in embeddings])
        else:
            # One vector per patch
            embeddings = np.concatenate(embeddings, axis=0)
    elif not isinstance(embeddings, np.ndarray):
        embeddings = np.stack(embeddings)

    # Aggregated patch embeddings are sample-level, so look up sample IDs
    if agg_fcn is not None:
        patches_field = None

    sample_ids, label_ids = get_ids(
        samples,
        patches_field=patches_field,
        data=embeddings,
        data_type="embeddings",
        handle_missing=handle_missing,
        ref_sample_ids=ref_sample_ids,
    )

    return embeddings, sample_ids, label_ids
def get_unique_name(name, ref_names_or_fcn, max_len=None):
    """Returns a variant of ``name`` that is not already taken.

    Args:
        name: the desired name
        ref_names_or_fcn: a container of existing names, or a function that
            returns True when a name is taken
        max_len (None): an optional maximum length. The base name is
            shortened one character at a time until the result fits (or
            the base name is empty)

    Returns:
        the unique name
    """
    unique_name = _get_unique_name(name, ref_names_or_fcn)
    if max_len is None:
        return unique_name

    while name and len(unique_name) > max_len:
        name = name[:-1]
        unique_name = _get_unique_name(name, ref_names_or_fcn)

    return unique_name
def _get_unique_name(name, ref_names_or_fcn):
    """Dispatches to the container-based or function-based uniqueness
    helper, depending on the type of ``ref_names_or_fcn``."""
    if not etau.is_container(ref_names_or_fcn):
        return _get_unique_name_from_function(name, ref_names_or_fcn)

    return _get_unique_name_from_list(name, ref_names_or_fcn)
def _get_unique_name_from_list(name, ref_names):
ref_names = set(ref_names)
if name not in ref_names:
return name
name += "-" + _get_random_characters(6)
while name in ref_names:
name += _get_random_characters(1)
return name
def _get_unique_name_from_function(name, exists_fcn):
if not exists_fcn(name):
return name
name += "-" + _get_random_characters(6)
while exists_fcn(name):
name += _get_random_characters(1)
return name
def _get_random_characters(n):
return "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(n)
)
def _empty_embeddings(patches_field):
embeddings = np.empty((0, 0), dtype=float)
sample_ids = np.array([], dtype="`_
|
"""
# For backwards-compatibility with older versions of plugins like
# https://github.com/voxel51/fiftyone-plugins/blob/5c800f1ded53c285f8e17f37e1ad9b2472fa93e7/plugins/brain/__init__.py#L25
from fiftyone.brain.visualization import (
Visualization,
UMAPVisualization,
TSNEVisualization,
PCAVisualization,
ManualVisualization,
)
================================================
FILE: fiftyone/brain/internal/models/.gitignore
================================================
cache/
================================================
FILE: fiftyone/brain/internal/models/__init__.py
================================================
"""
Brain models.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from copy import deepcopy
import logging
import os
from eta.core.config import ConfigError
import eta.core.learning as etal
import eta.core.models as etam
import fiftyone.core.models as fom
logger = logging.getLogger(__name__)
# Directory containing this module
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))

# Path to the JSON models manifest bundled with this package (presumably
# read by `_load_models_manifest()`)
_MODELS_MANIFEST_PATH = os.path.join(_THIS_DIR, "manifest.json")

# Directory in which downloaded models are stored (git-ignored as `cache/`)
_MODELS_DIR = os.path.join(_THIS_DIR, "cache")
def list_models():
    """Lists the names of all models in the Brain models manifest.

    Returns:
        an alphabetically sorted list of model names
    """
    names = [model.name for model in _load_models_manifest()]
    names.sort()
    return names
def list_downloaded_models():
    """Describes the manifest models that are present in the local models
    directory.

    Returns:
        a dict mapping model names to ``(model_path, model)`` tuples, where
        ``model`` is the ``eta.core.models.Model`` from the manifest
    """
    downloaded = {}
    for model in _load_models_manifest():
        if not model.is_in_dir(_MODELS_DIR):
            continue

        downloaded[model.name] = (model.get_path_in_dir(_MODELS_DIR), model)

    return downloaded
def is_model_downloaded(name):
    """Checks whether the named model is present in the local models
    directory.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is checked

    Returns:
        True/False
    """
    return _get_model(name).is_in_dir(_MODELS_DIR)
def download_model(name, overwrite=False):
    """Downloads the named model into the local models directory.

    An already-downloaded model is left untouched unless ``overwrite`` is
    True.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used. See :func:`list_models` for available models
        overwrite (False): whether to re-download over any existing files

    Returns:
        a ``(model, model_path)`` tuple containing the
        ``eta.core.models.Model`` and the model's local path
    """
    model, model_path = _get_model_in_dir(name)

    already_downloaded = not overwrite and is_model_downloaded(name)
    if already_downloaded:
        logger.info("Model '%s' is already downloaded", name)
    else:
        model.manager.download_model(model_path, force=overwrite)

    return model, model_path
def install_model_requirements(name, error_level=0):
    """Installs the package requirements of the named model.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used. See :func:`list_models` for available models
        error_level (0): how to handle install failures:

            - 0: raise an error
            - 1: log a warning
            - 2: ignore the failure
    """
    _get_model(name).install_requirements(error_level=error_level)
def ensure_model_requirements(name, error_level=0):
    """Verifies that the package requirements of the named model are
    satisfied.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used. See :func:`list_models` for available models
        error_level (0): how to handle unsatisfied requirements:

            - 0: raise an error
            - 1: log a warning
            - 2: ignore them
    """
    _get_model(name).ensure_requirements(error_level=error_level)
def load_model(
    name,
    download_if_necessary=True,
    install_requirements=False,
    error_level=0,
    **kwargs
):
    """Loads the named Brain model, downloading it first if necessary.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used. See :func:`list_models` for available models
        download_if_necessary (True): whether to download the model if it
            has not already been downloaded. If False and the model is
            missing, a ``ValueError`` is raised
        install_requirements (False): whether to install the model's
            requirements (True) or merely verify them (False) before loading
        error_level (0): how to handle requirement problems:

            - 0: raise an error
            - 1: log a warning
            - 2: ignore them

        **kwargs: keyword arguments to inject into the model's ``Config``

    Returns:
        a :class:`fiftyone.core.models.Model`

    Raises:
        ValueError: if the model is not downloaded and
            ``download_if_necessary`` is False
    """
    model = _get_model(name)

    if not model.is_in_dir(_MODELS_DIR):
        if not download_if_necessary:
            raise ValueError("Model '%s' is not downloaded" % name)

        download_model(name)

    if install_requirements:
        handle_requirements = model.install_requirements
    else:
        handle_requirements = model.ensure_requirements

    handle_requirements(error_level=error_level)

    # Copy so that the manifest's default config is never mutated downstream
    config_dict = deepcopy(model.default_deployment_config_dict)
    model_path = model.get_path_in_dir(_MODELS_DIR)

    return fom.load_model(config_dict, model_path=model_path, **kwargs)
def find_model(name):
    """Locates an already-downloaded model on disk.

    Use :func:`download_model` to download models.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used

    Returns:
        the local path to the model

    Raises:
        ValueError: if the model does not exist or is not downloaded
    """
    model, model_path = _get_model_in_dir(name)

    if model.is_model_downloaded(model_path):
        return model_path

    raise ValueError("Model '%s' is not downloaded" % name)
def get_model(name):
    """Returns the ``eta.core.models.Model`` instance for the model with the
    given name.

    Args:
        name: the name of the model, which can have ``@<version>`` appended
            to refer to a specific version of the model. If no version is
            specified, the latest version of the model is used

    Returns:
        the ``eta.core.models.Model`` for the model

    Raises:
        ValueError: if no matching model is found in the manifest
    """
    return _get_model(name)
def delete_model(name):
    """Removes the named model from the local models directory, if present.

    Args:
        name: the name of the model, optionally with ``@<version>`` appended
            to select a specific version. Without a version, the latest
            version is used
    """
    model, local_path = _get_model_in_dir(name)
    model.flush_model(local_path)
class HasBrainModel(etal.HasPublishedModel):
    """Mixin class for Config classes of :class:`fiftyone.core.models.Model`
    instances whose models are stored privately by the FiftyOne Brain.
    """

    def download_model_if_necessary(self):
        """Ensures that ``self.model_path`` refers to a local model file.

        If ``model_path`` is not set, the Brain model named ``model_name`` is
        downloaded (if necessary) and ``model_path`` is set to its local path.

        Raises:
            ConfigError: if neither ``model_name`` nor ``model_path`` is set
        """
        # pylint: disable=attribute-defined-outside-init
        if not self.model_name and not self.model_path:
            raise ConfigError(
                "Either `model_name` or `model_path` must be provided"
            )

        if self.model_path is None:
            # download_model() returns a (model, model_path) tuple; only the
            # path belongs in the config
            _, self.model_path = download_model(self.model_name)

    @classmethod
    def _get_model(cls, model_name):
        # Resolve names against the Brain's private models manifest
        return get_model(model_name)
def _load_models_manifest():
    """Parses the Brain models manifest JSON into an
    ``eta.core.models.ModelsManifest``.
    """
    manifest_path = _MODELS_MANIFEST_PATH
    return etam.ModelsManifest.from_json(manifest_path)
def _get_model_in_dir(name):
    """Resolves ``name`` to a ``(model, local_path)`` tuple, where
    ``local_path`` is the model's location inside the models directory.
    """
    model = _get_model(name)
    return model, model.get_path_in_dir(_MODELS_DIR)
def _get_model(name):
    """Resolves ``name`` to a manifest model. Names that carry a version
    string are matched exactly; otherwise the latest version is used.
    """
    if not etam.Model.has_version_str(name):
        return _get_latest_model(name)

    return _get_exact_model(name)
def _get_exact_model(name):
    """Returns the manifest model whose full (versioned) name is ``name``.

    Raises:
        ValueError: if no such model exists
    """
    manifest = _load_models_manifest()

    try:
        model = manifest.get_model_with_name(name)
    except etam.ModelError:
        raise ValueError("No model with name '%s' was found" % name)

    return model
def _get_latest_model(base_name):
    """Returns the latest-version manifest model with the given base name.

    Raises:
        ValueError: if no model has this base name
    """
    manifest = _load_models_manifest()

    try:
        latest = manifest.get_latest_model_with_base_name(base_name)
    except etam.ModelError:
        raise ValueError("No models found with base name '%s'" % base_name)

    return latest
================================================
FILE: fiftyone/brain/internal/models/manifest.json
================================================
{
"models": [
{
"base_name": "simple-resnet-cifar10",
"base_filename": "simple-resnet-cifar10.pth",
"version": "1.0",
"description": "Simple ResNet trained on CIFAR-10",
"manager": {
"type": "fiftyone.core.models.ModelManager",
"config": {
"google_drive_id": "1SIO9XreK0w1ja4EuhBWcR10CnWxCOsom"
}
},
"default_deployment_config_dict": {
"type": "fiftyone.brain.internal.models.torch.TorchImageModel",
"config": {
"entrypoint_fcn": "fiftyone.brain.internal.models.simple_resnet.simple_resnet",
"output_processor_cls": "fiftyone.utils.torch.ClassifierOutputProcessor",
"labels_string": "airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truck",
"image_size": [32, 32],
"image_mean": [0.4914, 0.4822, 0.4465],
"image_std": [0.2023, 0.1994, 0.201],
"embeddings_layer": "flatten",
"use_half_precision": false,
"cudnn_benchmark": true
}
},
"date_created": "2020-05-07 08:25:51"
}
]
}
================================================
FILE: fiftyone/brain/internal/models/simple_resnet.py
================================================
"""
Implementation of a simple ResNet that is suitable only for smallish data.
The original implementation of this is from David Page's work on fast model
training with resnets at https://github.com/davidcpage/cifar10-fast.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
from collections import namedtuple
import os
import numpy as np
import torch
from torch import nn
def simple_resnet(
    channels=None,
    weight=0.125,
    pool=None,
    extra_layers=(),
    res_layers=("layer1", "layer3"),
    num_classes=10,
):
    """Builds the simple ResNet :class:`Network` from ``cifar10-fast``.

    Args:
        channels (None): an optional dict mapping ``"prep"``, ``"layer1"``,
            ``"layer2"``, and ``"layer3"`` to output channel counts. If falsy,
            64/128/256/512 are used
        weight (0.125): the scalar by which the final linear outputs are
            multiplied to produce the ``"logits"`` output
        pool (None): the pooling module appended to each of ``layer1``,
            ``layer2``, and ``layer3``. If None, a new ``nn.MaxPool2d(2)`` is
            created for this network
        extra_layers (()): names of layers to which an extra conv-bn-relu
            block is appended
        res_layers (("layer1", "layer3")): names of layers to which a
            residual branch is appended
        num_classes (10): the number of outputs of the final linear layer

    Returns:
        a :class:`Network` whose input node is ``"input"`` and whose output
        node is ``"logits"``
    """
    if pool is None:
        # Create the default per call: a module used as a default argument is
        # instantiated once at import time and shared by every network built
        # by this function
        pool = nn.MaxPool2d(2)

    channels = channels or {
        "prep": 64,
        "layer1": 128,
        "layer2": 256,
        "layer3": 512,
    }

    net = {
        "input": (None, []),
        "prep": conv_bn(3, channels["prep"]),
        "layer1": dict(
            conv_bn(channels["prep"], channels["layer1"]), pool=pool
        ),
        "layer2": dict(
            conv_bn(channels["layer1"], channels["layer2"]), pool=pool
        ),
        "layer3": dict(
            conv_bn(channels["layer2"], channels["layer3"]), pool=pool
        ),
        "pool": nn.MaxPool2d(4),
        "flatten": Flatten(),
        "linear": nn.Linear(channels["layer3"], num_classes, bias=False),
        "logits": Mul(weight),
    }

    for layer in res_layers:
        net[layer]["residual"] = residual(channels[layer])

    for layer in extra_layers:
        net[layer]["extra"] = conv_bn(channels[layer], channels[layer])

    return Network(net, input_layer="input", output_layer="logits")
class Network(nn.Module):
    """A ``torch.nn.Module`` that evaluates a graph of named nodes.

    Args:
        net: a nested dict of nodes, as consumed by :func:`build_graph`
        input_layer (None): the name of the node that receives ``inputs`` in
            :meth:`forward`. If omitted, ``inputs`` must be convertible to a
            dict mapping node names to values
        output_layer (None): the name of the node whose output
            :meth:`forward` returns. If omitted, the dict of all node outputs
            is returned
    """

    def __init__(self, net, input_layer=None, output_layer=None):
        super().__init__()
        self.input_layer = input_layer
        self.output_layer = output_layer
        self.graph = build_graph(net)
        for path, (val, _) in self.graph.items():
            # build_graph() joins path components with os.path.sep, which is
            # "\\" on Windows. Normalize both separators so that attribute
            # names (and therefore the registered submodule names) are the
            # same on every platform
            attr_name = path.replace(os.path.sep, "_").replace("/", "_")
            setattr(self, attr_name, val)

    def nodes(self):
        """Returns a generator over the node objects of the graph."""
        return (node for node, _ in self.graph.values())

    def forward(self, inputs):
        outputs = (
            {self.input_layer: inputs} if self.input_layer else dict(inputs)
        )

        for k, (node, ins) in self.graph.items():
            # only compute nodes that are not supplied as inputs.
            if k not in outputs:
                outputs[k] = node(*[outputs[x] for x in ins])

        if self.output_layer:
            return outputs[self.output_layer]

        return outputs

    def half(self):
        """Converts every module node except ``nn.BatchNorm2d`` layers to
        half precision; batch norm layers are left unchanged.
        """
        for node in self.nodes():
            if isinstance(node, nn.Module) and not isinstance(
                node, nn.BatchNorm2d
            ):
                node.half()

        return self
def has_inputs(node):
    """Returns whether ``node`` is an explicit ``(node, inputs)`` pair.

    Only exact ``tuple`` instances count; tuple subclasses such as the
    namedtuple-based nodes in this module do not.
    """
    node_type = type(node)
    return node_type is tuple
def build_graph(net):
    """Resolves a nested node dict into a flat graph.

    Returns:
        a dict mapping each node's path to ``(node, input_paths)``. String
        inputs are interpreted relative to the node's parent path; integer
        inputs are offsets into the flattened node order
    """
    flattened = pipeline(net)

    def _resolve(rel_path, path, idx):
        if isinstance(rel_path, str):
            return os.path.normpath(os.path.sep.join((path, "..", rel_path)))

        return flattened[idx + rel_path][0]

    graph = {}
    for idx, (path, entry) in enumerate(flattened):
        resolved = [_resolve(rel_path, path, idx) for rel_path in entry[1]]
        graph[path] = (entry[0], resolved)

    return graph
def pipeline(net):
    """Flattens ``net`` into a list of ``(path, (node, inputs))`` tuples.

    Nodes that are not already ``(node, inputs)`` tuples are given the
    relative input ``[-1]``, i.e., the preceding node in this list.
    """
    flat = []
    for keys, node in path_iter(net):
        entry = node if has_inputs(node) else (node, [-1])
        flat.append((os.path.sep.join(keys), entry))

    return flat
class Crop(namedtuple("Crop", ("h", "w"))):
    """Crops an ``h x w`` window from the last two axes of an array."""

    def __call__(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        return x[..., rows, cols]

    def options(self, shape):
        """Enumerates every valid top-left corner ``{"x0", "y0"}`` for an
        input of the given shape, ordered by ``x0`` and then ``y0``.
        """
        *_, height, width = shape
        corners = []
        for x0 in range(width + 1 - self.w):
            for y0 in range(height + 1 - self.h):
                corners.append({"x0": x0, "y0": y0})

        return corners

    def output_shape(self, shape):
        """Returns ``shape`` with its last two dims replaced by ``(h, w)``."""
        *leading, _, _ = shape
        return (*leading, self.h, self.w)
class FlipLR(namedtuple("FlipLR", ())):
    """Left-right flip augmentation along the last axis, applied only when
    the ``choice`` argument is truthy.
    """

    def __call__(self, x, choice):
        """Flips ``x`` along its last axis if ``choice`` is truthy.

        Both numpy arrays and torch tensors honor ``choice``.

        Args:
            x: a ``numpy.ndarray`` or ``torch.Tensor``
            choice: whether to apply the flip

        Returns:
            the flipped result, or ``x`` itself if ``choice`` is falsy
        """
        if not choice:
            return x

        if isinstance(x, np.ndarray):
            # Copy so the result does not share a negative-stride view with x
            return x[..., ::-1].copy()

        return torch.flip(x, [-1])

    def options(self, shape):
        """Returns the possible ``choice`` values (flip and no-flip)."""
        return [{"choice": b} for b in [True, False]]
class Cutout(namedtuple("Cutout", ("h", "w"))):
    """Zeroes an ``h x w`` window in the last two axes of an array."""

    def __call__(self, x, x0, y0):
        # Note: modifies `x` in place and returns the same object
        window = (Ellipsis, slice(y0, y0 + self.h), slice(x0, x0 + self.w))
        x[window] = 0.0
        return x

    def options(self, shape):
        """Enumerates every valid top-left corner ``{"x0", "y0"}`` for an
        input of the given shape, ordered by ``x0`` and then ``y0``.
        """
        *_, height, width = shape
        corners = []
        for x0 in range(width + 1 - self.w):
            for y0 in range(height + 1 - self.h):
                corners.append({"x0": x0, "y0": y0})

        return corners
class PiecewiseLinear(namedtuple("PiecewiseLinear", ("knots", "vals"))):
    """Piecewise-linear schedule defined by ``knots`` and ``vals``, evaluated
    with ``numpy.interp``.
    """

    def __call__(self, t):
        interpolated = np.interp([t], self.knots, self.vals)
        return interpolated[0]
class Const(namedtuple("Const", ["val"])):
    """Node that ignores its input and returns the stored ``val``."""

    def __call__(self, x):
        constant = self.val
        return constant
class Identity(namedtuple("Identity", [])):
    """Node that returns its input unchanged."""

    def __call__(self, x):
        result = x
        return result
class Add(namedtuple("Add", [])):
    """Node that returns the sum of its two inputs."""

    def __call__(self, x, y):
        total = x + y
        return total
class AddWeighted(namedtuple("AddWeighted", ["wx", "wy"])):
    """Node that returns ``wx * x + wy * y``."""

    def __call__(self, x, y):
        weighted_x = self.wx * x
        weighted_y = self.wy * y
        return weighted_x + weighted_y
class Mul(nn.Module):
    """Module that multiplies its input by a fixed scalar ``weight``.

    Args:
        weight: the multiplier
    """

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        # Implemented as forward() rather than by overriding __call__ so that
        # nn.Module's call machinery (including registered hooks) still runs
        return x * self.weight
class Flatten(nn.Module):
    """Reshapes its input to ``(x.size(0), x.size(1))``.

    Because ``view`` must preserve the element count, any trailing dims must
    have size 1 (e.g. an ``(N, C, 1, 1)`` input).
    """

    def forward(self, x):
        batch_size, num_channels = x.size(0), x.size(1)
        return x.view(batch_size, num_channels)
class Concat(nn.Module):
    """Concatenates all of its inputs along dimension 1."""

    def forward(self, *xs):
        return torch.cat(list(xs), dim=1)
class BatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` with configurable initial values and optional
    freezing of the affine weight and bias.

    Args:
        num_features: the number of channels
        eps (1e-05): passed to ``nn.BatchNorm2d``
        momentum (0.1): passed to ``nn.BatchNorm2d``
        weight_freeze (False): whether to disable gradients for the weight
        bias_freeze (False): whether to disable gradients for the bias
        weight_init (1.0): the value to fill the weight with, or None to keep
            the ``nn.BatchNorm2d`` initialization
        bias_init (0.0): the value to fill the bias with, or None to keep
            the ``nn.BatchNorm2d`` initialization
    """

    def __init__(
        self,
        num_features,
        eps=1e-05,
        momentum=0.1,
        weight_freeze=False,
        bias_freeze=False,
        weight_init=1.0,
        bias_init=0.0,
    ):
        super().__init__(num_features, eps=eps, momentum=momentum)

        settings = (
            (self.weight, weight_init, weight_freeze),
            (self.bias, bias_init, bias_freeze),
        )
        for param, init_value, freeze in settings:
            if init_value is not None:
                param.data.fill_(init_value)

            param.requires_grad = not freeze
def conv_bn(c_in, c_out):
    """Returns an ordered ``conv -> bn -> relu`` node dict using a 3x3
    convolution (stride 1, padding 1, no bias) from ``c_in`` to ``c_out``
    channels.
    """
    conv = nn.Conv2d(
        c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False
    )
    norm = BatchNorm(c_out)
    activation = nn.ReLU(True)
    return {"conv": conv, "bn": norm, "relu": activation}
def residual(c):
    """Returns a residual branch node dict: two ``c -> c`` conv-bn-relu
    blocks whose output is added to the branch input.
    """
    block = {"in": Identity()}
    block["res1"] = conv_bn(c, c)
    block["res2"] = conv_bn(c, c)
    # "in" and "res2/relu" are resolved relative to this block's path
    block["add"] = (Add(), ["in", "res2/relu"])
    return block
def path_iter(nested_dict, pfx=()):
    """Yields ``(path_tuple, value)`` for every non-dict leaf of
    ``nested_dict``, depth-first in insertion order, with ``pfx`` prepended
    to every path.
    """
    for key, value in nested_dict.items():
        path = (*pfx, key)
        if not isinstance(value, dict):
            yield (path, value)
            continue

        yield from path_iter(value, path)
# String key constants. They are not referenced elsewhere in this module;
# NOTE(review): presumably retained from the upstream cifar10-fast code —
# confirm whether any caller still uses them
MODEL = "model"
VALID_MODEL = "valid_model"
OUTPUT = "output"
================================================
FILE: fiftyone/brain/internal/models/torch.py
================================================
"""
PyTorch utilities.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
import fiftyone.utils.torch as fout
from fiftyone.brain.internal.models import HasBrainModel
import torch
class TorchImageModelConfig(fout.TorchImageModelConfig, HasBrainModel):
    """Configuration for running a :class:`TorchImageModel`.

    See :class:`fiftyone.utils.torch.TorchImageModelConfig` for additional
    parameters.

    Args:
        model_name (None): the name of the Brain model state dict to load
        model_path (None): the path to a state dict on disk to load
    """

    def __init__(self, d):
        # `init()` comes from HasBrainModel's ETA base class; presumably it
        # parses `model_name`/`model_path` from `d` (confirm in
        # eta.core.learning)
        parsed = self.init(d)
        super().__init__(parsed)
class TorchImageModel(fout.TorchImageModel):
    """Wrapper for evaluating a Torch model on images whose state dict is
    stored privately by the Brain.

    Args:
        config: a :class:`TorchImageModelConfig`
    """

    def _download_model(self, config):
        # Delegates to HasBrainModel.download_model_if_necessary()
        config.download_model_if_necessary()

    def _load_state_dict(self, model, config):
        # NOTE(review): torch.load() unpickles the file. This is acceptable
        # for the Brain's own published weights but should not be pointed at
        # untrusted files
        weights = torch.load(config.model_path, map_location=self.device)
        model.load_state_dict(weights)
================================================
FILE: fiftyone/brain/similarity.py
================================================
"""
Similarity interface.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com `_
|
"""
from collections import defaultdict
from copy import deepcopy
import inspect
import logging
from bson import ObjectId
import numpy as np
import eta.core.utils as etau
import fiftyone.brain as fb
import fiftyone.core.brain as fob
import fiftyone.core.context as foc
import fiftyone.core.dataset as fod
import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
import fiftyone.core.media as fom
import fiftyone.core.patches as fop
import fiftyone.core.stages as fos
import fiftyone.core.utils as fou
import fiftyone.core.view as fov
import fiftyone.core.validation as fova
import fiftyone.zoo as foz
from fiftyone import ViewField as F
# Internal brain utilities, imported lazily via `fou.lazy_import()`
fbu = fou.lazy_import("fiftyone.brain.internal.core.utils")
logger = logging.getLogger(__name__)

# Label types that compute_similarity() accepts for `roi_field`
_ALLOWED_ROI_FIELD_TYPES = (
    fol.Detection,
    fol.Detections,
    fol.Polyline,
    fol.Polylines,
)

# Zoo model used by compute_similarity() when no model, embeddings, or
# existing embeddings field is provided
_DEFAULT_MODEL = "mobilenet-v2-imagenet-torch"

# Batch size used by compute_similarity() when none is given
# NOTE(review): None is forwarded to fbu.get_embeddings(); confirm how it is
# interpreted there
_DEFAULT_BATCH_SIZE = None
def compute_similarity(
    samples,
    patches_field,
    roi_field,
    embeddings,
    brain_key,
    model,
    model_kwargs,
    force_square,
    alpha,
    batch_size,
    num_workers,
    skip_failures,
    progress,
    backend,
    **kwargs,
):
    """See ``fiftyone/brain/__init__.py``.

    Builds a similarity index using the requested ``backend`` and, if
    ``brain_key`` is provided, registers it on the root dataset. Unless
    ``embeddings`` is False or a non-external index already contains data,
    embeddings are computed or loaded and added to the index.

    Returns:
        the similarity index returned by the backend's ``initialize()``
    """
    fova.validate_collection(samples)

    if roi_field is not None:
        fova.validate_collection_label_fields(
            samples, roi_field, _ALLOWED_ROI_FIELD_TYPES
        )

    # Allow for `embeddings_field=XXX` and `embeddings=False` together
    embeddings_field = kwargs.pop("embeddings_field", None)
    if embeddings_field is not None or etau.is_str(embeddings):
        if embeddings_field is None:
            embeddings_field = embeddings
            embeddings = None

        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples,
            embeddings_field,
            patches_field=patches_field or roi_field,
            data_type="embeddings",
        )
    else:
        embeddings_field = None
        embeddings_exist = None

    if model is None and embeddings is None and not embeddings_exist:
        model = _DEFAULT_MODEL

    if batch_size is None:
        batch_size = _DEFAULT_BATCH_SIZE

    if etau.is_str(model):
        _model_kwargs = model_kwargs or {}
        _model = foz.load_zoo_model(model, **_model_kwargs)
    else:
        _model = model

    # Best-effort probe: `_model` may be None or may not expose this
    # attribute. Catch `Exception` rather than using a bare `except` so that
    # KeyboardInterrupt/SystemExit still propagate
    try:
        supports_prompts = _model.can_embed_prompts
    except Exception:
        supports_prompts = False

    if brain_key is not None and supports_prompts and not etau.is_str(model):
        logger.warning(
            "This index will not support prompt queries in the App or in "
            "future Python sessions. You can support this by providing the "
            "string name of a zoo model rather than a Model instance to "
            "compute_similarity(model=)."
        )

    config = _parse_config(
        backend,
        embeddings_field=embeddings_field,
        patches_field=patches_field,
        roi_field=roi_field,
        model=model,
        model_kwargs=model_kwargs,
        supports_prompts=supports_prompts,
        **kwargs,
    )
    brain_method = config.build()
    brain_method.ensure_requirements()

    # Similarity indexes can be modified after creation, so we always register
    # the index on the full dataset so that queries will always be performed
    # against the full index by default
    dataset = samples._root_dataset
    if samples._is_frames:
        dataset = samples._base_view

    if brain_key is not None:
        # Don't allow overwriting an existing run with same key, since we
        # need the existing run in order to perform workflows like
        # automatically cleaning up the backend's index
        brain_method.register_run(dataset, brain_key, overwrite=False)

    results = brain_method.initialize(dataset, brain_key)
    results._model = _model
    results._supports_prompts = supports_prompts

    get_embeddings = embeddings is not False
    if not results.is_external and results.total_index_size > 0:
        # No need to load embeddings because the index already has them
        get_embeddings = False

    if get_embeddings:
        # Don't immediately store embeddings in DB; let `add_to_index()` do it
        if not embeddings_exist:
            embeddings_field = None

        if roi_field is not None:
            # Per-ROI embeddings are averaged into one embedding per sample
            handle_missing = "image"
            agg_fcn = lambda e: np.mean(e, axis=0)
        else:
            handle_missing = "skip"
            agg_fcn = None

        embeddings, sample_ids, label_ids = fbu.get_embeddings(
            samples,
            model=_model,
            patches_field=patches_field or roi_field,
            embeddings=embeddings,
            embeddings_field=embeddings_field,
            force_square=force_square,
            alpha=alpha,
            handle_missing=handle_missing,
            agg_fcn=agg_fcn,
            batch_size=batch_size,
            num_workers=num_workers,
            skip_failures=skip_failures,
            progress=progress,
        )
    else:
        embeddings = None
        sample_ids = None
        label_ids = None

    if embeddings is not None:
        results.add_to_index(embeddings, sample_ids, label_ids=label_ids)

    brain_method.save_run_results(dataset, brain_key, results)

    return results
def _parse_config(name, **kwargs):
    """Builds the similarity config for the given backend.

    Args:
        name: the backend name, a config class, or None to use
            ``fb.brain_config.default_similarity_backend``
        **kwargs: config parameters. These override the backend's
            registered parameters. An optional ``config_cls`` overrides the
            backend's registered config class

    Returns:
        a :class:`SimilarityConfig` instance

    Raises:
        ValueError: if the backend is unknown or has no ``config_cls``
    """
    if name is None:
        name = fb.brain_config.default_similarity_backend

    if inspect.isclass(name):
        return name(**kwargs)

    backends = fb.brain_config.similarity_backends

    if name not in backends:
        raise ValueError(
            "Unsupported backend '%s'. The available backends are %s"
            % (name, sorted(backends.keys()))
        )

    params = deepcopy(backends[name])

    # Always remove the registry's `config_cls` from `params`, even when the
    # caller overrides it, so that it is never forwarded to the config
    # constructor as a regular parameter
    default_config_cls = params.pop("config_cls", None)
    config_cls = kwargs.pop("config_cls", None)
    if config_cls is None:
        config_cls = default_config_cls

    if config_cls is None:
        raise ValueError("Similarity backend '%s' has no `config_cls`" % name)

    if etau.is_str(config_cls):
        config_cls = etau.get_class(config_cls)

    params.update(kwargs)

    return config_cls(**params)
class SimilarityConfig(fob.BrainMethodConfig):
    """Similarity configuration.

    Args:
        embeddings_field (None): the sample field containing the embeddings,
            if one was provided
        model (None): the name of the zoo model that was used to compute
            embeddings, if known. A :class:`fiftyone.core.models.Model`
            instance may also be passed, but it is not stored, because it
            cannot be reloaded in future sessions
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        roi_field (None): the sample field defining a region of interest within
            each image to use to compute embeddings, if any
        supports_prompts (None): whether this run supports prompt queries.
            Reset to None when ``model`` is not a string
        **kwargs: keyword arguments for
            ``fiftyone.core.brain.BrainMethodConfig``
    """

    def __init__(
        self,
        embeddings_field=None,
        model=None,
        model_kwargs=None,
        patches_field=None,
        roi_field=None,
        supports_prompts=None,
        **kwargs,
    ):
        if model is not None and not etau.is_str(model):
            model = None

            # We can't declare permanent support for prompts because we don't
            # know how to load the model in future sessions
            supports_prompts = None

        self.embeddings_field = embeddings_field
        self.model = model
        self.model_kwargs = model_kwargs
        self.patches_field = patches_field
        self.roi_field = roi_field
        self.supports_prompts = supports_prompts

        super().__init__(**kwargs)

    @property
    def type(self):
        return "similarity"

    @property
    def method(self):
        """The name of the similarity backend."""
        raise NotImplementedError("subclass must implement method")

    @property
    def max_k(self):
        """A maximum k value for nearest neighbor queries, or None if there is
        no limit.
        """
        raise NotImplementedError("subclass must implement max_k")

    @property
    def supports_least_similarity(self):
        """Whether this backend supports least similarity queries."""
        raise NotImplementedError(
            "subclass must implement supports_least_similarity"
        )

    @property
    def supported_aggregations(self):
        """A tuple of supported values for the ``aggregation`` parameter of the
        backend's ``sort_by_similarity()`` and ``_kneighbors()`` methods.
        """
        raise NotImplementedError(
            "subclass must implement supported_aggregations"
        )

    def load_credentials(self, **kwargs):
        """Sets the given parameters on this config, using the values
        registered for this backend in ``fb.brain_config`` for any that are
        None.
        """
        self._load_parameters(**kwargs)

    def _load_parameters(self, **kwargs):
        # Fallback values registered for this backend in the brain config
        parameters = fb.brain_config.similarity_backends.get(self.method, {})

        for param_name, value in kwargs.items():
            if value is None:
                value = parameters.get(param_name, None)

            if value is not None:
                setattr(self, param_name, value)
class Similarity(fob.BrainMethod):
    """Base class for similarity factories.

    Args:
        config: a :class:`SimilarityConfig`
    """

    def initialize(self, samples, brain_key):
        """Initializes a similarity index.

        Args:
            samples: a :class:`fiftyone.core.collections.SampleCollection`
            brain_key: the brain key

        Returns:
            a :class:`SimilarityIndex`
        """
        raise NotImplementedError("subclass must implement initialize()")

    def get_fields(self, samples, brain_key):
        """Returns the fields referenced by this run: the patches field and
        then the embeddings field, omitting any that are None.
        """
        candidates = (self.config.patches_field, self.config.embeddings_field)
        return [field for field in candidates if field is not None]
class SimilarityIndex(fob.BrainResults):
"""Base class for similarity indexes.
Args:
samples: the :class:`fiftyone.core.collections.SampleCollection` used
config: the :class:`SimilarityConfig` used
brain_key: the brain key
backend (None): a :class:`Similarity` backend
"""
def __init__(self, samples, config, brain_key, backend=None):
    super().__init__(samples, config, brain_key, backend=backend)

    # Set on the results object by compute_similarity()
    self._model = None
    self._supports_prompts = None

    # View stack used by the context manager interface
    self._last_view = None
    self._last_views = []

    # Current view and its lazily-computed ID state
    self._curr_view = self._curr_view_allow_missing = None
    self._curr_view_warn_missing = None
    self._curr_sample_ids = self._curr_label_ids = None
    self._curr_keep_inds = self._curr_missing_size = None

    self.use_view(samples)
def __enter__(self):
    # Save the view that was active before the latest use_view() call so
    # that __exit__ can restore it
    previous_view = self._last_view
    self._last_views.append(previous_view)
    return self
def __exit__(self, *args):
    # Restore the view saved by __enter__. If there is none, fall back to
    # the full collection. Only the IndexError raised by popping an empty
    # stack is expected here, so other exceptions are not swallowed
    try:
        last_view = self._last_views.pop()
    except IndexError:
        last_view = self._samples

    self.use_view(last_view)
@property
def config(self):
    """The :class:`SimilarityConfig` that produced these results."""
    cfg = self._config
    return cfg
@property
def supports_prompts(self):
    """Whether prompt queries are supported. A value set at runtime takes
    precedence over the config's setting, which is treated as False when
    unset.
    """
    if self._supports_prompts is None:
        return self.config.supports_prompts or False

    return self._supports_prompts
@property
def is_external(self):
    """True if this index manages its own embeddings; False if it reads them
    directly from the dataset's ``embeddings_field``.

    The base class reports True; subclasses override it where appropriate.
    """
    external = True
    return external
@property
def sample_ids(self):
    """The sample IDs of the full index. The base class returns None
    (unsupported); subclasses that can enumerate IDs override this.
    """
    return None
@property
def label_ids(self):
    """The label IDs of the full index. The base class returns None (not
    applicable or unsupported); subclasses override this as needed.
    """
    return None
@property
def total_index_size(self):
    """The number of data points in the full index.

    After :meth:`use_view`, this may exceed :meth:`index_size`, which counts
    only the active data points.
    """
    raise NotImplementedError("subclass must implement total_index_size")
@property
def has_view(self):
    """Whether the index is currently restricted to a view.

    Use :meth:`use_view` to restrict the index to a view, and use
    :meth:`clear_view` to reset to the full index.
    """
    curr_view = self._curr_view

    # Full dataset
    if isinstance(curr_view, fod.Dataset):
        return False

    # Full group slices view: a single SelectGroupSlices stage with an
    # empty pipeline on a grouped dataset
    is_full_group_slices = (
        isinstance(curr_view, fov.DatasetView)
        and curr_view._root_dataset.media_type == fom.GROUP
        and len(curr_view._stages) == 1
        and isinstance(curr_view._stages[0], fos.SelectGroupSlices)
        and curr_view._pipeline() == []
    )
    if is_full_group_slices:
        return False

    # Full patches view: patches are indexed and the view has one stage
    is_full_patches = (
        self.config.patches_field is not None
        and isinstance(curr_view, fop.PatchesView)
        and len(curr_view._all_stages) == 1
    )
    if is_full_patches:
        return False

    return curr_view.view() != self._samples.view()
@property
def view(self):
    """The :class:`fiftyone.core.collections.SampleCollection` that results
    are currently generated against.

    After :meth:`use_view`, this may differ from the collection on which the
    full index was built.
    """
    current = self._curr_view
    return current
@property
def current_sample_ids(self):
    """The sample IDs of the active data points in the index.

    After :meth:`use_view`, this may be a subset of the full index. When the
    backend cannot list its sample IDs (:meth:`sample_ids` is None), this is
    every sample ID in :meth:`view`, whether or not it is indexed.
    """
    self._apply_view_if_necessary()
    ids = self._curr_sample_ids
    return ids
@property
def current_label_ids(self):
    """The label IDs of the active data points in the index, or None if not
    applicable.

    After :meth:`use_view`, this may be a subset of the full index. When the
    backend cannot list its label IDs (:meth:`label_ids` is None), this is
    every label ID in :meth:`view`, whether or not it is indexed.
    """
    self._apply_view_if_necessary()
    ids = self._curr_label_ids
    return ids
@property
def _current_inds(self):
    """Positions of :meth:`current_sample_ids` within :meth:`sample_ids`, or
    None if unsupported or if the full index is in use.
    """
    self._apply_view_if_necessary()
    inds = self._curr_keep_inds
    return inds
@property
def index_size(self):
    """The number of active data points in the index, which reflects any
    restriction applied via :meth:`use_view`.
    """
    self._apply_view_if_necessary()
    active_ids = self._curr_sample_ids
    return len(active_ids)
@property
def missing_size(self):
    """The number of data points in :meth:`view` that are absent from this
    index, or None if unknown.

    This is only meaningful after :meth:`use_view` has been called, and it
    may be None when the backend cannot determine it.
    """
    self._apply_view_if_necessary()
    count = self._curr_missing_size
    return count
def add_to_index(
    self,
    embeddings,
    sample_ids,
    label_ids=None,
    overwrite=True,
    allow_existing=True,
    warn_existing=False,
    reload=True,
):
    """Adds the given embeddings to the index.

    Args:
        embeddings: a ``num_embeddings x num_dims`` array of embeddings
        sample_ids: a ``num_embeddings`` array of sample IDs
        label_ids (None): a ``num_embeddings`` array of label IDs, if
            applicable
        overwrite (True): whether to replace (True) or ignore (False)
            existing embeddings with the same sample/label IDs
        allow_existing (True): whether to ignore (True) or raise an error
            (False) when ``overwrite`` is False and a provided ID already
            exists in the index
        warn_existing (False): whether to log a warning if an embedding is
            not added to the index because its ID already exists
        reload (True): whether to call :meth:`reload` to refresh the
            current view after the update
    """
    raise NotImplementedError("subclass must implement add_to_index()")
def remove_from_index(
    self,
    sample_ids=None,
    label_ids=None,
    allow_missing=True,
    warn_missing=False,
    reload=True,
):
    """Deletes the embeddings with the given IDs from the index.

    Args:
        sample_ids (None): an array of sample IDs
        label_ids (None): an array of label IDs, if applicable
        allow_missing (True): whether IDs that are not in the index are
            tolerated (True) or cause an error (False)
        warn_missing (False): whether to log a warning for IDs that are not
            in the index
        reload (True): whether to call :meth:`reload` afterwards to refresh
            the current view
    """
    raise NotImplementedError(
        "subclass must implement remove_from_index()"
    )
def get_embeddings(
    self,
    sample_ids=None,
    label_ids=None,
    allow_missing=True,
    warn_missing=False,
):
    """Fetches embeddings from the index for the given IDs, or the whole
    index when no IDs are given.

    Args:
        sample_ids (None): a sample ID or list of sample IDs
        label_ids (None): a label ID or list of label IDs
        allow_missing (True): whether IDs that are not in the index are
            tolerated (True) or cause an error (False)
        warn_missing (False): whether to log a warning for IDs that are not
            in the index

    Returns:
        a tuple of:

        -   a ``num_embeddings x num_dims`` array of embeddings
        -   a ``num_embeddings`` array of sample IDs
        -   a ``num_embeddings`` array of label IDs, if applicable, or else
            ``None``
    """
    raise NotImplementedError("subclass must implement get_embeddings()")
def use_view(
    self,
    samples,
    allow_missing=True,
    warn_missing=False,
):
    """Restricts the index to the provided view.

    Subsequent calls on this instance operate only on data points in
    ``samples`` rather than the full index. Call :meth:`clear_view` to
    return to the full index, or use the context manager interface shown
    below to restore the previous view automatically on exit.

    Example usage::

        import fiftyone as fo
        import fiftyone.brain as fob
        import fiftyone.zoo as foz

        dataset = foz.load_zoo_dataset("quickstart")

        results = fob.compute_similarity(dataset)
        print(results.index_size)  # 200

        view = dataset.take(50)

        with results.use_view(view):
            print(results.index_size)  # 50

            results.find_unique(10)
            print(results.unique_ids)

            plot = results.visualize_unique()
            plot.show()

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        allow_missing (True): whether ``samples`` may contain data points
            that this index lacks (True) or whether that raises an error
            (False)
        warn_missing (False): whether to log a warning if ``samples``
            contains data points that this index lacks

    Returns:
        self
    """
    # Remember the outgoing view so that __enter__ can save it for restoring
    self._last_view = self._curr_view

    self._curr_view = samples
    self._curr_view_allow_missing = allow_missing
    self._curr_view_warn_missing = warn_missing

    # Invalidate cached ID state; _apply_view() recomputes it on demand
    (
        self._curr_sample_ids,
        self._curr_label_ids,
        self._curr_keep_inds,
        self._curr_missing_size,
    ) = (None, None, None, None)

    return self
def _apply_view(self):
sample_ids = self.sample_ids
label_ids = self.label_ids
if sample_ids is not None and not self.has_view:
keep_inds = None
good_inds = None
else:
sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids(
self._curr_view,
sample_ids,
label_ids,
patches_field=self.config.patches_field,
allow_missing=self._curr_view_allow_missing,
warn_missing=self._curr_view_warn_missing,
)
if good_inds is not None:
missing_size = good_inds.size - np.count_nonzero(good_inds)
else:
missing_size = None
self._curr_sample_ids = sample_ids
self._curr_label_ids = label_ids
self._curr_keep_inds = keep_inds
self._curr_missing_size = missing_size
def _apply_view_if_necessary(self):
if self._curr_sample_ids is None:
self._apply_view()
def clear_view(self):
    """Removes any view restriction applied via :meth:`use_view`.

    After calling this, operations act on the full index again.
    """
    full_collection = self._samples
    self.use_view(full_collection)
def reload(self):
    """Refreshes the index for the current view.

    The default implementation re-applies the current :meth:`view` via
    :meth:`use_view`, so the index's current ID set picks up any changes
    made to the view since it was last loaded. Subclasses may override
    this behavior.
    """
    current = self._curr_view
    self.use_view(current)
def cleanup(self):
    """Deletes the similarity index from the backend."""
    # Backend-specific; concrete subclasses must provide this
    message = "subclass must implement cleanup()"
    raise NotImplementedError(message)
def values(self, path_or_expr):
    """Returns a flat list of values of a field or expression for the
    current :meth:`view`.

    The values are ordered consistently with :meth:`current_sample_ids`
    and :meth:`current_label_ids`.

    Args:
        path_or_expr: the values to extract, either:

            -   the name of a sample field or an ``embedded.field.name``
                containing numeric or string values
            -   a :class:`fiftyone.core.expressions.ViewExpression` that
                computes numeric or string values via
                :meth:`fiftyone.core.collections.SampleCollection.values`

    Returns:
        a list of values
    """
    view = self.view
    patches_field = self.config.patches_field

    # Patch indexes are keyed by label ID; sample indexes by sample ID
    if patches_field is None:
        ids = self.current_sample_ids
    else:
        ids = self.current_label_ids

    return fbu.get_values(view, path_or_expr, ids, patches_field=patches_field)
def sort_by_similarity(
    self,
    query,
    k=None,
    reverse=False,
    aggregation="mean",
    dist_field=None,
    _mongo=False,
):
    """Sorts the samples/labels of the current :meth:`view` by their
    similarity to the given query.

    Queries by ID may reference any ID in the full index of this instance,
    even when the current :meth:`view` only covers part of it.

    Args:
        query: the query, which may be:

            -   an ID or iterable of IDs
            -   a ``num_dims`` vector or a ``num_queries x num_dims``
                array of vectors
            -   a prompt or iterable of prompts, if the index supports
                prompts

        k (None): the number of matches to return. Some backends accept
            ``None``, in which case every sample is sorted
        reverse (False): whether to sort by least similarity (True) or
            greatest similarity (False). Some backends do not support
            least similarity
        aggregation ("mean"): how to combine multiple queries. The default
            ``"mean"`` averages the query vectors before searching; some
            backends offer additional options
        dist_field (None): the name of a float field in which to record
            each example's distance to the query. The field is created if
            necessary

    Returns:
        a :class:`fiftyone.core.view.DatasetView`
    """
    samples = self.view
    patches_field = self.config.patches_field

    # For a patch index over a non-patches view, we select the parent
    # samples and filter their labels rather than selecting directly
    selecting_samples = patches_field is None or isinstance(
        samples, fop.PatchesView
    )

    want_dists = dist_field is not None
    result = self._kneighbors(
        query=self._parse_query(query),
        k=k,
        reverse=reverse,
        aggregation=aggregation,
        return_dists=want_dists,
    )
    if want_dists:
        sample_ids, label_ids, dists = result
    else:
        sample_ids, label_ids = result

    # When selecting_samples is False, patches_field is necessarily set, so
    # label IDs are the relevant IDs exactly when a patches field exists
    ids = label_ids if patches_field is not None else sample_ids

    # Store query distances
    if want_dists:
        if selecting_samples:
            samples.set_values(
                dist_field, dict(zip(ids, dists)), key_field="id"
            )
        else:
            label_type, path = samples._get_label_field_path(
                patches_field, dist_field
            )
            if issubclass(label_type, fol._LABEL_LIST_FIELDS):
                samples._set_list_values_by_id(
                    path,
                    sample_ids,
                    label_ids,
                    dists,
                    path.rsplit(".", 1)[0],
                )
            else:
                samples.set_values(
                    path, dict(zip(sample_ids, dists)), key_field="id"
                )

    # Construct the sorting stages
    if selecting_samples:
        stages = [fos.Select(ids, ordered=True)]
    else:
        # Not a patches view: order samples by their first-occurring
        # matched label and, when ``k`` is given, keep only matched labels
        stages = [fos.Select(_unique_no_sort(sample_ids), ordered=True)]
        if k is not None:
            object_ids = [ObjectId(_id) for _id in ids]
            stages.append(
                fos.FilterLabels(patches_field, F("_id").is_in(object_ids))
            )

    # Private option: emit the raw aggregation pipeline instead of a view
    if _mongo:
        pipeline = []
        for stage in stages:
            stage.validate(samples)
            pipeline.extend(stage.to_mongo(samples))

        return pipeline

    view = samples
    for stage in stages:
        view = view.add_stage(stage)

    return view
def _parse_query(self, query):
if query is None:
raise ValueError("At least one query must be provided")
if isinstance(query, np.ndarray):
# Query by vector(s)
if query.size == 0:
raise ValueError("At least one query vector must be provided")
return query
if etau.is_str(query):
query = [query]
else:
query = list(query)
if not query:
raise ValueError("At least one query must be provided")
if etau.is_numeric(query[0]):
return np.asarray(query)
try:
ObjectId(query[0])
is_prompts = False
except:
is_prompts = True
if is_prompts:
if not self.supports_prompts:
raise ValueError(
"Invalid query '%s'; this model does not support prompts"
% query[0]
)
model = self.get_model()
with model:
return model.embed_prompts(query)
return query
def _kneighbors(
self,
query=None,
k=None,
reverse=False,
aggregation=None,
return_dists=False,
):
"""Returns the k-nearest neighbors for the given query.
This method should only return results from the current :meth:`view`.
Args:
query (None): the query, which can be any of the following:
- an ID or list of IDs for which to return neighbors
- an embedding or ``num_queries x num_dim`` array of
embeddings for which to return neighbors
- Some backends may also support ``None``, in which case the
neighbors for all points in the current :meth:`view` are
returned
k (None): the number of neighbors to return. Some backends may
enforce upper bounds on this parameter
reverse (False): whether to sort by least similarity (True) or
greatest similarity (False). Some backends may not support
least similarity
aggregation (None): an optional aggregation method to use when
multiple queries are provided. All backends must support
``"mean"``, which averages query vectors prior to searching.
Backends may support additional options as well
return_dists (False): whether to return query-neighbor distances
Returns:
the query result, in one of the following formats:
- a ``(sample_ids, label_ids, dists)`` tuple, when
``return_dists`` is True
- a ``(sample_ids, label_ids)`` tuple, when ``return_dists``
is False
In the above, ``sample_ids`` and ``label_ids`` (if applicable)
contain the IDs of the nearest neighbors, in one of the following
formats:
- a list of nearest neighbor IDs, when a single query ID or
vector is provided, **or** when an ``aggregation`` is
provided
- a list of lists of nearest neighbor IDs, when multiple
query IDs/vectors and no ``aggregation`` is provided
and ``dists`` contains the corresponding query-neighbor distances
for each result.
If the backend supports full index queries (``query=None``), then
``inds`` are returned rather than ``(sample_ids, label_ids)``, in
the following format:
- a list of arrays of the **integer indexes** (not IDs) of
nearest neighbor points for every vector in the index, when
no query is provided
"""
raise NotImplementedError("subclass must implement _kneighbors()")
def get_model(self):
    """Loads (on first use) and returns the model stored with this index.

    Returns:
        a :class:`fiftyone.core.models.Model`
    """
    if self._model is not None:
        return self._model

    model = self.config.model
    if model is None:
        raise ValueError("These results don't have a stored model")

    # A string model is loaded from the model zoo
    if etau.is_str(model):
        zoo_kwargs = self.config.model_kwargs or {}
        model = foz.load_zoo_model(model, **zoo_kwargs)

    self._model = model
    return model
def compute_embeddings(
    self,
    samples,
    model=None,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    skip_existing=False,
    warn_existing=False,
    force_square=False,
    alpha=None,
    progress=None,
):
    """Embeds the given samples with this backend's model.

    Args:
        samples: a :class:`fiftyone.core.collections.SampleCollection`
        model (None): a :class:`fiftyone.core.models.Model` to apply. When
            omitted, these results must have been created with a stored
            model, which is used instead
        batch_size (None): an optional batch size to use when computing
            embeddings. Only applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading
            images. Only applicable when a Torch-based model computes the
            embeddings
        skip_failures (True): whether to carry on without raising when a
            sample's embedding cannot be generated
        skip_existing (False): whether to skip sample/label IDs that are
            already in the index
        warn_existing (False): whether to log a warning when any IDs
            already exist in the index
        force_square (False): whether to minimally reshape the patch
            bounding boxes into squares before extraction. Only applicable
            when a ``model`` and ``patches_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the
            patches before extraction, in ``[-1, inf)``. When provided, the
            box length and width are expanded (or contracted, when
            ``alpha < 0``) by ``(100 * alpha)%``; e.g. ``alpha = 0.1``
            expands boxes by 10% and ``alpha = -0.1`` contracts them by
            10%. Only applicable when a ``model`` and ``patches_field`` are
            specified
        progress (None): whether to render a progress bar (True/False), use
            the default value ``fiftyone.config.show_progress_bars``
            (None), or a progress callback function to invoke instead

    Returns:
        a tuple of:

        -   a ``num_embeddings x num_dims`` array of embeddings
        -   a ``num_embeddings`` array of sample IDs
        -   a ``num_embeddings`` array of label IDs, if applicable, or else
            ``None``
    """
    cfg = self.config

    if model is None:
        model = self.get_model()

    if skip_existing:
        if cfg.patches_field is None:
            index_ids = self.sample_ids
        else:
            index_ids = self.label_ids

        if index_ids is None:
            logger.warning(
                "This index does not support skipping existing IDs"
            )
        else:
            samples = fbu.skip_ids(
                samples,
                index_ids,
                patches_field=cfg.patches_field,
                warn_existing=warn_existing,
            )

    if cfg.roi_field is None:
        embed_field = cfg.patches_field
        handle_missing = "skip"
        agg_fcn = None
    else:
        # ROI embeddings are averaged per sample. NOTE(review): "image"
        # presumably embeds the whole image when a sample has no ROIs
        embed_field = cfg.roi_field
        handle_missing = "image"

        def agg_fcn(e):
            return np.mean(e, axis=0)

    return fbu.get_embeddings(
        samples,
        model=model,
        patches_field=embed_field,
        force_square=force_square,
        alpha=alpha,
        handle_missing=handle_missing,
        agg_fcn=agg_fcn,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        progress=progress,
    )
@classmethod
def _from_dict(cls, d, samples, config, brain_key):
"""Builds a :class:`SimilarityIndex` from a JSON representation of it.
Args:
d: a JSON dict
samples: the :class:`fiftyone.core.collections.SampleCollection`
for the run
config: the :class:`SimilarityConfig` for the run
brain_key: the brain key
Returns:
a :class:`SimilarityIndex`
"""
raise NotImplementedError("subclass must implement _from_dict()")
class DuplicatesMixin(object):
    """Mixin for :class:`SimilarityIndex` instances that support duplicate
    detection operations.

    Similarity backends can expose this mixin simply by implementing
    :meth:`_radius_neighbors`.
    """

    def __init__(self):
        self._thresh = None
        self._unique_ids = None
        self._duplicate_ids = None
        self._neighbors_map = None

    @property
    def thresh(self):
        """The threshold used by the last call to :meth:`find_duplicates` or
        :meth:`find_unique`.
        """
        return self._thresh

    @property
    def unique_ids(self):
        """A list of unique IDs from the last call to :meth:`find_duplicates`
        or :meth:`find_unique`.
        """
        return self._unique_ids

    @property
    def duplicate_ids(self):
        """A list of duplicate IDs from the last call to
        :meth:`find_duplicates` or :meth:`find_unique`.
        """
        return self._duplicate_ids

    @property
    def neighbors_map(self):
        """A dictionary mapping IDs to lists of ``(dup_id, dist)`` tuples from
        the last call to :meth:`find_duplicates`.
        """
        return self._neighbors_map

    def _radius_neighbors(self, query=None, thresh=None, return_dists=False):
        """Returns the neighbors within the given distance threshold for the
        given query.

        This method should only return results from the current :meth:`view`.

        Args:
            query (None): the query, which can be any of the following:

                -   an ID or list of IDs for which to return neighbors
                -   an embedding or ``num_queries x num_dim`` array of
                    embeddings for which to return neighbors
                -   ``None``, in which case the neighbors for all points in
                    the current :meth:`view` are returned

            thresh (None): the distance threshold to use
            return_dists (False): whether to return query-neighbor distances

        Returns:
            the query result, in one of the following formats:

            -   a ``(sample_ids, label_ids, dists)`` tuple, when
                ``return_dists`` is True
            -   a ``(sample_ids, label_ids)`` tuple, when ``return_dists``
                is False

            In the above, ``sample_ids`` and ``label_ids`` (if applicable)
            contain the IDs of the neighbors, in one of the following
            formats:

            -   a list of neighbor IDs, when a single query ID or vector is
                provided
            -   a list of lists of neighbor IDs, when multiple query
                IDs/vectors are provided

            and ``dists`` contains the corresponding query-neighbor
            distances for each result.

            If the backend supports full index queries (``query=None``),
            then ``inds`` are returned rather than
            ``(sample_ids, label_ids)``, in the following format:

            -   a list of arrays of the **integer indexes** (not IDs) of
                neighboring points for every vector in the index, when no
                query is provided
        """
        raise NotImplementedError(
            "subclass must implement _radius_neighbors()"
        )

    def find_duplicates(self, thresh=None, fraction=None):
        """Queries the index to find near-duplicate examples based on the
        provided parameters.

        Calling this method populates the :meth:`unique_ids`,
        :meth:`duplicate_ids`, :attr:`neighbors_map`, and :attr:`thresh`
        properties of this object with the results of the query.

        Use :meth:`duplicates_view` and :meth:`visualize_duplicates` to analyze
        the results generated by this method.

        Args:
            thresh (None): a distance threshold to use to determine duplicates.
                If specified, the non-duplicate set will be the (approximately)
                largest set such that all pairwise distances between
                non-duplicate examples are greater than this threshold
            fraction (None): a desired fraction of images/patches to tag as
                duplicates, in ``[0, 1]``. In this case ``thresh`` is
                automatically tuned to achieve the desired fraction of
                duplicates
        """
        if self.config.patches_field is not None:
            logger.info("Computing duplicate patches...")
            ids = self.current_label_ids
        else:
            logger.info("Computing duplicate samples...")
            ids = self.current_sample_ids

        # Detect duplicates
        if fraction is not None:
            # Clamp the keep-fraction to [0, 1] before converting to a count
            num_keep = int(round(min(max(0, 1.0 - fraction), 1) * len(ids)))
            unique_ids, thresh = self._remove_duplicates_count(
                num_keep, ids, init_thresh=thresh
            )
        else:
            unique_ids = self._remove_duplicates_thresh(thresh, ids)

        _unique_ids = set(unique_ids)
        duplicate_ids = [_id for _id in ids if _id not in _unique_ids]

        # Locate nearest non-duplicate for each duplicate
        if unique_ids and duplicate_ids:
            if self.config.patches_field is not None:
                unique_view = self._samples.select_labels(
                    ids=unique_ids, fields=self.config.patches_field
                )
            else:
                unique_view = self._samples.select(unique_ids)

            # Temporarily restrict the index to the unique examples so that
            # each duplicate's 1-NN is a non-duplicate
            with self.use_view(unique_view):
                _sample_ids, _label_ids, dists = self._kneighbors(
                    query=duplicate_ids, k=1, return_dists=True
                )

            if self.config.patches_field is not None:
                nearest_ids = _label_ids
            else:
                nearest_ids = _sample_ids

            neighbors_map = defaultdict(list)
            for dup_id, _ids, _dists in zip(duplicate_ids, nearest_ids, dists):
                neighbors_map[_ids[0]].append((dup_id, _dists[0]))

            neighbors_map = {
                k: sorted(v, key=lambda t: t[1])
                for k, v in neighbors_map.items()
            }
        else:
            neighbors_map = {}

        logger.info("Duplicates computation complete")

        self._thresh = thresh
        self._unique_ids = unique_ids
        self._duplicate_ids = duplicate_ids
        self._neighbors_map = neighbors_map

    def find_unique(self, count):
        """Queries the index to select a subset of examples of the specified
        size that are maximally unique with respect to each other.

        Calling this method populates the :meth:`unique_ids`,
        :meth:`duplicate_ids`, and :attr:`thresh` properties of this object
        with the results of the query.

        Use :meth:`unique_view` and :meth:`visualize_unique` to analyze the
        results generated by this method.

        Args:
            count: the desired number of unique examples
        """
        if self.config.patches_field is not None:
            logger.info("Computing unique patches...")
            ids = self.current_label_ids
        else:
            logger.info("Computing unique samples...")
            ids = self.current_sample_ids

        unique_ids, thresh = self._remove_duplicates_count(count, ids)

        _unique_ids = set(unique_ids)
        duplicate_ids = [_id for _id in ids if _id not in _unique_ids]

        logger.info("Uniqueness computation complete")

        self._thresh = thresh
        self._unique_ids = unique_ids
        self._duplicate_ids = duplicate_ids
        self._neighbors_map = None

    def _remove_duplicates_count(self, num_keep, ids, init_thresh=None):
        # Binary-searches for a distance threshold whose greedy
        # de-duplication (see _remove_duplicates_thresh) keeps ``num_keep``
        # of ``ids``. Returns a ``(keep_ids, thresh)`` tuple, where
        # ``keep_ids`` is always a list and ``thresh`` is None when no search
        # was required
        num_ids = len(ids)

        if num_keep <= 0:
            logger.info("threshold: -, kept: %d, target: %d", 0, num_keep)
            return [], None

        if num_keep >= num_ids:
            logger.info(
                "threshold: -, kept: %d, target: %d", num_ids, num_keep
            )
            return list(ids), None

        if init_thresh is not None:
            thresh = init_thresh
        else:
            thresh = 1

        # [lower, upper] bounds on the threshold; upper is unknown until the
        # search first keeps too few examples
        thresh_lims = [0, None]
        num_target = num_keep
        num_keep = -1

        while True:
            keep_ids = self._remove_duplicates_thresh(thresh, ids)
            num_keep_last = num_keep
            num_keep = len(keep_ids)

            logger.info(
                "threshold: %f, kept: %d, target: %d",
                thresh,
                num_keep,
                num_target,
            )

            # Stop on an exact hit, or when the bracket has collapsed and the
            # count stopped changing (the target may be unreachable)
            if num_keep == num_target or (
                num_keep == num_keep_last
                and thresh_lims[1] is not None
                and thresh_lims[1] - thresh_lims[0] < 1e-6
            ):
                break

            if num_keep < num_target:
                # Need to decrease threshold
                thresh_lims[1] = thresh
                thresh = 0.5 * (thresh_lims[0] + thresh)
            else:
                # Need to increase threshold
                thresh_lims[0] = thresh
                if thresh_lims[1] is not None:
                    thresh = 0.5 * (thresh + thresh_lims[1])
                else:
                    thresh *= 2

        return keep_ids, thresh

    def _remove_duplicates_thresh(self, thresh, ids):
        # Greedy de-duplication: walking the points in index order, each
        # point that is still kept discards every later point listed among
        # its neighbors within ``thresh``
        nearest_inds = self._radius_neighbors(thresh=thresh)

        num_ids = len(ids)
        keep = set(range(num_ids))
        for ind in range(num_ids):
            if ind in keep:
                keep -= {i for i in nearest_inds[ind] if i > ind}

        # Emit in index order so the result is deterministic and follows
        # the order of ``ids``
        return [ids[i] for i in range(num_ids) if i in keep]

    def plot_distances(self, bins=100, log=False, backend="plotly", **kwargs):
        """Plots a histogram of the distance between each example and its
        nearest neighbor.

        If :meth:`find_duplicates` or :meth:`find_unique` has been executed,
        the threshold used is also indicated on the plot.

        Args:
            bins (100): the number of bins to use
            log (False): whether to use a log scale y-axis
            backend ("plotly"): the plotting backend to use. Supported values
                are ``("plotly", "matplotlib")``
            **kwargs: keyword arguments for the backend plotting method

        Returns:
            one of the following:

            -   a :class:`fiftyone.core.plots.plotly.PlotlyNotebookPlot`, if
                you are working in a notebook context and the plotly backend
                is used
            -   a plotly or matplotlib figure, otherwise
        """
        metric = self.config.metric
        thresh = self.thresh

        # Full index query: returns (inds, dists) for every point
        _, dists = self._kneighbors(k=1, return_dists=True)
        dists = np.array([d[0] for d in dists])

        if backend == "matplotlib":
            return _plot_distances_mpl(
                dists, metric, thresh, bins, log, **kwargs
            )

        return _plot_distances_plotly(
            dists, metric, thresh, bins, log, **kwargs
        )

    def duplicates_view(
        self,
        type_field=None,
        id_field=None,
        dist_field=None,
        sort_by="distance",
        reverse=False,
    ):
        """Returns a view that contains only the duplicate examples and their
        corresponding nearest non-duplicate examples generated by the last call
        to :meth:`find_duplicates`.

        If you are analyzing patches, the returned view will be a
        :class:`fiftyone.core.patches.PatchesView`.

        The examples are organized so that each non-duplicate is immediately
        followed by all duplicate(s) that are nearest to it.

        Args:
            type_field (None): the name of a string field in which to store
                ``"nearest"`` and ``"duplicate"`` labels. The field is created
                if necessary
            id_field (None): the name of a string field in which to store the
                ID of the nearest non-duplicate for each example in the view.
                The field is created if necessary
            dist_field (None): the name of a float field in which to store the
                distance of each example to its nearest non-duplicate example.
                The field is created if necessary
            sort_by ("distance"): specifies how to sort the groups of duplicate
                examples. The supported values are:

                -   ``"distance"``: sort the groups by the distance between the
                    non-duplicate and its (nearest, if multiple) duplicate
                -   ``"count"``: sort the groups by the number of duplicate
                    examples

            reverse (False): whether to sort in descending order

        Returns:
            a :class:`fiftyone.core.view.DatasetView`
        """
        if self.neighbors_map is None:
            raise ValueError(
                "You must first call `find_duplicates()` to generate results"
            )

        # Validate arguments before doing any (potentially expensive) work
        if sort_by == "distance":
            key = lambda kv: min(e[1] for e in kv[1])
        elif sort_by == "count":
            key = lambda kv: len(kv[1])
        else:
            raise ValueError(
                "Invalid sort_by='%s'; supported values are %s"
                % (sort_by, ("distance", "count"))
            )

        samples = self.view
        patches_field = self.config.patches_field
        neighbors_map = self.neighbors_map

        if patches_field is not None and not isinstance(
            samples, fop.PatchesView
        ):
            samples = samples.to_patches(patches_field)

        # Only keep groups whose non-duplicate is still in the view
        existing_ids = set(samples.values("id"))
        neighbors = [
            (k, v) for k, v in neighbors_map.items() if k in existing_ids
        ]

        ids = []
        types = {}
        nearest_ids = {}
        dists = {}
        for _id, duplicates in sorted(neighbors, key=key, reverse=reverse):
            ids.append(_id)
            types[_id] = "nearest"
            nearest_ids[_id] = _id
            dists[_id] = 0.0

            for dup_id, dist in duplicates:
                ids.append(dup_id)
                types[dup_id] = "duplicate"
                nearest_ids[dup_id] = _id
                dists[dup_id] = dist

        if type_field is not None:
            samples.set_values(type_field, types, key_field="id")

        if id_field is not None:
            samples.set_values(id_field, nearest_ids, key_field="id")

        if dist_field is not None:
            samples.set_values(dist_field, dists, key_field="id")

        return samples.select(ids, ordered=True)

    def unique_view(self):
        """Returns a view that contains only the unique examples generated by
        the last call to :meth:`find_duplicates` or :meth:`find_unique`.

        If you are analyzing patches, the returned view will be a
        :class:`fiftyone.core.patches.PatchesView`.

        Returns:
            a :class:`fiftyone.core.view.DatasetView`
        """
        if self.unique_ids is None:
            raise ValueError(
                "You must first call `find_unique()` or `find_duplicates()` "
                "to generate results"
            )

        samples = self.view
        patches_field = self.config.patches_field
        unique_ids = self.unique_ids

        if patches_field is not None and not isinstance(
            samples, fop.PatchesView
        ):
            samples = samples.to_patches(patches_field)

        return samples.select(unique_ids)

    def visualize_duplicates(self, visualization, backend="plotly", **kwargs):
        """Generates an interactive scatterplot of the results generated by the
        last call to :meth:`find_duplicates`.

        The ``visualization`` argument can be any visualization computed on the
        same dataset (or subset of it) as long as it contains every
        sample/object in the view whose results you are visualizing.

        The points are colored based on the following partition:

            -   "duplicate": duplicate example
            -   "nearest": nearest neighbor of a duplicate example
            -   "unique": the remaining unique examples

        Edges are also drawn between each duplicate and its nearest
        non-duplicate neighbor.

        You can attach plots generated by this method to an App session via its
        :attr:`fiftyone.core.session.Session.plots` attribute, which will
        automatically sync the session's view with the currently selected
        points in the plot.

        Args:
            visualization: a
                :class:`fiftyone.brain.visualization.VisualizationResults`
                instance to use to visualize the results
            backend ("plotly"): the plotting backend to use. Supported values
                are ``("plotly", "matplotlib")``
            **kwargs: keyword arguments for the backend plotting method:

                -   "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`
                -   "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`

        Returns:
            a :class:`fiftyone.core.plots.base.InteractivePlot`
        """
        if self.neighbors_map is None:
            raise ValueError(
                "You must first call `find_duplicates()` to generate results"
            )

        samples = self.view
        duplicate_ids = self.duplicate_ids
        neighbors_map = self.neighbors_map
        patches_field = self.config.patches_field

        dup_ids = set(duplicate_ids)
        nearest_ids = set(neighbors_map.keys())

        with visualization.use_view(samples, allow_missing=True):
            if patches_field is not None:
                ids = visualization.current_label_ids
            else:
                ids = visualization.current_sample_ids

            labels = []
            for _id in ids:
                if _id in dup_ids:
                    label = "duplicate"
                elif _id in nearest_ids:
                    label = "nearest"
                else:
                    label = "unique"

                labels.append(label)

            if backend == "plotly":
                kwargs["edges"] = _build_edges(ids, neighbors_map)
                kwargs["edges_title"] = "neighbors"
                kwargs["labels_title"] = "type"

            return visualization.visualize(
                labels=labels,
                classes=["unique", "nearest", "duplicate"],
                backend=backend,
                **kwargs,
            )

    def visualize_unique(self, visualization, backend="plotly", **kwargs):
        """Generates an interactive scatterplot of the results generated by the
        last call to :meth:`find_unique` or :meth:`find_duplicates`.

        The ``visualization`` argument can be any visualization computed on the
        same dataset (or subset of it) as long as it contains every
        sample/object in the view whose results you are visualizing.

        The points are colored based on the following partition:

            -   "unique": the unique examples
            -   "other": the other examples

        You can attach plots generated by this method to an App session via its
        :attr:`fiftyone.core.session.Session.plots` attribute, which will
        automatically sync the session's view with the currently selected
        points in the plot.

        Args:
            visualization: a
                :class:`fiftyone.brain.visualization.VisualizationResults`
                instance to use to visualize the results
            backend ("plotly"): the plotting backend to use. Supported values
                are ``("plotly", "matplotlib")``
            **kwargs: keyword arguments for the backend plotting method:

                -   "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`
                -   "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`

        Returns:
            a :class:`fiftyone.core.plots.base.InteractivePlot`
        """
        if self.unique_ids is None:
            raise ValueError(
                "You must first call `find_unique()` or `find_duplicates()` "
                "to generate results"
            )

        samples = self.view
        unique_ids = set(self.unique_ids)
        patches_field = self.config.patches_field

        with visualization.use_view(samples, allow_missing=True):
            if patches_field is not None:
                ids = visualization.current_label_ids
            else:
                ids = visualization.current_sample_ids

            labels = []
            for _id in ids:
                if _id in unique_ids:
                    label = "unique"
                else:
                    label = "other"

                labels.append(label)

            return visualization.visualize(
                labels=labels,
                classes=["other", "unique"],
                backend=backend,
                **kwargs,
            )
def _unique_no_sort(values):
seen = set()
return [v for v in values if v not in seen and not seen.add(v)]
def _build_edges(ids, neighbors_map):
inds_map = {_id: idx for idx, _id in enumerate(ids)}
edges = []
for nearest_id, duplicates in neighbors_map.items():
nearest_ind = inds_map[nearest_id]
for dup_id, _ in duplicates:
dup_ind = inds_map[dup_id]
edges.append((dup_ind, nearest_ind))
return np.array(edges)
def _plot_distances_plotly(dists, metric, thresh, bins, log, **kwargs):
    """Plotly implementation of ``DuplicatesMixin.plot_distances()``.

    Args:
        dists: an array of nearest neighbor distances
        metric: the distance metric name, used in the x-axis title
        thresh: an optional threshold to draw as a vertical line, or None
        bins: the number of histogram bins
        log: whether to use a log scale y-axis
        **kwargs: additional layout options passed to
            ``figure.update_layout()``

    Returns:
        a :class:`fiftyone.core.plots.plotly.PlotlyNotebookPlot` in notebook
        contexts, or a plotly figure otherwise
    """
    import plotly.graph_objects as go
    import fiftyone.core.plots.plotly as fopl

    counts, edges = np.histogram(dists, bins=bins)
    left_edges = edges[:-1]
    widths = edges[1:] - edges[:-1]
    customdata = np.stack((edges[:-1], edges[1:]), axis=1)

    # Plotly hovertemplates are HTML: "<br>" separates lines and
    # "<extra></extra>" hides the secondary (trace name) hover box
    hover_lines = [
        "count: %{y}",
        "distance: [%{customdata[0]:.2f}, %{customdata[1]:.2f}]",
    ]
    hovertemplate = "<br>".join(hover_lines) + "<extra></extra>"

    bar = go.Bar(
        x=left_edges,
        y=counts,
        width=widths,
        customdata=customdata,
        offset=0,
        marker_color="#FF6D04",
        hovertemplate=hovertemplate,
        showlegend=False,
    )

    traces = [bar]

    if thresh is not None:
        line = go.Scatter(
            x=[thresh, thresh],
            y=[0, max(counts)],
            mode="lines",
            line=dict(color="#17191C", width=3),
            hovertemplate="thresh: %{x}<extra></extra>",
            showlegend=False,
        )
        traces.append(line)

    figure = go.Figure(traces)

    figure.update_layout(
        xaxis_title="nearest neighbor distance (%s)" % metric,
        yaxis_title="count",
        hovermode="x",
        yaxis_rangemode="tozero",
    )

    if log:
        figure.update_layout(yaxis_type="log")

    figure.update_layout(**fopl._DEFAULT_LAYOUT)
    figure.update_layout(**kwargs)

    if foc.is_jupyter_context():
        figure = fopl.PlotlyNotebookPlot(figure)

    return figure
def _plot_distances_mpl(
dists, metric, thresh, bins, log, ax=None, figsize=None, **kwargs
):
import matplotlib.pyplot as plt
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.figure
counts, edges = np.histogram(dists, bins=bins)
left_edges = edges[:-1]
widths = edges[1:] - edges[:-1]
ax.bar(
left_edges,
counts,
width=widths,
align="edge",
color="#FF6D04",
**kwargs,
)
if thresh is not None:
ax.vlines(thresh, 0, max(counts), color="#17191C", linewidth=3)
if log:
ax.set_yscale("log")
ax.set_xlabel("nearest neighbor distance (%s)" % metric)
ax.set_ylabel("count")
if figsize is not None:
fig.set_size_inches(*figsize)
plt.tight_layout()
return fig
================================================
FILE: fiftyone/brain/visualization.py
================================================
"""
Visualization interface.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from copy import deepcopy
import inspect
import logging
from packaging import version
import numpy as np
import sklearn
import sklearn.decomposition as skd
import sklearn.manifold as skm
import eta.core.utils as etau
import fiftyone.brain as fb
import fiftyone.core.brain as fob
import fiftyone.core.expressions as foe
import fiftyone.core.fields as fof
import fiftyone.core.plots as fop
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
# Internal brain utilities and the optional ``umap`` package, both loaded via
# ``fou.lazy_import`` (NOTE(review): presumably deferred until first use --
# confirm)
fbu = fou.lazy_import("fiftyone.brain.internal.core.utils")
umap = fou.lazy_import("umap")
logger = logging.getLogger(__name__)
# Zoo model that compute_visualization() falls back to when no model, points,
# embeddings, similarity index, or existing embeddings field is provided
_DEFAULT_MODEL = "mobilenet-v2-imagenet-torch"
# Substituted by compute_visualization() when ``batch_size=None``; since this
# is itself None, no explicit batch size is imposed here
_DEFAULT_BATCH_SIZE = None
def compute_visualization(
    samples,
    patches_field,
    embeddings,
    points,
    create_index,
    points_field,
    brain_key,
    num_dims,
    method,
    similarity_index,
    model,
    model_kwargs,
    force_square,
    alpha,
    batch_size,
    num_workers,
    skip_failures,
    progress,
    **kwargs,
):
    """See ``fiftyone/brain/__init__.py``.

    Implementation outline:

    1.  Resolve the data source: precomputed ``points``, an embeddings field
        name, an in-memory ``embeddings`` array, a similarity index (given as
        an instance or a brain key), or a model (falling back to
        ``_DEFAULT_MODEL`` when nothing else is provided)
    2.  Build the visualization method's config and register the run under
        ``brain_key``, if provided
    3.  Compute the low-dimensional points via the method's ``fit()``, or
        parse the provided ``points``
    4.  Optionally write the points to ``points_field`` (2D only), with a
        spatial index if ``create_index`` is True
    5.  Save and return the :class:`VisualizationResults`

    Returns:
        a :class:`VisualizationResults`

    Raises:
        ValueError: if ``method="manual"`` is used without ``points``, or if a
            ``points_field`` is requested when ``num_dims != 2``
    """
    fov.validate_collection(samples)

    if method == "manual" and points is None:
        raise ValueError(
            "You must provide your own `points` when `method='manual'`"
        )

    # Precomputed points take precedence over every other data source and
    # force the "manual" method; the dimension is inferred from the points
    if points is not None:
        method = "manual"
        model = None
        embeddings = None
        embeddings_field = None
        num_dims = _get_dimension(points)

    # Requesting an index without naming a field stores the points in a field
    # named after the brain key (which may itself be None)
    if create_index and points_field is None:
        points_field = brain_key

    if points_field is not None and num_dims != 2:
        raise ValueError("`points_field` is only supported when `num_dims=2`")

    # A string `embeddings` argument is interpreted as a field name
    if etau.is_str(embeddings):
        embeddings_field, embeddings_exist = fbu.parse_data_field(
            samples,
            embeddings,
            patches_field=patches_field,
            data_type="embeddings",
        )
        embeddings = None
    else:
        embeddings_field = None
        embeddings_exist = None

    if points_field is not None:
        points_field, _ = fbu.parse_data_field(
            samples,
            points_field,
            patches_field=patches_field,
            data_type="points",
        )

    # A string `similarity_index` argument is interpreted as a brain key
    if etau.is_str(similarity_index):
        similarity_index = samples.load_brain_results(similarity_index)

    # Fall back to the default zoo model only when no other source of
    # embeddings/points is available
    if (
        model is None
        and points is None
        and embeddings is None
        and similarity_index is None
        and not embeddings_exist
    ):
        model = _DEFAULT_MODEL
        if batch_size is None:
            batch_size = _DEFAULT_BATCH_SIZE

    config = _parse_config(
        method,
        embeddings_field=embeddings_field,
        points_field=points_field,
        similarity_index=similarity_index,
        model=model,
        model_kwargs=model_kwargs,
        patches_field=patches_field,
        num_dims=num_dims,
        **kwargs,
    )
    brain_method = config.build()
    brain_method.ensure_requirements()

    # The run is registered before any (potentially expensive) computation
    if brain_key is not None:
        brain_method.register_run(samples, brain_key)

    if points is None:
        embeddings, sample_ids, label_ids = fbu.get_embeddings(
            samples,
            model=model,
            model_kwargs=model_kwargs,
            patches_field=patches_field,
            embeddings_field=embeddings_field,
            embeddings=embeddings,
            similarity_index=similarity_index,
            force_square=force_square,
            alpha=alpha,
            batch_size=batch_size,
            num_workers=num_workers,
            skip_failures=skip_failures,
            progress=progress,
        )

        logger.info("Generating visualization...")
        points = brain_method.fit(embeddings)
    else:
        points, sample_ids, label_ids = fbu.parse_data(
            samples,
            patches_field=patches_field,
            data=points,
            data_type="points",
        )

    if points_field is not None:
        _generate_spatial_index(
            samples,
            points,
            points_field,
            sample_ids,
            label_ids=label_ids,
            patches_field=patches_field,
            create_index=create_index,
            progress=progress,
        )

    results = VisualizationResults(
        samples,
        config,
        brain_key,
        points,
        sample_ids=sample_ids,
        label_ids=label_ids,
    )

    brain_method.save_run_results(samples, brain_key, results)

    return results
def values(results, path_or_expr):
    """Extracts values of a field or expression for the currently active
    points of the given visualization results.

    Args:
        results: a :class:`VisualizationResults`
        path_or_expr: a field path or expression to evaluate

    Returns:
        a list of values, in the same order as the current points
    """
    patches_field = results.config.patches_field

    # Patch visualizations are keyed by label ID; others by sample ID
    if patches_field is None:
        ids = results.current_sample_ids
    else:
        ids = results.current_label_ids

    return fbu.get_values(
        results.view, path_or_expr, ids, patches_field=patches_field
    )
def visualize(
    results,
    labels=None,
    sizes=None,
    classes=None,
    backend="plotly",
    **kwargs,
):
    """Generates a scatterplot of the currently active points of the given
    visualization results.

    See :meth:`VisualizationResults.visualize` for argument details.

    Returns:
        the plot returned by :func:`fiftyone.core.plots.scatterplot`
    """
    patches_field = results.config.patches_field
    samples = results.view
    points = results.current_points
    good_inds = results._curr_good_inds

    ids = (
        results.current_sample_ids
        if patches_field is None
        else results.current_label_ids
    )

    def _is_raw_values(arg):
        # explicit per-point values, as opposed to an expression to evaluate
        return etau.is_container(arg) and not _is_expr(arg)

    # Explicit per-point values cover the full index, so they must be
    # filtered down to the points present in the current view
    if good_inds is not None:
        if _is_raw_values(labels):
            labels = fbu.filter_values(
                labels, good_inds, patches_field=patches_field
            )

        if _is_raw_values(sizes):
            sizes = fbu.filter_values(
                sizes, good_inds, patches_field=patches_field
            )

    # Expressions are evaluated against the current view
    if labels is not None and _is_expr(labels):
        labels = fbu.get_values(
            samples, labels, ids, patches_field=patches_field
        )

    if sizes is not None and _is_expr(sizes):
        sizes = fbu.get_values(samples, sizes, ids, patches_field=patches_field)

    return fop.scatterplot(
        points,
        samples=samples,
        ids=ids,
        link_field=patches_field,
        labels=labels,
        sizes=sizes,
        classes=classes,
        backend=backend,
        **kwargs,
    )
def _is_expr(arg):
return isinstance(arg, (foe.ViewExpression, dict))
def _parse_config(name, **kwargs):
if name is None:
name = fb.brain_config.default_visualization_method
if inspect.isclass(name):
return name(**kwargs)
methods = fb.brain_config.visualization_methods
if name not in methods:
raise ValueError(
"Unsupported method '%s'. The available methods are %s"
% (name, sorted(methods.keys()))
)
params = deepcopy(methods[name])
config_cls = kwargs.pop("config_cls", None)
if config_cls is None:
config_cls = params.pop("config_cls", None)
if config_cls is None:
raise ValueError(
"Visualization method '%s' has no `config_cls`" % name
)
if etau.is_str(config_cls):
config_cls = etau.get_class(config_cls)
params.update(**kwargs)
return config_cls(**params)
def _get_dimension(points):
if isinstance(points, dict):
points = next(iter(points.values()), None)
if isinstance(points, list):
points = next(iter(points), None)
if points is None:
return 2
return points.shape[-1]
def _generate_spatial_index(
samples,
points,
points_field,
sample_ids,
label_ids=None,
patches_field=None,
create_index=True,
progress=False,
):
# Indexes are not currently usable on patch visualizations
if create_index and patches_field is not None:
create_index = False
dataset = samples._root_dataset
if patches_field is not None:
_, points_field = dataset._get_label_field_path(
patches_field, points_field
)
logger.info("Generating spatial index in field '%s'...", points_field)
dataset.add_sample_field(
points_field, fof.ListField, subfield=fof.FloatField
)
points = points.astype(float)
if create_index:
min_val, max_val = points.min(), points.max()
dataset.create_index([(points_field, "2d")], min=min_val, max=max_val)
points = points.tolist()
if patches_field is not None:
values = dict(zip(label_ids, points))
dataset.set_label_values(points_field, values, progress=progress)
else:
values = dict(zip(sample_ids, points))
dataset.set_values(
points_field, values, key_field="id", progress=progress
)
class VisualizationResults(fob.BrainResults):
    """Class storing the results of
    :meth:`fiftyone.brain.compute_visualization`.

    The results can be restricted to a subset of their points via
    :meth:`use_view`. When the results are used as a context manager, the view
    that was active before entering the ``with`` block is restored on exit.

    Args:
        samples: the :class:`fiftyone.core.collections.SampleCollection` used
        config: the :class:`VisualizationConfig` used
        brain_key: the brain key
        points: a ``num_points x num_dims`` array of visualization points
        sample_ids (None): a ``num_points`` array of sample IDs
        label_ids (None): a ``num_points`` array of label IDs, if applicable
        backend (None): a :class:`Visualization` backend
    """

    def __init__(
        self,
        samples,
        config,
        brain_key,
        points,
        sample_ids=None,
        label_ids=None,
        backend=None,
    ):
        super().__init__(samples, config, brain_key, backend=backend)

        # Derive the IDs from the collection when they are not provided
        if sample_ids is None:
            sample_ids, label_ids = fbu.get_ids(
                samples,
                patches_field=config.patches_field,
                data=points,
                data_type="points",
            )

        self.points = points
        self.sample_ids = sample_ids
        self.label_ids = label_ids

        # Views to restore when exiting (possibly nested) `with` blocks
        self._view_stack = []

        # The view that was active before the most recent use_view() call
        self._prev_view = None

        self._curr_view = None
        self._curr_points = None
        self._curr_sample_ids = None
        self._curr_label_ids = None
        self._curr_keep_inds = None
        self._curr_good_inds = None

        self.use_view(samples)

    def __enter__(self):
        # In the documented `with results.use_view(view):` pattern, use_view()
        # has already switched to `view` by the time __enter__() runs, so the
        # view to restore on exit is the one that was active *before* that
        # call. If use_view() has not been called since construction (or
        # since the last `with` block), the current view is restored instead
        if self._prev_view is not None:
            restore_view = self._prev_view
        else:
            restore_view = self.view

        self._prev_view = None
        self._view_stack.append(restore_view)
        return self

    def __exit__(self, *args):
        self.use_view(self._view_stack.pop())
        self._prev_view = None

    @property
    def config(self):
        """The :class:`VisualizationConfig` for the results."""
        return self._config

    @property
    def index_size(self):
        """The number of active points in the index.

        If :meth:`use_view` has been called to restrict the index, this
        property will reflect the size of the active index.
        """
        return len(self._curr_sample_ids)

    @property
    def total_index_size(self):
        """The total number of data points in the index.

        If :meth:`use_view` has been called to restrict the index, this value
        may be larger than the current :meth:`index_size`.
        """
        return len(self.points)

    @property
    def missing_size(self):
        """The total number of data points in :meth:`view` that are missing
        from this index.

        This property is only applicable when :meth:`use_view` has been called,
        and it will be ``None`` if no data points are missing.
        """
        good = self._curr_good_inds

        if good is None:
            return None

        return good.size - np.count_nonzero(good)

    @property
    def current_points(self):
        """The currently active points in the index.

        If :meth:`use_view` has been called, this may be a subset of the full
        index.
        """
        return self._curr_points

    @property
    def current_sample_ids(self):
        """The sample IDs of the currently active points in the index.

        If :meth:`use_view` has been called, this may be a subset of the full
        index.
        """
        return self._curr_sample_ids

    @property
    def current_label_ids(self):
        """The label IDs of the currently active points in the index, or
        ``None`` if not applicable.

        If :meth:`use_view` has been called, this may be a subset of the full
        index.
        """
        return self._curr_label_ids

    @property
    def view(self):
        """The :class:`fiftyone.core.collections.SampleCollection` against
        which results are currently being generated.

        If :meth:`use_view` has been called, this view may be different than
        the collection on which the full index was generated.
        """
        return self._curr_view

    @property
    def has_spatial_index(self):
        """Whether these results have a spatial index.

        Use :meth:`index_points` to add a spatial index to an existing set of
        visualization results.
        """
        return self.config.points_field is not None

    def use_view(
        self, sample_collection, allow_missing=True, warn_missing=False
    ):
        """Restricts the index to the provided view.

        Subsequent calls to methods on this instance will only contain results
        from the specified view rather than the full index.

        Use :meth:`clear_view` to reset to the full index. Or use the context
        manager interface as demonstrated below to automatically restore the
        previously active view when the context exits.

        Example usage::

            import fiftyone as fo
            import fiftyone.brain as fob
            import fiftyone.zoo as foz

            dataset = foz.load_zoo_dataset("quickstart")

            results = fob.compute_visualization(dataset)
            print(results.index_size)  # 200

            view = dataset.take(50)

            with results.use_view(view):
                print(results.index_size)  # 50

                plot = results.visualize()
                plot.show()

            print(results.index_size)  # 200

        Args:
            sample_collection: a
                :class:`fiftyone.core.collections.SampleCollection`
            allow_missing (True): whether to allow the provided collection to
                contain data points that this index does not contain (True) or
                whether to raise an error in this case (False)
            warn_missing (False): whether to log a warning if the provided
                collection contains data points that this index does not
                contain

        Returns:
            self
        """
        sample_ids, label_ids, keep_inds, good_inds = fbu.filter_ids(
            sample_collection,
            self.sample_ids,
            self.label_ids,
            patches_field=self._config.patches_field,
            allow_missing=allow_missing,
            warn_missing=warn_missing,
        )

        if keep_inds is not None:
            points = self.points[keep_inds, :]
        else:
            points = self.points

        # Remember the outgoing view so that __enter__() can restore it when
        # this call is the expression of a `with` statement
        self._prev_view = self._curr_view

        self._curr_view = sample_collection
        self._curr_points = points
        self._curr_sample_ids = sample_ids
        self._curr_label_ids = label_ids
        self._curr_keep_inds = keep_inds
        self._curr_good_inds = good_inds

        return self

    def clear_view(self):
        """Clears the view set by :meth:`use_view`, if any.

        Subsequent operations will be performed on the full index.
        """
        self.use_view(self._samples)

    def values(self, path_or_expr):
        """Extracts a flat list of values from the given field or expression
        corresponding to the current :meth:`view`.

        This method always returns values in the same order as
        :meth:`current_points`, :meth:`current_sample_ids`, and
        :meth:`current_label_ids`.

        Args:
            path_or_expr: the values to extract, which can be:

                -   the name of a sample field or ``embedded.field.name`` from
                    which to extract numeric or string values
                -   a :class:`fiftyone.core.expressions.ViewExpression`
                    defining numeric or string values to compute via
                    :meth:`fiftyone.core.collections.SampleCollection.values`

        Returns:
            a list of values
        """
        return values(self, path_or_expr)

    def visualize(
        self,
        labels=None,
        sizes=None,
        classes=None,
        backend="plotly",
        **kwargs,
    ):
        """Generates an interactive scatterplot of the visualization results
        for the current :meth:`view`.

        This method supports 2D or 3D visualizations, but interactive point
        selection is only available in 2D.

        You can use the ``labels`` parameters to define a coloring for the
        points, and you can use the ``sizes`` parameter to scale the sizes of
        the points.

        You can attach plots generated by this method to an App session via its
        :attr:`fiftyone.core.session.Session.plots` attribute, which will
        automatically sync the session's view with the currently selected
        points in the plot.

        Args:
            labels (None): data to use to color the points. Can be any of the
                following:

                -   the name of a sample field or ``embedded.field.name`` from
                    which to extract numeric or string values
                -   a :class:`fiftyone.core.expressions.ViewExpression`
                    defining numeric or string values to compute via
                    :meth:`fiftyone.core.collections.SampleCollection.values`
                -   a list or array-like of numeric or string values
                -   a list of lists of numeric or string values, if the data in
                    this visualization corresponds to a label list field like
                    :class:`fiftyone.core.labels.Detections`

            sizes (None): data to use to scale the sizes of the points. Can be
                any of the following:

                -   the name of a sample field or ``embedded.field.name`` from
                    which to extract numeric values
                -   a :class:`fiftyone.core.expressions.ViewExpression`
                    defining numeric values to compute via
                    :meth:`fiftyone.core.collections.SampleCollection.values`
                -   a list or array-like of numeric values
                -   a list of lists of numeric values, if the data in this
                    visualization corresponds to a label list field like
                    :class:`fiftyone.core.labels.Detections`

            classes (None): an optional list of classes whose points to plot.
                Only applicable when ``labels`` contains strings
            backend ("plotly"): the plotting backend to use. Supported values
                are ``("plotly", "matplotlib")``
            **kwargs: keyword arguments for the backend plotting method:

                -   "plotly" backend: :meth:`fiftyone.core.plots.plotly.scatterplot`
                -   "matplotlib" backend: :meth:`fiftyone.core.plots.matplotlib.scatterplot`

        Returns:
            an :class:`fiftyone.core.plots.base.InteractivePlot`
        """
        return visualize(
            self,
            labels=labels,
            sizes=sizes,
            classes=classes,
            backend=backend,
            **kwargs,
        )

    def index_points(
        self,
        points_field=None,
        create_index=True,
        progress=None,
    ):
        """Adds a spatial index for these visualization results to its
        dataset's samples.

        This method is useful if you want to add a spatial index to existing
        visualization results that don't yet have one.

        Spatial indexes are highly recommended for large datasets as they
        enable efficient querying when lassoing points in embeddings plots.

        Args:
            points_field (None): an optional field name in which to store the
                spatial index. The default is the result's ``brain_key``
            create_index (True): whether to create a database index for the
                points
            progress (None): whether to render a progress bar (True/False),
                use the default value ``fiftyone.config.show_progress_bars``
                (None), or a progress callback function to invoke instead

        Raises:
            ValueError: if no ``points_field`` is provided and these results
                have no brain key
        """
        if points_field is None:
            if self.key is None:
                raise ValueError(
                    "You must provide a `points_field` when indexing points "
                    "that are not associated with a brain key"
                )

            points_field = self.key

        _generate_spatial_index(
            self.samples,
            self.points,
            points_field,
            self.sample_ids,
            label_ids=self.label_ids,
            patches_field=self.config.patches_field,
            create_index=create_index,
            progress=progress,
        )

        # The points field can only be recorded for persisted (keyed) runs
        if self.key is not None:
            self.config.points_field = points_field
            self.save_config()

    def remove_index(self):
        """Removes the spatial index from these visualization results, if one
        exists.
        """
        points_field = self.config.points_field
        if points_field is None:
            return

        dataset = self.samples._root_dataset
        if self.config.patches_field is not None:
            _, points_field = dataset._get_label_field_path(
                self.config.patches_field, points_field
            )

        dataset.delete_sample_field(points_field, error_level=1)

        if self.key is not None:
            self.config.points_field = None
            self.save_config()

    @classmethod
    def _from_dict(cls, d, samples, config, brain_key):
        points = np.array(d["points"])

        sample_ids = d.get("sample_ids", None)
        if sample_ids is not None:
            sample_ids = np.array(sample_ids)

        label_ids = d.get("label_ids", None)
        if label_ids is not None:
            label_ids = np.array(label_ids)

        return cls(
            samples,
            config,
            brain_key,
            points,
            sample_ids=sample_ids,
            label_ids=label_ids,
        )
class VisualizationConfig(fob.BrainMethodConfig):
    """Base class for the configs of visualization methods.

    Only serializable references are stored: a similarity index object is
    replaced by its ``key``, and a model that is not given by name is dropped.

    Args:
        embeddings_field (None): the sample field containing the embeddings,
            if one was provided
        points_field (None): the name of a field in which to store the
            visualization points, if requested
        similarity_index (None): the similarity index containing the
            embeddings, if one was provided
        model (None): the :class:`fiftyone.core.models.Model` or name of the
            zoo model that was used to compute embeddings, if known. Only
            model names are retained
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        num_dims (2): the dimension of the visualization space
    """

    def __init__(
        self,
        embeddings_field=None,
        points_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        patches_field=None,
        num_dims=2,
        **kwargs,
    ):
        if not (similarity_index is None or etau.is_str(similarity_index)):
            similarity_index = similarity_index.key

        if not (model is None or etau.is_str(model)):
            model = None

        self.embeddings_field = embeddings_field
        self.points_field = points_field
        self.similarity_index = similarity_index
        self.model = model
        self.model_kwargs = model_kwargs
        self.patches_field = patches_field
        self.num_dims = num_dims

        super().__init__(**kwargs)

    @property
    def type(self):
        return "visualization"
class Visualization(fob.BrainMethod):
    """Base class for visualization methods.

    Subclasses must implement :meth:`fit`.
    """

    def fit(self, embeddings):
        """Computes the low-dimensional representation of the embeddings.

        Args:
            embeddings: the embeddings to reduce

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError("subclass must implement fit()")

    def get_fields(self, samples, brain_key):
        """Returns the sample fields associated with this run: the patches
        field if there is one, otherwise the points field, if any.
        """
        patches_field = self.config.patches_field
        if patches_field is not None:
            return [patches_field]

        points_field = self.config.points_field
        if points_field is not None:
            return [points_field]

        return []

    def rename(self, samples, key, new_key):
        """Renames the points field along with the run, when the field is
        named after the brain key.
        """
        patches_field = self.config.patches_field
        points_field = self.config.points_field
        dataset = samples._root_dataset

        # Custom-named points fields are left untouched
        if points_field is None or points_field != key:
            return

        if patches_field is None:
            old_path, new_path = key, new_key
        else:
            _, old_path = dataset._get_label_field_path(patches_field, key)
            _, new_path = dataset._get_label_field_path(
                patches_field, new_key
            )

        self.config.points_field = new_key
        self.update_run_config(samples, key, self.config)
        dataset.rename_sample_field(old_path, new_path)

    def cleanup(self, samples, key):
        """Deletes the points field of this run, if any."""
        patches_field = self.config.patches_field
        points_field = self.config.points_field
        dataset = samples._root_dataset

        if points_field is None:
            return

        if patches_field is not None:
            _, points_field = dataset._get_label_field_path(
                patches_field, points_field
            )

        dataset.delete_sample_field(points_field, error_level=1)
class UMAPVisualizationConfig(VisualizationConfig):
    """Configuration for Uniform Manifold Approximation and Projection (UMAP)
    embedding visualization.

    See https://github.com/lmcinnes/umap for more information about the
    supported parameters.

    Args:
        embeddings_field (None): the sample field containing the embeddings,
            if one was provided
        points_field (None): the name of a field in which to store the
            visualization points, if requested
        similarity_index (None): the similarity index containing the
            embeddings, if one was provided
        model (None): the :class:`fiftyone.core.models.Model` or name of the
            zoo model that was used to compute embeddings, if known
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        num_dims (2): the dimension of the visualization space
        num_neighbors (15): the number of neighboring points used in local
            approximations of manifold structure. Larger values preserve more
            global structure at the cost of detailed local structure. Typical
            values are in ``[5, 50]``
        metric ("euclidean"): the metric to use when calculating distance
            between embeddings. See the UMAP documentation for supported values
        min_dist (0.1): the effective minimum distance between embedded
            points. This controls how tightly the embedding is allowed to
            compress points together. Larger values ensure embedded points are
            more evenly distributed, while smaller values allow the algorithm
            to optimize more accurately with regard to local structure.
            Typical values are in ``[0.001, 0.5]``
        seed (None): a random seed
        verbose (True): whether to log progress
    """

    def __init__(
        self,
        embeddings_field=None,
        points_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        patches_field=None,
        num_dims=2,
        num_neighbors=15,
        metric="euclidean",
        min_dist=0.1,
        seed=None,
        verbose=True,
        **kwargs,
    ):
        super().__init__(
            embeddings_field=embeddings_field,
            points_field=points_field,
            similarity_index=similarity_index,
            model=model,
            model_kwargs=model_kwargs,
            patches_field=patches_field,
            num_dims=num_dims,
            **kwargs,
        )

        # UMAP-specific parameters
        self.num_neighbors = num_neighbors
        self.metric = metric
        self.min_dist = min_dist
        self.seed = seed
        self.verbose = verbose

    @property
    def method(self):
        return "umap"
class UMAPVisualization(Visualization):
    """UMAP-based visualization method."""

    def ensure_requirements(self):
        """Verifies that the optional ``umap-learn`` package is installed."""
        error_msg = (
            "You must install the `umap-learn>=0.5` package in order to "
            "use UMAP-based visualization. This is recommended, as UMAP "
            "is awesome! If you do not wish to install UMAP, try "
            "`method='tsne'` instead"
        )
        fou.ensure_package("umap-learn>=0.5", error_msg=error_msg)

    def fit(self, embeddings):
        """Returns the UMAP projection of the given embeddings into
        ``config.num_dims`` dimensions.
        """
        cfg = self.config
        reducer = umap.UMAP(
            n_components=cfg.num_dims,
            n_neighbors=cfg.num_neighbors,
            metric=cfg.metric,
            min_dist=cfg.min_dist,
            random_state=cfg.seed,
            verbose=cfg.verbose,
        )
        return reducer.fit_transform(embeddings)
class TSNEVisualizationConfig(VisualizationConfig):
    """Configuration for t-distributed Stochastic Neighbor Embedding (t-SNE)
    visualization.

    See https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    for more information about the supported parameters.

    Args:
        embeddings_field (None): the sample field containing the embeddings,
            if one was provided
        points_field (None): the name of a field in which to store the
            visualization points, if requested
        similarity_index (None): the similarity index containing the
            embeddings, if one was provided
        model (None): the :class:`fiftyone.core.models.Model` or name of the
            zoo model that was used to compute embeddings, if known
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        num_dims (2): the dimension of the visualization space
        pca_dims (50): the number of PCA dimensions to compute prior to running
            t-SNE. It is highly recommended to reduce the number of dimensions
            to a reasonable number (e.g. 50) before running t-SNE, as this will
            suppress some noise and speed up the computation of pairwise
            distances between samples
        svd_solver ("randomized"): the SVD solver to use when performing PCA.
            Consult the sklearn documentation for details
        metric ("euclidean"): the metric to use when calculating distance
            between embeddings. Must be a supported value for the ``metric``
            argument of ``scipy.spatial.distance.pdist``
        perplexity (30.0): the perplexity to use. Perplexity is related to the
            number of nearest neighbors that is used in other manifold learning
            algorithms. Larger datasets usually require a larger perplexity.
            Typical values are in ``[5, 50]``
        learning_rate (200.0): the learning rate to use. Typical values are
            in ``[10, 1000]``. If the learning rate is too high, the data may
            look like a ball with any point approximately equidistant from its
            nearest neighbors. If the learning rate is too low, most points
            may look compressed in a dense cloud with few outliers. If the cost
            function gets stuck in a bad local minimum, increasing the learning
            rate may help
        max_iters (1000): the maximum number of iterations to run. Should be at
            least 250
        seed (None): a random seed
        verbose (True): whether to log progress
    """

    def __init__(
        self,
        embeddings_field=None,
        points_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        patches_field=None,
        num_dims=2,
        pca_dims=50,
        svd_solver="randomized",
        metric="euclidean",
        perplexity=30.0,
        learning_rate=200.0,
        max_iters=1000,
        seed=None,
        verbose=True,
        **kwargs,
    ):
        super().__init__(
            embeddings_field=embeddings_field,
            points_field=points_field,
            similarity_index=similarity_index,
            model=model,
            model_kwargs=model_kwargs,
            patches_field=patches_field,
            num_dims=num_dims,
            **kwargs,
        )

        # Optional PCA preprocessing
        self.pca_dims = pca_dims
        self.svd_solver = svd_solver

        # t-SNE parameters
        self.metric = metric
        self.perplexity = perplexity
        self.learning_rate = learning_rate
        self.max_iters = max_iters
        self.seed = seed
        self.verbose = verbose

    @property
    def method(self):
        return "tsne"
class TSNEVisualization(Visualization):
    """t-SNE-based visualization method, optionally preceded by PCA."""

    def fit(self, embeddings):
        """Computes the t-SNE representation of the given embeddings.

        When ``config.pca_dims`` is not None, the embeddings are first reduced
        via PCA. sklearn's PCA rejects ``n_components`` larger than
        ``min(num_samples, num_features)``, so:

        -   the PCA step is skipped when the embeddings already have fewer
            than ``pca_dims`` features, since there is nothing to reduce
        -   the number of PCA components is capped at the number of samples

        Args:
            embeddings: a ``num_samples x num_features`` array-like

        Returns:
            a ``num_samples x num_dims`` array
        """
        embeddings = np.asarray(embeddings)

        pca_dims = self.config.pca_dims
        if pca_dims is not None:
            num_samples, num_features = embeddings.shape
            if pca_dims <= num_features:
                _pca = skd.PCA(
                    n_components=min(pca_dims, num_samples),
                    svd_solver=self.config.svd_solver,
                    random_state=self.config.seed,
                )
                embeddings = _pca.fit_transform(embeddings)

        embeddings = embeddings.astype(np.float32, copy=False)

        verbose = 2 if self.config.verbose else 0

        # scikit-learn 1.5 renamed TSNE's `n_iter` parameter to `max_iter`
        sklearn_version = version.parse(sklearn.__version__)
        iter_param = (
            "max_iter"
            if sklearn_version >= version.parse("1.5.0")
            else "n_iter"
        )

        _tsne = skm.TSNE(
            n_components=self.config.num_dims,
            perplexity=self.config.perplexity,
            learning_rate=self.config.learning_rate,
            metric=self.config.metric,
            init="pca",
            random_state=self.config.seed,
            verbose=verbose,
            **{iter_param: self.config.max_iters},
        )
        return _tsne.fit_transform(embeddings)
class PCAVisualizationConfig(VisualizationConfig):
    """Configuration for principal component analysis (PCA) embedding
    visualization.

    See https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
    for more information about the supported parameters.

    Args:
        embeddings_field (None): the sample field containing the embeddings,
            if one was provided
        points_field (None): the name of a field in which to store the
            visualization points, if requested
        similarity_index (None): the similarity index containing the
            embeddings, if one was provided
        model (None): the :class:`fiftyone.core.models.Model` or name of the
            zoo model that was used to compute embeddings, if known
        model_kwargs (None): a dictionary of optional keyword arguments to pass
            to the model's ``Config`` when a model name is provided
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        num_dims (2): the dimension of the visualization space
        svd_solver ("randomized"): the SVD solver to use. Consult the sklearn
            documentation for details
        seed (None): a random seed
    """

    def __init__(
        self,
        embeddings_field=None,
        points_field=None,
        similarity_index=None,
        model=None,
        model_kwargs=None,
        patches_field=None,
        num_dims=2,
        svd_solver="randomized",
        seed=None,
        **kwargs,
    ):
        super().__init__(
            embeddings_field=embeddings_field,
            points_field=points_field,
            similarity_index=similarity_index,
            model=model,
            model_kwargs=model_kwargs,
            patches_field=patches_field,
            num_dims=num_dims,
            **kwargs,
        )

        # PCA-specific parameters
        self.svd_solver = svd_solver
        self.seed = seed

    @property
    def method(self):
        return "pca"
class PCAVisualization(Visualization):
    """PCA-based visualization method."""

    def fit(self, embeddings):
        """Projects the embeddings onto their first ``config.num_dims``
        principal components.
        """
        cfg = self.config
        reducer = skd.PCA(
            n_components=cfg.num_dims,
            svd_solver=cfg.svd_solver,
            random_state=cfg.seed,
        )
        return reducer.fit_transform(embeddings)
class ManualVisualizationConfig(VisualizationConfig):
    """Configuration for visualizations whose low-dimensional points are
    provided manually rather than computed.

    Args:
        patches_field (None): the sample field defining the patches being
            analyzed, if any
        num_dims (2): the dimension of the visualization space
    """

    def __init__(self, patches_field=None, num_dims=2, **kwargs):
        kwargs.update(patches_field=patches_field, num_dims=num_dims)
        super().__init__(**kwargs)

    @property
    def method(self):
        return "manual"
class ManualVisualization(Visualization):
    """Visualization method for manually provided points. It cannot compute
    points itself.
    """

    def fit(self, embeddings):
        """Always raises, since points must be provided manually.

        Raises:
            NotImplementedError: always
        """
        message = (
            "The low-dimensional representation must be manually provided "
            "when using this method"
        )
        raise NotImplementedError(message)
================================================
FILE: install.bat
================================================
@echo off
:: Installs the `fiftyone-brain` package and its dependencies.
::
:: Usage:
:: .\install.bat
::
:: Copyright 2017-2026, Voxel51, Inc.
:: voxel51.com
::
:: Commands:
:: -h Display help message
:: -d Install developer dependencies.

:: Keep variables local to this script so they don't leak into the caller's
:: command prompt session
setlocal

set DEV_INSTALL=false

:parse
IF "%~1"=="" GOTO endparse
IF "%~1"=="-h" GOTO helpmessage
IF "%~1"=="-d" set DEV_INSTALL=true
SHIFT
GOTO parse
:endparse

echo ***** INSTALLING FIFTYONE-BRAIN *****
:: Abort on the first failing command (like `set -e` in install.sh) so that a
:: failed install is never reported as complete
IF %DEV_INSTALL%==true (
    echo Performing dev install
    pip install -r requirements/dev.txt || exit /b 1
    pre-commit install || exit /b 1
    pip install -e . || exit /b 1
) else (
    pip install -r requirements.txt || exit /b 1
    pip install . || exit /b 1
)
echo ***** INSTALLATION COMPLETE *****
exit /b 0

:helpmessage
echo Additional Arguments:
echo -h Display help message
echo -d Install developer dependencies.
exit /b 0
================================================
FILE: install.sh
================================================
#!/bin/sh
# Installs the `fiftyone-brain` package and its dependencies.
#
# Usage:
# sh install.sh
#
# Copyright 2017-2026, Voxel51, Inc.
# voxel51.com
#

# Abort on the first failing command
set -e

# Show usage information
usage() {
    echo "Usage: sh $0 [-h] [-d]
Getting help:
-h Display this help message.
Custom installations:
-d Install developer dependencies.
"
}

# Parse flags
SHOW_HELP=false
DEV_INSTALL=false
while getopts "hd" FLAG; do
    case "${FLAG}" in
        h) SHOW_HELP=true ;;
        d) DEV_INSTALL=true ;;
        # Unknown flag: show usage and stop rather than installing anyway
        *) usage; exit 1 ;;
    esac
done

if [ "${SHOW_HELP}" = true ]; then
    usage
    exit 0
fi

echo "***** INSTALLING FIFTYONE-BRAIN *****"
if [ "${DEV_INSTALL}" = true ]; then
    echo "Performing dev install"
    pip install -r requirements/dev.txt
    pre-commit install
    pip install -e .
else
    pip install -r requirements.txt
    pip install .
fi

echo "***** INSTALLATION COMPLETE *****"
================================================
FILE: pylintrc
================================================
[MASTER]
# Specify a configuration file.
#rcfile=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=yes
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=1
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Allow optimization of some AST trees. This will activate a peephole AST
# optimizer, which will apply various small optimizations. For instance, it can
# be used to obtain the result of joining multiple strings with the addition
# operator. Joining a lot of strings can lead to a maximum recursion error in
# Pylint and this flag can prevent that. It has one side effect, the resulting
# AST will be different than the one from reality. This option is deprecated
# and it will be removed in Pylint 2.0.
optimize-ast=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W"
#disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating
disable=too-few-public-methods,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,non-iterator-returned,too-many-statements,useless-object-inheritance,abstract-method,too-many-ancestors,too-many-branches,unnecessary-pass,too-many-public-methods,bad-continuation
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=colorized
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file named "pylint_global.[txt|html]". This option was removed
# in Pylint 2.0 and is not recognized by Pylint 2.x.
files-output=no
# Tells whether to display a full report or only the messages
reports=no
score=no
# Python expression which should return a score less than 10 (10 is the
# highest score). You have access to the variables error, warning, refactor,
# convention and statement, which respectively contain the number of messages
# in each category and the total number of statements analyzed. This is used
# by the global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=i,j,k
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty
# Regular expression matching correct function names
function-rgx=[a-z_]([a-z0-9_]{0,30})$
# Naming hint for function names
function-name-hint=[a-z_]([a-z0-9_]{0,30})$
# Regular expression matching correct variable names
variable-rgx=[a-z_]([a-z0-9_]{0,30})$
# Naming hint for variable names
variable-name-hint=[a-z_]([a-z0-9_]{0,30})$
# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Regular expression matching correct attribute names
attr-rgx=[a-z_]([a-z0-9_]{0,30})$
# Naming hint for attribute names
attr-name-hint=[a-z_]([a-z0-9_]{0,30})$
# Regular expression matching correct argument names
argument-rgx=[a-z_]([a-z0-9_]{0,30})$
# Naming hint for argument names
argument-name-hint=[a-z_]([a-z0-9_]{0,30})$
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$
# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_]([A-Za-z0-9_]{0,30})|(__.*__))$
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Regular expression matching correct method names
method-rgx=[a-z_]([a-z0-9_]{0,30})$
# Naming hint for method names
method-name-hint=[a-z_]([a-z0-9_]{0,30})$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
[ELIF]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=79
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[MISCELLANEOUS]
# List of note tags to take into consideration, separated by a comma.
notes=FIXME,XXX,TODO
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[TYPECHECK]
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=torch.*,fiftyone.*,fo.*
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
[DESIGN]
# Maximum number of arguments for function / method
max-args=5
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of statements in function / method body
max-statements=50
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,TERMIOS,Bastion,rexec
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
================================================
FILE: pyproject.toml
================================================
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = '''
/(
| \.git
)/
'''
================================================
FILE: pytest.ini
================================================
[pytest]
python_files = *test*.py
filterwarnings =
ignore:dns.hash module will be removed in future versions:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning
ignore:numpy.* size changed, may indicate binary incompatibility:RuntimeWarning
================================================
FILE: requirements/build.txt
================================================
-r common.txt
pytest==5.4.3
twine>=3
================================================
FILE: requirements/common.txt
================================================
numpy
scipy
scikit-learn
================================================
FILE: requirements/dev.txt
================================================
-r common.txt
flickrapi==2.4.0
imageio==2.8.0
ipython>=7.16.1
pandas
pre-commit==2.0.1
pylint==2.3.1
pytest==7.3.1
twine>=3
voxel51-eta[storage]
================================================
FILE: requirements/prod.txt
================================================
-r common.txt
================================================
FILE: requirements.txt
================================================
-r requirements/prod.txt
================================================
FILE: setup.py
================================================
#!/usr/bin/env python
"""
Installs `fiftyone-brain`.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import os
from setuptools import setup
VERSION = "0.21.4"


def get_version():
    """Returns the version string to use for this build.

    If the ``RELEASE_VERSION`` environment variable is set, its value is
    returned, provided that it is a release of :const:`VERSION`: either
    ``VERSION`` itself, or ``VERSION`` followed by a suffix that does not
    start with a digit (e.g. ``rc1`` or ``.dev0``). Otherwise
    :const:`VERSION` is returned.

    Returns:
        the version string

    Raises:
        ValueError: if ``RELEASE_VERSION`` is set but does not correspond to
            :const:`VERSION`
    """
    if "RELEASE_VERSION" not in os.environ:
        return VERSION

    version = os.environ["RELEASE_VERSION"]

    # A bare startswith() check would wrongly accept e.g. "0.21.40" for
    # VERSION "0.21.4", so the character following the prefix (if any) must
    # not be a digit
    suffix = version[len(VERSION):]
    if not version.startswith(VERSION) or suffix[:1].isdigit():
        raise ValueError(
            "Release version does not match version: %s and %s"
            % (version, VERSION)
        )

    return version
# Read the README explicitly as UTF-8 so that the build does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows)
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="fiftyone-brain",
    version=get_version(),
    description="FiftyOne Brain",
    author="Voxel51, Inc.",
    author_email="info@voxel51.com",
    url="https://github.com/voxel51/fiftyone-brain",
    license="Apache",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=["fiftyone.brain"],
    include_package_data=True,
    install_requires=["numpy", "scipy>=1.2.0", "scikit-learn"],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Image Processing",
        "Topic :: Scientific/Engineering :: Image Recognition",
        "Topic :: Scientific/Engineering :: Information Analysis",
        "Topic :: Scientific/Engineering :: Visualization",
        "Operating System :: MacOS :: MacOS X",
        "Operating System :: POSIX :: Linux",
        "Operating System :: Microsoft :: Windows",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
    ],
    scripts=[],
    python_requires=">=3.9",
)
================================================
FILE: tests/README.md
================================================
# FiftyOne-Brain Tests
The brain currently uses both
[unittest](https://docs.python.org/3/library/unittest.html) and
[pytest](https://docs.pytest.org/en/stable) to implement its tests.
## Contents
| File | Description |
| -------------------- | -------------------------------------------------------- |
| `test_uniqueness.py` | Tests of the uniqueness capability |
| `models/*.py` | Tests of the various models used by the brain |
| `intensive/*.py` | Intensive tests that are not included in automated tests |
## Running tests
To run all tests in this directory, execute:
```shell
pytest . -s
```
To run a specific set of tests, execute:
```shell
pytest <file>.py -s
```
To run a specific test case, execute:
```shell
pytest <file>.py -s -k <test_name>
```
## Copyright
Copyright 2017-2026, Voxel51, Inc.
voxel51.com
================================================
FILE: tests/intensive/test_interface.py
================================================
"""
Brain interface tests.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import unittest
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
def test_uniqueness():
    """Runs uniqueness on a clone of quickstart and prints the run details."""
    ds = foz.load_zoo_dataset("quickstart").clone()
    fob.compute_uniqueness(ds)

    print(ds.list_brain_runs())
    print(ds.get_brain_info("uniqueness"))
    print(ds.bounds("uniqueness"))

    # Printing the dataset after deletion lets you check which fields remain
    ds.delete_brain_runs()
    print(ds)
def test_detection_mistakenness():
    """Runs detection mistakenness on quickstart and prints the populated
    sample- and label-level fields before and after deleting the run.
    """
    ds = foz.load_zoo_dataset("quickstart").clone()
    fob.compute_mistakenness(
        ds, "predictions", label_field="ground_truth", copy_missing=True
    )

    print(ds.list_brain_runs())
    print(ds.get_brain_info("mistakenness"))

    # should be non-trivial
    for field in ("mistakenness", "possible_missing", "possible_spurious"):
        print(ds.bounds(field))
    _print_label_mistakenness_fields(ds)

    ds.delete_brain_runs()
    print(ds)

    # should be None
    _print_label_mistakenness_fields(ds)


def _print_label_mistakenness_fields(ds):
    """Prints the label-level fields inspected by the mistakenness test."""
    for path in (
        "ground_truth.detections.mistakenness",
        "ground_truth.detections.mistakenness_loc",
    ):
        print(ds.bounds(path))

    for path in (
        "ground_truth.detections.possible_spurious",
        "predictions.detections.possible_missing",
        "ground_truth.detections.possible_missing",
    ):
        print(ds.count_values(path))
def test_classification_mistakenness_confidence():
    """Runs confidence-based classification mistakenness on 10 samples,
    using AlexNet predictions as labels and ResNet-50 as predictions.
    """
    ds = foz.load_zoo_dataset("quickstart").clone()
    subset = ds.take(10)

    # "alexnet" is the labels proxy; "resnet50" is the predictions proxy
    for zoo_name, field in (
        ("alexnet-imagenet-torch", "alexnet"),
        ("resnet50-imagenet-torch", "resnet50"),
    ):
        subset.apply_model(foz.load_zoo_model(zoo_name), field)

    fob.compute_mistakenness(subset, "resnet50", label_field="alexnet")

    print(ds.list_brain_runs())
    print(ds.load_brain_view("mistakenness"))
    print(ds.bounds("mistakenness"))

    ds.delete_brain_runs()
    print(ds)
def test_classification_mistakenness_logits():
    """Runs logits-based classification mistakenness on 10 samples, using
    AlexNet predictions as labels and ResNet-50 (with logits) as predictions.
    """
    ds = foz.load_zoo_dataset("quickstart").clone()
    subset = ds.take(10)

    labels_model = foz.load_zoo_model("alexnet-imagenet-torch")
    subset.apply_model(labels_model, "alexnet")

    # Logits are stored because use_logits=True is requested below
    preds_model = foz.load_zoo_model("resnet50-imagenet-torch")
    subset.apply_model(preds_model, "resnet50", store_logits=True)

    fob.compute_mistakenness(
        subset, "resnet50", label_field="alexnet", use_logits=True
    )

    print(ds.list_brain_runs())
    print(ds.load_brain_view("mistakenness"))
    print(ds.bounds("mistakenness"))

    ds.delete_brain_runs()
    print(ds)
def test_hardness():
    """Runs hardness on AlexNet predictions for 10 quickstart samples."""
    ds = foz.load_zoo_dataset("quickstart").clone()
    subset = ds.take(10)

    # compute_hardness() below runs on these predictions, stored with logits
    classifier = foz.load_zoo_model("alexnet-imagenet-torch")
    subset.apply_model(classifier, "alexnet", store_logits=True)
    fob.compute_hardness(subset, "alexnet")

    print(ds.list_brain_runs())
    print(ds.get_brain_info("hardness"))
    print(ds.load_brain_view("hardness"))
    print(ds.bounds("hardness"))

    ds.delete_brain_runs()
    print(ds)
if __name__ == "__main__":
    fo.config.show_progress_bars = True

    # The tests in this module are plain pytest-style functions rather than
    # unittest.TestCase methods, so unittest.main() would discover and run
    # none of them. Wrap each one in a FunctionTestCase so they actually run
    _suite = unittest.TestSuite(
        unittest.FunctionTestCase(_test)
        for _test in (
            test_uniqueness,
            test_detection_mistakenness,
            test_classification_mistakenness_confidence,
            test_classification_mistakenness_logits,
            test_hardness,
        )
    )
    _result = unittest.TextTestRunner(verbosity=2).run(_suite)

    # Mirror unittest.main()'s exit status: nonzero if anything failed
    raise SystemExit(not _result.wasSuccessful())
================================================
FILE: tests/intensive/test_similarity.py
================================================
"""
Similarity tests.
Usage::
# Optional: specific backends to test
export SIMILARITY_BACKENDS=qdrant,pinecone,milvus,redis,elasticsearch,mosaic,pgvector,lancedb
pytest tests/intensive/test_similarity.py -s -k test_XXX
Qdrant setup::
docker pull qdrant/qdrant
docker run -p 6333:6333 qdrant/qdrant
pip install qdrant-client
Pinecone setup::
# Sign up at https://www.pinecone.io
# Download API key and environment
pip install pinecone-client
Milvus setup::
# Instructions from: https://milvus.io/docs/install_standalone-docker.md
wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standalone-docker-compose.yml -O docker-compose.yml
docker compose up -d
pip install pymilvus
LanceDB setup::
pip install lancedb
Redis setup::
brew tap redis-stack/redis-stack
brew install redis-stack
redis-stack-server
pip install redis
Elasticsearch setup::
# Instructions from: https://www.elastic.co/guide/en/elasticsearch/reference/current/getting-started.html#run-elasticsearch
docker run -p 127.0.0.1:9200:9200 -d \
--name elasticsearch \
-e ELASTIC_PASSWORD=elastic \
-e "discovery.type=single-node" \
-e "xpack.security.http.ssl.enabled=false" \
-e "xpack.license.self_generated.type=trial" \
docker.elastic.co/elasticsearch/elasticsearch:8.15.0
pip install elasticsearch
Mosaic setup::
# In your databricks workspace, generate a personal access token for authentication.
# You will also need to create a catalog and schema in your workspace.
# You will have to create an endpoint under `compute` -> `vector search`
pip install databricks-vectorsearch
PGVector setup::
# Run a postgres instance locally with pgvector extension
docker pull pgvector/pgvector:pg17
docker run --name postgres -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword -d pgvector/pgvector:pg17
# Enter the container and create the vector extension
docker exec -it postgres ./bin/psql -U postgres
CREATE EXTENSION IF NOT EXISTS vector; # run in container
pip install psycopg2
Brain config setup at `~/.fiftyone/brain_config.json`::
{
"similarity_backends": {
"pinecone": {
"api_key": "XXXXXXXX",
"cloud": "aws",
"region": "us-east-1",
"environment": "us-east-1-aws"
},
"qdrant": {
"url": "http://localhost:6333"
},
"milvus": {
"uri": "http://localhost:19530"
},
"lancedb": {
"uri": "/tmp/lancedb"
},
"redis": {
"host": "localhost",
"port": 6379
},
"elasticsearch": {
"hosts": "http://localhost:9200",
"username": "elastic",
"password": "elastic"
},
"mosaic": {
"workspace_url": "https://<workspace>.cloud.databricks.com/",
"personal_access_token": "<personal-access-token>",
"catalog_name": "<catalog-name>",
"schema_name": "<schema-name>",
"endpoint_name": "<endpoint-name>"
},
"pgvector": {
"connection_string": "postgresql://postgres:mysecretpassword@localhost:5432/postgres"
}
}
}
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import random
import os
import time
import unittest
import numpy as np
import fiftyone as fo
import fiftyone.brain as fob # pylint: disable=import-error,no-name-in-module
import fiftyone.zoo as foz
from fiftyone import ViewField as F
CUSTOM_BACKENDS = [
    "qdrant",
    "pinecone",
    "milvus",
    "redis",
    "elasticsearch",
    "mosaic",
    "pgvector",
    "lancedb",
]


def get_custom_backends():
    """Returns the list of custom similarity backends to test.

    If the ``SIMILARITY_BACKENDS`` environment variable is set, it is parsed
    as a comma-separated list of backend names. Surrounding whitespace and
    empty entries are ignored, so ``"qdrant, pinecone,"`` yields
    ``["qdrant", "pinecone"]`` and an empty value yields ``[]``. Otherwise
    :const:`CUSTOM_BACKENDS` is returned.

    Returns:
        a list of backend names
    """
    if "SIMILARITY_BACKENDS" in os.environ:
        raw = os.environ["SIMILARITY_BACKENDS"]
        return [name.strip() for name in raw.split(",") if name.strip()]

    return CUSTOM_BACKENDS
def test_brain_config():
    """Checks that sklearn and every backend under test are registered in the
    brain config.

    Only the presence of each backend's entry is required. Connection
    parameters (e.g. qdrant ``url``; pinecone ``api_key``, ``cloud``,
    ``region``, ``environment``; milvus/lancedb ``uri``; redis ``host`` and
    ``port``) are optional, so they are not asserted.
    """
    similarity_backends = fob.brain_config.similarity_backends

    assert "sklearn" in similarity_backends

    # Check every requested backend, including mosaic and pgvector, rather
    # than a hard-coded subset that silently skips some of them
    for backend in get_custom_backends():
        assert backend in similarity_backends, (
            "Similarity backend '%s' is not configured" % backend
        )
def test_image_similarity_backends():
    """Round-trips CLIP image embeddings through the sklearn backend and then
    through every custom backend returned by :func:`get_custom_backends`.

    The embeddings computed via the sklearn index are reused for each custom
    backend. The embeddings each backend returns for a random subset of 100
    sample IDs are compared against the sklearn results.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart",
        dataset_name="quickstart-test-similarity-image",
        drop_existing_dataset=True,
    )

    # sklearn backend
    ###########################################################################

    index1 = fob.compute_similarity(
        dataset,
        model="clip-vit-base32-torch",
        metric="euclidean",
        embeddings=False,
        backend="sklearn",
        brain_key="clip_sklearn",
    )
    # Compute embeddings explicitly and add them to the index; the same
    # `embeddings`/`sample_ids` are reused below for every custom backend
    embeddings, sample_ids, _ = index1.compute_embeddings(dataset)
    index1.add_to_index(embeddings, sample_ids)
    index1.save()
    index1.reload()

    assert index1.total_index_size == 200
    assert index1.index_size == 200
    assert index1.missing_size is None

    prompt = "kites high in the air"
    view1 = dataset.sort_by_similarity(prompt, k=10, brain_key="clip_sklearn")
    assert len(view1) == 10

    # Discard the local index and the dataset's cache (presumably its cached
    # brain results) so that load_brain_results() below reloads the index
    del index1
    dataset.clear_cache()

    print(dataset.get_brain_info("clip_sklearn"))

    index1 = dataset.load_brain_results("clip_sklearn")
    assert index1.total_index_size == 200

    embeddings1, sample_ids1, _ = index1.get_embeddings()
    assert embeddings1.shape == (200, 512)
    assert sample_ids1.shape == (200,)

    # Random subset of 100 IDs; from here on `embeddings1`/`sample_ids1` hold
    # the reference values that each custom backend is compared against
    ids = random.sample(list(index1.sample_ids), 100)
    embeddings1, sample_ids1, _ = index1.get_embeddings(sample_ids=ids)
    assert embeddings1.shape == (100, 512)
    assert sample_ids1.shape == (100,)

    index1.remove_from_index(sample_ids=ids)
    assert index1.total_index_size == 100

    index1.cleanup()
    dataset.delete_brain_run("clip_sklearn")

    # custom backends
    ###########################################################################

    for backend in get_custom_backends():
        brain_key = "clip_" + backend

        index2 = fob.compute_similarity(
            dataset,
            model="clip-vit-base32-torch",
            metric="euclidean",
            embeddings=False,
            backend=backend,
            brain_key=brain_key,
        )
        # Reuse the embeddings computed via the sklearn index above
        index2.add_to_index(embeddings, sample_ids)

        # NOTE(review): _verify_total_index_size() is defined elsewhere in
        # this module; presumably it waits until the backend reports the
        # expected size (some backends index asynchronously) -- confirm
        assert _verify_total_index_size(index=index2, expected_size=200)
        assert index2.total_index_size == 200
        assert index2.index_size == 200
        assert index2.missing_size is None

        view2 = dataset.sort_by_similarity(prompt, k=10, brain_key=brain_key)
        assert len(view2) == 10

        del index2
        dataset.clear_cache()

        print(dataset.get_brain_info(brain_key))

        index2 = dataset.load_brain_results(brain_key)
        assert index2.total_index_size == 200

        # Pinecone and Milvus require IDs, so get_embeddings() without IDs
        # is not supported for them
        if backend not in ("pinecone", "milvus"):
            embeddings2, sample_ids2, _ = index2.get_embeddings()
            assert embeddings2.shape == (200, 512)
            assert sample_ids2.shape == (200,)

        # Fetch the same random subset as above and check that this backend
        # returns the sklearn embeddings, matched up by sample ID since the
        # backend may return them in a different order
        embeddings2, sample_ids2, _ = index2.get_embeddings(sample_ids=ids)
        assert embeddings2.shape == (100, 512)
        assert sample_ids2.shape == (100,)
        assert set(sample_ids1) == set(sample_ids2)

        embeddings2_dict = dict(zip(sample_ids2, embeddings2))
        _embeddings2 = np.array([embeddings2_dict[i] for i in sample_ids1])
        assert np.allclose(embeddings1, _embeddings2)

        index2.remove_from_index(sample_ids=ids)

        # Collection size is known to be wrong in Milvus after deletions
        # As of July 5, 2023 this has not been fixed
        # https://github.com/milvus-io/milvus/issues/17193
        if backend != "milvus":
            assert index2.total_index_size == 100

        index2.cleanup()
        dataset.delete_brain_run(brain_key)

    dataset.delete()
def test_patch_similarity_backends():
    """Round-trips CLIP embeddings of the ``ground_truth`` object patches
    through the sklearn backend and then through every custom backend
    returned by :func:`get_custom_backends`.

    The embeddings computed via the sklearn index are reused for each custom
    backend. The embeddings each backend returns for a random subset of 100
    label IDs are compared against the sklearn results.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart",
        dataset_name="quickstart-test-similarity-patch",
        drop_existing_dataset=True,
    )

    # sklearn backend
    ###########################################################################

    index1 = fob.compute_similarity(
        dataset,
        patches_field="ground_truth",
        model="clip-vit-base32-torch",
        metric="euclidean",
        embeddings=False,
        backend="sklearn",
        brain_key="gt_clip_sklearn",
    )
    # Compute patch embeddings explicitly and add them to the index; the
    # same `embeddings`/`sample_ids`/`label_ids` are reused below for every
    # custom backend
    embeddings, sample_ids, label_ids = index1.compute_embeddings(dataset)
    index1.add_to_index(embeddings, sample_ids, label_ids=label_ids)
    index1.save()
    index1.reload()

    assert index1.total_index_size == 1232
    assert index1.index_size == 1232
    assert index1.missing_size is None

    # Patch indexes are queried through the patches view of the label field
    view = dataset.to_patches("ground_truth")
    prompt = "cute puppies"
    view1 = view.sort_by_similarity(prompt, k=10, brain_key="gt_clip_sklearn")
    assert len(view1) == 10

    # Discard the local index and the dataset's cache (presumably its cached
    # brain results) so that load_brain_results() below reloads the index
    del index1
    dataset.clear_cache()

    print(dataset.get_brain_info("gt_clip_sklearn"))

    index1 = dataset.load_brain_results("gt_clip_sklearn")
    assert index1.total_index_size == 1232

    embeddings1, sample_ids1, label_ids1 = index1.get_embeddings()
    assert embeddings1.shape == (1232, 512)
    assert sample_ids1.shape == (1232,)
    assert label_ids1.shape == (1232,)

    # Random subset of 100 label IDs; from here on `embeddings1`/`label_ids1`
    # hold the reference values that each custom backend is compared against
    ids = random.sample(list(index1.label_ids), 100)
    embeddings1, sample_ids1, label_ids1 = index1.get_embeddings(label_ids=ids)
    assert embeddings1.shape == (100, 512)
    assert sample_ids1.shape == (100,)
    assert label_ids1.shape == (100,)

    index1.remove_from_index(label_ids=ids)
    assert index1.total_index_size == 1132

    index1.cleanup()
    dataset.delete_brain_run("gt_clip_sklearn")

    # custom backends
    ###########################################################################

    for backend in get_custom_backends():
        brain_key = "gt_clip_" + backend

        index2 = fob.compute_similarity(
            dataset,
            patches_field="ground_truth",
            model="clip-vit-base32-torch",
            metric="euclidean",
            embeddings=False,
            backend=backend,
            brain_key=brain_key,
        )
        # Reuse the embeddings computed via the sklearn index above
        index2.add_to_index(embeddings, sample_ids, label_ids=label_ids)

        # NOTE(review): _verify_total_index_size() is defined elsewhere in
        # this module; presumably it waits until the backend reports the
        # expected size (some backends index asynchronously) -- confirm
        assert _verify_total_index_size(index=index2, expected_size=1232)
        assert index2.total_index_size == 1232
        assert index2.index_size == 1232
        assert index2.missing_size is None

        view2 = view.sort_by_similarity(prompt, k=10, brain_key=brain_key)
        assert len(view2) == 10

        del index2
        dataset.clear_cache()

        print(dataset.get_brain_info(brain_key))

        index2 = dataset.load_brain_results(brain_key)
        assert index2.total_index_size == 1232

        # Pinecone and Milvus require IDs, so get_embeddings() without IDs
        # is not supported for them
        if backend not in ("pinecone", "milvus"):
            embeddings2, sample_ids2, label_ids2 = index2.get_embeddings()
            assert embeddings2.shape == (1232, 512)
            assert sample_ids2.shape == (1232,)
            assert label_ids2.shape == (1232,)

        # Fetch the same random subset as above and check that this backend
        # returns the sklearn embeddings, matched up by label ID since the
        # backend may return them in a different order
        embeddings2, sample_ids2, label_ids2 = index2.get_embeddings(
            label_ids=ids
        )
        assert embeddings2.shape == (100, 512)
        assert sample_ids2.shape == (100,)
        assert label_ids2.shape == (100,)
        assert set(label_ids1) == set(label_ids2)

        embeddings2_dict = dict(zip(label_ids2, embeddings2))
        _embeddings2 = np.array([embeddings2_dict[i] for i in label_ids1])
        assert np.allclose(embeddings1, _embeddings2)

        index2.remove_from_index(label_ids=ids)

        # Collection size is known to be wrong in Milvus after deletions
        # As of July 5, 2023 this has not been fixed
        # https://github.com/milvus-io/milvus/issues/17193
        if backend != "milvus":
            assert index2.total_index_size == 1132

        index2.cleanup()
        dataset.delete_brain_run(brain_key)

    dataset.delete()
def test_qdrant_backend_config():
    """White-box test of the gRPC settings passed to the Qdrant client.

    The ``*_similarity_backends`` tests use whatever custom backend
    configuration has been set externally. To exercise different connection
    details (e.g. for qdrant), change that configuration and re-run the
    tests. This test checks that the ``prefer_grpc`` and ``grpc_port``
    settings, when present in the brain config, are applied to the
    underlying ``QdrantClient``.
    """
    backend = "qdrant"
    requested_backends = get_custom_backends()
    if backend not in requested_backends:
        return

    dataset = foz.load_zoo_dataset("quickstart", max_samples=5)
    index = fob.compute_similarity(
        dataset,
        model="clip-vit-base32-torch",
        metric="euclidean",
        embeddings=False,
        backend=backend,
        brain_key="clip_" + backend,
    )

    # Reach into the wrapped client to read the settings it was built with
    remote_client = index.client._client
    config = fob.brain_config.similarity_backends["qdrant"]

    if "prefer_grpc" not in config:
        print("Qdrant config prefer_grpc unset")
    else:
        prefer_grpc = config["prefer_grpc"]
        assert remote_client._prefer_grpc == prefer_grpc
        print(f"Applied qdrant config prefer_grpc={prefer_grpc}")

    if "grpc_port" not in config:
        print("Qdrant config grpc_port unset")
    else:
        grpc_port = config["grpc_port"]
        assert remote_client._grpc_port == grpc_port
        print(f"Applied qdrant config grpc_port={grpc_port}")

    dataset.delete()
def test_images():
    """Checks that the stored ``img_sim`` index covers every sample."""
    dataset = _load_images_dataset()
    index = dataset.load_brain_results("img_sim")

    num_samples = len(dataset)
    assert index.total_index_size == num_samples

    expected_ids = set(dataset.values("id"))
    assert expected_ids == set(index.sample_ids)
def test_images_subset():
    """Checks that restricting the ``img_sim`` index to a view works."""
    dataset = _load_images_dataset()
    index = dataset.load_brain_results("img_sim")

    # Restrict the index to a 10-sample view of the dataset
    subset = dataset.take(10)
    index.use_view(subset)

    assert index.index_size == len(subset)

    subset_ids = set(subset.values("id"))
    assert subset_ids == set(index.current_sample_ids)
def test_images_missing():
    """Checks that samples whose media files don't exist are left out of the index.

    Four real samples are cloned and four samples with nonexistent filepaths
    are appended. The index is expected to contain only the four real samples.
    This is checked with no model given, and again with an explicit model
    whose embeddings are written to a sample field.
    """
    dataset = _load_images_dataset().limit(4).clone()
    bogus_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=path) for path in bogus_paths])

    valid_ids = set(dataset[:4].values("id"))

    # No model given; embeddings are kept inside the index
    sim_index = fob.compute_similarity(dataset, batch_size=1)
    assert sim_index.total_index_size == 4
    assert valid_ids == set(sim_index.sample_ids)

    # Explicit model; embeddings are written to a sample field
    inception = foz.load_zoo_model("inception-v3-imagenet-torch")
    sim_index = fob.compute_similarity(
        dataset,
        model=inception,
        embeddings="embeddings_missing",
        batch_size=1,
    )
    assert len(dataset.exists("embeddings_missing")) == 4
    assert sim_index.index_size == 4
    assert valid_ids == set(sim_index.sample_ids)
def test_images_embeddings():
    """Exercises the places where sklearn image-index embeddings can live.

    Four cases are covered:
    - embeddings computed on the fly and written to a new field
    - embeddings that already exist in a field
    - embeddings stored inside the index itself
    - embeddings added manually, with a target field that doesn't exist yet
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", max_samples=10, drop_existing_dataset=True
    )
    clip = foz.load_zoo_model("clip-vit-base32-torch")
    num_samples = len(dataset)

    def _check(sim_index, prompts, stored_in_index):
        # Shared assertions: size, prompt support, and where embeddings live
        assert sim_index.total_index_size == num_samples
        if prompts:
            assert sim_index.config.supports_prompts is True
        else:
            assert sim_index.config.supports_prompts is not True

        if stored_in_index:
            assert "embeddings" in sim_index.serialize()
        else:
            assert "embeddings" not in sim_index.serialize()

    # Embeddings are computed on-the-fly and stored on dataset
    on_the_fly = fob.compute_similarity(
        dataset,
        embeddings="embeddings",
        model="clip-vit-base32-torch",
        brain_key="img_sim1",
        backend="sklearn",
    )
    _check(on_the_fly, prompts=True, stored_in_index=False)

    # Embeddings already exist on dataset
    dataset.compute_embeddings(clip, embeddings_field="embeddings2")
    precomputed = fob.compute_similarity(
        dataset,
        embeddings="embeddings2",
        model="clip-vit-base32-torch",
        brain_key="img_sim2",
        backend="sklearn",
    )
    _check(precomputed, prompts=True, stored_in_index=False)

    # Embeddings stored in index itself
    in_index = fob.compute_similarity(
        dataset,
        model="clip-vit-base32-torch",
        brain_key="img_sim3",
        backend="sklearn",
    )
    _check(in_index, prompts=True, stored_in_index=True)

    # Embeddings stored on dataset (but field doesn't initially exist)
    manual = fob.compute_similarity(
        dataset,
        embeddings="embeddings4",
        brain_key="img_sim4",
        backend="sklearn",
    )
    random_embeddings = np.random.randn(num_samples, 512)
    manual.add_to_index(random_embeddings, dataset.values("id"))
    _check(manual, prompts=False, stored_in_index=False)

    dataset.delete()
def test_patches():
    """Checks that the cached ``gt_sim`` index covers every ground truth box."""
    dataset = _load_patches_dataset()
    sim_index = dataset.load_brain_results("gt_sim")
    gt_ids = dataset.values("ground_truth.detections.id", unwind=True)

    assert sim_index.total_index_size == len(gt_ids)
    assert set(sim_index.label_ids) == set(gt_ids)
def test_patches_subset():
    """Restricts the cached ``gt_sim`` patch index to "person" boxes only."""
    dataset = _load_patches_dataset()
    sim_index = dataset.load_brain_results("gt_sim")

    all_ids = dataset.values("ground_truth.detections.id", unwind=True)
    assert sim_index.total_index_size == len(all_ids)
    assert set(sim_index.label_ids) == set(all_ids)

    people = dataset.filter_labels("ground_truth", F("label") == "person")
    sim_index.use_view(people)

    person_ids = people.values("ground_truth.detections.id", unwind=True)
    assert sim_index.index_size == len(person_ids)
    assert set(sim_index.current_label_ids) == set(person_ids)
def test_patches_missing():
    """Checks that patches on samples with missing media are left out of the index.

    Four bogus samples, each with a single box, are appended to a 4-sample
    clone. Only the patches of the four real samples should be indexed.
    """
    dataset = _load_patches_dataset().limit(4).clone()
    bogus_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=path) for path in bogus_paths])

    # Give each bogus sample one box so there is something to (fail to) embed
    for sample in dataset[4:]:
        box = fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])
        sample["ground_truth"] = fo.Detections(detections=[box])
        sample.save()

    sim_index = fob.compute_similarity(
        dataset, patches_field="ground_truth", batch_size=1
    )
    real = dataset[:4]
    num_patches = real.count("ground_truth.detections")
    valid_ids = set(real.values("ground_truth.detections.id", unwind=True))
    assert sim_index.total_index_size == num_patches
    assert valid_ids == set(sim_index.label_ids)

    inception = foz.load_zoo_model("inception-v3-imagenet-torch")
    sim_index = fob.compute_similarity(
        dataset,
        model=inception,
        patches_field="ground_truth",
        embeddings="embeddings_missing",
        batch_size=1,
    )

    # `!= None` (not `is not None`) builds a ViewExpression
    has_embedding = dataset.filter_labels(
        "ground_truth", F("embeddings_missing") != None  # noqa: E711
    )
    assert has_embedding.count("ground_truth.detections") == num_patches
    assert sim_index.total_index_size == num_patches
    assert valid_ids == set(sim_index.label_ids)
def test_patches_embeddings():
    """Exercises the places where sklearn patch-index embeddings can live.

    This mirrors the image-level test on the ``ground_truth`` patches. Cases:
    computed on the fly, already on the labels, stored in the index, and
    added manually with a not-yet-existing target field.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", max_samples=10, drop_existing_dataset=True
    )
    clip = foz.load_zoo_model("clip-vit-base32-torch")
    num_patches = dataset.count("ground_truth.detections")

    def _check(sim_index, prompts, stored_in_index):
        # Shared assertions: size, prompt support, and where embeddings live
        assert sim_index.total_index_size == num_patches
        if prompts:
            assert sim_index.config.supports_prompts is True
        else:
            assert sim_index.config.supports_prompts is not True

        if stored_in_index:
            assert "embeddings" in sim_index.serialize()
        else:
            assert "embeddings" not in sim_index.serialize()

    # Embeddings are computed on-the-fly and stored on dataset
    on_the_fly = fob.compute_similarity(
        dataset,
        patches_field="ground_truth",
        embeddings="embeddings",
        model="clip-vit-base32-torch",
        brain_key="gt_sim1",
        backend="sklearn",
    )
    _check(on_the_fly, prompts=True, stored_in_index=False)

    # Embeddings already exist on dataset
    dataset.compute_patch_embeddings(
        clip, "ground_truth", embeddings_field="embeddings2"
    )
    precomputed = fob.compute_similarity(
        dataset,
        patches_field="ground_truth",
        embeddings="embeddings2",
        model="clip-vit-base32-torch",
        brain_key="gt_sim2",
        backend="sklearn",
    )
    _check(precomputed, prompts=True, stored_in_index=False)

    # Embeddings stored in index itself
    in_index = fob.compute_similarity(
        dataset,
        patches_field="ground_truth",
        model="clip-vit-base32-torch",
        brain_key="gt_sim3",
        backend="sklearn",
    )
    _check(in_index, prompts=True, stored_in_index=True)

    # Embeddings stored on dataset (but field doesn't initially exist)
    manual = fob.compute_similarity(
        dataset,
        patches_field="ground_truth",
        embeddings="embeddings4",
        brain_key="gt_sim4",
        backend="sklearn",
    )
    random_embeddings = np.random.randn(num_patches, 512)
    patches = dataset.to_patches("ground_truth")
    patch_sample_ids, patch_label_ids = patches.values(["sample_id", "id"])
    manual.add_to_index(
        random_embeddings, patch_sample_ids, label_ids=patch_label_ids
    )
    _check(manual, prompts=False, stored_in_index=False)

    dataset.delete()
def _load_images_dataset():
    """Returns the image-similarity test dataset, creating it on first use."""
    name = "test-similarity-images"
    if not fo.dataset_exists(name):
        return _make_images_dataset(name)

    return fo.load_dataset(name)
def _load_patches_dataset():
    """Returns the patch-similarity test dataset, creating it on first use."""
    name = "test-similarity-patches"
    if not fo.dataset_exists(name):
        return _make_patches_dataset(name)

    return fo.load_dataset(name)
def _make_images_dataset(name):
    """Creates dataset ``name`` with an ``img_sim`` similarity index.

    The dataset holds 20 quickstart samples. Inception v3 embeddings are
    stored in the ``embeddings`` field, and the index is built from them.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=name, max_samples=20
    )
    embedder = foz.load_zoo_model("inception-v3-imagenet-torch")

    dataset.compute_embeddings(
        embedder, batch_size=8, embeddings_field="embeddings"
    )
    fob.compute_similarity(
        dataset, brain_key="img_sim", embeddings="embeddings"
    )

    return dataset
def _make_patches_dataset(name):
    """Creates dataset ``name`` with a ``gt_sim`` patch similarity index.

    The dataset holds 20 quickstart samples. Inception v3 embeddings of the
    square-cropped ``ground_truth`` boxes are stored in ``embeddings``.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=name, max_samples=20
    )
    embedder = foz.load_zoo_model("inception-v3-imagenet-torch")

    dataset.compute_patch_embeddings(
        embedder,
        "ground_truth",
        force_square=True,
        batch_size=8,
        embeddings_field="embeddings",
    )
    fob.compute_similarity(
        dataset,
        brain_key="gt_sim",
        embeddings="embeddings",
        patches_field="ground_truth",
    )

    return dataset
def _verify_total_index_size(index, expected_size, timeout=10, interval=1):
elapsed_time = 0
while index.total_index_size != expected_size and elapsed_time < timeout:
time.sleep(interval)
elapsed_time += interval
return index.total_index_size == expected_size
if __name__ == "__main__":
    fo.config.show_progress_bars = True

    # The tests in this module are plain functions rather than
    # unittest.TestCase methods, so unittest.main() would discover and run
    # nothing. Wrap each module-level test_* function explicitly instead
    _tests = [
        unittest.FunctionTestCase(_func)
        for _name, _func in list(globals().items())
        if _name.startswith("test_") and callable(_func)
    ]
    _result = unittest.TextTestRunner(verbosity=2).run(
        unittest.TestSuite(_tests)
    )
    raise SystemExit(not _result.wasSuccessful())
================================================
FILE: tests/intensive/test_uniqueness.py
================================================
"""
Uniqueness tests.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import unittest
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
def test_uniqueness():
    """Image uniqueness with the default model and batch size."""
    _run_uniqueness(roi_field=None, model=None, batch_size=None)
def test_uniqueness_torch():
    """Image uniqueness using a Torch Inception v3 zoo model."""
    _run_uniqueness(
        model=foz.load_zoo_model("inception-v3-imagenet-torch"),
        batch_size=16,
    )
def test_uniqueness_tf():
    """Image uniqueness using a TensorFlow ResNet-v2-50 zoo model."""
    _run_uniqueness(
        model=foz.load_zoo_model("resnet-v2-50-imagenet-tf1"),
        batch_size=16,
    )
def test_uniqueness_missing():
    """Uniqueness on samples whose media doesn't exist stores no values."""
    dataset = fo.Dataset()
    missing_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=p) for p in missing_paths])

    fob.compute_uniqueness(dataset, batch_size=1)

    # The field is declared on the dataset, but no sample receives a value
    assert dataset.has_field("uniqueness")
    assert len(dataset.exists("uniqueness")) == 0
def test_roi_uniqueness():
    """ROI uniqueness over ``ground_truth`` with the default model."""
    _run_uniqueness("ground_truth")
def test_roi_uniqueness_torch():
    """ROI uniqueness over ``ground_truth`` with a Torch Inception v3 model."""
    _run_uniqueness(
        roi_field="ground_truth",
        model=foz.load_zoo_model("inception-v3-imagenet-torch"),
        batch_size=16,
    )
def test_roi_uniqueness_tf():
    """ROI uniqueness over ``ground_truth`` with a TensorFlow ResNet model."""
    _run_uniqueness(
        roi_field="ground_truth",
        model=foz.load_zoo_model("resnet-v2-50-imagenet-tf1"),
        batch_size=16,
    )
def test_roi_uniqueness_missing():
    """ROI uniqueness on samples whose media doesn't exist stores no values."""
    dataset = fo.Dataset()
    missing_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=p) for p in missing_paths])

    # Give every sample a single box so there is an ROI to (fail to) embed
    for sample in dataset:
        box = fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])
        sample["ground_truth"] = fo.Detections(detections=[box])
        sample.save()

    fob.compute_uniqueness(dataset, roi_field="ground_truth", batch_size=1)

    assert dataset.has_field("uniqueness")
    assert len(dataset.exists("uniqueness")) == 0
def test_uniqueness_similarity_index():
    """Computes uniqueness from full and partial sklearn similarity indexes."""
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=fo.get_default_dataset_name()
    )
    dataset.delete_sample_field("uniqueness")

    # Full similarity index: pass it as an object, then by brain key
    full_index = fob.compute_similarity(
        dataset, brain_key="sklearn_index", backend="sklearn"
    )
    fob.compute_uniqueness(dataset, similarity_index=full_index)
    assert dataset.has_field("uniqueness")

    dataset.clear_cache()
    dataset.delete_sample_field("uniqueness")
    fob.compute_uniqueness(dataset, similarity_index="sklearn_index")
    assert dataset.has_field("uniqueness")

    # Partial similarity index built on a 100-sample view; only those samples
    # are expected to receive scores
    subset = dataset.take(100, seed=51)
    fob.compute_similarity(
        subset, brain_key="sklearn_index2", backend="sklearn"
    )
    fob.compute_uniqueness(
        dataset,
        uniqueness_field="uniqueness2",
        similarity_index="sklearn_index2",
    )
    assert len(dataset.exists("uniqueness2")) == len(subset)
def _run_uniqueness(roi_field=None, model=None, batch_size=None):
    """Computes uniqueness on a random 50-sample view of a fresh quickstart copy.

    Args:
        roi_field (None): optional label field forwarded as ``roi_field``
        model (None): optional model forwarded to ``compute_uniqueness``
        batch_size (None): optional batch size forwarded as well
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=fo.get_default_dataset_name()
    )
    dataset.delete_sample_field("uniqueness")

    subset = dataset.take(50)
    expected_count = len(subset)
    fob.compute_uniqueness(
        subset, model=model, roi_field=roi_field, batch_size=batch_size
    )

    # Only the view's samples are scored, and every score lies in [0, 1]
    assert dataset.count("uniqueness") == expected_count
    bounds = dataset.bounds("uniqueness")
    assert 0 <= bounds[0]
    assert bounds[1] <= 1
if __name__ == "__main__":
    fo.config.show_progress_bars = True

    # The tests in this module are plain functions rather than
    # unittest.TestCase methods, so unittest.main() would discover and run
    # nothing. Wrap each module-level test_* function explicitly instead
    _tests = [
        unittest.FunctionTestCase(_func)
        for _name, _func in list(globals().items())
        if _name.startswith("test_") and callable(_func)
    ]
    _result = unittest.TextTestRunner(verbosity=2).run(
        unittest.TestSuite(_tests)
    )
    raise SystemExit(not _result.wasSuccessful())
================================================
FILE: tests/intensive/test_visualization.py
================================================
"""
Visualization tests.
All of these tests are designed to be run manually via::
pytest tests/intensive/test_visualization.py -s -k test_
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import unittest
import cv2
import numpy as np
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
from fiftyone import ViewField as F
def test_mnist():
    """Plots a 2D visualization of raw MNIST test pixels (manual inspection)."""
    dataset = foz.load_zoo_dataset("mnist", split="test")

    # Each image's flattened pixel values serve as its embedding
    flattened = []
    for filepath in dataset.values("filepath"):
        # pylint: disable=no-member
        pixels = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
        flattened.append(pixels.ravel())
    embeddings = np.array(flattened)

    viz = fob.compute_visualization(
        dataset,
        embeddings=embeddings,
        num_dims=2,
        verbose=True,
        seed=51,
    )

    viz.visualize(labels="ground_truth.label").show()
    input("Press enter to continue...")
def test_images():
    """Checks the cached ``img_viz`` results, then plots them (manual check)."""
    dataset = _load_images_dataset()
    viz = dataset.load_brain_results("img_viz")

    assert viz.total_index_size == len(dataset)
    assert set(viz.sample_ids) == set(dataset.values("id"))

    viz.visualize(labels="uniqueness").show()
    input("Press enter to continue...")
def test_images_subset():
    """Restricts the cached ``img_viz`` results to 10 samples and plots them."""
    dataset = _load_images_dataset()
    viz = dataset.load_brain_results("img_viz")

    subset = dataset.take(10)
    viz.use_view(subset)

    assert viz.index_size == len(subset)
    assert set(viz.current_sample_ids) == set(subset.values("id"))

    viz.visualize(labels="uniqueness").show()
    input("Press enter to continue...")
def test_images_missing():
    """Checks that samples whose media doesn't exist are left out of visualizations.

    Four real samples are cloned and four bogus ones appended. Only the real
    samples should appear in the results, both with no model given and with
    an explicit model whose embeddings go to a field.
    """
    dataset = _load_images_dataset().limit(4).clone()
    bogus_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=path) for path in bogus_paths])

    valid_ids = set(dataset[:4].values("id"))

    viz = fob.compute_visualization(dataset, batch_size=1)
    assert viz.total_index_size == 4
    assert valid_ids == set(viz.sample_ids)

    inception = foz.load_zoo_model("inception-v3-imagenet-torch")
    viz = fob.compute_visualization(
        dataset,
        model=inception,
        embeddings="embeddings_missing",
        batch_size=1,
    )
    assert len(dataset.exists("embeddings_missing")) == 4
    assert viz.total_index_size == 4
    assert valid_ids == set(viz.sample_ids)
def test_patches():
    """Checks the cached ``gt_viz`` patch results, then plots them."""
    dataset = _load_patches_dataset()
    viz = dataset.load_brain_results("gt_viz")
    gt_ids = dataset.values("ground_truth.detections.id", unwind=True)

    assert viz.total_index_size == len(gt_ids)
    assert set(viz.label_ids) == set(gt_ids)

    viz.visualize(labels="ground_truth.detections.label").show()
    input("Press enter to continue...")
def test_patches_subset():
    """Plots "person" patches, first via ``classes`` and then via ``use_view``."""
    dataset = _load_patches_dataset()
    viz = dataset.load_brain_results("gt_viz")
    label_path = "ground_truth.detections.label"

    # Full results, but only the "person" class is shown
    viz.visualize(labels=label_path, classes=["person"]).show()
    input("Press enter to continue...")

    # Results restricted to a view that contains only person boxes
    people = dataset.filter_labels("ground_truth", F("label") == "person")
    viz.use_view(people)

    person_ids = people.values("ground_truth.detections.id", unwind=True)
    assert viz.index_size == len(person_ids)
    assert set(viz.current_label_ids) == set(person_ids)

    viz.visualize(labels=label_path).show()
    input("Press enter to continue...")
def test_patches_missing():
    """Checks that patches on samples with missing media are left out of visualizations."""
    dataset = _load_patches_dataset().limit(4).clone()
    bogus_paths = (
        "non-existent1.png",
        "non-existent2.png",
        "non-existent3.png",
        "non-existent4.png",
    )
    dataset.add_samples([fo.Sample(filepath=path) for path in bogus_paths])

    # One box per bogus sample, so there is a patch to (fail to) embed
    for sample in dataset[4:]:
        box = fo.Detection(bounding_box=[0.1, 0.1, 0.8, 0.8])
        sample["ground_truth"] = fo.Detections(detections=[box])
        sample.save()

    viz = fob.compute_visualization(
        dataset, patches_field="ground_truth", batch_size=1
    )
    real = dataset[:4]
    num_patches = real.count("ground_truth.detections")
    valid_ids = set(real.values("ground_truth.detections.id", unwind=True))
    assert viz.total_index_size == num_patches
    assert valid_ids == set(viz.label_ids)

    inception = foz.load_zoo_model("inception-v3-imagenet-torch")
    viz = fob.compute_visualization(
        dataset,
        model=inception,
        patches_field="ground_truth",
        embeddings="embeddings_missing",
        batch_size=1,
    )

    # `!= None` (not `is not None`) builds a ViewExpression
    has_embedding = dataset.filter_labels(
        "ground_truth", F("embeddings_missing") != None  # noqa: E711
    )
    assert has_embedding.count("ground_truth.detections") == num_patches
    assert viz.total_index_size == num_patches
    assert valid_ids == set(viz.label_ids)
def test_points():
    """Checks that precomputed ``points`` are accepted as arrays or dicts.

    Both sample-level and patch-level visualizations are covered. The test
    runs on a freshly-named copy of quickstart, so an existing ``quickstart``
    dataset is neither modified nor deleted. The copy itself is deleted even
    if an assertion fails.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=fo.get_default_dataset_name()
    )
    try:
        n = len(dataset)
        p = dataset.count("ground_truth.detections")
        d = 512

        # Sample-level points as an (n, d) array
        points1 = np.random.rand(n, d)
        results1 = fob.compute_visualization(
            dataset,
            points=points1,
            brain_key="test1",
        )
        assert results1.points.shape == (n, d)

        # Sample-level points as a dict keyed by sample ID
        points2 = {_id: np.random.rand(d) for _id in dataset.values("id")}
        results2 = fob.compute_visualization(
            dataset,
            points=points2,
            brain_key="test2",
        )
        assert results2.points.shape == (n, d)

        # Patch-level points as a (p, d) array
        points3 = np.random.rand(p, d)
        results3 = fob.compute_visualization(
            dataset,
            patches_field="ground_truth",
            points=points3,
            brain_key="test3",
        )
        assert results3.points.shape == (p, d)

        # Patch-level points as a dict keyed by label ID
        points4 = {
            _id: np.random.rand(d)
            for _id in dataset.values("ground_truth.detections.id", unwind=True)
        }
        results4 = fob.compute_visualization(
            dataset,
            patches_field="ground_truth",
            points=points4,
            brain_key="test4",
        )
        assert results4.points.shape == (p, d)
    finally:
        dataset.delete()
def test_similarity_index():
    """Builds visualizations from full and partial sklearn similarity indexes."""
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=fo.get_default_dataset_name()
    )

    # Full similarity index, passed as an object
    full_index = fob.compute_similarity(
        dataset, brain_key="sklearn_index", backend="sklearn"
    )
    full_viz = fob.compute_visualization(
        dataset,
        brain_key="img_viz",
        similarity_index=full_index,
    )
    assert len(full_viz.points) == len(dataset)

    # Partial similarity index on a 100-sample view, passed by brain key;
    # only the view's samples are expected to get points
    subset = dataset.take(100, seed=51)
    fob.compute_similarity(
        subset, brain_key="sklearn_index2", backend="sklearn"
    )
    partial_viz = fob.compute_visualization(
        dataset,
        brain_key="img_viz2",
        similarity_index="sklearn_index2",
    )
    assert len(partial_viz.points) == len(subset)
def test_points_field():
    """Stores 2D image points in an indexed sample field, then deletes the run."""
    dataset = _load_images_dataset()
    num_points = len(dataset)
    brain_key = points_field = "test_points"

    fob.compute_visualization(
        dataset,
        brain_key=brain_key,
        points=np.random.randn(num_points, 2),
        create_index=True,
    )

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)

    # The points live in a dedicated sample field that has a database index
    assert results.config.points_field == points_field
    assert dataset.has_sample_field(points_field)
    assert points_field in dataset.list_indexes()

    first_point = dataset.first()[points_field]
    assert isinstance(first_point, list)
    assert len(first_point) == 2
    assert isinstance(first_point[0], float)

    stored = results.points
    assert len(stored) == num_points
    assert len(stored[0]) == 2
    assert np.allclose(stored, dataset.values(points_field))

    # Deleting the run also removes the field and its index
    dataset.delete_brain_run(brain_key)
    assert not dataset.has_sample_field(points_field)
    assert points_field not in dataset.list_indexes()
def test_points_field_patches():
    """Stores 2D patch points on ``ground_truth`` detections, then deletes the run."""
    dataset = _load_patches_dataset()
    num_points = dataset.count("ground_truth.detections")
    brain_key = points_field = "test_points"
    points_path = "ground_truth.detections." + points_field

    fob.compute_visualization(
        dataset,
        brain_key=brain_key,
        points=np.random.randn(num_points, 2),
        patches_field="ground_truth",
        create_index=True,
    )

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)

    assert results.config.points_field == points_field
    assert dataset.has_sample_field(points_path)

    # Patch visualizations can't currently make use of database indexes
    assert points_path not in dataset.list_indexes()

    first_box = dataset.first().ground_truth.detections[0]
    label_point = first_box[points_field]
    assert isinstance(label_point, list)
    assert len(label_point) == 2
    assert isinstance(label_point[0], float)

    stored = results.points
    assert len(stored) == num_points
    assert len(stored[0]) == 2
    flat_path = "ground_truth.detections[]." + points_field
    assert np.allclose(stored, dataset.values(flat_path))

    dataset.delete_brain_run(brain_key)
    assert not dataset.has_sample_field(points_path)
def test_index_points():
    """Checks that ``index_points()`` / ``remove_index()`` toggle the points field.

    The brain run is deleted at the end so that it does not linger on the
    cached images dataset.
    """
    dataset = _load_images_dataset()
    num_points = len(dataset)
    points = np.random.randn(num_points, 2)
    brain_key = "test_points"
    points_field = brain_key

    # Without create_index=True, no points field is created
    fob.compute_visualization(dataset, brain_key=brain_key, points=points)

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field is None
    assert not dataset.has_sample_field(points_field)
    assert points_field not in dataset.list_indexes()

    # index_points() materializes the points into an indexed field
    results.index_points()

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field == points_field
    assert dataset.has_sample_field(points_field)
    assert points_field in dataset.list_indexes()

    points = results.points
    all_points = dataset.values(points_field)
    assert np.allclose(points, all_points)

    # remove_index() undoes it
    results.remove_index()

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field is None
    assert not dataset.has_sample_field(points_field)
    assert points_field not in dataset.list_indexes()

    # Clean up the brain run itself
    dataset.delete_brain_run(brain_key)
def test_index_points_patches():
    """Patch version of ``test_index_points()`` on ``ground_truth`` detections.

    The brain run is deleted at the end so that it does not linger on the
    cached patches dataset.
    """
    dataset = _load_patches_dataset()
    num_points = dataset.count("ground_truth.detections")
    points = np.random.randn(num_points, 2)
    brain_key = "test_points"
    points_field = brain_key
    points_path = f"ground_truth.detections.{points_field}"

    fob.compute_visualization(
        dataset,
        brain_key=brain_key,
        points=points,
        patches_field="ground_truth",
    )

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field is None
    assert not dataset.has_sample_field(points_path)

    results.index_points()

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field == points_field
    assert dataset.has_sample_field(points_path)

    points = results.points
    all_points = dataset.values(f"ground_truth.detections[].{points_field}")
    assert np.allclose(points, all_points)

    results.remove_index()

    dataset.clear_cache()
    results = dataset.load_brain_results(brain_key)
    assert results.config.points_field is None
    assert not dataset.has_sample_field(points_path)

    # Clean up the brain run itself
    dataset.delete_brain_run(brain_key)
def _load_images_dataset():
    """Returns the image-visualization test dataset, creating it on first use."""
    name = "test-visualization-images"
    if not fo.dataset_exists(name):
        return _make_images_dataset(name)

    return fo.load_dataset(name)
def _load_patches_dataset():
    """Returns the patch-visualization test dataset, creating it on first use."""
    name = "test-visualization-patches"
    if not fo.dataset_exists(name):
        return _make_patches_dataset(name)

    return fo.load_dataset(name)
def _make_images_dataset(name):
    """Creates dataset ``name`` with a 2D ``img_viz`` visualization.

    The dataset holds 20 quickstart samples whose Inception v3 embeddings
    are stored in the ``embeddings`` field.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=name, max_samples=20
    )
    embedder = foz.load_zoo_model("inception-v3-imagenet-torch")

    dataset.compute_embeddings(
        embedder, batch_size=8, embeddings_field="embeddings"
    )
    fob.compute_visualization(
        dataset,
        brain_key="img_viz",
        embeddings="embeddings",
        num_dims=2,
        seed=51,
        verbose=True,
    )

    return dataset
def _make_patches_dataset(name):
    """Creates dataset ``name`` with a 2D ``gt_viz`` patch visualization.

    Inception v3 embeddings of the square-cropped ``ground_truth`` boxes of
    20 quickstart samples are stored in the ``embeddings`` attribute.
    """
    dataset = foz.load_zoo_dataset(
        "quickstart", dataset_name=name, max_samples=20
    )
    embedder = foz.load_zoo_model("inception-v3-imagenet-torch")

    dataset.compute_patch_embeddings(
        embedder,
        "ground_truth",
        force_square=True,
        batch_size=8,
        embeddings_field="embeddings",
    )
    fob.compute_visualization(
        dataset,
        brain_key="gt_viz",
        patches_field="ground_truth",
        embeddings="embeddings",
        num_dims=2,
        seed=51,
        verbose=True,
    )

    return dataset
if __name__ == "__main__":
    fo.config.show_progress_bars = True

    # The tests in this module are plain functions rather than
    # unittest.TestCase methods, so unittest.main() would discover and run
    # nothing. Wrap each module-level test_* function explicitly instead
    _tests = [
        unittest.FunctionTestCase(_func)
        for _name, _func in list(globals().items())
        if _name.startswith("test_") and callable(_func)
    ]
    _result = unittest.TextTestRunner(verbosity=2).run(
        unittest.TestSuite(_tests)
    )
    raise SystemExit(not _result.wasSuccessful())
================================================
FILE: tests/models/test_simple_resnet.py
================================================
"""
Tests for :mod:`fiftyone.brain.internal.models.simple_resnet`.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import imageio
from PIL import Image
import torch
import eta.core.image as etai
import fiftyone as fo
import fiftyone.core.utils as fou
import fiftyone.zoo as foz
import fiftyone.brain.internal.models as fbm
def _transpose(x, source, target):
return x.permute([source.index(d) for d in target])
def _check_prediction(actual, expected):
    """Checks that both predictions are ``fo.Classification`` instances."""
    for prediction in (actual, expected):
        assert isinstance(prediction, fo.Classification)

    # @todo fix me on 3.9
    # assert actual.label == expected.label
def test_simple_resnet():
    """Runs the simple ResNet CIFAR-10 model on one image in several formats.

    One random CIFAR-10 test image is loaded via PIL, imageio (numpy), torch,
    and ETA. The model then predicts on each representation, both with its
    built-in preprocessing and with ``model.transforms`` applied manually.
    The resulting predictions are type-checked with ``_check_prediction()``.
    """
    dataset = foz.load_zoo_dataset(
        "cifar10",
        split="test",
        dataset_name=fo.get_default_dataset_name(),
        shuffle=True,
        max_samples=1,
    )
    sample = dataset.first()
    filepath = sample.filepath
    print("Working on image at %s" % filepath)

    img_pil = Image.open(filepath)
    print("img_pil is type %s" % type(img_pil))

    img_numpy = imageio.imread(filepath)
    print("img_numpy is type %s" % type(img_numpy))
    print(img_numpy.shape)

    img_torch = torch.from_numpy(img_numpy)
    img_torch = _transpose(img_torch, "HWC", "CHW")
    print("img_torch is type %s" % type(img_torch))
    print(img_torch.shape)

    # img_torch is CHW while img_numpy is HWC. Reversing the torch shape gives
    # WHC, which only matches HWC for square images, so compare against the
    # explicitly permuted numpy shape instead
    height, width, channels = img_numpy.shape
    assert tuple(img_torch.shape) == (channels, height, width)

    img_eta = etai.read(filepath)
    print("img_eta is type %s" % type(img_eta))
    print(img_eta.shape)
    assert tuple(img_eta.shape) == img_numpy.shape

    model = fbm.load_model("simple-resnet-cifar10")
    with model:
        print("PIL")
        p_pil = model.predict(img_pil)
        print(p_pil)

        print("IMAGEIO")
        p_numpy = model.predict(img_numpy)
        print(p_numpy)
        _check_prediction(p_numpy, p_pil)

        print("ETA")
        p_eta = model.predict(img_eta)
        print(p_eta)
        _check_prediction(p_eta, p_pil)

        # With preprocess=False, transforms must be applied by the caller
        print("PIL (manual preprocessing)")
        with fou.SetAttributes(model, preprocess=False):
            img_tensor = model.transforms(img_pil)
            p_pil2 = model.predict(img_tensor)
        print(p_pil2)
        _check_prediction(p_pil2, p_pil)

        print("IMAGEIO (manual preprocessing)")
        with fou.SetAttributes(model, preprocess=False):
            img_tensor = model.transforms(img_numpy)
            p_numpy2 = model.predict(img_tensor)
        print(p_numpy2)
        _check_prediction(p_numpy2, p_numpy)
if __name__ == "__main__":
    # Run the module's single test directly when executed as a script
    test_simple_resnet()
================================================
FILE: tests/test_uniqueness.py
================================================
"""
Uniqueness tests.
| Copyright 2017-2026, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import os
import unittest
import eta.core.storage as etas
import eta.core.utils as etau
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
def test_uniqueness():
    """Computes uniqueness on a random 100-sample view of the CIFAR-10 test split.

    A freshly-named dataset is used, so a pre-existing ``cifar10-test``
    dataset cannot break the "no uniqueness field yet" precondition, and such
    a dataset is not modified by this test.
    """
    dataset = foz.load_zoo_dataset(
        "cifar10",
        split="test",
        dataset_name=fo.get_default_dataset_name(),
    )
    assert "uniqueness" not in dataset.get_field_schema()

    view = dataset.take(100)
    fob.compute_uniqueness(view)
    print(dataset)

    assert "uniqueness" in dataset.get_field_schema()

    # Only the samples in the view should have been scored
    assert dataset.count("uniqueness") == len(view)
def test_gray():
    """Test default support for handling grayscale images.

    Requires Voxel51 Google Drive credentials to download the test data.

    Every grayscale sample is expected to receive a uniqueness score;
    previously the test only checked that no exception was raised.
    """
    with etau.TempDir() as tmpdir:
        tmp_zip = os.path.join(tmpdir, "data.zip")
        tmp_data = os.path.join(tmpdir, "brain_grayscale_test_data")

        client = etas.GoogleDriveStorageClient()
        client.download("1ECeNnLmKQCHxlVdRqGefV5eXOD_OkmWx", tmp_zip)
        etau.extract_zip(tmp_zip, delete_zip=True)

        dataset = fo.Dataset.from_dir(tmp_data, fo.types.ImageDirectory)
        fob.compute_uniqueness(dataset)
        print(dataset)

        # Samples whose images could not be processed would be left unscored
        assert dataset.count("uniqueness") == len(dataset)
if __name__ == "__main__":
    fo.config.show_progress_bars = True

    # The tests in this module are plain functions rather than
    # unittest.TestCase methods, so unittest.main() would discover and run
    # nothing. Wrap each module-level test_* function explicitly instead
    _tests = [
        unittest.FunctionTestCase(_func)
        for _name, _func in list(globals().items())
        if _name.startswith("test_") and callable(_func)
    ]
    _result = unittest.TextTestRunner(verbosity=2).run(
        unittest.TestSuite(_tests)
    )
    raise SystemExit(not _result.wasSuccessful())