Repository: facebookresearch/vggt
Branch: main
Commit: 44b3afbd1869
Files: 83
Total size: 584.7 KB
Directory structure:
gitextract_yzgcgvr3/
├── .gitattributes
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── demo_colmap.py
├── demo_gradio.py
├── demo_viser.py
├── docs/
│ └── package.md
├── pyproject.toml
├── requirements.txt
├── requirements_demo.txt
├── training/
│ ├── README.md
│ ├── __init__.py
│ ├── config/
│ │ ├── default.yaml
│ │ └── default_dataset.yaml
│ ├── data/
│ │ ├── __init__.py
│ │ ├── augmentation.py
│ │ ├── base_dataset.py
│ │ ├── composed_dataset.py
│ │ ├── dataset_util.py
│ │ ├── datasets/
│ │ │ ├── co3d.py
│ │ │ └── vkitti.py
│ │ ├── dynamic_dataloader.py
│ │ ├── preprocess/
│ │ │ └── vkitti.sh
│ │ ├── track_util.py
│ │ └── worker_fn.py
│ ├── launch.py
│ ├── loss.py
│ ├── train_utils/
│ │ ├── __init__.py
│ │ ├── checkpoint.py
│ │ ├── distributed.py
│ │ ├── freeze.py
│ │ ├── general.py
│ │ ├── gradient_clip.py
│ │ ├── logging.py
│ │ ├── normalization.py
│ │ ├── optimizer.py
│ │ └── tb_writer.py
│ └── trainer.py
├── vggt/
│ ├── dependency/
│ │ ├── __init__.py
│ │ ├── distortion.py
│ │ ├── np_to_pycolmap.py
│ │ ├── projection.py
│ │ ├── track_modules/
│ │ │ ├── __init__.py
│ │ │ ├── base_track_predictor.py
│ │ │ ├── blocks.py
│ │ │ ├── modules.py
│ │ │ ├── track_refine.py
│ │ │ └── utils.py
│ │ ├── track_predict.py
│ │ ├── vggsfm_tracker.py
│ │ └── vggsfm_utils.py
│ ├── heads/
│ │ ├── camera_head.py
│ │ ├── dpt_head.py
│ │ ├── head_act.py
│ │ ├── track_head.py
│ │ ├── track_modules/
│ │ │ ├── __init__.py
│ │ │ ├── base_track_predictor.py
│ │ │ ├── blocks.py
│ │ │ ├── modules.py
│ │ │ └── utils.py
│ │ └── utils.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── block.py
│ │ ├── drop_path.py
│ │ ├── layer_scale.py
│ │ ├── mlp.py
│ │ ├── patch_embed.py
│ │ ├── rope.py
│ │ ├── swiglu_ffn.py
│ │ └── vision_transformer.py
│ ├── models/
│ │ ├── aggregator.py
│ │ └── vggt.py
│ └── utils/
│ ├── geometry.py
│ ├── helper.py
│ ├── load_fn.py
│ ├── pose_enc.py
│ ├── rotation.py
│ └── visual_track.py
└── visual_util.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
================================================
FILE: .gitignore
================================================
.hydra/
output/
ckpt/
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Profiling data
.prof
# Folder specific to your needs
**/tmp/
**/outputs/skyseg.onnx
skyseg.onnx
# pixi environments
.pixi
*.egg-info
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to vggt
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to vggt, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: LICENSE.txt
================================================
VGGT License
v1 Last Updated: July 29, 2025
“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
“Documentation” means the specifications, manuals and documentation accompanying
Research Materials distributed by Meta.
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
1. License Rights and Redistribution.
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
b. Redistribution and Use.
i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
5. Intellectual Property.
a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
Acceptable Use Policy
Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.
Prohibited Uses
You agree you will not use, or allow others to use, Research Materials to:
Violate the law or others’ rights, including to:
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
Violence or terrorism
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
Human trafficking, exploitation, and sexual violence
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
Sexual solicitation
Any other criminal activity
Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials
Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
Guns and illegal weapons (including weapon development)
Illegal drugs and regulated/controlled substances
Operation of critical infrastructure, transportation technologies, or heavy machinery
Self-harm or harm to others, including suicide, cutting, and eating disorders
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
3. Intentionally deceive or mislead others, including use of Research Materials related to the following:
Generating, promoting, or furthering fraud or the creation or promotion of disinformation
Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
Generating, promoting, or further distributing spam
Impersonating another individual without consent, authorization, or legal right
Representing that outputs of research materials or outputs from technology using Research Materials are human-generated
Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
================================================
FILE: README.md
================================================
<div align="center">
<h1>VGGT: Visual Geometry Grounded Transformer</h1>
<a href="https://jytime.github.io/data/VGGT_CVPR25.pdf" target="_blank" rel="noopener noreferrer">
<img src="https://img.shields.io/badge/Paper-VGGT" alt="Paper PDF">
</a>
<a href="https://arxiv.org/abs/2503.11651"><img src="https://img.shields.io/badge/arXiv-2503.11651-b31b1b" alt="arXiv"></a>
<a href="https://vgg-t.github.io/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
<a href="https://huggingface.co/spaces/facebook/vggt"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
**[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**; **[Meta AI](https://ai.facebook.com/research/)**
[Jianyuan Wang](https://jytime.github.io/), [Minghao Chen](https://silent-chen.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/)
</div>
```bibtex
@inproceedings{wang2025vggt,
title={VGGT: Visual Geometry Grounded Transformer},
author={Wang, Jianyuan and Chen, Minghao and Karaev, Nikita and Vedaldi, Andrea and Rupprecht, Christian and Novotny, David},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2025}
}
```
## Updates
- [July 29, 2025] We've updated the license for VGGT to permit **commercial use** (excluding military applications). All code in this repository is now under a commercial-use-friendly license. However, only the newly released checkpoint [**VGGT-1B-Commercial**](https://huggingface.co/facebook/VGGT-1B-Commercial) is licensed for commercial usage; the original checkpoint remains non-commercial. Full license details are available [here](https://github.com/facebookresearch/vggt/blob/main/LICENSE.txt). Access to the checkpoint requires completing an application form, which is processed automatically by a system similar to LLaMA's approval workflow. The new checkpoint delivers similar performance to the original model. Please submit an issue if you notice a significant performance discrepancy.
- [July 6, 2025] Training code is now available in the `training` folder, including an example to finetune VGGT on a custom dataset.
- [June 13, 2025] Honored to receive the Best Paper Award at CVPR 2025! Apologies if I’m slow to respond to queries or GitHub issues these days. If you’re interested, our oral presentation is available [here](https://docs.google.com/presentation/d/1JVuPnuZx6RgAy-U5Ezobg73XpBi7FrOh/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true). Another long presentation can be found [here](https://docs.google.com/presentation/d/1aSv0e5PmH1mnwn2MowlJIajFUYZkjqgw/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true) (Note: it’s shared in .pptx format with animations — quite large, but feel free to use it as a template if helpful.)
- [June 2, 2025] Added a script to run VGGT and save predictions in COLMAP format, with optional bundle adjustment support. The saved COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) or other NeRF/Gaussian splatting libraries.
- [May 3, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available in the [evaluation](https://github.com/facebookresearch/vggt/tree/evaluation) branch.
## Overview
Visual Geometry Grounded Transformer (VGGT, CVPR 2025) is a feed-forward neural network that directly infers all key 3D attributes of a scene, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks, **from one, a few, or hundreds of its views, within seconds**.
## Quick Start
First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub).
```bash
git clone git@github.com:facebookresearch/vggt.git
cd vggt
pip install -r requirements.txt
```
Alternatively, you can install VGGT as a package (<a href="docs/package.md">click here</a> for details).
Now, try the model with just a few lines of code:
```python
import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
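# Guard the capability query so the CPU fallback above does not raise on machines without CUDA
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16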
# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
# Load and preprocess example images (replace with your own image paths)
image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"]
images = load_and_preprocess_images(image_names).to(device)
with torch.no_grad():
    with torch.cuda.amp.autocast(dtype=dtype):
        # Predict attributes including cameras, depth maps, and point maps.
        predictions = model(images)
```
The model weights will be automatically downloaded from Hugging Face. If you encounter issues such as slow loading, you can manually download them [here](https://huggingface.co/facebook/VGGT-1B/blob/main/model.pt) and load them from the local file, or load them directly from the URL:
```python
model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
```
## Detailed Usage
<details>
<summary>Click to expand</summary>
You can also optionally choose which attributes (branches) to predict, as shown below. This achieves the same result as the example above. This example uses a batch size of 1 (processing a single scene), but it naturally works for multiple scenes.
```python
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
with torch.no_grad():
    with torch.cuda.amp.autocast(dtype=dtype):
        images = images[None]  # add batch dimension
        aggregated_tokens_list, ps_idx = model.aggregator(images)
    # Predict Cameras
    pose_enc = model.camera_head(aggregated_tokens_list)[-1]
    # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
    # Predict Depth Maps
    depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
    # Predict Point Maps
    point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
    # Construct 3D Points from Depth Maps and Cameras
    # which usually leads to more accurate 3D points than point map branch
    point_map_by_unprojection = unproject_depth_map_to_point_map(
        depth_map.squeeze(0),
        extrinsic.squeeze(0),
        intrinsic.squeeze(0),
    )
    # Predict Tracks
    # choose your own points to track, with shape (N, 2) for one scene
    query_points = torch.FloatTensor([[100.0, 200.0],
                                      [60.72, 259.94]]).to(device)
    track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])
```
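Because the extrinsics follow the camera-from-world convention, a world point maps to camera coordinates as `x_cam = R @ x_world + t`, so the camera position in world coordinates is `-R^T @ t`. The sketch below illustrates this with NumPy. The helper name `camera_center_from_extrinsic` is hypothetical (it is not part of this repository), and it assumes each extrinsic is a 3x4 `[R | t]` matrix, possibly with leading batch dimensions.

```python
import numpy as np


def camera_center_from_extrinsic(extrinsic):
    """Return camera centers in world coordinates.

    Assumes `extrinsic` has shape (..., 3, 4) and stores [R | t] in the
    camera-from-world convention (x_cam = R @ x_world + t).
    """
    extrinsic = np.asarray(extrinsic, dtype=np.float64)
    rotation = extrinsic[..., :3, :3]
    translation = extrinsic[..., :3, 3]
    # The camera center is the world point that maps to x_cam = 0,
    # i.e. -R^T @ t. The einsum computes R^T @ t for every leading index.
    return -np.einsum("...ji,...j->...i", rotation, translation)
```

To apply it to the predictions above, the tensor would first need to be moved to NumPy, for example with `extrinsic.squeeze(0).float().cpu().numpy()`.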
Furthermore, if certain pixels in the input frames are unwanted (e.g., reflective surfaces, sky, or water), you can simply mask them by setting the corresponding pixel values to 0 or 1. Precise segmentation masks aren't necessary - simple bounding box masks work effectively (check this [issue](https://github.com/facebookresearch/vggt/issues/47) for an example).
</details>
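The bounding-box masking tip from the Detailed Usage section can be sketched as follows. This is a minimal illustration rather than code from this repository: the helper name `mask_box` is hypothetical, and it assumes the images are stored as an array of shape `(S, 3, H, W)` with values in `[0, 1]`. NumPy is used here to keep the example self-contained; the same slicing works on a PyTorch tensor.

```python
import numpy as np


def mask_box(images, frame_idx, top, left, bottom, right, fill_value=0.0):
    """Return a copy of `images` with one rectangular region overwritten.

    Assumes `images` has shape (S, 3, H, W). The box covers rows
    top..bottom-1 and columns left..right-1 of frame `frame_idx`.
    """
    masked = np.array(images, copy=True)
    # Overwrite all three channels inside the box with the fill value (0 or 1).
    masked[frame_idx, :, top:bottom, left:right] = fill_value
    return masked
```

For example, `mask_box(images, 0, 0, 0, 100, 518)` would blank out the top 100 rows of the first frame (say, a sky region) before the frames are passed to the model.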
## Interactive Demo
We provide multiple ways to visualize your 3D reconstructions. Before using these visualization tools, install the required dependencies:
```bash
pip install -r requirements_demo.txt
```
### Interactive 3D Visualization
**Please note:** VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, independent of VGGT's processing time. Visualization is especially slow when the number of images is large.
#### Gradio Web Interface
Our Gradio-based interface allows you to upload images/videos, run reconstruction, and interactively explore the 3D scene in your browser. You can launch it on your local machine or try it on [Hugging Face](https://huggingface.co/spaces/facebook/vggt).
```bash
python demo_gradio.py
```
<details>
<summary>Click to preview the Gradio interactive interface</summary>

</details>
#### Viser 3D Viewer
Run the following command to run reconstruction and visualize the point clouds in viser. Note that this script requires a path to a folder containing images, and it assumes the folder contains only image files. You can set `--use_point_map` to use the point cloud from the point map branch, instead of the depth-based point cloud.
```bash
python demo_viser.py --image_folder path/to/your/images/folder
```
## Exporting to COLMAP Format
We also support exporting VGGT's predictions directly to COLMAP format:
```bash
# Feedforward prediction only
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/
# With bundle adjustment
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba
# Run with bundle adjustment using reduced parameters for faster processing
# Reduces max_query_pts from 4096 (default) to 2048 and query_frame_num from 8 (default) to 5
# Trade-off: Faster execution but potentially less robust reconstruction in complex scenes (you may consider setting query_frame_num equal to your total number of images)
# See demo_colmap.py for additional bundle adjustment configuration options
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba --max_query_pts=2048 --query_frame_num=5
```
Please ensure that the images are stored in `/YOUR/SCENE_DIR/images/`. This folder should contain only the images. Check the examples folder for the desired data structure.
The reconstruction result (camera parameters and 3D points) will be automatically saved under `/YOUR/SCENE_DIR/sparse/` in the COLMAP format, such as:
```
SCENE_DIR/
├── images/
└── sparse/
    ├── cameras.bin
    ├── images.bin
    └── points3D.bin
```
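As a quick sanity check after the export, the layout above can be verified with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name `missing_colmap_outputs` is hypothetical, and it only checks that the paths exist, not that the binary files are valid COLMAP models.

```python
from pathlib import Path

# File names taken from the directory layout shown above.
EXPECTED_SPARSE_FILES = ("cameras.bin", "images.bin", "points3D.bin")


def missing_colmap_outputs(scene_dir):
    """Return the expected paths that do not exist under `scene_dir`."""
    scene_dir = Path(scene_dir)
    expected = [scene_dir / "images"]
    expected += [scene_dir / "sparse" / name for name in EXPECTED_SPARSE_FILES]
    return [path for path in expected if not path.exists()]
```

An empty list from `missing_colmap_outputs("/YOUR/SCENE_DIR/")` means every expected entry is present; otherwise the returned paths show what is missing.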
## Integration with Gaussian Splatting
The exported COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) for Gaussian Splatting training. Install `gsplat` following their official instructions (we recommend `gsplat==1.3.0`).
An example command to train the model is:
```
cd gsplat
python examples/simple_trainer.py default --data_factor 1 --data_dir /YOUR/SCENE_DIR/ --result_dir /YOUR/RESULT_DIR/
```
## Zero-shot Single-view Reconstruction
Our model shows surprisingly good performance on single-view reconstruction, although it was never trained for this task. The model does not need to duplicate the single-view image into a pair; instead, it can directly infer the 3D structure from the tokens of the single-view image. Feel free to try it with our demos above, which naturally work for single-view reconstruction.
We did not quantitatively test monocular depth estimation performance ourselves, but [@kabouzeid](https://github.com/kabouzeid) generously provided a comparison of VGGT to recent methods [here](https://github.com/facebookresearch/vggt/issues/36). VGGT shows competitive or better results compared to state-of-the-art monocular approaches such as DepthAnything v2 or MoGe, despite never being explicitly trained for single-view tasks.
## Runtime and GPU Memory
We benchmark the runtime and GPU memory usage of VGGT's aggregator on a single NVIDIA H100 GPU across various input sizes.
| **Input Frames** | 1 | 2 | 4 | 8 | 10 | 20 | 50 | 100 | 200 |
|:----------------:|:-:|:-:|:-:|:-:|:--:|:--:|:--:|:---:|:---:|
| **Time (s)** | 0.04 | 0.05 | 0.07 | 0.11 | 0.14 | 0.31 | 1.04 | 3.12 | 8.75 |
| **Memory (GB)** | 1.88 | 2.07 | 2.45 | 3.23 | 3.63 | 5.58 | 11.41 | 21.15 | 40.63 |
Note that these results were obtained using Flash Attention 3, which is faster than the default Flash Attention 2 implementation while maintaining almost the same memory usage. Feel free to compile Flash Attention 3 from source to get better performance.
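To get a rough figure for a frame count that is not in the table, one option is to interpolate linearly between the benchmarked points. The sketch below makes that assumption explicitly. It gives an estimate, not a measurement, and applies only to the setup above (aggregator only, a single H100, Flash Attention 3). The function name `interpolate_benchmark` is hypothetical.

```python
import bisect

# Values copied from the benchmark table above.
FRAMES = [1, 2, 4, 8, 10, 20, 50, 100, 200]
TIME_S = [0.04, 0.05, 0.07, 0.11, 0.14, 0.31, 1.04, 3.12, 8.75]
MEMORY_GB = [1.88, 2.07, 2.45, 3.23, 3.63, 5.58, 11.41, 21.15, 40.63]


def interpolate_benchmark(num_frames, values):
    """Linearly interpolate a benchmarked quantity; extrapolation is refused."""
    if not FRAMES[0] <= num_frames <= FRAMES[-1]:
        raise ValueError(f"num_frames must be within [{FRAMES[0]}, {FRAMES[-1]}]")
    i = bisect.bisect_left(FRAMES, num_frames)
    if FRAMES[i] == num_frames:
        return values[i]  # exact table entry
    x0, x1 = FRAMES[i - 1], FRAMES[i]
    y0, y1 = values[i - 1], values[i]
    return y0 + (y1 - y0) * (num_frames - x0) / (x1 - x0)
```

For example, `interpolate_benchmark(30, MEMORY_GB)` interpolates between the 20-frame and 50-frame entries. Both time and memory grow faster than linearly at the upper end of the table, so treat values between widely spaced points as approximate.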
## Research Progression
Our work builds upon a series of previous research projects. If you're interested in understanding how our research evolved, check out our previous works:
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td align="left">
<a href="https://github.com/jytime/Deep-SfM-Revisited">Deep SfM Revisited</a>
</td>
<td style="white-space: pre;">──┐</td>
<td></td>
</tr>
<tr>
<td align="left">
<a href="https://github.com/facebookresearch/PoseDiffusion">PoseDiffusion</a>
</td>
<td style="white-space: pre;">─────►</td>
<td>
<a href="https://github.com/facebookresearch/vggsfm">VGGSfM</a> ──►
<a href="https://github.com/facebookresearch/vggt">VGGT</a>
</td>
</tr>
<tr>
<td align="left">
<a href="https://github.com/facebookresearch/co-tracker">CoTracker</a>
</td>
<td style="white-space: pre;">──┘</td>
<td></td>
</tr>
</table>
## Acknowledgements
Thanks to these great repositories: [PoseDiffusion](https://github.com/facebookresearch/PoseDiffusion), [VGGSfM](https://github.com/facebookresearch/vggsfm), [CoTracker](https://github.com/facebookresearch/co-tracker), [DINOv2](https://github.com/facebookresearch/dinov2), [DUSt3R](https://github.com/naver/dust3r), [MoGe](https://github.com/microsoft/moge), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [Sky Segmentation](https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), [Metric3D](https://github.com/YvanYin/Metric3D) and many other inspiring works in the community.
## Checklist
- [x] Release the training code
- [ ] Release VGGT-500M and VGGT-200M
## License
See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
Please note that only this [model checkpoint](https://huggingface.co/facebook/VGGT-1B-Commercial) allows commercial usage. This new checkpoint performs on par with the original one and may be slightly better, e.g., AUC@30: 90.37 vs. 89.98 on the Co3D dataset.
================================================
FILE: demo_colmap.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import glob
import os
import copy
import torch
import torch.nn.functional as F
# Configure CUDA settings
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
import argparse
from pathlib import Path
import trimesh
import pycolmap
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images_square
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues
from vggt.dependency.track_predict import predict_tracks
from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track
# TODO: add support for masks
# TODO: add iterative BA
# TODO: add support for radial distortion, which needs extra_params
# TODO: test with more cases
# TODO: test different camera types
def parse_args():
    parser = argparse.ArgumentParser(description="VGGT Demo")
    parser.add_argument("--scene_dir", type=str, required=True, help="Directory containing the scene images")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction")
    ######### BA parameters #########
    parser.add_argument(
        "--max_reproj_error", type=float, default=8.0, help="Maximum reprojection error for reconstruction"
    )
    parser.add_argument("--shared_camera", action="store_true", default=False, help="Use shared camera for all images")
    parser.add_argument("--camera_type", type=str, default="SIMPLE_PINHOLE", help="Camera type for reconstruction")
    parser.add_argument("--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks")
    parser.add_argument("--query_frame_num", type=int, default=8, help="Number of frames to query")
    parser.add_argument("--max_query_pts", type=int, default=4096, help="Maximum number of query points")
    # A store_true flag that defaults to True can never be switched off.
    # BooleanOptionalAction (Python 3.9+) keeps the default and adds --no-fine_tracking.
    parser.add_argument(
        "--fine_tracking",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use fine tracking (slower but more accurate)",
    )
    parser.add_argument(
        "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering (wo BA)"
    )
    return parser.parse_args()
def run_VGGT(model, images, dtype, resolution=518):
    # images: [B, 3, H, W]
    assert len(images.shape) == 4
    assert images.shape[1] == 3

    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            aggregated_tokens_list, ps_idx = model.aggregator(images)

        # Predict Cameras
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
        # Predict Depth Maps
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)

    extrinsic = extrinsic.squeeze(0).cpu().numpy()
    intrinsic = intrinsic.squeeze(0).cpu().numpy()
    depth_map = depth_map.squeeze(0).cpu().numpy()
    depth_conf = depth_conf.squeeze(0).cpu().numpy()
    return extrinsic, intrinsic, depth_map, depth_conf
def demo_fn(args):
    # Print configuration
    print("Arguments:", vars(args))

    # Set seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # for multi-GPU
    print(f"Setting seed as: {args.seed}")

    # Set device and dtype
    # Query the device capability only when CUDA is present; bfloat16 needs compute capability >= 8
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_bf16 = device == "cuda" and torch.cuda.get_device_capability()[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    print(f"Using device: {device}")
    print(f"Using dtype: {dtype}")

    # Run VGGT for camera and depth estimation
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model.eval()
    model = model.to(device)
    print("Model loaded")

    # Get image paths and preprocess them
    image_dir = os.path.join(args.scene_dir, "images")
    image_path_list = glob.glob(os.path.join(image_dir, "*"))
    if len(image_path_list) == 0:
        raise ValueError(f"No images found in {image_dir}")
    base_image_path_list = [os.path.basename(path) for path in image_path_list]

    # Load images and original coordinates
    # Load Image in 1024, while running VGGT with 518
    vggt_fixed_resolution = 518
    img_load_resolution = 1024

    images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution)
    images = images.to(device)
    original_coords = original_coords.to(device)
    print(f"Loaded {len(images)} images from {image_dir}")

    # Run VGGT to estimate camera and depth
    # Run with 518x518 images
    extrinsic, intrinsic, depth_map, depth_conf = run_VGGT(model, images, dtype, vggt_fixed_resolution)
    points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)

    if args.use_ba:
        image_size = np.array(images.shape[-2:])
        scale = img_load_resolution / vggt_fixed_resolution
        shared_camera = args.shared_camera

        with torch.cuda.amp.autocast(dtype=dtype):
            # Predicting Tracks
            # Using VGGSfM tracker instead of VGGT tracker for efficiency
            # VGGT tracker requires multiple backbone runs to query different frames (this is a problem caused by the training process)
            # Will be fixed in VGGT v2
            # You can also change the pred_tracks to tracks from any other methods
            # e.g., from COLMAP, from CoTracker, or by chaining 2D matches from Lightglue/LoFTR.
            pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = predict_tracks(
                images,
                conf=depth_conf,
                points_3d=points_3d,
                masks=None,
                max_query_pts=args.max_query_pts,
                query_frame_num=args.query_frame_num,
                keypoint_extractor="aliked+sp",
                fine_tracking=args.fine_tracking,
            )
            torch.cuda.empty_cache()

        # rescale the intrinsic matrix from 518 to 1024
        intrinsic[:, :2, :] *= scale
        track_mask = pred_vis_scores > args.vis_thresh

        # TODO: radial distortion, iterative BA, masks
        reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap(
            points_3d,
            extrinsic,
            intrinsic,
            pred_tracks,
            image_size,
            masks=track_mask,
            max_reproj_error=args.max_reproj_error,
            shared_camera=shared_camera,
            camera_type=args.camera_type,
            points_rgb=points_rgb,
        )
        if reconstruction is None:
            raise ValueError("No reconstruction can be built with BA")

        # Bundle Adjustment
        ba_options = pycolmap.BundleAdjustmentOptions()
        pycolmap.bundle_adjustment(reconstruction, ba_options)
        reconstruction_resolution = img_load_resolution
    else:
        conf_thres_value = args.conf_thres_value
        max_points_for_colmap = 100000  # randomly sample 3D points
        shared_camera = False  # in the feedforward manner, we do not support shared camera
        camera_type = "PINHOLE"  # in the feedforward manner, we only support PINHOLE camera

        image_size = np.array([vggt_fixed_resolution, vggt_fixed_resolution])
        num_frames, height, width, _ = points_3d.shape

        points_rgb = F.interpolate(
            images, size=(vggt_fixed_resolution, vggt_fixed_resolution), mode="bilinear", align_corners=False
        )
        points_rgb = (points_rgb.cpu().numpy() * 255).astype(np.uint8)
        points_rgb = points_rgb.transpose(0, 2, 3, 1)

        # (S, H, W, 3), with x, y coordinates and frame indices
        points_xyf = create_pixel_coordinate_grid(num_frames, height, width)

        conf_mask = depth_conf >= conf_thres_value
        # at most writing 100000 3d points to colmap reconstruction object
        conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap)

        points_3d = points_3d[conf_mask]
        points_xyf = points_xyf[conf_mask]
        points_rgb = points_rgb[conf_mask]

        print("Converting to COLMAP format")
        reconstruction = batch_np_matrix_to_pycolmap_wo_track(
            points_3d,
            points_xyf,
            points_rgb,
            extrinsic,
            intrinsic,
            image_size,
            shared_camera=shared_camera,
            camera_type=camera_type,
        )
        reconstruction_resolution = vggt_fixed_resolution

    reconstruction = rename_colmap_recons_and_rescale_camera(
        reconstruction,
        base_image_path_list,
        original_coords.cpu().numpy(),
        img_size=reconstruction_resolution,
        shift_point2d_to_original_res=True,
        shared_camera=shared_camera,
    )

    print(f"Saving reconstruction to {args.scene_dir}/sparse")
    sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse")
    os.makedirs(sparse_reconstruction_dir, exist_ok=True)
    reconstruction.write(sparse_reconstruction_dir)

    # Save point cloud for fast visualization
    trimesh.PointCloud(points_3d, colors=points_rgb).export(os.path.join(args.scene_dir, "sparse/points.ply"))
    return True
def rename_colmap_recons_and_rescale_camera(
    reconstruction, image_paths, original_coords, img_size, shift_point2d_to_original_res=False, shared_camera=False
):
    rescale_camera = True

    for pyimageid in reconstruction.images:
        # Reshape the padded and resized image back to the original size
        # Rename the images to the original names
        pyimage = reconstruction.images[pyimageid]
        pycamera = reconstruction.cameras[pyimage.camera_id]
        pyimage.name = image_paths[pyimageid - 1]

        if rescale_camera:
            # Rescale the camera parameters
            pred_params = copy.deepcopy(pycamera.params)

            real_image_size = original_coords[pyimageid - 1, -2:]
            resize_ratio = max(real_image_size) / img_size
            pred_params = pred_params * resize_ratio
            real_pp = real_image_size / 2
            pred_params[-2:] = real_pp  # center of the image

            pycamera.params = pred_params
            pycamera.width = real_image_size[0]
            pycamera.height = real_image_size[1]

        if shift_point2d_to_original_res:
            # Also shift the point2D to original resolution
            top_left = original_coords[pyimageid - 1, :2]

            for point2D in pyimage.points2D:
                point2D.xy = (point2D.xy - top_left) * resize_ratio

        if shared_camera:
            # If shared_camera, all images share the same camera
            # no need to rescale any more
            rescale_camera = False

    return reconstruction
if __name__ == "__main__":
    args = parse_args()
    with torch.no_grad():
        demo_fn(args)
# Work in Progress (WIP)
"""
VGGT Runner Script
=================
A script to run the VGGT model for 3D reconstruction from image sequences.
Directory Structure
------------------
Input:
    input_folder/
    └── images/            # Source images for reconstruction
Output:
    output_folder/
    ├── images/
    ├── sparse/            # Reconstruction results
    │   ├── cameras.bin    # Camera parameters (COLMAP format)
    │   ├── images.bin     # Pose for each image (COLMAP format)
    │   ├── points3D.bin   # 3D points (COLMAP format)
    │   └── points.ply     # Point cloud visualization file
    └── visuals/           # Visualization outputs TODO
Key Features
-----------
• Dual-mode Support: Run reconstructions using either VGGT or VGGT+BA
• Resolution Preservation: Maintains original image resolution in camera parameters and tracks
• COLMAP Compatibility: Exports results in standard COLMAP sparse reconstruction format
"""
================================================
FILE: demo_gradio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
sys.path.append("vggt/")
from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Initializing and loading VGGT model...")
# model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.eval()
model = model.to(device)
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(target_dir, model) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
    """
    print(f"Processing images from {target_dir}")

    # Device check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    # Move model to device
    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Run inference
    print("Running inference...")
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)

    # Convert pose encoding to extrinsic and intrinsic matrices
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    # Convert tensors to numpy
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension
    predictions["pose_enc_list"] = None  # remove pose_enc_list

    # Generate world points from depth map
    print("Computing world points from depth map...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Clean up
    torch.cuda.empty_cache()
    return predictions
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # 1 frame/sec; clamp to 1 so a missing or sub-1 FPS value cannot cause a modulo by zero
        frame_interval = max(1, int(fps * 1))

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever the user uploads or changes files, immediately handle them
    and show them in the gallery. Return (None, target_dir, image_paths, log_message).
    If nothing is uploaded, return None for all four values.
    """
    if not input_video and not input_images:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Pointmap Regression",
):
    """
    Perform reconstruction using the already-created target_dir/images.
    Return (glbfile, log_message, frame_filter_dropdown).
    """
    # Check for None/"None" before calling os.path.isdir, and return three values
    # to match the normal return path and the three outputs wired to this function.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Clear the 3D viewer by returning None for its output.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."
def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    loaded = np.load(predictions_path)
    predictions = {key: np.array(loaded[key]) for key in key_list}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------
great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = gr.themes.Ocean()
theme.set(
checkbox_label_background_fill_selected="*button_primary_background_fill",
checkbox_label_text_color_selected="*button_primary_text_color",
)
with gr.Blocks(
theme=theme,
css="""
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
""",
) as demo:
# Instead of gr.State, we use a hidden Textbox:
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
gr.HTML(
"""
<h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
<p>
<a href="https://github.com/facebookresearch/vggt">🐙 GitHub Repository</a> |
<a href="#">Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
<h3>Getting Started:</h3>
<ol>
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
<li>
<strong>Adjust Visualization (Optional):</strong>
After reconstruction, you can fine-tune the visualization using the options below
<details style="display:inline;">
<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
<ul>
<li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
<li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
<li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
<li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
</ul>
</details>
</li>
</ol>
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. </span></p>
</div>
"""
)
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
with gr.Row():
with gr.Column(scale=2):
input_video = gr.Video(label="Upload Video", interactive=True)
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
with gr.Column(scale=4):
with gr.Column():
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
)
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
clear_btn = gr.ClearButton(
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
scale=1,
)
with gr.Row():
prediction_mode = gr.Radio(
["Depthmap and Camera Branch", "Pointmap Branch"],
label="Select a Prediction Mode",
value="Depthmap and Camera Branch",
scale=1,
elem_id="my_radio",
)
with gr.Row():
conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
with gr.Column():
show_cam = gr.Checkbox(label="Show Camera", value=True)
mask_sky = gr.Checkbox(label="Filter Sky", value=False)
mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
# ---------------------- Examples section ----------------------
examples = [
[colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
[pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
[single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
[single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"],
[room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
[kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
[fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
]
def example_pipeline(
input_video,
num_images_str,
input_images,
conf_thres,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example_str,
):
"""
1) Copy example images to new target_dir
2) Reconstruct
3) Return model3D + logs + new_dir + updated dropdown + gallery
We do NOT return is_example. It's just an input.
"""
target_dir, image_paths = handle_uploads(input_video, input_images)
# Always use "All" for frame_filter in examples
frame_filter = "All"
glbfile, log_msg, dropdown = gradio_demo(
target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
)
return glbfile, log_msg, target_dir, dropdown, image_paths
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
gr.Examples(
examples=examples,
inputs=[
input_video,
num_images,
input_images,
conf_thres,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
fn=example_pipeline,
cache_examples=False,
examples_per_page=50,
)
# -------------------------------------------------------------------------
# "Reconstruct" button logic:
# - Clear fields
# - Update log
# - gradio_demo(...) with the existing target_dir
# - Then set is_example = "False"
# -------------------------------------------------------------------------
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
fn=update_log, inputs=[], outputs=[log_output]
).then(
fn=gradio_demo,
inputs=[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
],
outputs=[reconstruction_output, log_output, frame_filter],
).then(
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
)
# -------------------------------------------------------------------------
# Real-time Visualization Updates
# -------------------------------------------------------------------------
conf_thres.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
frame_filter.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
mask_black_bg.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
mask_white_bg.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
show_cam.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
mask_sky.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
prediction_mode.change(
update_visualization,
[
target_dir_output,
conf_thres,
frame_filter,
mask_black_bg,
mask_white_bg,
show_cam,
mask_sky,
prediction_mode,
is_example,
],
[reconstruction_output, log_output],
)
# -------------------------------------------------------------------------
# Auto-update gallery whenever user uploads or changes their files
# -------------------------------------------------------------------------
input_video.change(
fn=update_gallery_on_upload,
inputs=[input_video, input_images],
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
)
input_images.change(
fn=update_gallery_on_upload,
inputs=[input_video, input_images],
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
)
demo.queue(max_size=20).launch(show_error=True, share=True)
================================================
FILE: demo_viser.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import time
import threading
import argparse
from typing import List, Optional
import numpy as np
import torch
from tqdm.auto import tqdm
import viser
import viser.transforms as viser_tf
import cv2
try:
import onnxruntime
except ImportError:
print("onnxruntime not found. Sky segmentation may not work.")
from visual_util import segment_sky, download_file_from_url
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
def viser_wrapper(
pred_dict: dict,
port: int = 8080,
init_conf_threshold: float = 50.0, # represents percentage (e.g., 50 means filter lowest 50%)
use_point_map: bool = False,
background_mode: bool = False,
mask_sky: bool = False,
image_folder: Optional[str] = None,
):
"""
Visualize predicted 3D points and camera poses with viser.
Args:
pred_dict (dict):
{
"images": (S, 3, H, W) - Input images,
"world_points": (S, H, W, 3),
"world_points_conf": (S, H, W),
"depth": (S, H, W, 1),
"depth_conf": (S, H, W),
"extrinsic": (S, 3, 4),
"intrinsic": (S, 3, 3),
}
port (int): Port number for the viser server.
init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
use_point_map (bool): Whether to visualize world_points or use depth-based points.
background_mode (bool): Whether to run the server in background thread.
mask_sky (bool): Whether to apply sky segmentation to filter out sky points.
image_folder (str): Path to the folder containing input images.
"""
print(f"Starting viser server on port {port}")
server = viser.ViserServer(host="0.0.0.0", port=port)
server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
# Unpack prediction dict
images = pred_dict["images"] # (S, 3, H, W)
world_points_map = pred_dict["world_points"] # (S, H, W, 3)
conf_map = pred_dict["world_points_conf"] # (S, H, W)
depth_map = pred_dict["depth"] # (S, H, W, 1)
depth_conf = pred_dict["depth_conf"] # (S, H, W)
extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4)
intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3)
# Compute world points from depth if not using the precomputed point map
if not use_point_map:
world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
conf = depth_conf
else:
world_points = world_points_map
conf = conf_map
# Apply sky segmentation if enabled
if mask_sky and image_folder is not None:
conf = apply_sky_segmentation(conf, image_folder)
# Convert images from (S, 3, H, W) to (S, H, W, 3)
# Then flatten everything for the point cloud
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
S, H, W, _ = world_points.shape
# Flatten
points = world_points.reshape(-1, 3)
colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
conf_flat = conf.reshape(-1)
cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically
# For convenience, we store only (3,4) portion
cam_to_world = cam_to_world_mat[:, :3, :]
# Compute scene center and recenter
scene_center = np.mean(points, axis=0)
points_centered = points - scene_center
cam_to_world[..., -1] -= scene_center
# Store frame indices so we can filter by frame
frame_indices = np.repeat(np.arange(S), H * W)
# Build the viser GUI
gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
# Now the slider represents percentage of points to filter out
gui_points_conf = server.gui.add_slider(
"Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
)
gui_frame_selector = server.gui.add_dropdown(
"Show Points from Frames", options=["All"] + [str(i) for i in range(S)], initial_value="All"
)
# Create the main point cloud handle
# Compute the threshold value as the given percentile
init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 1e-5)  # same floor as update_point_cloud()
point_cloud = server.scene.add_point_cloud(
name="viser_pcd",
points=points_centered[init_conf_mask],
colors=colors_flat[init_conf_mask],
point_size=0.001,
point_shape="circle",
)
# We will store references to frames & frustums so we can toggle visibility
frames: List[viser.FrameHandle] = []
frustums: List[viser.CameraFrustumHandle] = []
def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None:
"""
Add camera frames and frustums to the scene.
extrinsics: (S, 3, 4)
images_: (S, 3, H, W)
"""
# Clear any existing frames or frustums
for f in frames:
f.remove()
frames.clear()
for fr in frustums:
fr.remove()
frustums.clear()
# Optionally attach a callback that sets the viewpoint to the chosen camera
def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
@frustum.on_click
def _(_) -> None:
for client in server.get_clients().values():
client.camera.wxyz = frame.wxyz
client.camera.position = frame.position
img_ids = range(S)
for img_id in tqdm(img_ids):
cam2world_3x4 = extrinsics[img_id]
T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4)
# Add a small frame axis
frame_axis = server.scene.add_frame(
f"frame_{img_id}",
wxyz=T_world_camera.rotation().wxyz,
position=T_world_camera.translation(),
axes_length=0.05,
axes_radius=0.002,
origin_radius=0.002,
)
frames.append(frame_axis)
# Convert the image for the frustum
img = images_[img_id] # shape (3, H, W)
img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
h, w = img.shape[:2]
# viser expects the vertical FOV in radians. To derive it from the predicted
# intrinsics, use the vertical focal length:
# fy = intrinsics_cam[img_id, 1, 1]
# fov = 2 * np.arctan2(h / 2, fy)
# For demonstration, we pick a simple approximate FOV:
fy = 1.1 * h
fov = 2 * np.arctan2(h / 2, fy)
# Add the frustum
frustum_cam = server.scene.add_camera_frustum(
f"frame_{img_id}/frustum", fov=fov, aspect=w / h, scale=0.05, image=img, line_width=1.0
)
frustums.append(frustum_cam)
attach_callback(frustum_cam, frame_axis)
def update_point_cloud() -> None:
"""Update the point cloud based on current GUI selections."""
# Here we compute the threshold value based on the current percentage
current_percentage = gui_points_conf.value
threshold_val = np.percentile(conf_flat, current_percentage)
print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
if gui_frame_selector.value == "All":
frame_mask = np.ones_like(conf_mask, dtype=bool)
else:
selected_idx = int(gui_frame_selector.value)
frame_mask = frame_indices == selected_idx
combined_mask = conf_mask & frame_mask
point_cloud.points = points_centered[combined_mask]
point_cloud.colors = colors_flat[combined_mask]
@gui_points_conf.on_update
def _(_) -> None:
update_point_cloud()
@gui_frame_selector.on_update
def _(_) -> None:
update_point_cloud()
@gui_show_frames.on_update
def _(_) -> None:
"""Toggle visibility of camera frames and frustums."""
for f in frames:
f.visible = gui_show_frames.value
for fr in frustums:
fr.visible = gui_show_frames.value
# Add the camera frames to the scene
visualize_frames(cam_to_world, images)
print("Starting viser server...")
# If background_mode is True, spawn a daemon thread so the main thread can continue.
if background_mode:
def server_loop():
while True:
time.sleep(0.001)
thread = threading.Thread(target=server_loop, daemon=True)
thread.start()
else:
while True:
time.sleep(0.01)
return server
# Helper functions for sky segmentation
def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
"""
Apply sky segmentation to confidence scores.
Args:
conf (np.ndarray): Confidence scores with shape (S, H, W)
image_folder (str): Path to the folder containing input images
Returns:
np.ndarray: Updated confidence scores with sky regions masked out
"""
S, H, W = conf.shape
sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
os.makedirs(sky_masks_dir, exist_ok=True)
# Download skyseg.onnx if it doesn't exist
if not os.path.exists("skyseg.onnx"):
print("Downloading skyseg.onnx...")
download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx")
skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
sky_mask_list = []
print("Generating sky masks...")
for i, image_path in enumerate(tqdm(image_files[:S])): # Limit to the number of images in the batch
image_name = os.path.basename(image_path)
mask_filepath = os.path.join(sky_masks_dir, image_name)
if os.path.exists(mask_filepath):
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
else:
sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
# Resize mask to match H×W if needed
if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
sky_mask = cv2.resize(sky_mask, (W, H))
sky_mask_list.append(sky_mask)
# Convert list to numpy array with shape S×H×W
sky_mask_array = np.array(sky_mask_list)
# Apply sky mask to confidence scores
sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
conf = conf * sky_mask_binary
print("Sky segmentation applied successfully")
return conf
parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
parser.add_argument(
"--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images"
)
parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points")
parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode")
parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server")
parser.add_argument(
"--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out"
)
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
def main():
"""
Main function for the VGGT demo with viser for 3D visualization.
This function:
1. Loads the VGGT model
2. Processes input images from the specified folder
3. Runs inference to generate 3D points and camera poses
4. Optionally applies sky segmentation to filter out sky points
5. Visualizes the results using viser
Command-line arguments:
--image_folder: Path to folder containing input images
--use_point_map: Use point map instead of depth-based points
--background_mode: Run the viser server in background mode
--port: Port number for the viser server
--conf_threshold: Initial percentage of low-confidence points to filter out
--mask_sky: Apply sky segmentation to filter out sky points
"""
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Initializing and loading VGGT model...")
# model = VGGT.from_pretrained("facebook/VGGT-1B")
model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.eval()
model = model.to(device)
# Use the provided image folder path
print(f"Loading images from {args.image_folder}...")
# Sort so the frame order matches the sorted order used in apply_sky_segmentation()
image_names = sorted(glob.glob(os.path.join(args.image_folder, "*")))
print(f"Found {len(image_names)} images")
images = load_and_preprocess_images(image_names).to(device)
print(f"Preprocessed images shape: {images.shape}")
print("Running inference...")
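# Only query the CUDA capability when a GPU is present; the call raises on CPU-only machines.
dtype = torch.bfloat16 if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8 else torch.float16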
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
predictions = model(images)
print("Converting pose encoding to extrinsic and intrinsic matrices...")
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
predictions["extrinsic"] = extrinsic
predictions["intrinsic"] = intrinsic
print("Processing model outputs...")
for key in predictions.keys():
if isinstance(predictions[key], torch.Tensor):
predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy
if args.use_point_map:
print("Visualizing 3D points from point map")
else:
print("Visualizing 3D points by unprojecting depth map by cameras")
if args.mask_sky:
print("Sky segmentation enabled - will filter out sky points")
print("Starting viser visualization...")
viser_server = viser_wrapper(
predictions,
port=args.port,
init_conf_threshold=args.conf_threshold,
use_point_map=args.use_point_map,
background_mode=args.background_mode,
mask_sky=args.mask_sky,
image_folder=args.image_folder,
)
print("Visualization complete")
if __name__ == "__main__":
main()
================================================
FILE: docs/package.md
================================================
# Alternative Installation Methods
This document explains how to install VGGT as a package using different package managers.
## Prerequisites
Before installing VGGT as a package, you need to install PyTorch and torchvision. We do not list them as dependencies, to avoid CUDA version mismatches. Install them first, for example:
```bash
# install pytorch 2.3.1 with cuda 12.1
pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
```
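Once PyTorch and torchvision are installed, a quick check like the one below can confirm that both are visible to the interpreter before you install VGGT. This is only a minimal sketch: `check_prerequisites` is a hypothetical helper written for this document, not part of the VGGT package.

```python
from importlib.util import find_spec


def check_prerequisites(packages=("torch", "torchvision")):
    """Return {package_name: bool}, True when the package can be found."""
    # find_spec looks a package up without importing it, so the check is
    # cheap and does not initialize CUDA.
    return {name: find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_prerequisites().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

If either package is reported as missing, install it with the command above before continuing.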
## Installation Options
### Install with pip
The simplest way to install VGGT is using pip:
```bash
pip install -e .
```
### Install and run with pixi
[Pixi](https://pixi.sh) is a package management tool for creating reproducible environments.
1. First, [download and install pixi](https://pixi.sh/latest/get_started/)
2. Then run:
```bash
pixi run -e demo python demo_gradio.py
```
### Install and run with uv
[uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver.
1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/)
2. Then run:
```bash
uv run --extra demo demo_gradio.py
```
================================================
FILE: pyproject.toml
================================================
[project]
authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}]
dependencies = [
"numpy<2",
"Pillow",
"huggingface_hub",
"einops",
"safetensors",
"opencv-python",
]
name = "vggt"
requires-python = ">= 3.10"
version = "0.0.1"
[project.optional-dependencies]
demo = [
"gradio==5.17.1",
"viser==0.2.23",
"tqdm",
"hydra-core",
"omegaconf",
"opencv-python",
"scipy",
"onnxruntime",
"requests",
"trimesh",
"matplotlib",
]
# Using setuptools as the build backend
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
# setuptools configuration
[tool.setuptools.packages.find]
where = ["."]
include = ["vggt*"]
# Pixi configuration
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]
[tool.pixi.pypi-dependencies]
vggt = { path = ".", editable = true }
[tool.pixi.environments]
default = { solve-group = "default" }
demo = { features = ["demo"], solve-group = "default" }
[tool.pixi.tasks]
================================================
FILE: requirements.txt
================================================
torch==2.3.1
torchvision==0.18.1
numpy==1.26.1
Pillow
huggingface_hub
einops
safetensors
================================================
FILE: requirements_demo.txt
================================================
gradio==5.17.1
viser==0.2.23
tqdm
hydra-core
omegaconf
opencv-python
scipy
onnxruntime
requests
trimesh
matplotlib
pydantic==2.10.6
# feel free to skip the dependencies below if you do not need demo_colmap.py
pycolmap==3.10.0
pyceres==2.3
git+https://github.com/jytime/LightGlue.git#egg=lightglue
================================================
FILE: training/README.md
================================================
# Training
This is a re-implementation of our framework for training VGGT. This document shows how to set up the environment and run VGGT training. I have aimed to faithfully reproduce the original training framework, but please open an issue if anything looks off.
## 1. Prerequisites
Before you begin, ensure you have completed the following steps:
1. **Install VGGT as a package:**
```bash
pip install -e .
```
2. **Prepare the dataset and annotations:**
- Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d).
- Download the required annotation files from [Hugging Face](https://huggingface.co/datasets/JianyuanWang/co3d_anno/tree/main).
## 2. Configuration
After downloading the dataset and annotations, configure the paths in `training/config/default.yaml`.
### Required Path Configuration
1. Open `training/config/default.yaml`
2. Update the following paths with your absolute directory paths:
- `CO3D_DIR`: Path to your Co3D dataset
- `CO3D_ANNOTATION_DIR`: Path to your Co3D annotation files
- `resume_checkpoint_path`: Path to your pre-trained VGGT checkpoint
### Configuration Example
```yaml
data:
  train:
    dataset:
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
  # ... same for val ...

checkpoint:
  resume_checkpoint_path: /YOUR/PATH/TO/CKPT
```
## 3. Fine-tuning on Co3D
To fine-tune the provided pre-trained model on the Co3D dataset, run the following command from the `training/` directory. This example uses 4 GPUs with PyTorch Distributed Data Parallel (DDP):
```bash
torchrun --nproc_per_node=4 launch.py
```
The default configuration in `training/config/default.yaml` is set up for fine-tuning. It automatically resumes from a checkpoint and freezes the model's `aggregator` module during training.
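To illustrate what freezing by name pattern means, the sketch below matches a glob-style pattern such as `*aggregator*` (the entry under `frozen_module_names` in the default config) against parameter names. It is a hedged sketch built on Python's `fnmatch`. The helper `select_frozen` and the parameter names are hypothetical, and the actual logic in `train_utils/freeze.py` may differ.

```python
from fnmatch import fnmatch


def select_frozen(param_names, patterns):
    """Return the names that match at least one glob pattern."""
    return [n for n in param_names if any(fnmatch(n, p) for p in patterns)]


# Hypothetical parameter names, for illustration only.
names = [
    "aggregator.frame_blocks.0.attn.qkv.weight",
    "camera_head.trunk.0.mlp.fc1.weight",
    "depth_head.scratch.output_conv1.weight",
]

frozen = select_frozen(names, ["*aggregator*"])
# In a real model, the matching parameters would get requires_grad = False.
print(frozen)
```

With this pattern only the aggregator weights are selected, so the camera and depth heads stay trainable.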
## 4. Training on Multiple Datasets
The dataloader supports multiple datasets out of the box. For example, if you have downloaded VKitti using `training/data/preprocess/vkitti.sh`, you can train on Co3D+VKitti with the following configuration:
```yaml
data:
  train:
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
          len_train: 100000
        - _target_: data.datasets.vkitti.VKittiDataset
          split: train
          VKitti_DIR: /YOUR/PATH/TO/VKitti
          len_train: 100000
          expand_ratio: 8
```
The sampling ratio between datasets is controlled by `len_train`. For example, Co3D with `len_train: 10000` and VKitti with `len_train: 2000` results in Co3D being sampled five times as often as VKitti.
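The arithmetic behind this ratio can be sketched as follows. The snippet assumes that samples are drawn in proportion to each dataset's `len_train`, as the paragraph above describes. `sampling_fractions` is a hypothetical helper and does not call the actual `ComposedDataset`.

```python
def sampling_fractions(len_train_by_dataset):
    """Fraction of samples expected from each dataset, proportional to len_train."""
    total = sum(len_train_by_dataset.values())
    return {name: n / total for name, n in len_train_by_dataset.items()}


fractions = sampling_fractions({"co3d": 10000, "vkitti": 2000})
for name, frac in fractions.items():
    print(f"{name}: {frac:.1%} of samples")
```

With these values roughly five out of every six samples come from Co3D.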
## 5. Common Questions
### Memory Management
If you encounter out-of-memory (OOM) errors on your GPU, consider adjusting the following parameters in `training/config/default.yaml`:
- `max_img_per_gpu`: Reduce this value to decrease the batch size per GPU
- `accum_steps`: Sets the number of gradient accumulation steps (default is 2). This feature splits batches into smaller chunks to save memory, though it may slightly increase training time. Note that gradient accumulation was not used for the original VGGT model.
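
To see why gradient accumulation saves memory without changing the update, consider the toy example below. It uses a scalar model in plain Python (an assumption made purely for illustration, not the trainer's code) and checks that averaging the gradients of equally sized chunks reproduces the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Split the batch into equal chunks and average the per-chunk gradients."""
    chunk = len(xs) // accum_steps
    assert chunk * accum_steps == len(xs), "batch must divide evenly in this sketch"
    total = 0.0
    for i in range(accum_steps):
        sl = slice(i * chunk, (i + 1) * chunk)
        # Scale each chunk by 1 / accum_steps before summing, mirroring the
        # usual `loss / accum_steps` convention.
        total += grad_mse(w, xs[sl], ys[sl]) / accum_steps
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, accum_steps=2)
print(full, accum)
```

Only one chunk needs to be held in memory at a time, at the cost of one extra forward and backward pass per chunk.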
### Learning Rate Tuning
The main hyperparameter to tune carefully is the learning rate. The appropriate learning rate depends on the effective batch size, which is `batch_size_per_gpu × num_gpus`. I therefore highly recommend trying several learning rates for your training setup. Values such as `5e-6`, `1e-5`, `5e-5`, `1e-4`, and `5e-4` are usually a sufficient sweep.
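When sweeping peak learning rates, it can help to see the shape of the schedule. The default config combines a linear warmup over the first 5% of training with a cosine decay over the remaining 95%, and the peak value appears in the scheduler entries as `end_value` and `start_value`. The sketch below approximates that shape in plain Python. `lr_at` is a hypothetical helper and does not call fvcore.

```python
import math


def lr_at(progress, peak=5e-5, floor=1e-8, warmup_frac=0.05):
    """Learning rate at a training progress fraction in [0, 1]."""
    if progress < warmup_frac:
        # Linear warmup from `floor` to `peak`.
        t = progress / warmup_frac
        return floor + (peak - floor) * t
    # Cosine decay from `peak` back to `floor` over the remaining fraction.
    t = (progress - warmup_frac) / (1.0 - warmup_frac)
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * t))


for p in (0.0, 0.05, 0.5, 1.0):
    print(f"progress={p:.2f} lr={lr_at(p):.3e}")
```

Passing a different `peak` shows how each candidate learning rate would evolve over training.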
### Tracking Head
The tracking head can slightly improve accuracy but is not necessary. For general cases, especially when GPU resources are limited, we suggest fine-tuning the pre-trained model only with camera and depth heads, which is the setting in `default.yaml`. This will provide good enough results.
### Dataloader Validation
To check if your dataloader is working correctly, the best approach is to visualize its output. You can save the 3D world points as follows and then visually inspect the PLY files:
```python
import numpy as np
import torch


def save_ply(points, colors, filename):
    """Write 3D points with per-point RGB colors to an ASCII PLY file."""
    import open3d as o3d  # imported lazily; only needed for debugging

    # Accept either torch tensors or numpy arrays and flatten to (N, 3)
    if torch.is_tensor(points):
        points_visual = points.reshape(-1, 3).cpu().numpy()
    else:
        points_visual = points.reshape(-1, 3)
    if torch.is_tensor(colors):
        points_visual_rgb = colors.reshape(-1, 3).cpu().numpy()
    else:
        points_visual_rgb = colors.reshape(-1, 3)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points_visual.astype(np.float64))
    # Open3D expects colors as floats in [0, 1]
    pcd.colors = o3d.utility.Vector3dVector(points_visual_rgb.astype(np.float64))
    o3d.io.write_point_cloud(filename, pcd, write_ascii=True)


# Usage example (`batch` is one batch produced by the training dataloader)
save_ply(
    batch["world_points"][0].reshape(-1, 3),
    batch["images"][0].permute(0, 2, 3, 1).reshape(-1, 3),
    "debug.ply"
)
```
### Handling Unordered Sequences
For unordered sequences, you can check how we compute the ranking (similarity) between one frame and all other frames, as discussed in [Issue #82](https://github.com/facebookresearch/vggt/issues/82).
### Expected Coordinate System
Camera poses are expected to follow the OpenCV `camera-from-world` convention. Depth maps should be aligned with their corresponding camera poses.
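As a sanity check on this convention, the sketch below recovers a camera center in world coordinates from a camera-from-world extrinsic `[R | t]`. It is a minimal NumPy illustration with a made-up pose. `camera_center_from_extrinsic` is a hypothetical helper, not a function from this repository.

```python
import numpy as np


def camera_center_from_extrinsic(extrinsic):
    """Camera center in world coordinates from a 3x4 camera-from-world [R | t]."""
    R, t = extrinsic[:, :3], extrinsic[:, 3]
    # x_cam = R @ x_world + t, so the camera origin (x_cam = 0) sits at -R^T t.
    return -R.T @ t


# Toy pose: 90-degree rotation about the z-axis plus a translation.
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
t = np.array([1.0, 2.0, 3.0])
extrinsic = np.hstack([R, t[:, None]])

center = camera_center_from_extrinsic(extrinsic)
# Mapping the center back through the extrinsic should give the camera origin.
print(center, R @ center + t)
```

If the poses in your dataset are stored as world-from-camera instead, this check fails and the matrices need to be inverted before training.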
================================================
FILE: training/__init__.py
================================================
================================================
FILE: training/config/default.yaml
================================================
defaults:
- default_dataset.yaml
exp_name: exp001
img_size: 518
num_workers: 8
seed_value: 42
accum_steps: 2 # Gradient accumulation was not used in our original training; enable it if you run into OOM.
patch_size: 14
val_epoch_freq: 5
max_img_per_gpu: 48
limit_train_batches: 800
limit_val_batches: 400
data:
# The code for data still looks too complicated. I should refactor this again (do I have time?...)
train:
_target_: data.dynamic_dataloader.DynamicTorchDataset
num_workers: ${num_workers}
max_img_per_gpu: ${max_img_per_gpu}
common_config:
img_size: ${img_size}
patch_size: ${patch_size}
debug: False
repeat_batch: False
dataset:
_target_: data.composed_dataset.ComposedDataset
dataset_configs:
- _target_: data.datasets.co3d.Co3dDataset
split: train
CO3D_DIR: /YOUR/PATH/TO/CO3D
CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
val:
_target_: data.dynamic_dataloader.DynamicTorchDataset
num_workers: ${num_workers}
max_img_per_gpu: ${max_img_per_gpu}
common_config:
img_size: ${img_size}
patch_size: ${patch_size}
debug: False
dataset:
_target_: data.composed_dataset.ComposedDataset
dataset_configs:
- _target_: data.datasets.co3d.Co3dDataset
split: test
CO3D_DIR: /YOUR/PATH/TO/CO3D
CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
logging:
log_dir: logs
log_visuals: False
log_freq: 1
log_level_primary: DEBUG
log_level_secondary: WARNING
all_ranks: False
tensorboard_writer:
_target_: train_utils.tb_writer.TensorBoardLogger
path: ${logging.log_dir}/tensorboard
scalar_keys_to_log:
train:
keys_to_log:
- loss_objective
- loss_camera
- loss_T
- loss_R
- loss_FL
- loss_conf_depth
- loss_reg_depth
- loss_grad_depth
val:
keys_to_log:
- loss_objective
- loss_camera
- loss_T
- loss_R
- loss_FL
- loss_conf_depth
- loss_reg_depth
- loss_grad_depth
checkpoint:
save_dir: logs/${exp_name}/ckpts
save_freq: 5
resume_checkpoint_path: /YOUR/PATH/TO/CKPT
strict: False
loss:
_target_: loss.MultitaskLoss
camera:
weight: 5.0
loss_type: "l1" # The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss.
depth:
weight: 1.0
gradient_loss_fn: "grad"
valid_range: 0.98
point: null
# If you want to enable point, use the following config
# point:
# weight: 1.0
# gradient_loss_fn: "normal"
# valid_range: 0.98
track: null
optim:
param_group_modifiers: False
optimizer:
_target_: torch.optim.AdamW
lr: 5e-5
weight_decay: 0.05
frozen_module_names:
- "*aggregator*" # example, freeze the aggregator
amp:
enabled: True
amp_dtype: bfloat16
gradient_clip:
_target_: train_utils.gradient_clip.GradientClipper
configs:
- module_name: ["aggregator"]
max_norm: 1.0 # feel free to reduce this if you see instabilities
norm_type: 2
- module_name: ["depth"]
max_norm: 1.0 # feel free to reduce this if you see instabilities
norm_type: 2
- module_name: ["camera"]
max_norm: 1.0 # feel free to reduce this if you see instabilities
norm_type: 2
options:
lr:
- scheduler:
_target_: fvcore.common.param_scheduler.CompositeParamScheduler
schedulers:
- _target_: fvcore.common.param_scheduler.LinearParamScheduler
start_value: 1e-8
end_value: 5e-5
- _target_: fvcore.common.param_scheduler.CosineParamScheduler
start_value: 5e-5
end_value: 1e-8
lengths: [0.05, 0.95]
interval_scaling: ['rescaled', 'rescaled']
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.05
max_epochs: 20
model:
_target_: vggt.models.vggt.VGGT
enable_camera: True
enable_depth: True
enable_point: False
enable_track: False
distributed:
# check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options
backend: nccl
comms_dtype: None
find_unused_parameters: False
timeout_mins: 30
gradient_as_bucket_view: True # Less memory used
bucket_cap_mb: 25
broadcast_buffers: True
cuda:
cudnn_deterministic: False
cudnn_benchmark: False
allow_tf32: True
================================================
FILE: training/config/default_dataset.yaml
================================================
# Template for the dataset config
data:
# The code still looks too complicated. I should refactor this again (do I have time?...)
train:
_target_: data.dynamic_dataloader.DynamicTorchDataset
num_workers: 8
max_img_per_gpu: 48
# Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory
# (see: https://github.com/pytorch/pytorch/issues/13246).
# To avoid this, set shuffle=False and enable common_config.inside_random=True instead.
shuffle: True
pin_memory: False
common_config: # common config shared by all training datasets
fix_img_num: -1 # -1 means do not fix the number of images
fix_aspect_ratio: 1.0
load_track: False
track_num: 1024
training: True
inside_random: True
img_size: 224
patch_size: 14
rescale: True
rescale_aug: True
landscape_check: False
debug: False
get_nearby: True
load_depth: True
img_nums: [2, 24]
max_img_per_gpu: 48
allow_duplicate_img: True
repeat_batch: False
augs:
cojitter: True
cojitter_ratio: 0.3
scales: [0.8, 1.2]
aspects: [0.33, 1.0]
color_jitter:
brightness: 0.5
contrast: 0.5
saturation: 0.5
hue: 0.1
p: 0.9
gray_scale: True
gau_blur: False
val:
_target_: data.dynamic_dataloader.DynamicTorchDataset
num_workers: 8
max_img_per_gpu: 48
# Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory
# (see: https://github.com/pytorch/pytorch/issues/13246).
# To avoid this, set shuffle=False and enable common_config.inside_random=True instead.
shuffle: True
pin_memory: False
common_config: # common config for evaluation
fix_img_num: -1 # -1 means do not fix the number of images
fix_aspect_ratio: 1.0
load_track: False
track_num: 1024
training: False
inside_random: True
img_size: 224
patch_size: 14
rescale: True
rescale_aug: False
landscape_check: False
debug: False
get_nearby: True
load_depth: True
img_nums: [2, 12]
allow_duplicate_img: True
augs:
cojitter: False
cojitter_ratio: 0.5
scales: null
aspects: [1.0, 1.0]
color_jitter: null
gray_scale: False
gau_blur: False
================================================
FILE: training/data/__init__.py
================================================
================================================
FILE: training/data/augmentation.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Dict
from torchvision import transforms
def get_image_augmentation(
color_jitter: Optional[Dict[str, float]] = None,
gray_scale: bool = True,
gau_blur: bool = False
) -> Optional[transforms.Compose]:
"""Create a composition of image augmentations.
Args:
color_jitter: Dictionary containing color jitter parameters:
- brightness: float (default: 0.5)
- contrast: float (default: 0.5)
- saturation: float (default: 0.5)
- hue: float (default: 0.1)
- p: probability of applying (default: 0.9)
If None, uses default values
gray_scale: Whether to apply random grayscale (default: True)
gau_blur: Whether to apply gaussian blur (default: False)
Returns:
A Compose object of transforms or None if no transforms are added
"""
transform_list = []
default_jitter = {
"brightness": 0.5,
"contrast": 0.5,
"saturation": 0.5,
"hue": 0.1,
"p": 0.9
}
# Handle color jitter
if color_jitter is not None:
# Merge with defaults for missing keys
effective_jitter = {**default_jitter, **color_jitter}
else:
effective_jitter = default_jitter
transform_list.append(
transforms.RandomApply(
[
transforms.ColorJitter(
brightness=effective_jitter["brightness"],
contrast=effective_jitter["contrast"],
saturation=effective_jitter["saturation"],
hue=effective_jitter["hue"],
)
],
p=effective_jitter["p"],
)
)
if gray_scale:
transform_list.append(transforms.RandomGrayscale(p=0.05))
if gau_blur:
transform_list.append(
transforms.RandomApply(
[transforms.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05
)
)
return transforms.Compose(transform_list) if transform_list else None
================================================
FILE: training/data/base_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from .dataset_util import *
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
class BaseDataset(Dataset):
"""
Base dataset class for VGGT and VGGSfM training.
This abstract class handles common operations like image resizing,
augmentation, and coordinate transformations. Concrete dataset
implementations should inherit from this class.
Attributes:
img_size: Target image size (typically the width)
patch_size: Patch size of the ViT backbone
aug_scale: Scale range for data augmentation [min, max], read from common_conf.augs.scales
rescale: Whether to rescale images
rescale_aug: Whether to apply augmentation during rescaling
landscape_check: Whether to handle landscape vs portrait orientation
"""
def __init__(
self,
common_conf,
):
"""
Initialize the base dataset with common configuration.
Args:
common_conf: Configuration object with the following properties, shared by all datasets:
- img_size: Default is 518
- patch_size: Default is 14
- augs.scales: Default is [0.8, 1.2]
- rescale: Default is True
- rescale_aug: Default is True
- landscape_check: Default is True
"""
super().__init__()
self.img_size = common_conf.img_size
self.patch_size = common_conf.patch_size
self.aug_scale = common_conf.augs.scales
self.rescale = common_conf.rescale
self.rescale_aug = common_conf.rescale_aug
self.landscape_check = common_conf.landscape_check
def __len__(self):
return self.len_train
def __getitem__(self, idx_N):
"""
Get an item from the dataset.
Args:
idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio)
Returns:
Dataset item as returned by get_data()
"""
seq_index, img_per_seq, aspect_ratio = idx_N
return self.get_data(
seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio
)
def get_data(self, seq_index=None, seq_name=None, ids=None, aspect_ratio=1.0):
"""
Abstract method to retrieve data for a given sequence.
Args:
seq_index (int, optional): Index of the sequence
seq_name (str, optional): Name of the sequence
ids (list, optional): List of frame IDs
aspect_ratio (float, optional): Target aspect ratio.
Returns:
Dataset-specific data
Raises:
NotImplementedError: This method must be implemented by subclasses
"""
raise NotImplementedError(
"This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method."
)
def get_target_shape(self, aspect_ratio):
"""
Calculate the target shape based on the given aspect ratio.
Args:
aspect_ratio: Target aspect ratio
Returns:
numpy.ndarray: Target image shape [height, width]
"""
short_size = int(self.img_size * aspect_ratio)
small_size = self.patch_size
# ensure the input shape is friendly to vision transformer
if short_size % small_size != 0:
short_size = (short_size // small_size) * small_size
image_shape = np.array([short_size, self.img_size])
return image_shape
def process_one_image(
self,
image,
depth_map,
extri_opencv,
intri_opencv,
original_size,
target_image_shape,
track=None,
filepath=None,
safe_bound=4,
):
"""
Process a single image and its associated data.
This method handles image transformations, depth processing, and coordinate conversions.
Args:
image (numpy.ndarray): Input image array
depth_map (numpy.ndarray): Depth map array
extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention)
intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention)
original_size (numpy.ndarray): Original image size [height, width]
target_image_shape (numpy.ndarray): Target image shape after processing
track (numpy.ndarray, optional): Optional tracking information. Defaults to None.
filepath (str, optional): Optional file path for debugging. Defaults to None.
safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4.
Returns:
tuple: (
image (numpy.ndarray): Processed image,
depth_map (numpy.ndarray): Processed depth map,
extri_opencv (numpy.ndarray): Updated extrinsic matrix,
intri_opencv (numpy.ndarray): Updated intrinsic matrix,
world_coords_points (numpy.ndarray): 3D points in world coordinates,
cam_coords_points (numpy.ndarray): 3D points in camera coordinates,
point_mask (numpy.ndarray): Boolean mask of valid points,
track (numpy.ndarray, optional): Updated tracking information
)
"""
# Make copies to avoid in-place operations affecting original data
image = np.copy(image)
depth_map = np.copy(depth_map)
extri_opencv = np.copy(extri_opencv)
intri_opencv = np.copy(intri_opencv)
if track is not None:
track = np.copy(track)
# Apply random scale augmentation during training if enabled
if self.training and self.aug_scale:
random_h_scale, random_w_scale = np.random.uniform(
self.aug_scale[0], self.aug_scale[1], 2
)
# Avoid random padding by capping at 1.0
random_h_scale = min(random_h_scale, 1.0)
random_w_scale = min(random_w_scale, 1.0)
aug_size = original_size * np.array([random_h_scale, random_w_scale])
aug_size = aug_size.astype(np.int32)
else:
aug_size = original_size
# Move principal point to the image center and crop if necessary
image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
image, depth_map, intri_opencv, aug_size, track=track, filepath=filepath,
)
original_size = np.array(image.shape[:2]) # update original_size
target_shape = target_image_shape
# Handle landscape vs. portrait orientation
rotate_to_portrait = False
if self.landscape_check:
# Switch between landscape and portrait if necessary
if original_size[0] > 1.25 * original_size[1]:
if (target_image_shape[0] != target_image_shape[1]) and (np.random.rand() > 0.5):
target_shape = np.array([target_image_shape[1], target_image_shape[0]])
rotate_to_portrait = True
# Resize images and update intrinsics
if self.rescale:
image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic(
image, depth_map, intri_opencv, target_shape, original_size, track=track,
safe_bound=safe_bound,
rescale_aug=self.rescale_aug
)
else:
print("Not rescaling the images")
# Ensure final crop to target shape
image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
image, depth_map, intri_opencv, target_shape, track=track, filepath=filepath, strict=True,
)
# Apply 90-degree rotation if needed
if rotate_to_portrait:
assert self.landscape_check
clockwise = np.random.rand() > 0.5
image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees(
image,
depth_map,
extri_opencv,
intri_opencv,
clockwise=clockwise,
track=track,
)
# Convert depth to world and camera coordinates
world_coords_points, cam_coords_points, point_mask = (
depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv)
)
return (
image,
depth_map,
extri_opencv,
intri_opencv,
world_coords_points,
cam_coords_points,
point_mask,
track,
)
def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None):
"""
TODO: add the function to sample the ids by pose similarity ranking.
Sample a set of IDs from a sequence close to a given start index.
You can specify the range either as a ratio of the number of input IDs
or as a fixed integer window.
Args:
ids (list): Initial list of IDs. The first element is used as the anchor.
full_seq_num (int): Total number of items in the full sequence.
expand_ratio (float, optional): Factor by which the number of IDs expands
around the start index. Default is 2.0 if neither expand_ratio nor
expand_range is provided.
expand_range (int, optional): Fixed number of items to expand around the
start index. If provided, expand_ratio is ignored.
Returns:
numpy.ndarray: Array of sampled IDs, with the first element being the
original start index.
Examples:
# Using expand_ratio (default behavior)
# If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0,
# expand_range = int(3 * 2.0) = 6, so IDs are sampled from the half-open range [94, 106), i.e. 94..105 (if boundaries allow).
# Using expand_range directly
# If ids=[100,101,102] and full_seq_num=200, with expand_range=10,
# IDs are sampled from the half-open range [90, 110), i.e. 90..109 (if boundaries allow).
Raises:
ValueError: If no IDs are provided.
"""
if len(ids) == 0:
raise ValueError("No IDs provided.")
if expand_range is None and expand_ratio is None:
expand_ratio = 2.0 # Default behavior
total_ids = len(ids)
start_idx = ids[0]
# Determine the actual expand_range
if expand_range is None:
# Use ratio to determine range
expand_range = int(total_ids * expand_ratio)
# Calculate valid boundaries
low_bound = max(0, start_idx - expand_range)
high_bound = min(full_seq_num, start_idx + expand_range)
# Create the valid range of indices
valid_range = np.arange(low_bound, high_bound)
# Sample 'total_ids - 1' items, because we already have the start_idx
sampled_ids = np.random.choice(
valid_range,
size=(total_ids - 1),
replace=True, # we accept the situation that some sampled ids are the same
)
# Insert the start_idx at the beginning
result_ids = np.insert(sampled_ids, 0, start_idx)
return result_ids
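# Illustrative sketch (not part of the repository): the arithmetic behind
# get_target_shape and the sampling window of get_nearby_ids, written as
# standalone functions. The names target_shape and nearby_window are hypothetical,
# and the defaults (518, 14, 2.0) are the ones documented above.

```python
def target_shape(img_size=518, patch_size=14, aspect_ratio=1.0):
    """Return (height, width) with the height rounded down to a patch multiple."""
    short_size = int(img_size * aspect_ratio)
    short_size = (short_size // patch_size) * patch_size
    return short_size, img_size


def nearby_window(start_idx, num_ids, full_seq_num, expand_ratio=2.0):
    """Return the half-open [low, high) window that frame IDs are drawn from."""
    expand_range = int(num_ids * expand_ratio)
    low = max(0, start_idx - expand_range)
    high = min(full_seq_num, start_idx + expand_range)
    return low, high


print(target_shape(aspect_ratio=0.5))   # height is rounded down to a multiple of 14
print(nearby_window(100, 3, 200))       # window around frame 100
print(nearby_window(2, 3, 200))         # window clipped at the start of the sequence
```

# The upper bound is exclusive because the window is materialised with np.arange.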
================================================
FILE: training/data/composed_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC
from hydra.utils import instantiate
import torch
import random
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset
import bisect
from .dataset_util import *
from .track_util import *
from .augmentation import get_image_augmentation
class ComposedDataset(Dataset, ABC):
"""
Composes multiple base datasets and applies common configurations.
This dataset provides a flexible way to combine multiple base datasets while
applying shared augmentations, track generation, and other processing steps.
It handles image normalization, tensor conversion, and other preparations
needed for training computer vision models with sequences of images.
"""
def __init__(self, dataset_configs: dict, common_config: dict, **kwargs):
"""
Initializes the ComposedDataset.
Args:
dataset_configs (list): Hydra configurations, one per base dataset.
common_config (dict): Shared configurations (augs, tracks, mode, etc.).
**kwargs: Additional arguments (unused).
"""
base_dataset_list = []
# Instantiate each base dataset with common configuration
for baseset_dict in dataset_configs:
baseset = instantiate(baseset_dict, common_conf=common_config)
base_dataset_list.append(baseset)
# Use custom concatenation class that supports tuple indexing
self.base_dataset = TupleConcatDataset(base_dataset_list, common_config)
# --- Augmentation Settings ---
# Controls whether to apply identical color jittering across all frames in a sequence
self.cojitter = common_config.augs.cojitter
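# Threshold for shared jitter: all frames share one jitter when random.random() > cojitter_ratio,
# i.e. with probability 1 - cojitter_ratio; otherwise each frame is jittered independently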
self.cojitter_ratio = common_config.augs.cojitter_ratio
# Initialize image augmentations (color jitter, grayscale, gaussian blur)
self.image_aug = get_image_augmentation(
color_jitter=common_config.augs.color_jitter,
gray_scale=common_config.augs.gray_scale,
gau_blur=common_config.augs.gau_blur,
)
# --- Optional Fixed Settings (useful for debugging) ---
# Force each sequence to have exactly this many images (if > 0)
self.fixed_num_images = common_config.fix_img_num
# Force a specific aspect ratio for all images
self.fixed_aspect_ratio = common_config.fix_aspect_ratio
# --- Track Settings ---
# Whether to include point tracks in the output
self.load_track = common_config.load_track
# Number of point tracks to include per sequence
self.track_num = common_config.track_num
# --- Mode Settings ---
# Whether the dataset is being used for training (affects augmentations)
self.training = common_config.training
self.common_config = common_config
self.total_samples = len(self.base_dataset)
def __len__(self):
"""Returns the total number of sequences in the dataset."""
return self.total_samples
def __getitem__(self, idx_tuple):
"""
Retrieves a data sample (sequence) from the dataset.
Loads raw data, converts to PyTorch tensors, applies augmentations,
and prepares tracks if enabled.
Args:
idx_tuple (tuple): a tuple of (seq_idx, num_images, aspect_ratio)
Returns:
dict: A dictionary containing the sequence data (images, poses, tracks, etc.).
"""
# If fixed settings are provided, override the tuple values
if self.fixed_num_images > 0:
seq_idx = idx_tuple[0] if isinstance(idx_tuple, tuple) else idx_tuple
idx_tuple = (seq_idx, self.fixed_num_images, self.fixed_aspect_ratio)
# Retrieve the raw data batch from the appropriate base dataset
batch = self.base_dataset[idx_tuple]
seq_name = batch["seq_name"]
# --- Data Conversion and Preparation ---
# Convert numpy arrays to tensors
images = torch.from_numpy(np.stack(batch["images"]).astype(np.float32)).contiguous()
# Normalize images from [0, 255] to [0, 1]
images = images.permute(0,3,1,2).to(torch.get_default_dtype()).div(255)
# Convert other data to tensors with appropriate types
depths = torch.from_numpy(np.stack(batch["depths"]).astype(np.float32))
extrinsics = torch.from_numpy(np.stack(batch["extrinsics"]).astype(np.float32))
intrinsics = torch.from_numpy(np.stack(batch["intrinsics"]).astype(np.float32))
cam_points = torch.from_numpy(np.stack(batch["cam_points"]).astype(np.float32))
world_points = torch.from_numpy(np.stack(batch["world_points"]).astype(np.float32))
point_masks = torch.from_numpy(np.stack(batch["point_masks"])) # Mask indicating valid depths / world points / cam points per frame
ids = torch.from_numpy(batch["ids"]) # Frame indices sampled from the original sequence
# --- Apply Color Augmentation (training mode only) ---
if self.training and self.image_aug is not None:
if self.cojitter and random.random() > self.cojitter_ratio:
# Apply the same color jittering transformation to all frames
images = self.image_aug(images)
else:
# Apply different color jittering to each frame individually
for aug_img_idx in range(len(images)):
images[aug_img_idx] = self.image_aug(images[aug_img_idx])
# --- Prepare Final Sample Dictionary ---
sample = {
"seq_name": seq_name,
"ids": ids,
"images": images,
"depths": depths,
"extrinsics": extrinsics,
"intrinsics": intrinsics,
"cam_points": cam_points,
"world_points": world_points,
"point_masks": point_masks,
}
# --- Track Processing (if enabled) ---
if self.load_track:
if batch["tracks"] is not None:
# Use pre-computed tracks from the dataset
tracks = torch.from_numpy(np.stack(batch["tracks"]).astype(np.float32))
track_vis_mask = torch.from_numpy(np.stack(batch["track_masks"]).astype(bool))
# Sample a subset of tracks randomly
valid_indices = torch.where(track_vis_mask[0])[0]
if len(valid_indices) >= self.track_num:
# If we have enough tracks, sample without replacement
sampled_indices = valid_indices[torch.randperm(len(valid_indices))][:self.track_num]
else:
# If not enough tracks, sample with replacement (allow duplicates)
sampled_indices = valid_indices[torch.randint(0, len(valid_indices),
(self.track_num,),
dtype=torch.int64,
device=valid_indices.device)]
# Extract the sampled tracks and their masks
tracks = tracks[:, sampled_indices, :]
track_vis_mask = track_vis_mask[:, sampled_indices]
track_positive_mask = torch.ones(track_vis_mask.shape[1]).bool()
else:
# Generate tracks on-the-fly using depth information
# This creates synthetic tracks based on the 3D information available
tracks, track_vis_mask, track_positive_mask = build_tracks_by_depth(
extrinsics, intrinsics, world_points, depths, point_masks, images,
target_track_num=self.track_num, seq_name=seq_name
)
# Add track information to the sample dictionary
sample["tracks"] = tracks
sample["track_vis_mask"] = track_vis_mask
sample["track_positive_mask"] = track_positive_mask
return sample
class TupleConcatDataset(ConcatDataset):
"""
A custom ConcatDataset that supports indexing with a tuple.
Standard PyTorch ConcatDataset only accepts an integer index. This class extends
that functionality to allow passing a tuple like (sample_idx, num_images, aspect_ratio),
where the first element is used to determine which sample to fetch, and the full
tuple is passed down to the selected dataset's __getitem__ method.
It also supports an option to randomly sample across all datasets, ignoring the
provided index. This is useful during training when shuffling the entire dataset
might cause memory issues due to duplicating dictionaries. If doing this, you can
set pytorch's dataloader shuffle to False.
"""
def __init__(self, datasets, common_config):
"""
Initialize the TupleConcatDataset.
Args:
datasets (iterable): An iterable of PyTorch Dataset objects to concatenate.
common_config (dict): Common configuration dict, used to check for random sampling.
"""
super().__init__(datasets)
# If True, ignores the input index and samples randomly across all datasets
# This provides an alternative to dataloader shuffling for large datasets
self.inside_random = common_config.inside_random
def __getitem__(self, idx):
"""
Retrieves an item using a tuple index.
Args:
idx (tuple): The index as (seq_index, num_images, aspect_ratio). The first element is
the sequence index across the concatenated datasets, and the rest are
passed down. A bare int is not accepted: the underlying datasets expect the full three-element tuple.
Returns:
The item returned by the underlying dataset's __getitem__ method.
Raises:
ValueError: If the index is out of range or the tuple doesn't have exactly 3 elements.
"""
idx_tuple = None
if isinstance(idx, tuple):
idx_tuple = idx
idx = idx_tuple[0] # Extract the sequence index
# Override index with random value if inside_random is enabled
if self.inside_random:
total_len = self.cumulative_sizes[-1]
idx = random.randint(0, total_len - 1)
# Handle negative indices
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
# Find which dataset the index belongs to
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
# Create the tuple to pass to the underlying dataset
if idx_tuple is not None and len(idx_tuple) == 3:
idx_tuple = (sample_idx,) + idx_tuple[1:]
else:
raise ValueError("Tuple index must have exactly three elements")
# Pass the modified tuple to the appropriate dataset
return self.datasets[dataset_idx][idx_tuple]
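# Illustrative sketch (not part of the repository): how a global sequence index
# is mapped to (dataset_idx, sample_idx) with cumulative sizes and bisect, as in
# TupleConcatDataset.__getitem__. The function name locate and the example sizes
# are hypothetical.

```python
import bisect
from itertools import accumulate


def locate(idx, sizes):
    """Map a global index to (dataset_idx, sample_idx) for datasets of the given sizes."""
    cumulative = list(accumulate(sizes))  # e.g. [3, 5, 2] -> [3, 8, 10]
    total = cumulative[-1]
    if idx < 0:
        if -idx > total:
            raise ValueError("absolute value of index should not exceed dataset length")
        idx = total + idx
    # bisect_right returns the first dataset whose cumulative size exceeds idx.
    dataset_idx = bisect.bisect_right(cumulative, idx)
    if dataset_idx == 0:
        sample_idx = idx
    else:
        sample_idx = idx - cumulative[dataset_idx - 1]
    return dataset_idx, sample_idx


sizes = [3, 5, 2]
for global_idx in (0, 3, 7, 8, -1):
    print(global_idx, locate(global_idx, sizes))
```

# In the class above, sample_idx then replaces the first element of the index
# tuple before it is passed to the selected dataset.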
================================================
FILE: training/data/dataset_util.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
import math
import numpy as np
from PIL import Image
import PIL
try:
lanczos = PIL.Image.Resampling.LANCZOS
bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
lanczos = PIL.Image.LANCZOS
bicubic = PIL.Image.BICUBIC
from vggt.utils.geometry import closed_form_inverse_se3
#####################################################################################################################
def crop_image_depth_and_intrinsic_by_pp(
image, depth_map, intrinsic, target_shape, track=None, filepath=None, strict=False
):
"""
TODO: some names of width and height seem not consistent. Need to check.
Crops the given image and depth map around the camera's principal point, as defined by `intrinsic`.
Specifically:
- Ensures that the crop is centered on (cx, cy).
- Optionally pads the image (and depth map) if `strict=True` and the result is smaller than `target_shape`.
- Shifts the camera intrinsic matrix (and `track` if provided) accordingly.
Args:
image (np.ndarray):
Input image array of shape (H, W, 3).
depth_map (np.ndarray or None):
Depth map array of shape (H, W), or None if not available.
intrinsic (np.ndarray):
Camera intrinsic matrix (3x3). The principal point is read in (row, column) order, i.e. as (intrinsic[1,2], intrinsic[0,2]).
target_shape (tuple[int, int]):
Desired output shape as (height, width).
track (np.ndarray or None):
Optional array of shape (N, 2). Interpreted as (x, y) pixel coordinates. Will be shifted after cropping.
filepath (str or None):
An optional file path for debug logging (only used if strict mode triggers warnings).
strict (bool):
If True, will zero-pad to ensure the exact target_shape even if the cropped region is smaller.
Raises:
AssertionError:
If the input image is smaller than `target_shape`.
ValueError:
If the cropped image is larger than `target_shape` (in strict mode), which should not normally happen.
Returns:
tuple:
(cropped_image, cropped_depth_map, updated_intrinsic, updated_track)
- cropped_image (np.ndarray): Cropped (and optionally padded) image.
- cropped_depth_map (np.ndarray or None): Cropped (and optionally padded) depth map.
- updated_intrinsic (np.ndarray): Intrinsic matrix adjusted for the crop.
- updated_track (np.ndarray or None): Track array adjusted for the crop, or None if track was not provided.
"""
original_size = np.array(image.shape)
intrinsic = np.copy(intrinsic)
if original_size[0] < target_shape[0]:
error_message = (
f"Height check failed: original height {original_size[0]} "
f"is less than target height {target_shape[0]}."
)
print(error_message)
raise AssertionError(error_message)
if original_size[1] < target_shape[1]:
error_message = (
f"Width check failed: original width {original_size[1]} "
f"is less than target width {target_shape[1]}."
)
print(error_message)
raise AssertionError(error_message)
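# Identify the principal point from the intrinsic matrix.
# Note: in this function "x" indexes rows and "y" indexes columns, so `cx` holds
# intrinsic[1, 2] and `cy` holds intrinsic[0, 2] (swapped relative to the usual OpenCV naming).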
cx = (intrinsic[1, 2])
cy = (intrinsic[0, 2])
# Compute how far we can crop in each direction
if strict:
half_x = min((target_shape[0] / 2), cx)
half_y = min((target_shape[1] / 2), cy)
else:
half_x = min((target_shape[0] / 2), cx, original_size[0] - cx)
half_y = min((target_shape[1] / 2), cy, original_size[1] - cy)
# Compute starting indices
start_x = math.floor(cx) - math.floor(half_x)
start_y = math.floor(cy) - math.floor(half_y)
assert start_x >= 0
assert start_y >= 0
# Compute ending indices
if strict:
end_x = start_x + target_shape[0]
end_y = start_y + target_shape[1]
else:
end_x = start_x + 2 * math.floor(half_x)
end_y = start_y + 2 * math.floor(half_y)
# Perform the crop
image = image[start_x:end_x, start_y:end_y, :]
if depth_map is not None:
depth_map = depth_map[start_x:end_x, start_y:end_y]
# Shift the principal point in the intrinsic
intrinsic[1, 2] = intrinsic[1, 2] - start_x
intrinsic[0, 2] = intrinsic[0, 2] - start_y
# Adjust track if provided
if track is not None:
track[:, 1] = track[:, 1] - start_x
track[:, 0] = track[:, 0] - start_y
# If strict, zero-pad if the new shape is smaller than target_shape
if strict:
if (image.shape[:2] != target_shape).any():
print(f"{filepath} does not meet the target shape")
current_h, current_w = image.shape[:2]
target_h, target_w = target_shape[0], target_shape[1]
pad_h = target_h - current_h
pad_w = target_w - current_w
if pad_h < 0 or pad_w < 0:
raise ValueError(
f"The cropped image is bigger than the target shape: "
f"cropped=({current_h},{current_w}), "
f"target=({target_h},{target_w})."
)
image = np.pad(
image,
pad_width=((0, pad_h), (0, pad_w), (0, 0)),
mode="constant",
constant_values=0,
)
if depth_map is not None:
depth_map = np.pad(
depth_map,
pad_width=((0, pad_h), (0, pad_w)),
mode="constant",
constant_values=0,
)
return image, depth_map, intrinsic, track
def resize_image_depth_and_intrinsic(
image,
depth_map,
intrinsic,
target_shape,
original_size,
track=None,
pixel_center=True,
safe_bound=4,
rescale_aug=True,
):
"""
Resizes the given image and depth map (if provided) to slightly larger than `target_shape`,
updating the intrinsic matrix (and track array if present). Optionally uses random rescaling
to create some additional margin (based on `rescale_aug`).
Steps:
1. Compute a scaling factor so that the resized result is at least `target_shape + safe_bound`.
2. Apply an optional triangular random factor if `rescale_aug=True`.
3. Resize the image with LANCZOS if downscaling, BICUBIC if upscaling.
4. Resize the depth map with nearest-neighbor.
5. Update the camera intrinsic and track coordinates (if any).
Args:
image (np.ndarray):
Input image array (H, W, 3).
depth_map (np.ndarray or None):
Depth map array (H, W), or None if unavailable.
intrinsic (np.ndarray):
Camera intrinsic matrix (3x3).
target_shape (np.ndarray or tuple[int, int]):
Desired final shape (height, width).
original_size (np.ndarray or tuple[int, int]):
Original size of the image in (height, width).
track (np.ndarray or None):
Optional (N, 2) array of pixel coordinates. Will be scaled.
pixel_center (bool):
If True, accounts for 0.5 pixel center shift during resizing.
safe_bound (int or float):
Additional margin (in pixels) to add to target_shape before resizing.
rescale_aug (bool):
If True, randomly increase the `safe_bound` within a certain range to simulate augmentation.
Returns:
tuple:
(resized_image, resized_depth_map, updated_intrinsic, updated_track)
- resized_image (np.ndarray): The resized image.
- resized_depth_map (np.ndarray or None): The resized depth map.
- updated_intrinsic (np.ndarray): Camera intrinsic updated for new resolution.
- updated_track (np.ndarray or None): Track array updated or None if not provided.
Raises:
AssertionError:
If the shapes of the resized image and depth map do not match.
"""
if rescale_aug:
random_boundary = np.random.triangular(0, 0, 0.3)
safe_bound = safe_bound + random_boundary * target_shape.max()
resize_scales = (target_shape + safe_bound) / original_size
max_resize_scale = np.max(resize_scales)
intrinsic = np.copy(intrinsic)
# Convert image to PIL for resizing
image = Image.fromarray(image)
input_resolution = np.array(image.size)
output_resolution = np.floor(input_resolution * max_resize_scale).astype(int)
image = image.resize(tuple(output_resolution), resample=lanczos if max_resize_scale < 1 else bicubic)
image = np.array(image)
if depth_map is not None:
depth_map = cv2.resize(
depth_map,
output_resolution,
fx=max_resize_scale,
fy=max_resize_scale,
interpolation=cv2.INTER_NEAREST,
)
actual_size = np.array(image.shape[:2])
actual_resize_scale = np.max(actual_size / original_size)
if pixel_center:
intrinsic[0, 2] = intrinsic[0, 2] + 0.5
intrinsic[1, 2] = intrinsic[1, 2] + 0.5
intrinsic[:2, :] = intrinsic[:2, :] * actual_resize_scale
if track is not None:
track = track * actual_resize_scale
if pixel_center:
intrinsic[0, 2] = intrinsic[0, 2] - 0.5
intrinsic[1, 2] = intrinsic[1, 2] - 0.5
assert depth_map is None or image.shape[:2] == depth_map.shape[:2]
return image, depth_map, intrinsic, track
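# Illustrative sketch (not part of the repository): the scale selection and the
# pixel-center update of the principal point used by
# resize_image_depth_and_intrinsic, without the random rescale_aug margin.
# The names plan_resize and rescale_principal_point are hypothetical, and the
# sketch reuses the planned scale where the function above re-measures the
# actual scale from the resized image.

```python
import math


def plan_resize(target_hw, original_hw, safe_bound=4):
    """Return (scale, (out_h, out_w)) so the output covers target + safe_bound."""
    # Use the larger of the two per-axis scales so that both axes are covered.
    scale = max((t + safe_bound) / o for t, o in zip(target_hw, original_hw))
    out_hw = tuple(math.floor(o * scale) for o in original_hw)
    return scale, out_hw


def rescale_principal_point(c, scale, pixel_center=True):
    """Scale one principal-point coordinate, optionally about pixel centers."""
    if pixel_center:
        # Shift to pixel-center coordinates, scale, then shift back.
        return (c + 0.5) * scale - 0.5
    return c * scale


scale, out_hw = plan_resize((294, 518), (720, 1280))
print(scale, out_hw)
print(rescale_principal_point(639.5, 0.5))
```

# The resized image is slightly larger than the target, which leaves room for
# the strict crop around the principal point that follows in process_one_image.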
def threshold_depth_map(
depth_map: np.ndarray,
max_percentile: float = 99,
min_percentile: float = 1,
max_depth: float = -1,
) -> np.ndarray:
"""
Thresholds a depth map using percentile-based limits and optional maximum depth clamping.
Steps:
1. If `max_depth > 0`, set all values above `max_depth` to zero.
2. Compute `max_percentile` and `min_percentile` thresholds using nanpercentile.
3. Zero out values above/below these thresholds, if thresholds are > 0.
Args:
depth_map (np.ndarray):
Input depth map (H, W).
max_percentile (float):
Upper percentile (0-100). Values above this will be set to zero.
min_percentile (float):
Lower percentile (0-100). Values below this will be set to zero.
max_depth (float):
Absolute maximum depth. If > 0, any depth above this is set to zero.
If <= 0, no maximum-depth clamp is applied.
Returns:
np.ndarray:
Depth map (H, W) after thresholding. Some or all values may be zero.
Returns None if depth_map is None.
"""
if depth_map is None:
return None
depth_map = depth_map.astype(float, copy=True)
# Optional clamp by max_depth
if max_depth > 0:
depth_map[depth_map > max_depth] = 0.0
# Percentile-based thresholds
depth_max_thres = (
np.nanpercentile(depth_map, max_percentile) if max_percentile > 0 else None
)
depth_min_thres = (
np.nanpercentile(depth_map, min_percentile) if min_percentile > 0 else None
)
# Apply the thresholds if they are > 0
if depth_max_thres is not None and depth_max_thres > 0:
depth_map[depth_map > depth_max_thres] = 0.0
if depth_min_thres is not None and depth_min_thres > 0:
depth_map[depth_map < depth_min_thres] = 0.0
return depth_map
def depth_to_world_coords_points(
depth_map: np.ndarray,
extrinsic: np.ndarray,
intrinsic: np.ndarray,
eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Converts a depth map to world coordinates (HxWx3) given the camera extrinsic and intrinsic.
Returns both the world coordinates and the intermediate camera coordinates,
as well as a mask for valid depth.
Args:
depth_map (np.ndarray):
Depth map of shape (H, W).
extrinsic (np.ndarray):
Extrinsic matrix of shape (3, 4), representing the camera pose in OpenCV convention (camera-from-world).
intrinsic (np.ndarray):
Intrinsic matrix of shape (3, 3).
eps (float):
Small epsilon for thresholding valid depth.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]:
(world_coords_points, cam_coords_points, point_mask)
- world_coords_points: (H, W, 3) array of 3D points in world frame.
- cam_coords_points: (H, W, 3) array of 3D points in camera frame.
- point_mask: (H, W) boolean array where True indicates valid (non-zero) depth.
"""
if depth_map is None:
return None, None, None
# Valid depth mask
point_mask = depth_map > eps
# Convert depth map to camera coordinates
cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
# The extrinsic is camera-from-world, so invert it to transform camera->world
cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
R_cam_to_world = cam_to_world_extrinsic[:3, :3]
t_cam_to_world = cam_to_world_extrinsic[:3, 3]
# Apply the rotation and translation to the camera coordinates
world_coords_points = (
np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
) # HxWx3, 3x3 -> HxWx3
# world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
return world_coords_points, cam_coords_points, point_mask
def depth_to_cam_coords_points(
depth_map: np.ndarray, intrinsic: np.ndarray
) -> np.ndarray:
"""
Unprojects a depth map into camera coordinates, returning (H, W, 3).
Args:
depth_map (np.ndarray):
Depth map of shape (H, W).
intrinsic (np.ndarray):
3x3 camera intrinsic matrix.
Assumes zero skew and standard OpenCV layout:
[ fx 0 cx ]
[ 0 fy cy ]
[ 0 0 1 ]
Returns:
np.ndarray:
An (H, W, 3) array, where each pixel is mapped to (x, y, z) in the camera frame.
"""
H, W = depth_map.shape
assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
assert (
intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0
), "Intrinsic matrix must have zero skew"
# Intrinsic parameters
fu, fv = intrinsic[0, 0], intrinsic[1, 1]
cu, cv = intrinsic[0, 2], intrinsic[1, 2]
# Generate grid of pixel coordinates
u, v = np.meshgrid(np.arange(W), np.arange(H))
# Unproject to camera coordinates
x_cam = (u - cu) * depth_map / fu
y_cam = (v - cv) * depth_map / fv
z_cam = depth_map
# Stack to form camera coordinates
return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
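# Illustrative sketch (not part of the repository): the pinhole unprojection used
# above, written for a single pixel in plain Python, together with the matching
# projection as a round-trip check. The function names and the example
# intrinsics (fx = fy = 500, cx = 320, cy = 240) are hypothetical.

```python
def unproject_pixel(u, v, depth, fx, fy, cx, cy):
    """Map pixel (u, v) with the given depth to (x, y, z) in the camera frame."""
    x = (u - cx) * depth / fx
    y = (v - cy) * depth / fy
    return x, y, depth


def project_point(x, y, z, fx, fy, cx, cy):
    """Map a camera-frame point back to pixel coordinates (zero skew assumed)."""
    return fx * x / z + cx, fy * y / z + cy


fx = fy = 500.0
cx, cy = 320.0, 240.0

point = unproject_pixel(420, 140, 2.0, fx, fy, cx, cy)
pixel = project_point(*point, fx, fy, cx, cy)
print(point)
print(pixel)
```

# A pixel at the principal point unprojects to a point on the optical axis
# (x = y = 0), and projecting an unprojected point recovers the original pixel
# up to floating-point error.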
def rotate_90_degrees(
image, depth_map, extri_opencv, intri_opencv, clockwise=True, track=None
):
"""
Rotates the input image, depth map, and camera parameters by 90 degrees.
Applies one of two 90-degree rotations:
- Clockwise
- Counterclockwise (if clockwise=False)
The extrinsic and intrinsic matrices are adjusted accordingly to maintain
correct camera geometry. Track coordinates are also updated if provided.
Args:
image (np.ndarray):
Input image of shape (H, W, 3).
depth_map (np.ndarray or None):
Depth map of shape (H, W), or None if not available.
extri_opencv (np.ndarray):
Extrinsic matrix (3x4) in OpenCV convention.
intri_opencv (np.ndarray):
Intrinsic matrix (3x3).
clockwise (bool):
If True, rotates the image 90 degrees clockwise; else 90 degrees counterclockwise.
track (np.ndarray or None):
Optional (N, 2) track array. Will be rotated accordingly.
Returns:
tuple:
(
rotated_image,
rotated_depth_map,
new_extri_opencv,
new_intri_opencv,
new_track
)
Where each is the updated version after the rotation.
"""
image_height, image_width = image.shape[:2]
# Rotate the image and depth map
rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise)
# Adjust the intrinsic matrix
new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise)
if track is not None:
new_track = adjust_track_rot90(track, image_width, image_height, clockwise)
else:
new_track = None
# Adjust the extrinsic matrix
new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)
return (
rotated_image,
rotated_depth_map,
new_extri_opencv,
new_intri_opencv,
new_track,
)
def rotate_image_and_depth_rot90(image, depth_map, clockwise):
"""
Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise),
using a transpose+flip pattern.
Args:
image (np.ndarray):
Input image of shape (H, W, 3).
depth_map (np.ndarray or None):
Depth map of shape (H, W), or None if not available.
clockwise (bool):
If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.
Returns:
tuple:
(rotated_image, rotated_depth_map)
"""
rotated_depth_map = None
if clockwise:
rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width
rotated_image = np.flip(rotated_image, axis=1) # Flip horizontally
if depth_map is not None:
rotated_depth_map = np.transpose(depth_map, (1, 0))
rotated_depth_map = np.flip(rotated_depth_map, axis=1)
else:
rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width
rotated_image = np.flip(rotated_image, axis=0) # Flip vertically
if depth_map is not None:
rotated_depth_map = np.transpose(depth_map, (1, 0))
rotated_depth_map = np.flip(rotated_depth_map, axis=0)
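    # np.copy(None) would return a 0-d object array, so keep a missing depth map as None
    rotated_image = np.copy(rotated_image)
    if rotated_depth_map is not None:
        rotated_depth_map = np.copy(rotated_depth_map)
    return rotated_image, rotated_depth_map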
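# Illustrative sketch, not part of the original module: where a single pixel lands
# under the transpose+flip pattern used above. `_sketch_rot90_pixel` is a hypothetical
# helper; it marks pixel (x, y) in an empty image, rotates it, and reads the mark back.
import numpy as np


def _sketch_rot90_pixel(x, y, width, height, clockwise=True):
    marker = np.zeros((height, width), dtype=np.uint8)
    marker[y, x] = 1  # row index is y, column index is x
    rotated = np.transpose(marker, (1, 0))
    # Horizontal flip gives a clockwise turn, vertical flip a counterclockwise one
    rotated = np.flip(rotated, axis=1 if clockwise else 0)
    new_y, new_x = np.argwhere(rotated == 1)[0]
    return int(new_x), int(new_y)


# For an H x W input, the clockwise pattern sends (x, y) to (H - 1 - y, x) and the
# counterclockwise pattern sends it to (y, W - 1 - x).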
def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise):
"""
Adjusts the extrinsic matrix (3x4) for a 90-degree rotation of the image.
The rotation is in the image plane. This modifies the camera orientation
accordingly. The function applies either a clockwise or counterclockwise
90-degree rotation.
Args:
extri_opencv (np.ndarray):
Extrinsic matrix (3x4) in OpenCV convention.
clockwise (bool):
If True, rotate extrinsic for a 90-degree clockwise image rotation;
otherwise, counterclockwise.
Returns:
np.ndarray:
A new 3x4 extrinsic matrix after the rotation.
"""
R = extri_opencv[:, :3]
t = extri_opencv[:, 3]
if clockwise:
R_rotation = np.array([
[0, -1, 0],
[1, 0, 0],
[0, 0, 1]
])
else:
R_rotation = np.array([
[0, 1, 0],
[-1, 0, 0],
[0, 0, 1]
])
new_R = np.dot(R_rotation, R)
new_t = np.dot(R_rotation, t)
new_extri_opencv = np.hstack((new_R, new_t.reshape(-1, 1)))
return new_extri_opencv
def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise):
"""
Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane.
Args:
intri_opencv (np.ndarray):
Intrinsic matrix (3x3).
image_width (int):
Original width of the image.
image_height (int):
Original height of the image.
clockwise (bool):
If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.
Returns:
np.ndarray:
A new 3x3 intrinsic matrix after the rotation.
"""
fx, fy, cx, cy = (
intri_opencv[0, 0],
intri_opencv[1, 1],
intri_opencv[0, 2],
intri_opencv[1, 2],
)
new_intri_opencv = np.eye(3)
if clockwise:
new_intri_opencv[0, 0] = fy
new_intri_opencv[1, 1] = fx
new_intri_opencv[0, 2] = image_height - cy
new_intri_opencv[1, 2] = cx
else:
new_intri_opencv[0, 0] = fy
new_intri_opencv[1, 1] = fx
new_intri_opencv[0, 2] = cy
new_intri_opencv[1, 2] = image_width - cx
return new_intri_opencv
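# Illustrative sketch, not part of the original module: a numerical check that the
# clockwise extrinsic and intrinsic updates above agree with each other. The helper
# names are hypothetical and any numeric inputs are assumptions chosen by the caller.
import numpy as np


def _sketch_project(point_world, extri, intri):
    # OpenCV convention: x_cam = R @ X + t, then pixel = (K @ x_cam) / z
    cam = extri[:, :3] @ point_world + extri[:, 3]
    pix = intri @ cam
    return pix[:2] / pix[2]


def _sketch_rot90_clockwise_camera(extri, intri, image_height):
    # Mirrors the clockwise branches of the two adjust_* functions above
    R_rotation = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    new_extri = R_rotation @ extri  # rotates both R and t
    new_intri = np.eye(3)
    new_intri[0, 0], new_intri[1, 1] = intri[1, 1], intri[0, 0]
    new_intri[0, 2] = image_height - intri[1, 2]
    new_intri[1, 2] = intri[0, 2]
    return new_extri, new_intri


# With these updates, a point that projected to (u, v) in the original camera
# projects to (image_height - v, u) in the rotated camera.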
def adjust_track_rot90(track, image_width, image_height, clockwise):
"""
Adjusts a track (N, 2) for a 90-degree rotation of the image in the image plane.
Args:
track (np.ndarray):
(N, 2) array of pixel coordinates, each row is (x, y).
image_width (int):
Original image width.
image_height (int):
Original image height.
clockwise (bool):
Whether the rotation is 90 degrees clockwise or counterclockwise.
Returns:
np.ndarray:
A new track of shape (N, 2) after rotation.
"""
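    if clockwise:
        # Matches transpose + horizontal flip in rotate_image_and_depth_rot90:
        # (x, y) -> (image_height - 1 - y, x)
        new_track = np.stack((image_height - 1 - track[:, 1], track[:, 0]), axis=-1)
    else:
        # Matches transpose + vertical flip in rotate_image_and_depth_rot90:
        # (x, y) -> (y, image_width - 1 - x)
        new_track = np.stack((track[:, 1], image_width - 1 - track[:, 0]), axis=-1)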
return new_track
def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray:
"""
    Reads an image from disk using OpenCV, returning an (H, W, 3) array in RGB order by default.
    Args:
        path (str):
            File path to the image.
        rgb (bool):
            If True, convert the image from OpenCV's BGR order to RGB.
            If False, leave the image as loaded by OpenCV (BGR).
Returns:
np.ndarray or None:
A numpy array of shape (H, W, 3) if successful,
or None if the file does not exist or could not be read.
"""
if not os.path.exists(path) or os.path.getsize(path) == 0:
print(f"File does not exist or is empty: {path}")
return None
img = cv2.imread(path)
if img is None:
print(f"Could not load image={path}. Retrying...")
img = cv2.imread(path)
if img is None:
print("Retry failed.")
return None
if rgb:
if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray:
"""
Reads a depth map from disk in either .exr or .png format. The .exr is loaded using OpenCV
with the environment variable OPENCV_IO_ENABLE_OPENEXR=1. The .png is assumed to be a 16-bit
PNG (converted from half float).
Args:
path (str):
File path to the depth image. Must end with .exr or .png.
scale_adjustment (float):
A multiplier for adjusting the loaded depth values (default=1.0).
Returns:
np.ndarray:
A float32 array (H, W) containing the loaded depth. Zeros or non-finite values
may indicate invalid regions.
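    Raises:
        FileNotFoundError:
            If an .exr file cannot be read by OpenCV.
        ValueError:
            If the file extension is not supported.
    """
    if path.lower().endswith(".exr"):
        # OpenCV decodes .exr files only when OPENCV_IO_ENABLE_OPENEXR is set to "1"
        d = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if d is None:
            raise FileNotFoundError(f'could not read depth file "{path}"')
        d = d[..., 0]
        d[d > 1e9] = 0.0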
elif path.lower().endswith(".png"):
d = load_16big_png_depth(path)
else:
raise ValueError(f'unsupported depth file name "{path}"')
d = d * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d
def load_16big_png_depth(depth_png: str) -> np.ndarray:
"""
Loads a 16-bit PNG as a half-float depth map (H, W), returning a float32 NumPy array.
Implementation detail:
- PIL loads 16-bit data as 32-bit "I" mode.
- We reinterpret the bits as float16, then cast to float32.
Args:
depth_png (str):
File path to the 16-bit PNG.
Returns:
np.ndarray:
A float32 depth array of shape (H, W).
"""
with Image.open(depth_png) as depth_pil:
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
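# Illustrative sketch, not part of the original module: the bit reinterpretation that
# the loader above relies on, shown without any file I/O. The helper names are
# hypothetical, and the encoding side is an assumption about how such PNGs are written.
import numpy as np


def _sketch_encode_half_as_uint16(depth):
    # Keep each float16 value's raw 16 bits and label them as unsigned integers
    return depth.astype(np.float16).view(np.uint16)


def _sketch_decode_uint16_as_half(raw):
    # Reinterpret the same bits as float16, then widen to float32
    raw = np.ascontiguousarray(raw, dtype=np.uint16)
    return raw.view(np.float16).astype(np.float32)


# The integers stored in the PNG are therefore not depths in any unit; they only
# become depths once their bits are read back as float16.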
================================================
FILE: training/data/datasets/co3d.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import json
import os.path as osp
import os
import logging
import cv2
import random
import numpy as np
from data.dataset_util import *
from data.base_dataset import BaseDataset
SEEN_CATEGORIES = [
"apple",
"backpack",
"banana",
"baseballbat",
"baseballglove",
"bench",
"bicycle",
"bottle",
"bowl",
"broccoli",
"cake",
"car",
"carrot",
"cellphone",
"chair",
"cup",
"donut",
"hairdryer",
"handbag",
"hydrant",
"keyboard",
"laptop",
"microwave",
"motorcycle",
"mouse",
"orange",
"parkingmeter",
"pizza",
"plant",
"stopsign",
"teddybear",
"toaster",
"toilet",
"toybus",
"toyplane",
"toytrain",
"toytruck",
"tv",
"umbrella",
"vase",
"wineglass",
]
class Co3dDataset(BaseDataset):
def __init__(
self,
common_conf,
split: str = "train",
CO3D_DIR: str = None,
CO3D_ANNOTATION_DIR: str = None,
min_num_images: int = 24,
len_train: int = 100000,
len_test: int = 10000,
):
"""
Initialize the Co3dDataset.
Args:
common_conf: Configuration object with common settings.
split (str): Dataset split, either 'train' or 'test'.
CO3D_DIR (str): Directory path to CO3D data.
CO3D_ANNOTATION_DIR (str): Directory path to CO3D annotations.
min_num_images (int): Minimum number of images per sequence.
len_train (int): Length of the training dataset.
len_test (int): Length of the test dataset.
Raises:
ValueError: If CO3D_DIR or CO3D_ANNOTATION_DIR is not specified.
"""
super().__init__(common_conf=common_conf)
self.debug = common_conf.debug
self.training = common_conf.training
self.get_nearby = common_conf.get_nearby
self.load_depth = common_conf.load_depth
self.inside_random = common_conf.inside_random
self.allow_duplicate_img = common_conf.allow_duplicate_img
if CO3D_DIR is None or CO3D_ANNOTATION_DIR is None:
raise ValueError("Both CO3D_DIR and CO3D_ANNOTATION_DIR must be specified.")
category = sorted(SEEN_CATEGORIES)
if self.debug:
category = ["apple"]
if split == "train":
split_name_list = ["train"]
self.len_train = len_train
elif split == "test":
split_name_list = ["test"]
self.len_train = len_test
else:
raise ValueError(f"Invalid split: {split}")
self.invalid_sequence = [] # set any invalid sequence names here
self.category_map = {}
self.data_store = {}
self.seqlen = None
self.min_num_images = min_num_images
logging.info(f"CO3D_DIR is {CO3D_DIR}")
self.CO3D_DIR = CO3D_DIR
self.CO3D_ANNOTATION_DIR = CO3D_ANNOTATION_DIR
total_frame_num = 0
for c in category:
for split_name in split_name_list:
annotation_file = osp.join(
self.CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz"
)
try:
with gzip.open(annotation_file, "r") as fin:
annotation = json.loads(fin.read())
except FileNotFoundError:
logging.error(f"Annotation file not found: {annotation_file}")
continue
for seq_name, seq_data in annotation.items():
if len(seq_data) < min_num_images:
continue
if seq_name in self.invalid_sequence:
continue
total_frame_num += len(seq_data)
self.data_store[seq_name] = seq_data
self.sequence_list = list(self.data_store.keys())
self.sequence_list_len = len(self.sequence_list)
self.total_frame_num = total_frame_num
status = "Training" if self.training else "Testing"
logging.info(f"{status}: Co3D Data size: {self.sequence_list_len}")
logging.info(f"{status}: Co3D Data dataset length: {len(self)}")
def get_data(
self,
seq_index: int = None,
img_per_seq: int = None,
seq_name: str = None,
ids: list = None,
aspect_ratio: float = 1.0,
) -> dict:
"""
Retrieve data for a specific sequence.
Args:
seq_index (int): Index of the sequence to retrieve.
img_per_seq (int): Number of images per sequence.
seq_name (str): Name of the sequence.
ids (list): Specific IDs to retrieve.
aspect_ratio (float): Aspect ratio for image processing.
Returns:
dict: A batch of data including images, depths, and other metadata.
"""
if self.inside_random:
seq_index = random.randint(0, self.sequence_list_len - 1)
if seq_name is None:
seq_name = self.sequence_list[seq_index]
metadata = self.data_store[seq_name]
if ids is None:
ids = np.random.choice(
len(metadata), img_per_seq, replace=self.allow_duplicate_img
)
annos = [metadata[i] for i in ids]
target_image_shape = self.get_target_shape(aspect_ratio)
images = []
depths = []
cam_points = []
world_points = []
point_masks = []
extrinsics = []
intrinsics = []
image_paths = []
original_sizes = []
for anno in annos:
filepath = anno["filepath"]
image_path = osp.join(self.CO3D_DIR, filepath)
image = read_image_cv2(image_path)
if self.load_depth:
depth_path = image_path.replace("/images", "/depths") + ".geometric.png"
depth_map = read_depth(depth_path, 1.0)
mvs_mask_path = image_path.replace(
"/images", "/depth_masks"
).replace(".jpg", ".png")
mvs_mask = cv2.imread(mvs_mask_path, cv2.IMREAD_GRAYSCALE) > 128
depth_map[~mvs_mask] = 0
depth_map = threshold_depth_map(
depth_map, min_percentile=-1, max_percentile=98
)
else:
depth_map = None
original_size = np.array(image.shape[:2])
extri_opencv = np.array(anno["extri"])
intri_opencv = np.array(anno["intri"])
(
image,
depth_map,
extri_opencv,
intri_opencv,
world_coords_points,
cam_coords_points,
point_mask,
_,
) = self.process_one_image(
image,
depth_map,
extri_opencv,
intri_opencv,
original_size,
target_image_shape,
filepath=filepath,
)
images.append(image)
depths.append(depth_map)
extrinsics.append(extri_opencv)
intrinsics.append(intri_opencv)
cam_points.append(cam_coords_points)
world_points.append(world_coords_points)
point_masks.append(point_mask)
image_paths.append(image_path)
original_sizes.append(original_size)
set_name = "co3d"
batch = {
"seq_name": set_name + "_" + seq_name,
"ids": ids,
"frame_num": len(extrinsics),
"images": images,
"depths": depths,
"extrinsics": extrinsics,
"intrinsics": intrinsics,
"cam_points": cam_points,
"world_points": world_points,
"point_masks": point_masks,
"original_sizes": original_sizes,
}
return batch
================================================
FILE: training/data/datasets/vkitti.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import os.path as osp
import logging
import random
import glob
import cv2
import numpy as np
from data.dataset_util import *
from data.base_dataset import BaseDataset
class VKittiDataset(BaseDataset):
def __init__(
self,
common_conf,
split: str = "train",
VKitti_DIR: str = "/checkpoint/repligen/jianyuan/datasets/vkitti/",
min_num_images: int = 24,
len_train: int = 100000,
len_test: int = 10000,
expand_ratio: int = 8,
):
"""
Initialize the VKittiDataset.
Args:
common_conf: Configuration object with common settings.
split (str): Dataset split, either 'train' or 'test'.
VKitti_DIR (str): Directory path to VKitti data.
min_num_images (int): Minimum number of images per sequence.
len_train (int): Length of the training dataset.
len_test (int): Length of the test dataset.
            expand_ratio (int): Expansion ratio passed to get_nearby_ids when selecting nearby images.
"""
super().__init__(common_conf=common_conf)
self.debug = common_conf.debug
self.training = common_conf.training
self.get_nearby = common_conf.get_nearby
self.inside_random = common_conf.inside_random
self.allow_duplicate_img = common_conf.allow_duplicate_img
self.expand_ratio = expand_ratio
self.VKitti_DIR = VKitti_DIR
self.min_num_images = min_num_images
if split == "train":
self.len_train = len_train
elif split == "test":
self.len_train = len_test
else:
raise ValueError(f"Invalid split: {split}")
logging.info(f"VKitti_DIR is {self.VKitti_DIR}")
# Load or generate sequence list
txt_path = osp.join(self.VKitti_DIR, "sequence_list.txt")
if osp.exists(txt_path):
with open(txt_path, 'r') as f:
sequence_list = [line.strip() for line in f.readlines()]
else:
# Generate sequence list and save to txt
sequence_list = glob.glob(osp.join(self.VKitti_DIR, "*/*/*/rgb/*"))
sequence_list = [file_path.split(self.VKitti_DIR)[-1].lstrip('/') for file_path in sequence_list]
sequence_list = sorted(sequence_list)
# Save to txt file
with open(txt_path, 'w') as f:
f.write('\n'.join(sequence_list))
self.sequence_list = sequence_list
self.sequence_list_len = len(self.sequence_list)
self.depth_max = 80
status = "Training" if self.training else "Testing"
logging.info(f"{status}: VKitti Real Data size: {self.sequence_list_len}")
logging.info(f"{status}: VKitti Data dataset length: {len(self)}")
def get_data(
self,
seq_index: int = None,
img_per_seq: int = None,
seq_name: str = None,
ids: list = None,
aspect_ratio: float = 1.0,
) -> dict:
"""
Retrieve data for a specific sequence.
Args:
seq_index (int): Index of the sequence to retrieve.
img_per_seq (int): Number of images per sequence.
seq_name (str): Name of the sequence.
ids (list): Specific IDs to retrieve.
aspect_ratio (float): Aspect ratio for image processing.
Returns:
dict: A batch of data including images, depths, and other metadata.
"""
if self.inside_random and self.training:
seq_index = random.randint(0, self.sequence_list_len - 1)
if seq_name is None:
seq_name = self.sequence_list[seq_index]
camera_id = int(seq_name[-1])
# Load camera parameters
try:
camera_parameters = np.loadtxt(
osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "extrinsic.txt"),
delimiter=" ",
skiprows=1
)
camera_parameters = camera_parameters[camera_parameters[:, 1] == camera_id]
camera_intrinsic = np.loadtxt(
osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "intrinsic.txt"),
delimiter=" ",
skiprows=1
)
camera_intrinsic = camera_intrinsic[camera_intrinsic[:, 1] == camera_id]
except Exception as e:
logging.error(f"Error loading camera parameters for {seq_name}: {e}")
raise
num_images = len(camera_parameters)
if ids is None:
ids = np.random.choice(num_images, img_per_seq, replace=self.allow_duplicate_img)
if self.get_nearby:
ids = self.get_nearby_ids(ids, num_images, expand_ratio=self.expand_ratio)
target_image_shape = self.get_target_shape(aspect_ratio)
images = []
depths = []
cam_points = []
world_points = []
point_masks = []
extrinsics = []
intrinsics = []
original_sizes = []
for image_idx in ids:
image_filepath = osp.join(self.VKitti_DIR, seq_name, f"rgb_{image_idx:05d}.jpg")
depth_filepath = osp.join(self.VKitti_DIR, seq_name, f"depth_{image_idx:05d}.png").replace("/rgb", "/depth")
image = read_image_cv2(image_filepath)
depth_map = cv2.imread(depth_filepath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
depth_map = depth_map / 100
depth_map = threshold_depth_map(depth_map, max_percentile=-1, min_percentile=-1, max_depth=self.depth_max)
assert image.shape[:2] == depth_map.shape, f"Image and depth shape mismatch: {image.shape[:2]} vs {depth_map.shape}"
original_size = np.array(image.shape[:2])
# Process camera matrices
extri_opencv = camera_parameters[image_idx][2:].reshape(4, 4)
extri_opencv = extri_opencv[:3]
intri_opencv = np.eye(3)
intri_opencv[0, 0] = camera_intrinsic[image_idx][-4]
intri_opencv[1, 1] = camera_intrinsic[image_idx][-3]
intri_opencv[0, 2] = camera_intrinsic[image_idx][-2]
intri_opencv[1, 2] = camera_intrinsic[image_idx][-1]
(
image,
depth_map,
extri_opencv,
intri_opencv,
world_coords_points,
cam_coords_points,
point_mask,
_,
) = self.process_one_image(
image,
depth_map,
extri_opencv,
intri_opencv,
original_size,
target_image_shape,
filepath=image_filepath,
)
if (image.shape[:2] != target_image_shape).any():
logging.error(f"Wrong shape for {seq_name}: expected {target_image_shape}, got {image.shape[:2]}")
continue
images.append(image)
depths.append(depth_map)
extrinsics.append(extri_opencv)
intrinsics.append(intri_opencv)
cam_points.append(cam_coords_points)
world_points.append(world_coords_points)
point_masks.append(point_mask)
original_sizes.append(original_size)
set_name = "vkitti"
batch = {
"seq_name": set_name + "_" + seq_name,
"ids": ids,
"frame_num": len(extrinsics),
"images": images,
"depths": depths,
"extrinsics": extrinsics,
"intrinsics": intrinsics,
"cam_points": cam_points,
"world_points": world_points,
"point_masks": point_masks,
"original_sizes": original_sizes,
}
return batch
================================================
FILE: training/data/dynamic_dataloader.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from hydra.utils import instantiate
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler
from abc import ABC, abstractmethod
from .worker_fn import get_worker_init_fn
class DynamicTorchDataset(ABC):
def __init__(
self,
dataset: dict,
common_config: dict,
num_workers: int,
shuffle: bool,
pin_memory: bool,
drop_last: bool = True,
collate_fn: Optional[Callable] = None,
worker_init_fn: Optional[Callable] = None,
persistent_workers: bool = False,
seed: int = 42,
max_img_per_gpu: int = 48,
) -> None:
self.dataset_config = dataset
self.common_config = common_config
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
self.collate_fn = collate_fn
self.worker_init_fn = worker_init_fn
self.persistent_workers = persistent_workers
self.seed = seed
self.max_img_per_gpu = max_img_per_gpu
# Instantiate the dataset
self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False)
# Extract aspect ratio and image number ranges from the configuration
self.aspect_ratio_range = common_config.augs.aspects # e.g., [0.5, 1.0]
self.image_num_range = common_config.img_nums # e.g., [2, 24]
# Validate the aspect ratio and image number ranges
if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]:
raise ValueError(f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}")
if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]:
raise ValueError(f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}")
# Create samplers
self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle)
self.batch_sampler = DynamicBatchSampler(
self.sampler,
self.aspect_ratio_range,
self.image_num_range,
seed=seed,
max_img_per_gpu=max_img_per_gpu
)
def get_loader(self, epoch):
print("Building dynamic dataloader with epoch:", epoch)
# Set the epoch for the sampler
self.sampler.set_epoch(epoch)
if hasattr(self.dataset, "epoch"):
self.dataset.epoch = epoch
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
# Create and return the dataloader
return DataLoader(
self.dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
batch_sampler=self.batch_sampler,
collate_fn=self.collate_fn,
persistent_workers=self.persistent_workers,
worker_init_fn=get_worker_init_fn(
seed=self.seed,
num_workers=self.num_workers,
epoch=epoch,
worker_init_fn=self.worker_init_fn,
),
)
class DynamicBatchSampler(Sampler):
"""
    A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number
    for each batch. Samples within a batch share the same aspect ratio and image number.
"""
def __init__(self,
sampler,
aspect_ratio_range,
image_num_range,
epoch=0,
seed=42,
max_img_per_gpu=48):
"""
Initializes the dynamic batch sampler.
Args:
sampler: Instance of DynamicDistributedSampler.
aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
image_num_range: List containing [min_images, max_images] per sample.
epoch: Current epoch number.
seed: Random seed for reproducibility.
max_img_per_gpu: Maximum number of images to fit in GPU memory.
"""
self.sampler = sampler
self.aspect_ratio_range = aspect_ratio_range
self.image_num_range = image_num_range
self.rng = random.Random()
# Uniformly sample from the range of possible image numbers
# For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here.
self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1]+1)}
# Possible image numbers, e.g., [2, 3, 4, ..., 24]
self.possible_nums = np.array([n for n in self.image_num_weights.keys()
if self.image_num_range[0] <= n <= self.image_num_range[1]])
# Normalize weights for sampling
weights = [self.image_num_weights[n] for n in self.possible_nums]
self.normalized_weights = np.array(weights) / sum(weights)
# Maximum image number per GPU
self.max_img_per_gpu = max_img_per_gpu
# Set the epoch for the sampler
self.set_epoch(epoch + seed)
def set_epoch(self, epoch):
"""
Sets the epoch for this sampler, affecting the random sequence.
Args:
epoch: The epoch number.
"""
self.sampler.set_epoch(epoch)
self.epoch = epoch
self.rng.seed(epoch * 100)
def __iter__(self):
"""
Yields batches of samples with synchronized dynamic parameters.
Returns:
Iterator yielding batches of indices with associated parameters.
"""
sampler_iterator = iter(self.sampler)
while True:
try:
# Sample random image number and aspect ratio
random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights))
random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2)
# Update sampler parameters
self.sampler.update_parameters(
aspect_ratio=random_aspect_ratio,
image_num=random_image_num
)
# Calculate batch size based on max images per GPU and current image number
                # Integer division floors the ratio; the batch size is clamped to at least 1
                batch_size = max(1, int(self.max_img_per_gpu // random_image_num))
# Collect samples for the current batch
current_batch = []
for _ in range(batch_size):
try:
                        item = next(sampler_iterator)  # item is (idx, image_num, aspect_ratio)
current_batch.append(item)
except StopIteration:
break # No more samples
if not current_batch:
break # No more data to yield
yield current_batch
except StopIteration:
break # End of sampler's iterator
def __len__(self):
# Return a large dummy length
return 1000000
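# Illustrative sketch, not part of the original module: how the per-batch size in
# DynamicBatchSampler.__iter__ follows from the image budget. `_sketch_batch_size`
# is a hypothetical helper, not something the sampler calls.
def _sketch_batch_size(max_img_per_gpu, image_num):
    # As many samples as fit in the budget, but never fewer than one
    return max(1, max_img_per_gpu // image_num)


# With the default budget of 48 images, 2-image samples are grouped 24 per batch and
# 24-image samples 2 per batch. A sample with more images than the budget still
# forms a batch of one, so the budget is a target rather than a hard limit.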
class DynamicDistributedSampler(DistributedSampler):
"""
Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num
parameters, which can be passed into the dataset's __getitem__ method.
"""
def __init__(
self,
dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False,
):
super().__init__(
dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last
)
self.aspect_ratio = None
self.image_num = None
def __iter__(self):
"""
Yields a sequence of (index, image_num, aspect_ratio).
Relies on the parent class's logic for shuffling/distributing
the indices across replicas, then attaches extra parameters.
"""
indices_iter = super().__iter__()
for idx in indices_iter:
yield (idx, self.image_num, self.aspect_ratio,)
def update_parameters(self, aspect_ratio, image_num):
"""
Updates dynamic parameters for each new epoch or iteration.
Args:
aspect_ratio: The aspect ratio to set.
image_num: The number of images to set.
"""
self.aspect_ratio = aspect_ratio
self.image_num = image_num
================================================
FILE: training/data/preprocess/vkitti.sh
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
mkdir vkitti
cd vkitti
wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar
tar -xvf vkitti_2.0.3_rgb.tar
wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar
tar -xvf vkitti_2.0.3_depth.tar
wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_textgt.tar.gz
tar -xvf vkitti_2.0.3_textgt.tar.gz
cd ..
================================================
FILE: training/data/track_util.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import numpy as np
import torch
import logging
from vggt.utils.geometry import *
def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, point_masks, images, pos_rel_thres=0.05, neg_epipolar_thres=16,
boundary_thres=4, target_track_num=512, neg_ratio = 0.0, neg_sample_size_ratio = 0.5, seq_name=None):
"""
    Args:
        extrinsics: (N, 3, 4)
        intrinsics: (N, 3, 3)
        world_points: (N, H, W, 3)
        depths: (N, H, W)
        point_masks: (N, H, W)
        images: input frames; currently unused by this function
        pos_rel_thres: float, relative threshold for positive track depth check
        neg_epipolar_thres: float, threshold for negative track epipolar check, in px
        boundary_thres: int, boundary in px to skip near edges
        target_track_num: int, total # tracks to build
        neg_ratio: fraction of final tracks that should be negative
        neg_sample_size_ratio: fraction of W/H used for random offset
        seq_name: optional sequence name, used only in log messages
    Returns:
        final_tracks: (N, P, 2) float
        final_vis_masks: (N, P) bool
        final_pos_masks: (P,) bool, whether each track is positive (True) or negative (False)
"""
# Wait, should we do this before resizing the image?
B, H, W, _ = world_points.shape
# We use the first frame as the query frame, so [0]
query_world_points = world_points[0]
query_point_masks = point_masks[0]
if (query_point_masks).sum() > 0:
# at least one point
valid_query_points = query_world_points[query_point_masks]
# image_points: BxPx2
# cam_points: Bx3xP (yes 3xP instead of Px3). Probably we can change it in the future
image_points, cam_points = project_world_points_to_cam(valid_query_points, extrinsics, intrinsics)
# proj_depths: BxP
proj_depths = cam_points[:, -1]
# floor to get the left top corner
uv_int = image_points.floor().long().clone()
uv_inside_flag = (uv_int[..., 0] >= boundary_thres) & (uv_int[..., 0] < (W - boundary_thres)) & (uv_int[..., 1] >= boundary_thres) & (uv_int[..., 1] < (H - boundary_thres))
uv_int[~uv_inside_flag] = 0
        batch_indices = torch.arange(B, device=uv_int.device).view(B, 1).expand(-1, uv_int.shape[1])
# Use these indices to sample from the depth map
# since we interpolate depths by nearest,
# so assume the left top corner is (x, y)
# we want to check for (x,y), (x+1,y), (x,y+1), (x+1,y+1)
depth_inside_flag = None
for shift in [(0,0), (1,0), (0,1), (1,1)]:
            cur_uv_int = uv_int + torch.tensor(shift, device=uv_int.device)
cur_depth_inside_flag = get_depth_inside_flag(depths, batch_indices, cur_uv_int, proj_depths, pos_rel_thres)
if depth_inside_flag is None:
depth_inside_flag = cur_depth_inside_flag
else:
depth_inside_flag = torch.logical_or(depth_inside_flag, cur_depth_inside_flag)
# B, P, 2
positive_tracks = image_points
positive_vis_masks = torch.logical_and(uv_inside_flag, depth_inside_flag)
else:
print(f"No valid query points in {seq_name}")
positive_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32)
positive_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool)
sampled_neg_track_num = target_track_num * 4 # we sample more negative tracks to ensure the quality
perb_range = [int(W*neg_sample_size_ratio), int(H*neg_sample_size_ratio)]
# sample negative query points
us = torch.randint(low=0, high=W, size=(1, sampled_neg_track_num), device=world_points.device)
vs = torch.randint(low=0, high=H, size=(1, sampled_neg_track_num), device=world_points.device)
neg_query_uvs = torch.stack([us, vs], dim=-1)
# construct negative tracks
delta_us = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[0]
delta_vs = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[1]
delta_us[0] = 0
delta_vs[0] = 0
negative_tracks = neg_query_uvs + torch.stack([delta_us, delta_vs], dim=-1)
# Do epipolar check here
negative_sampson_distances = track_epipolar_check(negative_tracks, extrinsics, intrinsics)
    negative_epipolar_check = (negative_sampson_distances > neg_epipolar_thres).all(dim=0)  # threshold is neg_epipolar_thres, in px
    # Keep only the candidates that violate the epipolar constraint, i.e. true negatives
negative_tracks = negative_tracks[:, negative_epipolar_check]
# Prepare for output
final_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32)
final_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool)
final_pos_masks = torch.zeros(target_track_num, device=world_points.device, dtype=torch.bool)
target_pos_track_num = target_track_num - int(target_track_num * neg_ratio)
sampled_pos_track_num = 0
sampled_positive_tracks, sampled_positive_vis_masks = sample_positive_tracks(positive_tracks, positive_vis_masks, target_pos_track_num)
sampled_pos_track_num = sampled_positive_tracks.shape[1]
final_tracks[:, :sampled_pos_track_num] = sampled_positive_tracks
final_vis_masks[:, :sampled_pos_track_num] = sampled_positive_vis_masks
final_pos_masks[:sampled_pos_track_num] = True
target_neg_track_num = target_track_num - sampled_pos_track_num
# Now we need to sample negative tracks
# just do simple random sampling
rand_indices = torch.randperm(negative_tracks.shape[1], device=negative_tracks.device)
sampled_neg_tracks = negative_tracks[:, rand_indices[:target_neg_track_num]]
sampled_neg_track_num = sampled_neg_tracks.shape[1]
final_tracks[:, sampled_pos_track_num:sampled_pos_track_num+sampled_neg_track_num] = sampled_neg_tracks
if sampled_pos_track_num+sampled_neg_track_num!=target_track_num:
logging.warning(f"sampled_pos_track_num+sampled_neg_track_num!=target_track_num: {sampled_pos_track_num+sampled_neg_track_num} != {target_track_num}")
# Do not need to set final_vis_masks and final_pos_masks, because they are all False
# Do not need to check the shape of final_tracks, as it is zeroed out
# NOTE: We need to do some visual checks
return final_tracks, final_vis_masks, final_pos_masks
def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, rel_thres):
sampled_depths = depths[batch_indices, uv_int[..., 1], uv_int[..., 0]]
depth_diff = (proj_depths - sampled_depths).abs()
depth_inside_flag = torch.logical_and(depth_diff < (proj_depths * rel_thres), depth_diff < (sampled_depths * rel_thres))
return depth_inside_flag
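# A scalar sketch of the symmetric relative-depth test in get_depth_inside_flag,
# written in plain Python instead of torch. The name `depth_is_consistent` is
# hypothetical and used only for illustration: the difference must be below
# rel_thres times BOTH depths, so the smaller depth sets the effective tolerance.

```python
def depth_is_consistent(proj_depth: float, sampled_depth: float, rel_thres: float) -> bool:
    """Return True when the two depths agree within rel_thres of both values."""
    diff = abs(proj_depth - sampled_depth)
    return diff < proj_depth * rel_thres and diff < sampled_depth * rel_thres


# Both bounds hold: |10.0 - 10.4| is below 0.5 and below 0.52.
print(depth_is_consistent(10.0, 10.4, 0.05))   # True
# Only one bound holds: the difference is below 0.21 but not below 0.189.
print(depth_is_consistent(2.0, 1.8, 0.105))    # False
```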
def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = True, seq_name=None):
# tracks: (B, T, 2)
# tracks_mask: (B, T)
# track_num: int
# half_top: bool
# if the query frame is not valid, then the track is not valid
tracks_mask[:, tracks_mask[0]==False] = False
track_frame_num = tracks_mask.sum(dim=0)
tracks_mask[:, track_frame_num<=1] = False
track_frame_num = tracks_mask.sum(dim=0)
_, track_num_sort_idx = track_frame_num.sort(descending=True)
if half_top:
if len(track_num_sort_idx)//2 > track_num:
# drop those tracks with too small number of valid frames
# track_num_sort_idx = track_num_sort_idx[:track_num]
track_num_sort_idx = track_num_sort_idx[:len(track_num_sort_idx)//2]
pick_idx = torch.randperm(len(track_num_sort_idx))[:track_num]
track_num_sort_idx = track_num_sort_idx[pick_idx]
tracks = tracks[:, track_num_sort_idx].clone()
tracks_mask = tracks_mask[:, track_num_sort_idx].clone()
tracks_mask = tracks_mask.bool() # ensure the type is bool
return tracks, tracks_mask
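# A plain-Python sketch of the index selection performed by sample_positive_tracks:
# rank tracks by their number of valid frames, optionally keep only the better half,
# then draw track_num of them at random. `pick_track_indices` and its `rng` argument
# are hypothetical names for illustration; the function above works on torch tensors.

```python
import random


def pick_track_indices(valid_frame_counts, track_num, half_top=True, rng=None):
    """Return up to track_num track indices, favouring long-lived tracks."""
    rng = rng or random.Random()
    # Sort indices by valid-frame count, longest tracks first.
    order = sorted(range(len(valid_frame_counts)),
                   key=lambda i: valid_frame_counts[i], reverse=True)
    # Drop the weaker half only when the better half still exceeds track_num.
    if half_top and len(order) // 2 > track_num:
        order = order[: len(order) // 2]
    # Random subset, mirroring torch.randperm(...)[:track_num].
    rng.shuffle(order)
    return order[:track_num]


counts = [5, 1, 4, 0, 3, 2, 6, 0, 1, 0]
picked = pick_track_indices(counts, track_num=2, rng=random.Random(0))
print(picked)  # two indices drawn from the top half {6, 0, 2, 4, 5}
```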
# Only for Debugging and Visualization
def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_mat = False):
from kornia.geometry.epipolar import sampson_epipolar_distance
B, T, _ = tracks.shape
essential_mats = get_essential_matrix(extrinsics[0:1].expand(B-1, -1, -1), extrinsics[1:])
if use_essential_mat:
tracks_normalized = cam_from_img(tracks, intrinsics)
sampson_distances = sampson_epipolar_distance(tracks_normalized[0:1].expand(B-1, -1, -1), tracks_normalized[1:], essential_mats)
else:
K1 = intrinsics[0:1].expand(B-1, -1, -1)
K2 = intrinsics[1:].expand(B-1, -1, -1)
fundamental_mats = K2.inverse().permute(0, 2, 1).matmul(essential_mats).matmul(K1.inverse())
sampson_distances = sampson_epipolar_distance(tracks[0:1].expand(B-1, -1, -1), tracks[1:], fundamental_mats)
return sampson_distances
def get_essential_matrix(extrinsic1, extrinsic2):
R1 = extrinsic1[:, :3, :3]
t1 = extrinsic1[:, :3, 3]
R2 = extrinsic2[:, :3, :3]
t2 = extrinsic2[:, :3, 3]
R12 = R2.matmul(R1.permute(0, 2, 1))
t12 = t2 - R12.matmul(t1[..., None])[..., 0]
E_R = R12
E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0]
E = E_R.matmul(hat(E_t))
return E
def hat(v: torch.Tensor) -> torch.Tensor:
N, dim = v.shape
if dim != 3:
raise ValueError("Input vectors have to be 3-dimensional.")
x, y, z = v.unbind(1)
h_01 = -z.view(N, 1, 1)
h_02 = y.view(N, 1, 1)
h_10 = z.view(N, 1, 1)
h_12 = -x.view(N, 1, 1)
h_20 = -y.view(N, 1, 1)
h_21 = x.view(N, 1, 1)
zeros = torch.zeros((N, 1, 1), dtype=v.dtype, device=v.device)
row1 = torch.cat((zeros, h_01, h_02), dim=2)
row2 = torch.cat((h_10, zeros, h_12), dim=2)
row3 = torch.cat((h_20, h_21, zeros), dim=2)
h = torch.cat((row1, row2, row3), dim=1)
return h
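# A small check of what hat() builds, sketched in plain Python for a single vector:
# the skew-symmetric matrix of v, whose product with w equals the cross product v x w.
# The helper names `hat3`, `matvec`, and `cross` are hypothetical and only mirror the
# row layout used above; they are not part of this module.

```python
def hat3(v):
    """Skew-symmetric matrix of a 3-vector, with the same row layout as hat()."""
    x, y, z = v
    return [[0.0, -z, y],
            [z, 0.0, -x],
            [-y, x, 0.0]]


def matvec(m, w):
    """Multiply a 3x3 matrix (list of rows) by a 3-vector."""
    return [sum(m[i][j] * w[j] for j in range(3)) for i in range(3)]


def cross(v, w):
    """Cross product of two 3-vectors."""
    return [v[1] * w[2] - v[2] * w[1],
            v[2] * w[0] - v[0] * w[2],
            v[0] * w[1] - v[1] * w[0]]


v, w = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
print(matvec(hat3(v), w))  # [-3.0, 6.0, -3.0]
print(cross(v, w))         # [-3.0, 6.0, -3.0]
```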
def color_from_xy(x, y, W, H, cmap_name="hsv"):
"""
Map (x, y) -> color in (R, G, B).
1) Normalize x,y to [0,1].
2) Combine them into a single scalar c in [0,1].
3) Use matplotlib's colormap to convert c -> (R,G,B).
You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
"""
import matplotlib.cm
import matplotlib.colors
x_norm = x / max(W - 1, 1)
y_norm = y / max(H - 1, 1)
# Simple combination:
c = (x_norm + y_norm) / 2.0
# matplotlib.cm.get_cmap was removed in matplotlib 3.9; the colormaps registry (3.5+) replaces it
cmap = matplotlib.colormaps[cmap_name]
# cmap(c) -> (r,g,b,a) in [0,1]
rgba = cmap(c)
r, g, b = rgba[0], rgba[1], rgba[2]
return (r, g, b) # in [0,1], RGB order
def get_track_colors_by_position(
tracks_b,
vis_mask_b=None,
image_width=None,
image_height=None,
cmap_name="hsv"
):
"""
Given all tracks in one sample (b), compute a (N,3) array of RGB color values
in [0,255]. The color is determined by the (x,y) position in the first
visible frame for each track.
Args:
tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
image_width, image_height: used for normalizing (x, y).
cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
Returns:
track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
"""
S, N, _ = tracks_b.shape
track_colors = np.zeros((N, 3), dtype=np.uint8)
if vis_mask_b is None:
# treat all as visible
vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
for i in range(N):
# Find first visible frame for track i
visible_frames = torch.where(vis_mask_b[:, i])[0]
if len(visible_frames) == 0:
# track is never visible; just assign black or something
track_colors[i] = (0, 0, 0)
continue
first_s = int(visible_frames[0].item())
# use that frame's (x,y)
x, y = tracks_b[first_s, i].tolist()
# map (x,y) -> (R,G,B) in [0,1]
r, g, b = color_from_xy(
x, y,
W=image_width,
H=image_height,
cmap_name=cmap_name
)
# scale to [0,255]
r, g, b = int(r*255), int(g*255), int(b*255)
track_colors[i] = (r, g, b)
return track_colors
def visualize_tracks_on_images(
images,
tracks,
track_vis_mask=None,
out_dir="track_visuals_concat_by_xy",
image_format="CHW", # "CHW" or "HWC"
normalize_mode="[0,1]",
cmap_name="hsv" # e.g. "hsv", "rainbow", "jet"
):
"""
Visualizes all frames for each sample (b) in ONE horizontal row, saving
one PNG per sample. Each track's color is determined by its (x,y) position
in the first visible frame (or frame 0 if always visible).
Finally convert the BGR result to RGB before saving.
Args:
images: torch.Tensor (B, S, 3, H, W) if CHW or (B, S, H, W, 3) if HWC.
tracks: torch.Tensor (B, S, N, 2), last dim = (x, y).
track_vis_mask: torch.Tensor (B, S, N) or None.
out_dir: folder to save visualizations.
image_format: "CHW" or "HWC".
normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
cmap_name: a matplotlib colormap name for color_from_xy.
Returns:
None (saves images in out_dir).
"""
import matplotlib
matplotlib.use('Agg') # for non-interactive (optional)
os.makedirs(out_dir, exist_ok=True)
B, S = images.shape[0], images.shape[1]
_, _, N, _ = tracks.shape # (B, S, N, 2)
# Move to CPU
images = images.cpu().clone()
tracks = tracks.cpu().clone()
if track_vis_mask is not None:
track_vis_mask = track_vis_mask.cpu().clone()
# Infer H, W from images shape
if image_format == "CHW":
# e.g. images[b, s].shape = (3, H, W)
H, W = images.shape[3], images.shape[4]
else:
# e.g. images[b, s].shape = (H, W, 3)
H, W = images.shape[2], images.shape[3]
for b in range(B):
# Pre-compute the color for each track i based on first visible position
# in sample b:
track_colors_rgb = get_track_colors_by_position(
tracks[b], # shape (S, N, 2)
vis_mask_b=track_vis_mask[b] if track_vis_mask is not None else None,
image_width=W,
image_height=H,
cmap_name=cmap_name
)
# We'll accumulate each frame’s drawn image in a list
frame_images = []
for s in range(S):
# shape => either (3, H, W) or (H, W, 3)
img = images[b, s]
# Convert to (H, W, 3)
if image_format == "CHW":
img = img.permute(1, 2, 0) # (H, W, 3)
# else "HWC", do nothing
img = img.numpy().astype(np.float32)
# Scale to [0,255] if needed
if normalize_mode == "[0,1]":
img = np.clip(img, 0, 1) * 255.0
elif normalize_mode == "[-1,1]":
img = (img + 1.0) * 0.5 * 255.0
img = np.clip(img, 0, 255.0)
# else no normalization
# Convert to uint8
img = img.astype(np.uint8)
# For drawing in OpenCV, the image is assumed BGR,
# but *currently* it's in (R,G,B) if your original is truly RGB.
# We'll do the color conversion AFTER drawing so that we can call
# cv2.circle(...) with BGR color.
# That means we need to swap the channels now to get BGR for drawing.
# If your images are actually BGR, you may skip or adapt.
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# Draw each visible track
cur_tracks = tracks[b, s] # shape (N, 2)
if track_vis_mask is not None:
valid_indices = torch.where(track_vis_mask[b, s])[0]
else:
valid_indices = range(N)
cur_tracks_np = cur_tracks.numpy()
for i in valid_indices:
x, y = cur_tracks_np[i]
pt = (int(round(x)), int(round(y)))
# track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
R, G, B = track_colors_rgb[i]
color_bgr = (int(B), int(G), int(R))
cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
# Convert back to RGB for consistent final saving:
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
frame_images.append(img_rgb)
# Concatenate all frames horizontally: (H, S*W, 3)
row_img = np.concatenate(frame_images, axis=1)
out_path = os.path.join(out_dir, f"tracks_b{b}.png")
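# cv2.imwrite expects BGR channel order, so convert the RGB row image before writing
cv2.imwrite(out_path, cv2.cvtColor(row_img, cv2.COLOR_RGB2BGR))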
print(f"[INFO] Saved color-by-XY track visualization for sample b={b} -> {out_path}")
================================================
FILE: training/data/worker_fn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities for distributed training and deterministic seed generation.
This module provides functions for working with PyTorch's distributed
training capabilities and ensuring reproducible data loading.
"""
import os
import torch
import random
import numpy as np
import torch.distributed as dist
from functools import partial
def is_dist_avail_and_initialized():
"""
Check if distributed training is available and initialized.
Returns:
bool: True if distributed training is available and initialized, False otherwise.
"""
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
"""
Get the rank of the current process in distributed training.
Returns:
int: The rank of the current process, or 0 if distributed training is not initialized.
"""
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def get_world_size():
"""
Get the total number of processes in distributed training.
Returns:
int: The world size, or 1 if distributed training is not initialized.
"""
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def default_worker_init_fn(worker_id, num_workers, epoch, seed=0):
"""
Default function to initialize random seeds for dataloader workers.
Combines the base seed, rank, worker ID, world size, and epoch into one seed
so that workers are seeded differently from one another yet reproducibly.
Args:
worker_id (int): ID of the dataloader worker.
num_workers (int): Total number of dataloader workers.
epoch (int): Current training epoch.
seed (int, optional): Base seed for randomization. Defaults to 0.
"""
rank = get_rank()
world_size = get_world_size()
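# Fall back to 0 when RANK is unset (e.g. single-process runs); int(None) would raise TypeError
distributed_rank = int(os.environ.get("RANK", 0))
# Fixed multipliers that weight each term of the seed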
RANK_MULTIPLIER = 1
WORKER_MULTIPLIER = 1
WORLD_MULTIPLIER = 1
EPOCH_MULTIPLIER = 12345
DISTRIBUTED_RANK_MULTIPLIER = 1042
worker_seed = (
rank * num_workers * RANK_MULTIPLIER +
worker_id * WORKER_MULTIPLIER +
seed +
world_size * WORLD_MULTIPLIER +
epoch * EPOCH_MULTIPLIER
+ distributed_rank * DISTRIBUTED_RANK_MULTIPLIER
)
print(f"Rank: {rank}, World size: {world_size}, Distributed rank: {distributed_rank}")
print(f"Worker seed: {worker_seed}")
torch.random.manual_seed(worker_seed)
np.random.seed(worker_seed)
random.seed(worker_seed)
return
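# The seed arithmetic above can be reproduced without torch. The sketch below uses a
# hypothetical helper, `compose_worker_seed`, and assumes the RANK environment variable
# equals the process rank (so distributed_rank == rank); multipliers equal to 1 are
# folded into the expression.

```python
def compose_worker_seed(rank, worker_id, num_workers, epoch, world_size,
                        distributed_rank, seed=0):
    """Combine the same terms as default_worker_init_fn into one integer seed."""
    EPOCH_MULTIPLIER = 12345
    DISTRIBUTED_RANK_MULTIPLIER = 1042
    return (rank * num_workers + worker_id + seed + world_size
            + epoch * EPOCH_MULTIPLIER
            + distributed_rank * DISTRIBUTED_RANK_MULTIPLIER)


# One worked value: rank 1, worker 2 of 4, epoch 3, world size 8, base seed 42.
print(compose_worker_seed(1, 2, 4, 3, 8, 1, seed=42))  # 38133

# Seeds for 8 ranks x 4 workers x 3 epochs are pairwise distinct in this small grid.
seeds = {
    compose_worker_seed(r, w, 4, e, 8, r)
    for r in range(8) for w in range(4) for e in range(3)
}
print(len(seeds))  # 96
```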
def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None):
"""
Get a worker initialization function for dataloaders.
Args:
seed (int): Base seed for randomization.
num_workers (int): Number of dataloader workers.
epoch (int): Current training epoch.
worker_init_fn (callable, optional): Custom worker initialization function.
If provided, this will be returned instead of the default one.
Returns:
callable: A worker initialization function to use with DataLoader.
"""
if worker_init_fn is not None:
return worker_init_fn
return partial(
default_worker_init_fn,
num_workers=num_workers,
epoch=epoch,
seed=seed,
)
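# get_worker_init_fn relies on functools.partial so that the DataLoader, which calls
# worker_init_fn with the worker id as its only argument, still reaches a function that
# needs num_workers, epoch, and seed. A minimal sketch with a stand-in function;
# `toy_worker_init_fn` is hypothetical and returns its arguments instead of seeding RNGs.

```python
from functools import partial


def toy_worker_init_fn(worker_id, num_workers, epoch, seed=0):
    """Stand-in for default_worker_init_fn that just reports what it received."""
    return {"worker_id": worker_id, "num_workers": num_workers,
            "epoch": epoch, "seed": seed}


# Bind everything except worker_id, as get_worker_init_fn does.
init_fn = partial(toy_worker_init_fn, num_workers=4, epoch=2, seed=7)

# Simulate the per-worker calls a DataLoader would make.
calls = [init_fn(worker_id) for worker_id in range(4)]
print(calls[3])  # {'worker_id': 3, 'num_workers': 4, 'epoch': 2, 'seed': 7}
```

# Because epoch is bound when the partial is created, a fresh init function has to be
# built whenever the epoch changes.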
================================================
FILE: training/launch.py
================================================
# Copyright (c) Meta Platforms, Inc. and affil
SYMBOL INDEX (430 symbols across 62 files)
FILE: demo_colmap.py
function parse_args (line 42) | def parse_args():
function run_VGGT (line 65) | def run_VGGT(model, images, dtype, resolution=518):
function demo_fn (line 93) | def demo_fn(args):
function rename_colmap_recons_and_rescale_camera (line 254) | def rename_colmap_recons_and_rescale_camera(
FILE: demo_gradio.py
function run_model (line 44) | def run_model(target_dir, model) -> dict:
function handle_uploads (line 103) | def handle_uploads(input_video, input_images):
function update_gallery_on_upload (line 171) | def update_gallery_on_upload(input_video, input_images):
function gradio_demo (line 186) | def gradio_demo(
function clear_fields (line 259) | def clear_fields():
function update_log (line 266) | def update_log():
function update_visualization (line 273) | def update_visualization(
function example_pipeline (line 494) | def example_pipeline(
FILE: demo_viser.py
function viser_wrapper (line 34) | def viser_wrapper(
function apply_sky_segmentation (line 258) | def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.nd...
function main (line 321) | def main():
FILE: training/data/augmentation.py
function get_image_augmentation (line 11) | def get_image_augmentation(
FILE: training/data/base_dataset.py
class BaseDataset (line 17) | class BaseDataset(Dataset):
method __init__ (line 33) | def __init__(
method __len__ (line 57) | def __len__(self):
method __getitem__ (line 60) | def __getitem__(self, idx_N):
method get_data (line 75) | def get_data(self, seq_index=None, seq_name=None, ids=None, aspect_rat...
method get_target_shape (line 95) | def get_target_shape(self, aspect_ratio):
method process_one_image (line 115) | def process_one_image(
method get_nearby_ids (line 237) | def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_...
FILE: training/data/composed_dataset.py
class ComposedDataset (line 21) | class ComposedDataset(Dataset, ABC):
method __init__ (line 30) | def __init__(self, dataset_configs: dict, common_config: dict, **kwargs):
method __len__ (line 80) | def __len__(self):
method __getitem__ (line 85) | def __getitem__(self, idx_tuple):
class TupleConcatDataset (line 187) | class TupleConcatDataset(ConcatDataset):
method __init__ (line 201) | def __init__(self, datasets, common_config):
method __getitem__ (line 214) | def __getitem__(self, idx):
FILE: training/data/dataset_util.py
function crop_image_depth_and_intrinsic_by_pp (line 26) | def crop_image_depth_and_intrinsic_by_pp(
function resize_image_depth_and_intrinsic (line 161) | def resize_image_depth_and_intrinsic(
function threshold_depth_map (line 261) | def threshold_depth_map(
function depth_to_world_coords_points (line 317) | def depth_to_world_coords_points(
function depth_to_cam_coords_points (line 369) | def depth_to_cam_coords_points(
function rotate_90_degrees (line 411) | def rotate_90_degrees(
function rotate_image_and_depth_rot90 (line 474) | def rotate_image_and_depth_rot90(image, depth_map, clockwise):
function adjust_extrinsic_matrix_rot90 (line 507) | def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise):
function adjust_intrinsic_matrix_rot90 (line 548) | def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_heigh...
function adjust_track_rot90 (line 588) | def adjust_track_rot90(track, image_width, image_height, clockwise):
function read_image_cv2 (line 616) | def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray:
function read_depth (line 653) | def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray:
function load_16big_png_depth (line 689) | def load_16big_png_depth(depth_png: str) -> np.ndarray:
FILE: training/data/datasets/co3d.py
class Co3dDataset (line 67) | class Co3dDataset(BaseDataset):
method __init__ (line 68) | def __init__(
method get_data (line 162) | def get_data(
FILE: training/data/datasets/vkitti.py
class VKittiDataset (line 20) | class VKittiDataset(BaseDataset):
method __init__ (line 21) | def __init__(
method get_data (line 89) | def get_data(
FILE: training/data/dynamic_dataloader.py
class DynamicTorchDataset (line 17) | class DynamicTorchDataset(ABC):
method __init__ (line 18) | def __init__(
method get_loader (line 67) | def get_loader(self, epoch):
class DynamicBatchSampler (line 94) | class DynamicBatchSampler(Sampler):
method __init__ (line 99) | def __init__(self,
method set_epoch (line 140) | def set_epoch(self, epoch):
method __iter__ (line 151) | def __iter__(self):
method __len__ (line 194) | def __len__(self):
class DynamicDistributedSampler (line 199) | class DynamicDistributedSampler(DistributedSampler):
method __init__ (line 204) | def __init__(
method __iter__ (line 224) | def __iter__(self):
method update_parameters (line 235) | def update_parameters(self, aspect_ratio, image_num):
FILE: training/data/track_util.py
function build_tracks_by_depth (line 19) | def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, ...
function get_depth_inside_flag (line 149) | def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, re...
function sample_positive_tracks (line 161) | def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = Tr...
function track_epipolar_check (line 198) | def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_m...
function get_essential_matrix (line 216) | def get_essential_matrix(extrinsic1, extrinsic2):
function hat (line 231) | def hat(v: torch.Tensor) -> torch.Tensor:
function color_from_xy (line 257) | def color_from_xy(x, y, W, H, cmap_name="hsv"):
function get_track_colors_by_position (line 281) | def get_track_colors_by_position(
function visualize_tracks_on_images (line 335) | def visualize_tracks_on_images(
FILE: training/data/worker_fn.py
function is_dist_avail_and_initialized (line 22) | def is_dist_avail_and_initialized():
function get_rank (line 36) | def get_rank():
function get_world_size (line 48) | def get_world_size():
function default_worker_init_fn (line 60) | def default_worker_init_fn(worker_id, num_workers, epoch, seed=0):
function get_worker_init_fn (line 102) | def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None):
FILE: training/launch.py
function main (line 13) | def main():
FILE: training/loss.py
class MultitaskLoss (line 17) | class MultitaskLoss(torch.nn.Module):
method __init__ (line 27) | def __init__(self, camera=None, depth=None, point=None, track=None, **...
method forward (line 35) | def forward(self, predictions, batch) -> torch.Tensor:
function compute_camera_loss (line 81) | def compute_camera_loss(
function camera_loss_single (line 157) | def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"):
function compute_point_loss (line 199) | def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradien...
function compute_depth_loss (line 239) | def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradien...
function regression_loss (line 281) | def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, ga...
function gradient_loss_multi_scale_wrapper (line 370) | def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4...
function normal_loss (line 398) | def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma...
function gradient_loss (line 456) | def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=...
function point_map_to_normal (line 511) | def point_map_to_normal(point_map, mask, eps=1e-6):
function filter_by_quantile (line 567) | def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard...
function torch_quantile (line 606) | def torch_quantile(
FILE: training/train_utils/checkpoint.py
class DDPCheckpointSaver (line 38) | class DDPCheckpointSaver:
method __init__ (line 39) | def __init__(
method save_checkpoint (line 52) | def save_checkpoint(
function robust_torch_save (line 72) | def robust_torch_save(checkpoint: Dict[str, Any], checkpoint_path: str) ...
FILE: training/train_utils/distributed.py
function get_machine_local_and_dist_rank (line 12) | def get_machine_local_and_dist_rank():
FILE: training/train_utils/freeze.py
function freeze_modules (line 24) | def freeze_modules(model: nn.Module, patterns: List[str], recursive: boo...
function _freeze (line 62) | def _freeze(mod: nn.Module, recursive: bool) -> None:
function _check_every_pattern_used (line 91) | def _check_every_pattern_used(matched_names: set[str], patterns: List[st...
FILE: training/train_utils/general.py
function check_and_fix_inf_nan (line 29) | def check_and_fix_inf_nan(input_tensor, loss_name="default", hard_max=100):
function get_resume_checkpoint (line 60) | def get_resume_checkpoint(checkpoint_save_dir):
class DurationMeter (line 69) | class DurationMeter:
method __init__ (line 70) | def __init__(self, name, device, fmt=":f"):
method reset (line 76) | def reset(self):
method update (line 79) | def update(self, val):
method add (line 82) | def add(self, val):
method __str__ (line 85) | def __str__(self):
function human_readable_time (line 89) | def human_readable_time(time_seconds):
class ProgressMeter (line 98) | class ProgressMeter:
method __init__ (line 99) | def __init__(self, num_batches, meters, real_meters, prefix=""):
method display (line 105) | def display(self, batch):
method _get_batch_fmtstr (line 119) | def _get_batch_fmtstr(self, num_batches):
class _CopyableData (line 127) | class _CopyableData(Protocol):
method to (line 128) | def to(self, device: torch.device, *args: Any, **kwargs: Any):
function _is_named_tuple (line 133) | def _is_named_tuple(x) -> bool:
function copy_data_to_device (line 137) | def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs...
function safe_makedirs (line 197) | def safe_makedirs(path: str):
function set_seeds (line 215) | def set_seeds(seed_value, max_epochs, dist_rank):
function log_env_variables (line 233) | def log_env_variables():
function is_dist_avail_and_initialized (line 243) | def is_dist_avail_and_initialized():
class AverageMeter (line 252) | class AverageMeter:
method __init__ (line 260) | def __init__(self, name: str, device: Optional[torch.device] = None, f...
method reset (line 266) | def reset(self):
method update (line 273) | def update(self, val, n=1):
method __str__ (line 282) | def __str__(self) -> str:
method value (line 288) | def value(self) -> float:
method average (line 293) | def average(self) -> float:
function pretty_int (line 302) | def pretty_int(n: int) -> str:
function model_summary (line 313) | def model_summary(model: torch.nn.Module,
function get_rank (line 364) | def get_rank():
FILE: training/train_utils/gradient_clip.py
class GradientClipper (line 12) | class GradientClipper:
method __init__ (line 17) | def __init__(self, configs, *args, **kwargs):
method setup_clipping (line 40) | def setup_clipping(self, model: nn.Module) -> None:
method __call__ (line 80) | def __call__(self, model: nn.Module) -> Optional[torch.Tensor]:
FILE: training/train_utils/logging.py
function _cached_log_stream (line 22) | def _cached_log_stream(filename):
function setup_logging (line 30) | def setup_logging(
FILE: training/train_utils/normalization.py
function check_valid_tensor (line 14) | def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str =...
function normalize_camera_extrinsics_and_points_batch (line 27) | def normalize_camera_extrinsics_and_points_batch(
FILE: training/train_utils/optimizer.py
class OptimizerWrapper (line 20) | class OptimizerWrapper:
method __init__ (line 23) | def __init__(self, optimizer: torch.optim.Optimizer, schedulers=None) ...
method step (line 33) | def step(self, where: float = 1.0, closure=None):
method zero_grad (line 38) | def zero_grad(self, *args, **kwargs):
method _validate_optimizer_schedulers (line 41) | def _validate_optimizer_schedulers(self):
method step_schedulers (line 51) | def step_schedulers(self, where: float) -> None:
function validate_param_group_params (line 64) | def validate_param_group_params(param_groups: List[Dict], model: nn.Modu...
function get_full_parameter_name (line 96) | def get_full_parameter_name(module_name: str, param_name: str) -> str:
function get_module_cls_to_param_names (line 100) | def get_module_cls_to_param_names(model: nn.Module) -> Dict[type, Set[st...
function unix_param_pattern_to_parameter_names (line 111) | def unix_param_pattern_to_parameter_names(filter_param_names: Union[List...
function unix_module_cls_pattern_to_parameter_names (line 125) | def unix_module_cls_pattern_to_parameter_names(filter_module_cls_names: ...
function _unix_pattern_to_parameter_names (line 142) | def _unix_pattern_to_parameter_names(scheduler_cfg,
function set_default_parameters (line 161) | def set_default_parameters(scheduler_cfgs: List[dict], all_parameter_nam...
function name_constraints_to_parameters (line 180) | def name_constraints_to_parameters(param_constraints: List[Set[str]],
function map_scheduler_cfgs_to_param_groups (line 186) | def map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs: Iterable[List...
function construct_optimizer (line 208) | def construct_optimizer(model: nn.Module,
function construct_optimizers (line 262) | def construct_optimizers(model: nn.Module, optim_conf) -> Union[List[Opt...
FILE: training/train_utils/tb_writer.py
class TensorBoardLogger (line 18) | class TensorBoardLogger:
method __init__ (line 25) | def __init__(
method writer (line 62) | def writer(self) -> Optional[SummaryWriter]:
method path (line 67) | def path(self) -> str:
method flush (line 71) | def flush(self) -> None:
method close (line 76) | def close(self) -> None:
method log_dict (line 85) | def log_dict(self, payload: Dict[str, Any], step: int) -> None:
method log (line 98) | def log(self, name: str, data: Any, step: int) -> None:
method log_visuals (line 111) | def log_visuals(
FILE: training/trainer.py
class Trainer (line 46) | class Trainer:
method __init__ (line 60) | def __init__(
method _setup_timers (line 170) | def _setup_timers(self):
method _setup_env_variables (line 175) | def _setup_env_variables(self, env_variables_conf: Optional[Dict[str, ...
method _setup_torch_dist_and_backend (line 182) | def _setup_torch_dist_and_backend(self, cuda_conf: Dict, distributed_c...
method _load_resuming_checkpoint (line 198) | def _load_resuming_checkpoint(self, ckpt_path: str):
method _setup_device (line 228) | def _setup_device(self, device: str):
method _setup_components (line 239) | def _setup_components(self):
method _setup_dataloaders (line 273) | def _setup_dataloaders(self):
method _setup_ddp_distributed_training (line 289) | def _setup_ddp_distributed_training(self, distributed_conf: Dict, devi...
method save_checkpoint (line 306) | def save_checkpoint(self, epoch: int, checkpoint_names: Optional[List[...
method _get_scalar_log_keys (line 359) | def _get_scalar_log_keys(self, phase: str) -> List[str]:
method run (line 365) | def run(self):
method run_train (line 377) | def run_train(self):
method run_val (line 403) | def run_val(self):
method val_epoch (line 419) | def val_epoch(self, val_loader):
method train_epoch (line 501) | def train_epoch(self, train_loader):
method _run_steps_on_batch_chunks (line 638) | def _run_steps_on_batch_chunks(
method _apply_batch_repetition (line 692) | def _apply_batch_repetition(self, batch: Mapping) -> Mapping:
method _process_batch (line 716) | def _process_batch(self, batch: Mapping):
method _step (line 738) | def _step(self, batch, model: nn.Module, phase: str, loss_meters: dict):
method _update_and_log_scalars (line 760) | def _update_and_log_scalars(self, data: Mapping, phase: str, step: int...
method _log_tb_visuals (line 772) | def _log_tb_visuals(self, batch: Mapping, phase: str, step: int) -> None:
function chunk_batch_for_accum_steps (line 823) | def chunk_batch_for_accum_steps(batch: Mapping, accum_steps: int) -> Lis...
function is_sequence_of_primitives (line 829) | def is_sequence_of_primitives(data: Any) -> bool:
function get_chunk_from_data (line 838) | def get_chunk_from_data(data: Any, chunk_id: int, num_chunks: int) -> Any:
FILE: vggt/dependency/distortion.py
function _is_numpy (line 14) | def _is_numpy(x: ArrayLike) -> bool:
function _is_torch (line 18) | def _is_torch(x: ArrayLike) -> bool:
function _ensure_torch (line 22) | def _ensure_torch(x: ArrayLike) -> torch.Tensor:
function single_undistortion (line 32) | def single_undistortion(params, tracks_normalized):
function iterative_undistortion (line 51) | def iterative_undistortion(params, tracks_normalized, max_iterations=100...
function apply_distortion (line 99) | def apply_distortion(extra_params, u, v):
FILE: vggt/dependency/np_to_pycolmap.py
function batch_np_matrix_to_pycolmap (line 12) | def batch_np_matrix_to_pycolmap(
function pycolmap_to_batch_np_matrix (line 148) | def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_typ...
function batch_np_matrix_to_pycolmap_wo_track (line 201) | def batch_np_matrix_to_pycolmap_wo_track(
function _build_pycolmap_intri (line 293) | def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=No...
FILE: vggt/dependency/projection.py
function img_from_cam_np (line 12) | def img_from_cam_np(
function project_3D_points_np (line 50) | def project_3D_points_np(
function project_3D_points (line 105) | def project_3D_points(points3D, extrinsics, intrinsics=None, extra_param...
function img_from_cam (line 140) | def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
FILE: vggt/dependency/track_modules/base_track_predictor.py
class BaseTrackerPredictor (line 15) | class BaseTrackerPredictor(nn.Module):
method __init__ (line 16) | def __init__(
method forward (line 71) | def forward(self, query_points, fmaps=None, iters=4, return_feat=False...
FILE: vggt/dependency/track_modules/blocks.py
class BasicEncoder (line 25) | class BasicEncoder(nn.Module):
method __init__ (line 26) | def __init__(self, input_dim=3, output_dim=128, stride=4):
method _make_layer (line 58) | def _make_layer(self, dim, stride=1):
method forward (line 66) | def forward(self, x):
class ShallowEncoder (line 90) | class ShallowEncoder(nn.Module):
method __init__ (line 91) | def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="inst...
method _make_layer (line 126) | def _make_layer(self, dim, stride=1):
method forward (line 132) | def forward(self, x):
function _bilinear_intepolate (line 151) | def _bilinear_intepolate(x, stride, H, W):
class EfficientUpdateFormer (line 155) | class EfficientUpdateFormer(nn.Module):
method __init__ (line 160) | def __init__(
method initialize_weights (line 210) | def initialize_weights(self):
method forward (line 224) | def forward(self, input_tensor, mask=None):
class CorrBlock (line 264) | class CorrBlock:
method __init__ (line 265) | def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats...
method sample (line 282) | def sample(self, coords):
method corr (line 309) | def corr(self, targets):
FILE: vggt/dependency/track_modules/modules.py
function _ntuple (line 19) | def _ntuple(n):
function exists (line 28) | def exists(val):
function default (line 32) | def default(val, d):
class ResidualBlock (line 39) | class ResidualBlock(nn.Module):
method __init__ (line 44) | def __init__(self, in_planes, planes, norm_fn="group", stride=1, kerne...
method forward (line 86) | def forward(self, x):
class Mlp (line 97) | class Mlp(nn.Module):
method __init__ (line 100) | def __init__(
method forward (line 124) | def forward(self, x):
class AttnBlock (line 133) | class AttnBlock(nn.Module):
method __init__ (line 134) | def __init__(
method forward (line 155) | def forward(self, x, mask=None):
class CrossAttnBlock (line 172) | class CrossAttnBlock(nn.Module):
method __init__ (line 173) | def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4....
method forward (line 190) | def forward(self, x, context, mask=None):
FILE: vggt/dependency/track_modules/track_refine.py
function refine_track (line 22) | def refine_track(
function refine_track_v0 (line 163) | def refine_track_v0(
function compute_score_fn (line 302) | def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, srad...
function extract_glimpse (line 381) | def extract_glimpse(
FILE: vggt/dependency/track_modules/utils.py
function get_2d_sincos_pos_embed (line 19) | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[...
function get_2d_sincos_pos_embed_from_grid (line 44) | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor...
function get_1d_sincos_pos_embed_from_grid (line 65) | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor)...
function get_2d_embedding (line 91) | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) ...
function bilinear_sampler (line 125) | def bilinear_sampler(input, coords, align_corners=True, padding_mode="bo...
function sample_features4d (line 186) | def sample_features4d(input, coords):
FILE: vggt/dependency/track_predict.py
function predict_tracks (line 12) | def predict_tracks(
function _forward_on_query (line 135) | def _forward_on_query(
function _augment_non_visible_frames (line 232) | def _augment_non_visible_frames(
FILE: vggt/dependency/vggsfm_tracker.py
class TrackerPredictor (line 25) | class TrackerPredictor(nn.Module):
method __init__ (line 26) | def __init__(self, **extra_args):
method forward (line 58) | def forward(
method process_images_to_fmaps (line 106) | def process_images_to_fmaps(self, images):
FILE: vggt/dependency/vggsfm_utils.py
function build_vggsfm_tracker (line 29) | def build_vggsfm_tracker(model_path=None):
function generate_rank_by_dino (line 51) | def generate_rank_by_dino(
function farthest_point_sampling (line 118) | def farthest_point_sampling(distance_matrix, num_samples, most_common_fr...
function calculate_index_mappings (line 153) | def calculate_index_mappings(query_index, S, device=None):
function switch_tensor_order (line 174) | def switch_tensor_order(tensors, order, dim=1):
function initialize_feature_extractors (line 189) | def initialize_feature_extractors(max_query_num, det_thres=0.005, extrac...
function extract_keypoints (line 227) | def extract_keypoints(query_image, extractors, round_keypoints=True):
function predict_tracks_in_chunks (line 255) | def predict_tracks_in_chunks(
FILE: vggt/heads/camera_head.py
class CameraHead (line 19) | class CameraHead(nn.Module):
method __init__ (line 26) | def __init__(
method forward (line 73) | def forward(self, aggregated_tokens_list: list, num_iterations: int = ...
method trunk_fn (line 95) | def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> ...
function modulate (line 144) | def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) ...
FILE: vggt/heads/dpt_head.py
class DPTHead (line 21) | class DPTHead(nn.Module):
method __init__ (line 43) | def __init__(
method forward (line 115) | def forward(
method _forward_impl (line 172) | def _forward_impl(
method _apply_pos_embed (line 249) | def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: flo...
method scratch_forward (line 261) | def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
function _make_fusion_block (line 299) | def _make_fusion_block(features: int, size: int = None, has_residual: bo...
function _make_scratch (line 313) | def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, ...
class ResidualConvUnit (line 344) | class ResidualConvUnit(nn.Module):
method __init__ (line 347) | def __init__(self, features, activation, bn, groups=1):
method forward (line 366) | def forward(self, x):
class FeatureFusionBlock (line 389) | class FeatureFusionBlock(nn.Module):
method __init__ (line 392) | def __init__(
method forward (line 432) | def forward(self, *xs, size=None):
function custom_interpolate (line 459) | def custom_interpolate(
FILE: vggt/heads/head_act.py
function activate_pose (line 12) | def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", ...
function base_pose_act (line 38) | def base_pose_act(pose_enc, act_type="linear"):
function activate_head (line 61) | def activate_head(out, activation="norm_exp", conf_activation="expp1"):
function inverse_log_transform (line 115) | def inverse_log_transform(y):
FILE: vggt/heads/track_head.py
class TrackHead (line 12) | class TrackHead(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 72) | def forward(self, aggregated_tokens_list, images, patch_start_idx, que...
FILE: vggt/heads/track_modules/base_track_predictor.py
class BaseTrackerPredictor (line 17) | class BaseTrackerPredictor(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 82) | def forward(self, query_points, fmaps=None, iters=6, return_feat=False...
FILE: vggt/heads/track_modules/blocks.py
class EfficientUpdateFormer (line 19) | class EfficientUpdateFormer(nn.Module):
method __init__ (line 24) | def __init__(
method initialize_weights (line 80) | def initialize_weights(self):
method forward (line 90) | def forward(self, input_tensor, mask=None):
class CorrBlock (line 137) | class CorrBlock:
method __init__ (line 138) | def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats...
method corr_sample (line 176) | def corr_sample(self, targets, coords):
function compute_corr_level (line 231) | def compute_corr_level(fmap1, fmap2s, C):
FILE: vggt/heads/track_modules/modules.py
function _ntuple (line 19) | def _ntuple(n):
function exists (line 28) | def exists(val):
function default (line 32) | def default(val, d):
class ResidualBlock (line 39) | class ResidualBlock(nn.Module):
method __init__ (line 44) | def __init__(self, in_planes, planes, norm_fn="group", stride=1, kerne...
method forward (line 86) | def forward(self, x):
class Mlp (line 97) | class Mlp(nn.Module):
method __init__ (line 100) | def __init__(
method forward (line 124) | def forward(self, x):
class AttnBlock (line 133) | class AttnBlock(nn.Module):
method __init__ (line 134) | def __init__(
method forward (line 156) | def forward(self, x, mask=None):
class CrossAttnBlock (line 173) | class CrossAttnBlock(nn.Module):
method __init__ (line 174) | def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4....
method forward (line 192) | def forward(self, x, context, mask=None):
FILE: vggt/heads/track_modules/utils.py
function get_2d_sincos_pos_embed (line 18) | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[...
function get_2d_sincos_pos_embed_from_grid (line 43) | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor...
function get_1d_sincos_pos_embed_from_grid (line 64) | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor)...
function get_2d_embedding (line 90) | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) ...
function bilinear_sampler (line 124) | def bilinear_sampler(input, coords, align_corners=True, padding_mode="bo...
function sample_features4d (line 193) | def sample_features4d(input, coords):
FILE: vggt/heads/utils.py
function position_grid_to_embed (line 11) | def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega...
function make_sincos_pos_embed (line 36) | def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: fl...
function create_uv_grid (line 66) | def create_uv_grid(
FILE: vggt/layers/attention.py
class Attention (line 21) | class Attention(nn.Module):
method __init__ (line 22) | def __init__(
method forward (line 50) | def forward(self, x: Tensor, pos=None) -> Tensor:
class MemEffAttention (line 75) | class MemEffAttention(Attention):
method forward (line 76) | def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
FILE: vggt/layers/block.py
class Block (line 27) | class Block(nn.Module):
method __init__ (line 28) | def __init__(
method forward (line 77) | def forward(self, x: Tensor, pos=None) -> Tensor:
function drop_add_residual_stochastic_depth (line 101) | def drop_add_residual_stochastic_depth(
function get_branges_scales (line 128) | def get_branges_scales(x, sample_drop_ratio=0.0):
function add_residual (line 136) | def add_residual(x, brange, residual, residual_scale_factor, scaling_vec...
function get_attn_bias_and_cat (line 151) | def get_attn_bias_and_cat(x_list, branges=None):
function drop_add_residual_stochastic_depth_list (line 175) | def drop_add_residual_stochastic_depth_list(
class NestedTensorBlock (line 198) | class NestedTensorBlock(Block):
method forward_nested (line 199) | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
method forward (line 239) | def forward(self, x_or_x_list):
FILE: vggt/layers/drop_path.py
function drop_path (line 14) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class DropPath (line 26) | class DropPath(nn.Module):
method __init__ (line 29) | def __init__(self, drop_prob=None):
method forward (line 33) | def forward(self, x):
FILE: vggt/layers/layer_scale.py
class LayerScale (line 15) | class LayerScale(nn.Module):
method __init__ (line 16) | def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5,...
method forward (line 21) | def forward(self, x: Tensor) -> Tensor:
FILE: vggt/layers/mlp.py
class Mlp (line 16) | class Mlp(nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 34) | def forward(self, x: Tensor) -> Tensor:
FILE: vggt/layers/patch_embed.py
function make_2tuple (line 16) | def make_2tuple(x):
class PatchEmbed (line 25) | class PatchEmbed(nn.Module):
method __init__ (line 37) | def __init__(
method forward (line 65) | def forward(self, x: Tensor) -> Tensor:
method flops (line 80) | def flops(self) -> float:
FILE: vggt/layers/rope.py
class PositionGetter (line 24) | class PositionGetter:
method __init__ (line 35) | def __init__(self):
method __call__ (line 39) | def __call__(self, batch_size: int, height: int, width: int, device: t...
class RotaryPositionEmbedding2D (line 62) | class RotaryPositionEmbedding2D(nn.Module):
method __init__ (line 79) | def __init__(self, frequency: float = 100.0, scaling_factor: float = 1...
method _compute_frequency_components (line 86) | def _compute_frequency_components(
method _rotate_features (line 120) | def _rotate_features(x: torch.Tensor) -> torch.Tensor:
method _apply_1d_rope (line 133) | def _apply_1d_rope(
method forward (line 154) | def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> to...
FILE: vggt/layers/swiglu_ffn.py
class SwiGLUFFN (line 14) | class SwiGLUFFN(nn.Module):
method __init__ (line 15) | def __init__(
method forward (line 30) | def forward(self, x: Tensor) -> Tensor:
class SwiGLUFFNFused (line 54) | class SwiGLUFFNFused(SwiGLU):
method __init__ (line 55) | def __init__(
FILE: vggt/layers/vision_transformer.py
function named_apply (line 24) | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=Tr...
class BlockChunk (line 35) | class BlockChunk(nn.ModuleList):
method forward (line 36) | def forward(self, x):
class DinoVisionTransformer (line 42) | class DinoVisionTransformer(nn.Module):
method __init__ (line 43) | def __init__(
method init_weights (line 173) | def init_weights(self):
method interpolate_pos_encoding (line 180) | def interpolate_pos_encoding(self, x, w, h):
method prepare_tokens_with_masks (line 214) | def prepare_tokens_with_masks(self, x, masks=None):
method forward_features_list (line 228) | def forward_features_list(self, x_list, masks_list):
method forward_features (line 252) | def forward_features(self, x, masks=None):
method _get_intermediate_layers_not_chunked (line 273) | def _get_intermediate_layers_not_chunked(self, x, n=1):
method _get_intermediate_layers_chunked (line 285) | def _get_intermediate_layers_chunked(self, x, n=1):
method get_intermediate_layers (line 299) | def get_intermediate_layers(
method forward (line 325) | def forward(self, *args, is_training=True, **kwargs):
function init_weights_vit_timm (line 333) | def init_weights_vit_timm(module: nn.Module, name: str = ""):
function vit_small (line 341) | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
function vit_base (line 355) | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
function vit_large (line 369) | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
function vit_giant2 (line 383) | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
FILE: vggt/models/aggregator.py
class Aggregator (line 25) | class Aggregator(nn.Module):
method __init__ (line 52) | def __init__(
method __build_patch_embed__ (line 143) | def __build_patch_embed__(
method forward (line 184) | def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], i...
method _process_frame_attention (line 260) | def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=...
method _process_global_attention (line 284) | def _process_global_attention(self, tokens, B, S, P, C, global_idx, po...
function slice_expand_and_flatten (line 308) | def slice_expand_and_flatten(token_tensor, B, S):
FILE: vggt/models/vggt.py
class VGGT (line 17) | class VGGT(nn.Module, PyTorchModelHubMixin):
method __init__ (line 18) | def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
method forward (line 29) | def forward(self, images: torch.Tensor, query_points: torch.Tensor = N...
FILE: vggt/utils/geometry.py
function unproject_depth_map_to_point_map (line 15) | def unproject_depth_map_to_point_map(
function depth_to_world_coords_points (line 47) | def depth_to_world_coords_points(
function depth_to_cam_coords_points (line 87) | def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndar...
function closed_form_inverse_se3 (line 120) | def closed_form_inverse_se3(se3, R=None, T=None):
function project_world_points_to_camera_points_batch (line 175) | def project_world_points_to_camera_points_batch(world_points, cam_extrin...
function project_world_points_to_cam (line 204) | def project_world_points_to_cam(
function img_from_cam (line 251) | def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, def...
function cam_from_img (line 294) | def cam_from_img(pred_tracks, intrinsics, extra_params=None):
FILE: vggt/utils/helper.py
function randomly_limit_trues (line 10) | def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray:
function create_pixel_coordinate_grid (line 33) | def create_pixel_coordinate_grid(num_frames, height, width):
FILE: vggt/utils/load_fn.py
function load_and_preprocess_images_square (line 13) | def load_and_preprocess_images_square(image_path_list, target_size=1024):
function load_and_preprocess_images (line 97) | def load_and_preprocess_images(image_path_list, mode="crop"):
FILE: vggt/utils/pose_enc.py
function extri_intri_to_pose_encoding (line 11) | def extri_intri_to_pose_encoding(
function pose_encoding_to_extri_intri (line 62) | def pose_encoding_to_extri_intri(
FILE: vggt/utils/rotation.py
function quat_to_mat (line 14) | def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
function mat_to_quat (line 47) | def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
function _sqrt_positive_part (line 106) | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
function standardize_quaternion (line 120) | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
FILE: vggt/utils/visual_track.py
function color_from_xy (line 13) | def color_from_xy(x, y, W, H, cmap_name="hsv"):
function get_track_colors_by_position (line 37) | def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=...
function visualize_tracks_on_images (line 80) | def visualize_tracks_on_images(
FILE: visual_util.py
function predictions_to_glb (line 18) | def predictions_to_glb(
function integrate_camera_into_scene (line 218) | def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndar...
function apply_scene_alignment (line 263) | def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: ...
function get_opengl_conversion_matrix (line 287) | def get_opengl_conversion_matrix() -> np.ndarray:
function transform_points (line 304) | def transform_points(transformation: np.ndarray, points: np.ndarray, dim...
function compute_camera_faces (line 329) | def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
function segment_sky (line 365) | def segment_sky(image_path, onnx_session, mask_filename=None):
function run_skyseg (line 396) | def run_skyseg(onnx_session, input_size, image):
function download_file_from_url (line 436) | def download_file_from_url(url, filename):
Condensed preview: 83 files, each showing path, character count, and a content snippet (full structured content: 624K chars).
[
{
"path": ".gitattributes",
"chars": 122,
"preview": "# SCM syntax highlighting & preventing 3-way merges\npixi.lock merge=binary linguist-language=YAML linguist-generated=tru"
},
{
"path": ".gitignore",
"chars": 2020,
"preview": ".hydra/\noutput/\nckpt/\n# Byte-compiled / optimized / DLL files\n__pycache__/\n**/__pycache__/\n*.py[cod]\n*$py.class\n\n# C ext"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 3537,
"preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
},
{
"path": "CONTRIBUTING.md",
"chars": 1242,
"preview": "# Contributing to vggt\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Reques"
},
{
"path": "LICENSE.txt",
"chars": 10666,
"preview": "VGGT License\n\nv1 Last Updated: July 29, 2025\n\n“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Res"
},
{
"path": "README.md",
"chars": 14662,
"preview": "<div align=\"center\">\n<h1>VGGT: Visual Geometry Grounded Transformer</h1>\n\n<a href=\"https://jytime.github.io/data/VGGT_CV"
},
{
"path": "demo_colmap.py",
"chars": 12668,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "demo_gradio.py",
"chars": 25027,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "demo_viser.py",
"chars": 14794,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "docs/package.md",
"chars": 1127,
"preview": "# Alternative Installation Methods\n\nThis document explains how to install VGGT as a package using different package mana"
},
{
"path": "pyproject.toml",
"chars": 1040,
"preview": "[project]\nauthors = [{name = \"Jianyuan Wang\", email = \"jianyuan@robots.ox.ac.uk\"}]\ndependencies = [\n \"numpy<2\",\n \""
},
{
"path": "requirements.txt",
"chars": 89,
"preview": "torch==2.3.1\ntorchvision==0.18.1\nnumpy==1.26.1\nPillow\nhuggingface_hub\neinops\nsafetensors\n"
},
{
"path": "requirements_demo.txt",
"chars": 298,
"preview": "gradio==5.17.1\nviser==0.2.23\ntqdm\nhydra-core\nomegaconf\nopencv-python\nscipy\nonnxruntime\nrequests\ntrimesh\nmatplotlib\npydan"
},
{
"path": "training/README.md",
"chars": 5499,
"preview": "# Training\n\nThis is a re-implementation of our framework for training VGGT. This document shows how to set up the enviro"
},
{
"path": "training/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "training/config/default.yaml",
"chars": 4619,
"preview": "defaults:\n - default_dataset.yaml\n\nexp_name: exp001\nimg_size: 518\nnum_workers: 8\nseed_value: 42\naccum_steps: 2 # We "
},
{
"path": "training/config/default_dataset.yaml",
"chars": 2411,
"preview": "# Template for the dataset config\ndata:\n # The code still looks too complicated. I should refactor this again (do I hav"
},
{
"path": "training/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "training/data/augmentation.py",
"chars": 2247,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/base_dataset.py",
"chars": 11524,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/composed_dataset.py",
"chars": 11436,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/dataset_util.py",
"chars": 24624,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/datasets/co3d.py",
"chars": 8205,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/datasets/vkitti.py",
"chars": 8090,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/dynamic_dataloader.py",
"chars": 9007,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/preprocess/vkitti.sh",
"chars": 595,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/track_util.py",
"chars": 17058,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/data/worker_fn.py",
"chars": 3613,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/launch.py",
"chars": 843,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/loss.py",
"chars": 30473,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "training/train_utils/checkpoint.py",
"chars": 2640,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/distributed.py",
"chars": 652,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/freeze.py",
"chars": 3136,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/general.py",
"chars": 11550,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/gradient_clip.py",
"chars": 4151,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/logging.py",
"chars": 2363,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/normalization.py",
"chars": 5077,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/optimizer.py",
"chars": 11008,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/train_utils/tb_writer.py",
"chars": 4299,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "training/trainer.py",
"chars": 33636,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/__init__.py",
"chars": 185,
"preview": "from .track_modules.track_refine import refine_track\nfrom .track_modules.blocks import BasicEncoder, ShallowEncoder\nfrom"
},
{
"path": "vggt/dependency/distortion.py",
"chars": 6511,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/np_to_pycolmap.py",
"chars": 10764,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/projection.py",
"chars": 9059,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/track_modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "vggt/dependency/track_modules/base_track_predictor.py",
"chars": 7425,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/track_modules/blocks.py",
"chars": 12239,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "vggt/dependency/track_modules/modules.py",
"chars": 6274,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/track_modules/track_refine.py",
"chars": 17290,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/track_modules/utils.py",
"chars": 7825,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/track_predict.py",
"chars": 12355,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/vggsfm_tracker.py",
"chars": 4668,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/dependency/vggsfm_utils.py",
"chars": 11093,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/camera_head.py",
"chars": 5817,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/dpt_head.py",
"chars": 17327,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/head_act.py",
"chars": 3741,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/track_head.py",
"chars": 4207,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/track_modules/__init__.py",
"chars": 198,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/track_modules/base_track_predictor.py",
"chars": 8027,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/track_modules/blocks.py",
"chars": 9806,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "vggt/heads/track_modules/modules.py",
"chars": 6132,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/track_modules/utils.py",
"chars": 8136,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/heads/utils.py",
"chars": 3941,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/layers/__init__.py",
"chars": 382,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/layers/attention.py",
"chars": 3067,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/block.py",
"chars": 9379,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/drop_path.py",
"chars": 1157,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/layer_scale.py",
"chars": 781,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/mlp.py",
"chars": 1269,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/patch_embed.py",
"chars": 2794,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/rope.py",
"chars": 7676,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/swiglu_ffn.py",
"chars": 2131,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/layers/vision_transformer.py",
"chars": 15075,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "vggt/models/aggregator.py",
"chars": 12944,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/models/vggt.py",
"chars": 4862,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/geometry.py",
"chars": 11849,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/helper.py",
"chars": 2317,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/load_fn.py",
"chars": 8724,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/pose_enc.py",
"chars": 5517,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/rotation.py",
"chars": 4584,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "vggt/utils/visual_track.py",
"chars": 8354,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "visual_util.py",
"chars": 16797,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
}
]
About this extraction
This page contains the full source code of the facebookresearch/vggt GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 83 files (584.7 KB), approximately 146.8k tokens, and a symbol index with 430 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.