Repository: facebookresearch/vggt
Branch: main
Commit: 44b3afbd1869
Files: 83
Total size: 584.7 KB

Directory structure:
vggt/

├── .gitattributes
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── demo_colmap.py
├── demo_gradio.py
├── demo_viser.py
├── docs/
│   └── package.md
├── pyproject.toml
├── requirements.txt
├── requirements_demo.txt
├── training/
│   ├── README.md
│   ├── __init__.py
│   ├── config/
│   │   ├── default.yaml
│   │   └── default_dataset.yaml
│   ├── data/
│   │   ├── __init__.py
│   │   ├── augmentation.py
│   │   ├── base_dataset.py
│   │   ├── composed_dataset.py
│   │   ├── dataset_util.py
│   │   ├── datasets/
│   │   │   ├── co3d.py
│   │   │   └── vkitti.py
│   │   ├── dynamic_dataloader.py
│   │   ├── preprocess/
│   │   │   └── vkitti.sh
│   │   ├── track_util.py
│   │   └── worker_fn.py
│   ├── launch.py
│   ├── loss.py
│   ├── train_utils/
│   │   ├── __init__.py
│   │   ├── checkpoint.py
│   │   ├── distributed.py
│   │   ├── freeze.py
│   │   ├── general.py
│   │   ├── gradient_clip.py
│   │   ├── logging.py
│   │   ├── normalization.py
│   │   ├── optimizer.py
│   │   └── tb_writer.py
│   └── trainer.py
├── vggt/
│   ├── dependency/
│   │   ├── __init__.py
│   │   ├── distortion.py
│   │   ├── np_to_pycolmap.py
│   │   ├── projection.py
│   │   ├── track_modules/
│   │   │   ├── __init__.py
│   │   │   ├── base_track_predictor.py
│   │   │   ├── blocks.py
│   │   │   ├── modules.py
│   │   │   ├── track_refine.py
│   │   │   └── utils.py
│   │   ├── track_predict.py
│   │   ├── vggsfm_tracker.py
│   │   └── vggsfm_utils.py
│   ├── heads/
│   │   ├── camera_head.py
│   │   ├── dpt_head.py
│   │   ├── head_act.py
│   │   ├── track_head.py
│   │   ├── track_modules/
│   │   │   ├── __init__.py
│   │   │   ├── base_track_predictor.py
│   │   │   ├── blocks.py
│   │   │   ├── modules.py
│   │   │   └── utils.py
│   │   └── utils.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── drop_path.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   ├── rope.py
│   │   ├── swiglu_ffn.py
│   │   └── vision_transformer.py
│   ├── models/
│   │   ├── aggregator.py
│   │   └── vggt.py
│   └── utils/
│       ├── geometry.py
│       ├── helper.py
│       ├── load_fn.py
│       ├── pose_enc.py
│       ├── rotation.py
│       └── visual_track.py
└── visual_util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true


================================================
FILE: .gitignore
================================================
.hydra/
output/
ckpt/
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Profiling data
.prof

# Folder specific to your needs
**/tmp/
**/outputs/skyseg.onnx
skyseg.onnx

# pixi environments
.pixi
*.egg-info


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to vggt
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to vggt, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

================================================
FILE: LICENSE.txt
================================================
VGGT License

v1 Last Updated: July 29, 2025

“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.

“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.


“Documentation” means the specifications, manuals and documentation accompanying Research Materials distributed by Meta.


“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.

By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.


1. License Rights and Redistribution.


a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.  
 
b. Redistribution and Use.  


i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.


ii.  If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.


iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.

2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.


3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
 
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
 
5. Intellectual Property.


a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
 
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
 
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement. 
 
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. 


8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.


Acceptable Use Policy 

Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. 

As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.  

Prohibited Uses

You agree you will not use, or allow others to use, Research Materials to:

1. Violate the law or others’ rights, including to:
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
Violence or terrorism
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
Human trafficking, exploitation, and sexual violence
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
Sexual solicitation
Any other criminal activity

Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals

Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services

Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices

Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws

Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials

Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system

2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:

Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State

Guns and illegal weapons (including weapon development)

Illegal drugs and regulated/controlled substances
Operation of critical infrastructure, transportation technologies, or heavy machinery

Self-harm or harm to others, including suicide, cutting, and eating disorders
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual

3. Intentionally deceive or mislead others, including use of Research Materials related to the following:

 Generating, promoting, or furthering fraud or the creation or promotion of disinformation
 Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content

Generating, promoting, or further distributing spam

 Impersonating another individual without consent, authorization, or legal right

Representing that outputs of research materials or outputs from technology using Research Materials are human-generated

Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement

4. Fail to appropriately disclose to end users any known dangers of your Research Materials.


================================================
FILE: README.md
================================================
<div align="center">
<h1>VGGT: Visual Geometry Grounded Transformer</h1>

<a href="https://jytime.github.io/data/VGGT_CVPR25.pdf" target="_blank" rel="noopener noreferrer">
  <img src="https://img.shields.io/badge/Paper-VGGT" alt="Paper PDF">
</a>
<a href="https://arxiv.org/abs/2503.11651"><img src="https://img.shields.io/badge/arXiv-2503.11651-b31b1b" alt="arXiv"></a>
<a href="https://vgg-t.github.io/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
<a href="https://huggingface.co/spaces/facebook/vggt"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>


**[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**; **[Meta AI](https://ai.facebook.com/research/)**


[Jianyuan Wang](https://jytime.github.io/), [Minghao Chen](https://silent-chen.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/)
</div>

```bibtex
@inproceedings{wang2025vggt,
  title={VGGT: Visual Geometry Grounded Transformer},
  author={Wang, Jianyuan and Chen, Minghao and Karaev, Nikita and Vedaldi, Andrea and Rupprecht, Christian and Novotny, David},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2025}
}
```

## Updates

- [July 29, 2025] We've updated the license for VGGT to permit **commercial use** (excluding military applications). All code in this repository is now under a commercial-use-friendly license. However, only the newly released checkpoint [**VGGT-1B-Commercial**](https://huggingface.co/facebook/VGGT-1B-Commercial) is licensed for commercial use; the original checkpoint remains non-commercial. Full license details are available [here](https://github.com/facebookresearch/vggt/blob/main/LICENSE.txt). Access to the checkpoint requires completing an application form, which is processed automatically by a system similar to LLaMA's approval workflow. The new checkpoint performs similarly to the original model. Please open an issue if you notice a significant performance discrepancy.



- [July 6, 2025] Training code is now available in the `training` folder, including an example to finetune VGGT on a custom dataset. 


- [June 13, 2025] Honored to receive the Best Paper Award at CVPR 2025! Apologies if I’m slow to respond to queries or GitHub issues these days. If you’re interested, our oral presentation is available [here](https://docs.google.com/presentation/d/1JVuPnuZx6RgAy-U5Ezobg73XpBi7FrOh/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true). Another long presentation can be found [here](https://docs.google.com/presentation/d/1aSv0e5PmH1mnwn2MowlJIajFUYZkjqgw/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true) (Note: it’s shared in .pptx format with animations — quite large, but feel free to use it as a template if helpful.)


- [June 2, 2025] Added a script to run VGGT and save predictions in COLMAP format, with optional bundle adjustment. The saved COLMAP files can be used directly with [gsplat](https://github.com/nerfstudio-project/gsplat) or other NeRF/Gaussian splatting libraries.


- [May 3, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available in the [evaluation](https://github.com/facebookresearch/vggt/tree/evaluation) branch. 


## Overview

Visual Geometry Grounded Transformer (VGGT, CVPR 2025) is a feed-forward neural network that directly infers all key 3D attributes of a scene, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks, **from one, a few, or hundreds of its views, within seconds**.


## Quick Start

First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub). 

```bash
git clone git@github.com:facebookresearch/vggt.git 
cd vggt
pip install -r requirements.txt
```

Alternatively, you can install VGGT as a package (<a href="docs/package.md">click here</a> for details).


Now, try the model with just a few lines of code:

```python
import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
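# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); fall back to float16 otherwise.
# The device check keeps the capability query from crashing on CPU-only machines.
dtype = torch.bfloat16 if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8 else torch.float16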

# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

# Load and preprocess example images (replace with your own image paths)
image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"]  
images = load_and_preprocess_images(image_names).to(device)

with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=dtype):
        # Predict attributes including cameras, depth maps, and point maps.
        predictions = model(images)
```

The model weights are downloaded automatically from Hugging Face. If you encounter issues such as slow loading, you can download them manually [here](https://huggingface.co/facebook/VGGT-1B/blob/main/model.pt) and load them from disk, or load them directly from the URL:

```python
model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
```

## Detailed Usage

<details>
<summary>Click to expand</summary>

You can also optionally choose which attributes (branches) to predict, as shown below. This achieves the same result as the example above. This example uses a batch size of 1 (processing a single scene), but it naturally works for multiple scenes.

```python
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=dtype):
        images = images[None]  # add batch dimension
        aggregated_tokens_list, ps_idx = model.aggregator(images)

    # Predict Cameras
    pose_enc = model.camera_head(aggregated_tokens_list)[-1]
    # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])

    # Predict Depth Maps
    depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)

    # Predict Point Maps
    point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)

    # Construct 3D points from depth maps and cameras,
    # which usually yields more accurate 3D points than the point map branch
    point_map_by_unprojection = unproject_depth_map_to_point_map(
        depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0)
    )

    # Predict Tracks
    # Choose your own points to track, with shape (N, 2) for one scene
    query_points = torch.FloatTensor([[100.0, 200.0],
                                      [60.72, 259.94]]).to(device)
    track_list, vis_score, conf_score = model.track_head(
        aggregated_tokens_list, images, ps_idx, query_points=query_points[None]
    )
```
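
To make the depth-plus-camera route concrete, here is what unprojection does for a single pixel. This is not the repository's `unproject_depth_map_to_point_map`; it is a plain-Python sketch assuming the OpenCV camera-from-world convention noted in the code (`x_cam = R @ x_world + t`), a skew-free pinhole intrinsic matrix, and depth measured along the camera's z-axis. The helper names `unproject_pixel` and `camera_center` are hypothetical.

```python
def unproject_pixel(u, v, depth, K, extrinsic):
    """Lift pixel (u, v) with z-depth `depth` into world coordinates.

    K:         3x3 intrinsics [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] (no skew assumed).
    extrinsic: 3x4 camera-from-world matrix [R | t], i.e. x_cam = R x_world + t.
    """
    fx, fy = K[0][0], K[1][1]
    cx, cy = K[0][2], K[1][2]
    # Back-project into the camera frame (OpenCV: +z looks forward, depth = z).
    x_cam = [(u - cx) / fx * depth, (v - cy) / fy * depth, depth]
    R = [row[:3] for row in extrinsic]
    t = [row[3] for row in extrinsic]
    # Invert x_cam = R x_world + t  =>  x_world = R^T (x_cam - t).
    d = [x_cam[i] - t[i] for i in range(3)]
    return [sum(R[j][i] * d[j] for j in range(3)) for i in range(3)]


def camera_center(extrinsic):
    """World-space camera center C = -R^T t for a camera-from-world [R | t]."""
    R = [row[:3] for row in extrinsic]
    t = [row[3] for row in extrinsic]
    return [-sum(R[j][i] * t[j] for j in range(3)) for i in range(3)]
```

Applying this to every pixel of a depth map with that frame's extrinsic and intrinsic gives a world-space point map comparable to the point map branch's output.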


Furthermore, if certain pixels in the input frames are unwanted (e.g., reflective surfaces, sky, or water), you can mask them by setting the corresponding pixel values to 0 or 1. Precise segmentation masks aren't necessary; simple bounding-box masks work effectively (see this [issue](https://github.com/facebookresearch/vggt/issues/47) for an example).
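
A bounding-box mask amounts to overwriting a rectangle of pixels with a constant. The sketch below shows the idea on a channel-first image stored as nested Python lists; `mask_box` is a hypothetical helper. With the preprocessed tensor, assuming it has shape `(S, 3, H, W)` as returned by `load_and_preprocess_images`, the equivalent would be a slice assignment such as `images[frame_idx, :, y0:y1, x0:x1] = 0.0` applied before adding the batch dimension.

```python
def mask_box(image, box, value=0.0):
    """Fill a rectangle of a channel-first (C, H, W) image with a constant.

    image: nested lists indexed as image[channel][row][col], modified in place.
    box:   (x0, y0, x1, y1) in pixels, half-open like Python slices.
    value: fill value; 0.0 or 1.0 for images normalized to [0, 1].
    """
    x0, y0, x1, y1 = box
    for channel in image:
        # channel[y0:y1] is a new list, but its elements are the original row lists.
        for row in channel[y0:y1]:
            row[x0:x1] = [value] * len(row[x0:x1])
    return image
```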

</details>


## Interactive Demo

We provide multiple ways to visualize your 3D reconstructions. Before using these visualization tools, install the required dependencies:

```bash
pip install -r requirements_demo.txt
```

### Interactive 3D Visualization

**Please note:** VGGT typically reconstructs a scene in less than 1 second. However, visualizing the 3D points may take tens of seconds because of third-party rendering, independent of VGGT's processing time. Visualization is especially slow when the number of images is large.


#### Gradio Web Interface

Our Gradio-based interface lets you upload images or videos, run reconstruction, and interactively explore the 3D scene in your browser. You can launch it on your local machine or try it on [Hugging Face](https://huggingface.co/spaces/facebook/vggt).


```bash
python demo_gradio.py
```

<details>
<summary>Click to preview the Gradio interactive interface</summary>

![Gradio Web Interface Preview](https://jytime.github.io/data/vggt_hf_demo_screen.png)
</details>


#### Viser 3D Viewer

Run the following command to reconstruct a scene and visualize the point cloud in viser. The script takes a path to a folder that should contain only image files. Pass `--use_point_map` to use the point cloud from the point map branch instead of the depth-based point cloud.

```bash
python demo_viser.py --image_folder path/to/your/images/folder
```

## Exporting to COLMAP Format

VGGT's predictions can also be exported directly to COLMAP format:

```bash 
# Feedforward prediction only
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ 

# With bundle adjustment
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba

# Run with bundle adjustment using reduced parameters for faster processing
# Reduces max_query_pts from 4096 (default) to 2048 and query_frame_num from 8 (default) to 5
# Trade-off: Faster execution but potentially less robust reconstruction in complex scenes (you may consider setting query_frame_num equal to your total number of images) 
# See demo_colmap.py for additional bundle adjustment configuration options
python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba --max_query_pts=2048 --query_frame_num=5
```

Please ensure that the images are stored in `/YOUR/SCENE_DIR/images/`. This folder should contain only the images. Check the examples folder for the desired data structure. 
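
Because stray files in `images/` can break the export, a quick pre-flight check may save a failed run. The sketch below is a hypothetical helper, not part of `demo_colmap.py`; the accepted extension set is an assumption, so adjust it to the formats you actually use.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your images use another format.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def find_non_images(scene_dir):
    """Return the sorted names of entries in scene_dir/images that are not image files.

    Subdirectories and files with unrecognized extensions are both reported.
    Raises FileNotFoundError if scene_dir/images does not exist.
    """
    images_dir = Path(scene_dir) / "images"
    if not images_dir.is_dir():
        raise FileNotFoundError(f"expected an images/ folder at {images_dir}")
    return sorted(
        p.name
        for p in images_dir.iterdir()
        if not (p.is_file() and p.suffix.lower() in IMAGE_EXTS)
    )
```

An empty list means the folder matches the expected layout.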

The reconstruction result (camera parameters and 3D points) will be automatically saved under `/YOUR/SCENE_DIR/sparse/` in the COLMAP format, such as:

``` 
SCENE_DIR/
├── images/
└── sparse/
    ├── cameras.bin
    ├── images.bin
    └── points3D.bin
```
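
To sanity-check the exported cameras without extra dependencies, `cameras.bin` can be parsed with the standard library. The sketch below follows COLMAP's binary layout (a little-endian `uint64` camera count, then per camera an `int32` id, `int32` model id, `uint64` width and height, and the model's parameters as `float64`). It handles only the few camera models listed; `read_cameras_bin` is a hypothetical helper, and pycolmap or COLMAP's own model-reading scripts are the complete option.

```python
import struct

# Subset of COLMAP camera models: model_id -> (name, number of float64 params).
CAMERA_MODELS = {
    0: ("SIMPLE_PINHOLE", 3),
    1: ("PINHOLE", 4),
    2: ("SIMPLE_RADIAL", 4),
    3: ("RADIAL", 5),
    4: ("OPENCV", 8),
}


def read_cameras_bin(path):
    """Parse a COLMAP cameras.bin file into {camera_id: {...}}."""
    cameras = {}
    with open(path, "rb") as f:
        (num_cameras,) = struct.unpack("<Q", f.read(8))
        for _ in range(num_cameras):
            # camera_id, model_id (int32) followed by width, height (uint64): 24 bytes.
            camera_id, model_id, width, height = struct.unpack("<iiQQ", f.read(24))
            if model_id not in CAMERA_MODELS:
                raise ValueError(f"unsupported camera model id {model_id}")
            name, num_params = CAMERA_MODELS[model_id]
            params = struct.unpack(f"<{num_params}d", f.read(8 * num_params))
            cameras[camera_id] = {
                "model": name,
                "width": width,
                "height": height,
                "params": params,
            }
    return cameras
```

For example, `read_cameras_bin("/YOUR/SCENE_DIR/sparse/cameras.bin")` should report one entry per exported camera along with its image size and intrinsic parameters.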

## Integration with Gaussian Splatting


The exported COLMAP files can be used directly with [gsplat](https://github.com/nerfstudio-project/gsplat) for Gaussian Splatting training. Install `gsplat` following its official instructions (we recommend `gsplat==1.3.0`).

An example training command is:
```bash
cd gsplat
python examples/simple_trainer.py default --data_factor 1 --data_dir /YOUR/SCENE_DIR/ --result_dir /YOUR/RESULT_DIR/
```



## Zero-shot Single-view Reconstruction

Our model performs surprisingly well on single-view reconstruction, although it was never trained for this task. It does not need the single image duplicated into a pair; it infers the 3D structure directly from the tokens of that one image. The demos above work for single-view inputs as-is, so feel free to try them.


We did not quantitatively test monocular depth estimation performance ourselves, but [@kabouzeid](https://github.com/kabouzeid) generously provided a comparison of VGGT to recent methods [here](https://github.com/facebookresearch/vggt/issues/36). VGGT shows competitive or better results compared to state-of-the-art monocular approaches such as DepthAnything v2 or MoGe, despite never being explicitly trained for single-view tasks. 



## Runtime and GPU Memory

We benchmark the runtime and GPU memory usage of VGGT's aggregator on a single NVIDIA H100 GPU for different numbers of input frames.

| **Input Frames** | 1 | 2 | 4 | 8 | 10 | 20 | 50 | 100 | 200 |
|:----------------:|:-:|:-:|:-:|:-:|:--:|:--:|:--:|:---:|:---:|
| **Time (s)**     | 0.04 | 0.05 | 0.07 | 0.11 | 0.14 | 0.31 | 1.04 | 3.12 | 8.75 |
| **Memory (GB)**  | 1.88 | 2.07 | 2.45 | 3.23 | 3.63 | 5.58 | 11.41 | 21.15 | 40.63 |

Note that these results were obtained with Flash Attention 3, which is faster than the default Flash Attention 2 implementation while using almost the same memory. You can compile Flash Attention 3 from source for better performance.
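The memory row of the table can be used to estimate how many frames fit in a given GPU budget. A rough sketch, assuming linear interpolation between the benchmarked points is reasonable; it covers the aggregator only, on an H100 with Flash Attention 3, so other GPUs, attention backends, and the prediction heads will shift the numbers. `max_frames_within` is a hypothetical helper.

```python
import numpy as np

# Aggregator benchmark from the table above (single H100, Flash Attention 3).
FRAMES = np.array([1, 2, 4, 8, 10, 20, 50, 100, 200])
MEMORY_GB = np.array([1.88, 2.07, 2.45, 3.23, 3.63, 5.58, 11.41, 21.15, 40.63])


def max_frames_within(budget_gb):
    """Largest frame count whose interpolated aggregator memory fits in budget_gb.

    Linearly interpolates between benchmarked points. Returns None if even a
    single frame exceeds the budget, and caps the answer at 200 frames
    because the table gives no data beyond that.
    """
    if budget_gb < MEMORY_GB[0]:
        return None
    candidates = np.arange(FRAMES[0], FRAMES[-1] + 1)
    memory = np.interp(candidates, FRAMES, MEMORY_GB)
    return int(candidates[memory <= budget_gb].max())
```

Between 20 and 200 frames, the table's memory grows by roughly 0.19 GB per additional frame, so the linear interpolation is a fair fit in that range.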


## Research Progression

Our work builds upon a series of previous research projects. If you're interested in understanding how our research evolved, check out our previous works:


<table border="0" cellspacing="0" cellpadding="0">
  <tr>
    <td align="left">
      <a href="https://github.com/jytime/Deep-SfM-Revisited">Deep SfM Revisited</a>
    </td>
    <td style="white-space: pre;">──┐</td>
    <td></td>
  </tr>
  <tr>
    <td align="left">
      <a href="https://github.com/facebookresearch/PoseDiffusion">PoseDiffusion</a>
    </td>
    <td style="white-space: pre;">─────►</td>
    <td>
      <a href="https://github.com/facebookresearch/vggsfm">VGGSfM</a> ──►
      <a href="https://github.com/facebookresearch/vggt">VGGT</a>
    </td>
  </tr>
  <tr>
    <td align="left">
      <a href="https://github.com/facebookresearch/co-tracker">CoTracker</a>
    </td>
    <td style="white-space: pre;">──┘</td>
    <td></td>
  </tr>
</table>


## Acknowledgements

Thanks to these great repositories: [PoseDiffusion](https://github.com/facebookresearch/PoseDiffusion), [VGGSfM](https://github.com/facebookresearch/vggsfm), [CoTracker](https://github.com/facebookresearch/co-tracker), [DINOv2](https://github.com/facebookresearch/dinov2), [Dust3r](https://github.com/naver/dust3r), [Moge](https://github.com/microsoft/moge), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [Sky Segmentation](https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), [Metric3D](https://github.com/YvanYin/Metric3D) and many other inspiring works in the community.

## Checklist

- [x] Release the training code
- [ ] Release VGGT-500M and VGGT-200M


## License
See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.

Please note that only this [model checkpoint](https://huggingface.co/facebook/VGGT-1B-Commercial) allows commercial use. It performs on par with (and possibly slightly better than) the original checkpoint, e.g., AUC@30 of 90.37 vs. 89.98 on the Co3D dataset.


================================================
FILE: demo_colmap.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
import numpy as np
import glob
import os
import copy
import torch
import torch.nn.functional as F

# Configure CUDA settings
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

import argparse
from pathlib import Path
import trimesh
import pycolmap


from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images_square
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues
from vggt.dependency.track_predict import predict_tracks
from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track


# TODO: add support for masks
# TODO: add iterative BA
# TODO: add support for radial distortion, which needs extra_params
# TODO: test with more cases
# TODO: test different camera types


def parse_args():
    parser = argparse.ArgumentParser(description="VGGT Demo")
    parser.add_argument("--scene_dir", type=str, required=True, help="Directory containing the scene images")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction")
    ######### BA parameters #########
    parser.add_argument(
        "--max_reproj_error", type=float, default=8.0, help="Maximum reprojection error for reconstruction"
    )
    parser.add_argument("--shared_camera", action="store_true", default=False, help="Use shared camera for all images")
    parser.add_argument("--camera_type", type=str, default="SIMPLE_PINHOLE", help="Camera type for reconstruction")
    parser.add_argument("--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks")
    parser.add_argument("--query_frame_num", type=int, default=8, help="Number of frames to query")
    parser.add_argument("--max_query_pts", type=int, default=4096, help="Maximum number of query points")
    parser.add_argument(
        "--fine_tracking", action="store_true", default=True, help="Use fine tracking (slower but more accurate)"
    )
    parser.add_argument(
        "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering (wo BA)"
    )
    return parser.parse_args()


def run_VGGT(model, images, dtype, resolution=518):
    # images: [S, 3, H, W], i.e. S frames of one scene; a batch dimension is added below

    assert len(images.shape) == 4
    assert images.shape[1] == 3

    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            aggregated_tokens_list, ps_idx = model.aggregator(images)

        # Predict Cameras
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
        # Predict Depth Maps
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)

    extrinsic = extrinsic.squeeze(0).cpu().numpy()
    intrinsic = intrinsic.squeeze(0).cpu().numpy()
    depth_map = depth_map.squeeze(0).cpu().numpy()
    depth_conf = depth_conf.squeeze(0).cpu().numpy()
    return extrinsic, intrinsic, depth_map, depth_conf


def demo_fn(args):
    # Print configuration
    print("Arguments:", vars(args))

    # Set seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # for multi-GPU
    print(f"Setting seed as: {args.seed}")

    # Set device and dtype
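    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 needs compute capability >= 8.0 (Ampere or newer); query it only when CUDA exists
    dtype = torch.bfloat16 if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8 else torch.float16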
    print(f"Using device: {device}")
    print(f"Using dtype: {dtype}")

    # Run VGGT for camera and depth estimation
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model.eval()
    model = model.to(device)
    print("Model loaded")

    # Get image paths and preprocess them
    image_dir = os.path.join(args.scene_dir, "images")
    # Sort so the frame order (and thus the COLMAP image ids) is deterministic
    image_path_list = sorted(glob.glob(os.path.join(image_dir, "*")))
    if len(image_path_list) == 0:
        raise ValueError(f"No images found in {image_dir}")
    base_image_path_list = [os.path.basename(path) for path in image_path_list]

    # Load images and original coordinates
    # Load Image in 1024, while running VGGT with 518
    vggt_fixed_resolution = 518
    img_load_resolution = 1024

    images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution)
    images = images.to(device)
    original_coords = original_coords.to(device)
    print(f"Loaded {len(images)} images from {image_dir}")

    # Run VGGT to estimate camera and depth
    # Run with 518x518 images
    extrinsic, intrinsic, depth_map, depth_conf = run_VGGT(model, images, dtype, vggt_fixed_resolution)
    points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)

    if args.use_ba:
        image_size = np.array(images.shape[-2:])
        scale = img_load_resolution / vggt_fixed_resolution
        shared_camera = args.shared_camera

        with torch.cuda.amp.autocast(dtype=dtype):
            # Predicting Tracks
            # Using VGGSfM tracker instead of VGGT tracker for efficiency
            # VGGT tracker requires multiple backbone runs to query different frames (this is a problem caused by the training process)
            # Will be fixed in VGGT v2

            # You can also change the pred_tracks to tracks from any other methods
            # e.g., from COLMAP, from CoTracker, or by chaining 2D matches from Lightglue/LoFTR.
            pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = predict_tracks(
                images,
                conf=depth_conf,
                points_3d=points_3d,
                masks=None,
                max_query_pts=args.max_query_pts,
                query_frame_num=args.query_frame_num,
                keypoint_extractor="aliked+sp",
                fine_tracking=args.fine_tracking,
            )

            torch.cuda.empty_cache()

        # rescale the intrinsic matrix from 518 to 1024
        intrinsic[:, :2, :] *= scale
        track_mask = pred_vis_scores > args.vis_thresh

        # TODO: radial distortion, iterative BA, masks
        reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap(
            points_3d,
            extrinsic,
            intrinsic,
            pred_tracks,
            image_size,
            masks=track_mask,
            max_reproj_error=args.max_reproj_error,
            shared_camera=shared_camera,
            camera_type=args.camera_type,
            points_rgb=points_rgb,
        )

        if reconstruction is None:
            raise ValueError("No reconstruction can be built with BA")

        # Bundle Adjustment
        ba_options = pycolmap.BundleAdjustmentOptions()
        pycolmap.bundle_adjustment(reconstruction, ba_options)

        reconstruction_resolution = img_load_resolution
    else:
        conf_thres_value = args.conf_thres_value
        max_points_for_colmap = 100000  # randomly sample 3D points
        shared_camera = False  # in the feedforward manner, we do not support shared camera
        camera_type = "PINHOLE"  # in the feedforward manner, we only support PINHOLE camera

        image_size = np.array([vggt_fixed_resolution, vggt_fixed_resolution])
        num_frames, height, width, _ = points_3d.shape

        points_rgb = F.interpolate(
            images, size=(vggt_fixed_resolution, vggt_fixed_resolution), mode="bilinear", align_corners=False
        )
        points_rgb = (points_rgb.cpu().numpy() * 255).astype(np.uint8)
        points_rgb = points_rgb.transpose(0, 2, 3, 1)

        # (S, H, W, 3), with x, y coordinates and frame indices
        points_xyf = create_pixel_coordinate_grid(num_frames, height, width)

        conf_mask = depth_conf >= conf_thres_value
        # at most writing 100000 3d points to colmap reconstruction object
        conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap)

        points_3d = points_3d[conf_mask]
        points_xyf = points_xyf[conf_mask]
        points_rgb = points_rgb[conf_mask]

        print("Converting to COLMAP format")
        reconstruction = batch_np_matrix_to_pycolmap_wo_track(
            points_3d,
            points_xyf,
            points_rgb,
            extrinsic,
            intrinsic,
            image_size,
            shared_camera=shared_camera,
            camera_type=camera_type,
        )

        reconstruction_resolution = vggt_fixed_resolution

    reconstruction = rename_colmap_recons_and_rescale_camera(
        reconstruction,
        base_image_path_list,
        original_coords.cpu().numpy(),
        img_size=reconstruction_resolution,
        shift_point2d_to_original_res=True,
        shared_camera=shared_camera,
    )

    print(f"Saving reconstruction to {args.scene_dir}/sparse")
    sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse")
    os.makedirs(sparse_reconstruction_dir, exist_ok=True)
    reconstruction.write(sparse_reconstruction_dir)

    # Save point cloud for fast visualization
    trimesh.PointCloud(points_3d, colors=points_rgb).export(os.path.join(args.scene_dir, "sparse/points.ply"))

    return True


def rename_colmap_recons_and_rescale_camera(
    reconstruction, image_paths, original_coords, img_size, shift_point2d_to_original_res=False, shared_camera=False
):
    rescale_camera = True

    for pyimageid in reconstruction.images:
        # Rescale camera parameters from the padded and resized square image back to the original size,
        # and rename each image to its original file name
        pyimage = reconstruction.images[pyimageid]
        pycamera = reconstruction.cameras[pyimage.camera_id]
        pyimage.name = image_paths[pyimageid - 1]

        if rescale_camera:
            # Rescale the camera parameters
            pred_params = copy.deepcopy(pycamera.params)

            real_image_size = original_coords[pyimageid - 1, -2:]
            resize_ratio = max(real_image_size) / img_size
            pred_params = pred_params * resize_ratio
            real_pp = real_image_size / 2
            pred_params[-2:] = real_pp  # center of the image

            pycamera.params = pred_params
            pycamera.width = real_image_size[0]
            pycamera.height = real_image_size[1]

        if shift_point2d_to_original_res:
            # Also shift the point2D to original resolution
            top_left = original_coords[pyimageid - 1, :2]

            for point2D in pyimage.points2D:
                point2D.xy = (point2D.xy - top_left) * resize_ratio

        if shared_camera:
            # If shared_camera, all images share the same camera
            # no need to rescale any more
            rescale_camera = False

    return reconstruction


if __name__ == "__main__":
    args = parse_args()
    with torch.no_grad():
        demo_fn(args)


# Work in Progress (WIP)

"""
VGGT Runner Script
==================

A script to run the VGGT model for 3D reconstruction from image sequences.

Directory Structure
-------------------
Input:
    input_folder/
    └── images/            # Source images for reconstruction

Output:
    output_folder/
    ├── images/
    ├── sparse/           # Reconstruction results
    │   ├── cameras.bin   # Camera parameters (COLMAP format)
    │   ├── images.bin    # Pose for each image (COLMAP format)
    │   ├── points3D.bin  # 3D points (COLMAP format)
    │   └── points.ply    # Point cloud visualization file 
    └── visuals/          # Visualization outputs TODO

Key Features
------------
• Dual-mode Support: Run reconstruction with either VGGT alone (feedforward) or VGGT + bundle adjustment (BA)
• Resolution Preservation: Maintains original image resolution in camera parameters and tracks
• COLMAP Compatibility: Exports results in standard COLMAP sparse reconstruction format
"""


================================================
FILE: demo_gradio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time

sys.path.append("vggt/")

from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Initializing and loading VGGT model...")
# model = VGGT.from_pretrained("facebook/VGGT-1B")  # another way to load the model

model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))


model.eval()
model = model.to(device)


# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(target_dir, model) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
    """
    print(f"Processing images from {target_dir}")

    # Device check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    # Move model to device
    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Run inference
    print("Running inference...")
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)

    # Convert pose encoding to extrinsic and intrinsic matrices
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    # Convert tensors to numpy
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension
    predictions["pose_enc_list"] = None  # remove pose_enc_list

    # Generate world points from depth map
    print("Computing world points from depth map...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Clean up
    torch.cuda.empty_cache()
    return predictions


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        frame_interval = max(int(fps), 1)  # ~1 frame/sec; guard against fps < 1 (e.g. missing metadata)

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever the user uploads or changes files, handle them immediately and
    show them in the gallery. Return (None, target_dir, image_paths, log_message).
    If nothing is uploaded, return (None, None, None, None).
    """
    if not input_video and not input_images:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Pointmap Regression",
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Clears the 3D viewer, the stored target_dir, and empties the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."


def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.
    """

    # If it's an example click, skip as requested
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    loaded = np.load(predictions_path)
    predictions = {key: np.array(loaded[key]) for key in key_list}

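    # Treat a cleared frame filter like "All" (same default as gradio_demo)
    if frame_filter is None:
        frame_filter = "All"
    glbfile = os.path.join(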
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"


# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------

great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"


# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }
    
    .example-log * {
        font-style: italic;
        font-size: 16px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent !important;
    }
    
    #my_radio .wrap {
        display: flex;
        flex-wrap: nowrap;
        justify-content: center;
        align-items: center;
    }

    #my_radio .wrap label {
        display: flex;
        width: 50%;
        justify-content: center;
        align-items: center;
        margin: 0;
        padding: 10px 0;
        box-sizing: border-box;
    }
    """,
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")

    gr.HTML(
        """
    <h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
    <p>
    <a href="https://github.com/facebookresearch/vggt">🐙 GitHub Repository</a> |
    <a href="#">Project Page</a>
    </p>

    <div style="font-size: 16px; line-height: 1.5;">
    <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>

    <h3>Getting Started:</h3>
    <ol>
        <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
        <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
        <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
        <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that visualizing the 3D points may be slow for a large number of input images.</li>
        <li>
        <strong>Adjust Visualization (Optional):</strong>
        After reconstruction, you can fine-tune the visualization using the options below
        <details style="display:inline;">
            <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
            <ul>
            <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
            <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
            <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
            <li><em>Filter Sky / Filter Black Background / Filter White Background:</em> Remove sky, black-background, or white-background points.</li>
            <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
            </ul>
        </details>
        </li>
    </ol>
    <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of VGGT's processing time.</span></p>
    </div>
    """
    )

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                    scale=1,
                )

            with gr.Row():
                prediction_mode = gr.Radio(
                    ["Depthmap and Camera Branch", "Pointmap Branch"],
                    label="Select a Prediction Mode",
                    value="Depthmap and Camera Branch",
                    scale=1,
                    elem_id="my_radio",
                )

            with gr.Row():
                conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
                frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True)
                    mask_sky = gr.Checkbox(label="Filter Sky", value=False)
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)

    # ---------------------- Examples section ----------------------
    examples = [
        [colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
        [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
        [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
        [single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"],
        [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
        [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
        [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
    ]

    def example_pipeline(
        input_video,
        num_images_str,
        input_images,
        conf_thres,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
        is_example_str,
    ):
        """
        1) Copy example images to new target_dir
        2) Reconstruct
        3) Return model3D + logs + new_dir + updated dropdown + gallery
        We do NOT return is_example. It's just an input.
        """
        target_dir, image_paths = handle_uploads(input_video, input_images)
        # Always use "All" for frame_filter in examples
        frame_filter = "All"
        glbfile, log_msg, dropdown = gradio_demo(
            target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
        )
        return glbfile, log_msg, target_dir, dropdown, image_paths

    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

    gr.Examples(
        examples=examples,
        inputs=[
            input_video,
            num_images,
            input_images,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
        fn=example_pipeline,
        cache_examples=False,
        examples_per_page=50,
    )

    # -------------------------------------------------------------------------
    # "Reconstruct" button logic:
    #  - Clear fields
    #  - Update log
    #  - gradio_demo(...) with the existing target_dir
    #  - Then set is_example = "False"
    # -------------------------------------------------------------------------
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
        ],
        outputs=[reconstruction_output, log_output, frame_filter],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    visualization_inputs = [
        target_dir_output,
        conf_thres,
        frame_filter,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
        is_example,
    ]
    # Re-render the visualization whenever any display option changes
    for component in [conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode]:
        component.change(update_visualization, visualization_inputs, [reconstruction_output, log_output])

    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

    demo.queue(max_size=20).launch(show_error=True, share=True)


================================================
FILE: demo_viser.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import glob
import time
import threading
import argparse
from typing import List, Optional

import numpy as np
import torch
from tqdm.auto import tqdm
import viser
import viser.transforms as viser_tf
import cv2


try:
    import onnxruntime
except ImportError:
    print("onnxruntime not found. Sky segmentation may not work.")

from visual_util import segment_sky, download_file_from_url
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri


def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,  # represents percentage (e.g., 50 means filter lowest 50%)
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    Args:
        pred_dict (dict):
            {
                "images": (S, 3, H, W)   - Input images,
                "world_points": (S, H, W, 3),
                "world_points_conf": (S, H, W),
                "depth": (S, H, W, 1),
                "depth_conf": (S, H, W),
                "extrinsic": (S, 3, 4),
                "intrinsic": (S, 3, 3),
            }
        port (int): Port number for the viser server.
        init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
        use_point_map (bool): If True, visualize the predicted world_points; otherwise unproject the depth map with the predicted cameras.
        background_mode (bool): Whether to run the server in background thread.
        mask_sky (bool): Whether to apply sky segmentation to filter out sky points.
        image_folder (str): Path to the folder containing input images.
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)

    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)

    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)

    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

    # Apply sky segmentation if enabled
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    # Then flatten everything for the point cloud
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    S, H, W, _ = world_points.shape

    # Flatten
    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)

    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)  # shape (S, 4, 4) typically
    # For convenience, we store only (3,4) portion
    cam_to_world = cam_to_world_mat[:, :3, :]

    # Compute scene center and recenter
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

    # Store frame indices so we can filter by frame
    frame_indices = np.repeat(np.arange(S), H * W)

    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)

    # Now the slider represents percentage of points to filter out
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )

    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames", options=["All"] + [str(i) for i in range(S)], initial_value="All"
    )

    # Create the main point cloud handle
    # Compute the threshold value as the given percentile
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
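    # Use the same minimum-confidence floor as update_point_cloud so the initial view matches later slider updates
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 1e-5)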
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.001,
        point_shape="circle",
    )

    # We will store references to frames & frustums so we can toggle visibility
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None:
        """
        Add camera frames and frustums to the scene.
        extrinsics: (S, 3, 4)
        images_:    (S, 3, H, W)
        """
        # Clear any existing frames or frustums
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

        # Optionally attach a callback that sets the viewpoint to the chosen camera
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        img_ids = range(S)
        for img_id in tqdm(img_ids):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4)

            # Add a small frame axis
            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

            # Convert the image for the frustum
            img = images_[img_id]  # shape (3, H, W)
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

            # The FOV below is derived from the image height, so the matching focal length is fy.
            # To use the predicted intrinsics instead:
            # fy = intrinsics_cam[img_id, 1, 1]
            # fov = 2 * np.arctan2(h / 2, fy)
            # For simplicity, we use an approximate focal length:
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            # Add the frustum
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum", fov=fov, aspect=w / h, scale=0.05, image=img, line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update the point cloud based on current GUI selections."""
        # Here we compute the threshold value based on the current percentage
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)

        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        """Toggle visibility of camera frames and frustums."""
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add the camera frames to the scene
    visualize_frames(cam_to_world, images)

    print("Starting viser server...")
    # If background_mode is True, spawn a daemon thread so the main thread can continue.
    if background_mode:

        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        while True:
            time.sleep(0.01)

    return server


# Helper functions for sky segmentation


def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf (np.ndarray): Confidence scores with shape (S, H, W)
        image_folder (str): Path to the folder containing input images

    Returns:
        np.ndarray: Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape
    sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
    os.makedirs(sky_masks_dir, exist_ok=True)

    # Download skyseg.onnx if it doesn't exist
    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx")

    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
    sky_mask_list = []

    print("Generating sky masks...")
    for i, image_path in enumerate(tqdm(image_files[:S])):  # Limit to the number of images in the batch
        image_name = os.path.basename(image_path)
        mask_filepath = os.path.join(sky_masks_dir, image_name)

        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

        # Resize mask to match H×W if needed
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H))

        sky_mask_list.append(sky_mask)

    # Convert list to numpy array with shape S×H×W
    sky_mask_array = np.array(sky_mask_list)
    # Apply sky mask to confidence scores
    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf


parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
parser.add_argument(
    "--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images"
)
parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points")
parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode")
parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server")
parser.add_argument(
    "--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out"
)
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")


def main():
    """
    Main function for the VGGT demo with viser for 3D visualization.

    This function:
    1. Loads the VGGT model
    2. Processes input images from the specified folder
    3. Runs inference to generate 3D points and camera poses
    4. Optionally applies sky segmentation to filter out sky points
    5. Visualizes the results using viser

    Command-line arguments:
    --image_folder: Path to folder containing input images
    --use_point_map: Use point map instead of depth-based points
    --background_mode: Run the viser server in background mode
    --port: Port number for the viser server
    --conf_threshold: Initial percentage of low-confidence points to filter out
    --mask_sky: Apply sky segmentation to filter out sky points
    """
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("Initializing and loading VGGT model...")
    # model = VGGT.from_pretrained("facebook/VGGT-1B")

    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

    model.eval()
    model = model.to(device)

    # Use the provided image folder path
    print(f"Loading images from {args.image_folder}...")
    # Sort so the frame order matches the sorted order used by apply_sky_segmentation
    image_names = sorted(glob.glob(os.path.join(args.image_folder, "*")))
    print(f"Found {len(image_names)} images")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print("Running inference...")
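    # bfloat16 is well supported on GPUs with compute capability >= 8 (Ampere or newer); otherwise use float16.
    # Check the device first so the capability query does not fail on CPU-only machines.
    if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16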

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    print("Processing model outputs...")
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension and convert to numpy

    if args.use_point_map:
        print("Visualizing 3D points from point map")
    else:
        print("Visualizing 3D points by unprojecting depth map by cameras")

    if args.mask_sky:
        print("Sky segmentation enabled - will filter out sky points")

    print("Starting viser visualization...")

    viser_server = viser_wrapper(
        predictions,
        port=args.port,
        init_conf_threshold=args.conf_threshold,
        use_point_map=args.use_point_map,
        background_mode=args.background_mode,
        mask_sky=args.mask_sky,
        image_folder=args.image_folder,
    )
    print("Visualization complete")


if __name__ == "__main__":
    main()


================================================
FILE: docs/package.md
================================================
# Alternative Installation Methods

This document explains how to install VGGT as a package using different package managers. 

## Prerequisites

Before installing VGGT as a package, you need to install PyTorch and torchvision. We don't list them as dependencies to avoid CUDA version mismatches. Install them first, for example:

```bash
# install pytorch 2.3.1 with cuda 12.1
pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
```
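Before installing VGGT, you can confirm that the pinned versions are actually present by querying the installed package metadata. The sketch below uses only the standard library; `EXPECTED`, `installed_version`, and `check_prerequisites` are hypothetical helper names, and the version numbers simply mirror the example command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Versions from the example install command above; adjust them to your CUDA setup.
EXPECTED = {"torch": "2.3.1", "torchvision": "0.18.1"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_prerequisites(expected: dict) -> list:
    """Return human-readable problems; an empty list means everything matches."""
    problems = []
    for package, wanted in expected.items():
        found = installed_version(package)
        if found is None:
            problems.append(f"{package} is not installed (expected {wanted})")
        # CUDA wheels carry a local suffix such as '2.3.1+cu121', so compare only the base version.
        elif found.split("+")[0] != wanted:
            problems.append(f"{package} {found} found, expected {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_prerequisites(EXPECTED) or ["PyTorch prerequisites look fine"]:
        print(line)
```

Comparing only the part before `+` keeps the check independent of which CUDA build you picked.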

## Installation Options

### Install with pip

The simplest way to install VGGT is with pip. From the repository root, run:

```bash
pip install -e .
```

### Install and run with pixi

[Pixi](https://pixi.sh) is a package management tool for creating reproducible environments.

1. First, [download and install pixi](https://pixi.sh/latest/get_started/)
2. Then run:

```bash
pixi run -e demo python demo_gradio.py
```

### Install and run with uv

[uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver.

1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/)
2. Then run:

```bash
uv run --extra demo demo_gradio.py
```



================================================
FILE: pyproject.toml
================================================
[project]
authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}]
dependencies = [
    "numpy<2",
    "Pillow",
    "huggingface_hub",
    "einops",
    "safetensors",
    "opencv-python",
]
name = "vggt"
requires-python = ">= 3.10"
version = "0.0.1"

[project.optional-dependencies]
demo = [
    "gradio==5.17.1",
    "viser==0.2.23",
    "tqdm",
    "hydra-core",
    "omegaconf",
    "opencv-python",
    "scipy",
    "onnxruntime",
    "requests",
    "trimesh",
    "matplotlib",
]

# Using setuptools as the build backend
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

# setuptools configuration
[tool.setuptools.packages.find]
where = ["."]
include = ["vggt*"]

# Pixi configuration
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]

[tool.pixi.pypi-dependencies]
vggt = { path = ".", editable = true }

[tool.pixi.environments]
default = { solve-group = "default" }
demo = { features = ["demo"], solve-group = "default" }

[tool.pixi.tasks]


================================================
FILE: requirements.txt
================================================
torch==2.3.1
torchvision==0.18.1
numpy==1.26.1
Pillow
huggingface_hub
einops
safetensors


================================================
FILE: requirements_demo.txt
================================================
gradio==5.17.1
viser==0.2.23
tqdm
hydra-core
omegaconf
opencv-python
scipy
onnxruntime
requests
trimesh
matplotlib
pydantic==2.10.6
# feel free to skip the dependencies below if you do not need demo_colmap.py
pycolmap==3.10.0
pyceres==2.3
git+https://github.com/jytime/LightGlue.git#egg=lightglue



================================================
FILE: training/README.md
================================================
# Training

This is a re-implementation of our framework for training VGGT. This document shows how to set up the environment and run VGGT training. I have aimed to faithfully reproduce the original training framework, but please open an issue if anything looks off.

## 1. Prerequisites

Before you begin, ensure you have completed the following steps:

1. **Install VGGT as a package:**
   ```bash
   pip install -e .
   ```

2. **Prepare the dataset and annotations:**
   - Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d).
   - Download the required annotation files from [Hugging Face](https://huggingface.co/datasets/JianyuanWang/co3d_anno/tree/main).

## 2. Configuration

After downloading the dataset and annotations, configure the paths in `training/config/default.yaml`.

### Required Path Configuration

1. Open `training/config/default.yaml`
2. Update the following paths with your absolute directory paths:
   - `CO3D_DIR`: Path to your Co3D dataset
   - `CO3D_ANNOTATION_DIR`: Path to your Co3D annotation files
   - `resume_checkpoint_path`: Path to your pre-trained VGGT checkpoint

### Configuration Example

```yaml
data:
  train:
    dataset:
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
# ... same for val ...

checkpoint:
  resume_checkpoint_path: /YOUR/PATH/TO/CKPT
```
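A forgotten placeholder path is an easy mistake to make here. The sketch below scans a config file for leftover `/YOUR/PATH/TO/...` strings, assuming your copy of `training/config/default.yaml` uses the same placeholder form as the example above; `find_placeholders` is a hypothetical helper, not part of the training code.

```python
import re
from pathlib import Path

# Placeholder prefix used in the configuration example above (an assumption about your config file).
PLACEHOLDER = re.compile(r"/YOUR/PATH/TO/\S*")


def find_placeholders(config_text: str) -> list:
    """Return (line_number, placeholder) pairs for every placeholder path still in the text."""
    hits = []
    for lineno, line in enumerate(config_text.splitlines(), start=1):
        for match in PLACEHOLDER.finditer(line):
            hits.append((lineno, match.group(0)))
    return hits


if __name__ == "__main__":
    config_path = Path("training/config/default.yaml")  # relative to the repository root
    if config_path.exists():
        for lineno, placeholder in find_placeholders(config_path.read_text()):
            print(f"{config_path}:{lineno}: still set to {placeholder}")
    else:
        print(f"{config_path} not found; run this from the repository root")
```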

## 3. Fine-tuning on Co3D

To fine-tune the provided pre-trained model on the Co3D dataset, run the following command from the `training/` directory. This example uses 4 GPUs with PyTorch Distributed Data Parallel (DDP):

```bash
torchrun --nproc_per_node=4 launch.py
```

The default configuration in `training/config/default.yaml` is set up for fine-tuning. It automatically resumes from a checkpoint and freezes the model's `aggregator` module during training.

## 4. Training on Multiple Datasets

The dataloader supports training on multiple datasets. For example, if you have downloaded VKitti using `training/data/preprocess/vkitti.sh`, you can train on Co3D+VKitti by configuring:

```yaml
data:
  train:
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
          len_train: 100000
        - _target_: data.datasets.vkitti.VKittiDataset
          split: train
          VKitti_DIR: /YOUR/PATH/TO/VKitti
          len_train: 100000
          expand_ratio: 8 
```

The ratio of different datasets can be controlled by setting `len_train`. For example, Co3D with `len_train: 10000` and VKitti with `len_train: 2000` will result in Co3D being sampled five times more frequently than VKitti.
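To see why `len_train` acts as a sampling weight, it helps to picture the datasets concatenated into one index space that is sampled uniformly. The sketch below assumes `ComposedDataset` behaves like that (similar to PyTorch's `ConcatDataset`); it has not been checked against `composed_dataset.py`, and `build_index`/`locate` are hypothetical names.

```python
import bisect
import random
from collections import Counter
from itertools import accumulate


def build_index(lengths: dict) -> tuple:
    """Return dataset names and the cumulative end offset of each dataset in the joint index space."""
    names = list(lengths)
    ends = list(accumulate(lengths[name] for name in names))
    return names, ends


def locate(global_idx: int, names: list, ends: list) -> tuple:
    """Map a global index to (dataset_name, local_index) via binary search over the end offsets."""
    pos = bisect.bisect_right(ends, global_idx)
    start = ends[pos - 1] if pos > 0 else 0
    return names[pos], global_idx - start


if __name__ == "__main__":
    names, ends = build_index({"co3d": 10000, "vkitti": 2000})
    rng = random.Random(0)
    # Uniform sampling over all 12000 indices picks Co3D about five times as often as VKitti
    draws = Counter(locate(rng.randrange(ends[-1]), names, ends)[0] for _ in range(60000))
    print(draws["co3d"] / draws["vkitti"])
```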

## 5. Common Questions

### Memory Management

If you encounter out-of-memory (OOM) errors on your GPU, consider adjusting the following parameters in `training/config/default.yaml`:

- `max_img_per_gpu`: Reduce this value to decrease the batch size per GPU
- `accum_steps`: Sets the number of gradient accumulation steps (default is 2). This feature splits batches into smaller chunks to save memory, though it may slightly increase training time. Note that gradient accumulation was not used for the original VGGT model.
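Gradient accumulation trades time for memory without changing the gradient, as long as each chunk's loss is scaled by `1 / accum_steps`. The toy sketch below shows this for a one-parameter least-squares model; it is a generic illustration under that assumption, not the trainer's actual implementation, and `grad_mse`/`accumulated_grad` are hypothetical names.

```python
def grad_mse(w: float, xs: list, ys: list) -> float:
    """Gradient with respect to w of mean((w * x - y) ** 2) for a one-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list, ys: list, accum_steps: int) -> float:
    """Split the batch into accum_steps equal chunks and sum each chunk's gradient / accum_steps."""
    if len(xs) % accum_steps:
        raise ValueError("batch size must be divisible by accum_steps for an exact match")
    chunk = len(xs) // accum_steps
    total = 0.0
    for i in range(accum_steps):
        sl = slice(i * chunk, (i + 1) * chunk)
        # Only one chunk is "in memory" at a time; scaling by 1/accum_steps mimics loss / accum_steps
        total += grad_mse(w, xs[sl], ys[sl]) / accum_steps
    return total


if __name__ == "__main__":
    xs = [float(i) for i in range(48)]  # e.g. 48 images, matching max_img_per_gpu
    ys = [3.0 * x + 1.0 for x in xs]
    print(grad_mse(0.5, xs, ys), accumulated_grad(0.5, xs, ys, accum_steps=2))
```

Both printed values agree up to floating-point rounding, while each backward pass only sees half of the batch.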

### Learning Rate Tuning

The main hyperparameter to tune carefully is the learning rate. Note that a suitable learning rate depends on the effective batch size, which is `batch_size_per_gpu × num_gpus`. Therefore, I highly recommend trying several learning rates for your training setup. Generally, trying values like `5e-6`, `1e-5`, `5e-5`, `1e-4`, `5e-4` should be sufficient.
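A small sweep over the values above can be planned as follows, tagging each run with its effective batch size so that results from different GPU counts are not mixed up. The helper names are hypothetical, and `linearly_scaled_lr` implements the common linear-scaling rule of thumb, which is an assumption for picking a starting point, not something this README prescribes.

```python
CANDIDATE_LRS = [5e-6, 1e-5, 5e-5, 1e-4, 5e-4]  # values suggested in this section


def effective_batch_size(batch_size_per_gpu: int, num_gpus: int) -> int:
    """Effective batch size as defined above: batch_size_per_gpu x num_gpus."""
    return batch_size_per_gpu * num_gpus


def linearly_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Rule-of-thumb heuristic: scale the learning rate in proportion to the effective batch size."""
    return base_lr * new_batch / base_batch


def plan_sweep(batch_size_per_gpu: int, num_gpus: int, lrs=CANDIDATE_LRS) -> list:
    """Return (run_name, lr) pairs, each name recording the LR and the effective batch size."""
    ebs = effective_batch_size(batch_size_per_gpu, num_gpus)
    return [(f"lr{lr:.0e}_ebs{ebs}", lr) for lr in lrs]


if __name__ == "__main__":
    for name, lr in plan_sweep(batch_size_per_gpu=12, num_gpus=4):
        print(name, lr)
```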

### Tracking Head

The tracking head can slightly improve accuracy but is not necessary. For general cases, especially when GPU resources are limited, we suggest fine-tuning the pre-trained model only with camera and depth heads, which is the setting in `default.yaml`. This will provide good enough results.

### Dataloader Validation

To check if your dataloader is working correctly, the best approach is to visualize its output. You can save the 3D world points as follows and then visually inspect the PLY files:

```python
import numpy as np
import torch


def save_ply(points, colors, filename):
    import open3d as o3d  # optional dependency, only needed for this debugging helper

    # Move tensors to CPU numpy arrays and flatten to (N, 3)
    if torch.is_tensor(points):
        points_visual = points.reshape(-1, 3).cpu().numpy()
    else:
        points_visual = points.reshape(-1, 3)
    if torch.is_tensor(colors):
        points_visual_rgb = colors.reshape(-1, 3).cpu().numpy()
    else:
        points_visual_rgb = colors.reshape(-1, 3)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points_visual.astype(np.float64))
    pcd.colors = o3d.utility.Vector3dVector(points_visual_rgb.astype(np.float64))
    o3d.io.write_point_cloud(filename, pcd, write_ascii=True)

# Usage example
save_ply(
    batch["world_points"][0].reshape(-1, 3),
    batch["images"][0].permute(0, 2, 3, 1).reshape(-1, 3),
    "debug.ply"
)
```

### Handling Unordered Sequences

For unordered sequences, you can check how we compute the ranking (similarity) between one frame and all other frames, as discussed in [Issue #82](https://github.com/facebookresearch/vggt/issues/82).

### Expected Coordinate System

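Camera poses are expected to follow the OpenCV `camera-from-world` convention. Each extrinsic is a 3×4 matrix `[R | t]` that maps a world point `X_world` to camera coordinates as `X_cam = R X_world + t`. Depth maps must be aligned with their corresponding camera poses.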
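Here is a minimal NumPy sketch of this convention. It assumes an undistorted pinhole camera and depth maps that store z-depth (distance along the optical axis). The sketch:

1. lifts a depth map to camera coordinates with the intrinsics,
2. moves the points to world coordinates by inverting `[R | t]`, and
3. maps them back to the camera frame.

The helper names are hypothetical. In the training code, `process_one_image` delegates this step to `depth_to_world_coords_points`.

```python
import numpy as np


def unproject_depth(depth, K):
    """Lift an (H, W) z-depth map to camera-frame points of shape (H, W, 3)."""
    h, w = depth.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))  # column and row index grids
    x = (u - K[0, 2]) * depth / K[0, 0]
    y = (v - K[1, 2]) * depth / K[1, 1]
    return np.stack([x, y, depth], axis=-1)


def world_to_cam(points_world, extrinsic):
    """Camera-from-world: X_cam = R @ X_world + t (row-vector form)."""
    R, t = extrinsic[:, :3], extrinsic[:, 3]
    return points_world @ R.T + t


def cam_to_world(points_cam, extrinsic):
    """Inverse mapping: X_world = R.T @ (X_cam - t) (row-vector form)."""
    R, t = extrinsic[:, :3], extrinsic[:, 3]
    return (points_cam - t) @ R


# Toy camera: 30 degree rotation about z plus a translation.
theta = np.deg2rad(30.0)
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta), np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])
extrinsic = np.hstack([R, np.array([[0.1], [-0.2], [1.5]])])  # 3x4 [R | t]
K = np.array([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
depth = np.full((48, 64), 2.0)

points_cam = unproject_depth(depth, K)
points_world = cam_to_world(points_cam, extrinsic)
recovered = world_to_cam(points_world, extrinsic)
# Aligned depth: the z of each point in that camera's frame equals the depth map.
print(np.allclose(recovered[..., 2], depth))  # True
```

If your dataset stores camera-to-world poses instead, invert them first: `R_cw = R_wc.T` and `t_cw = -R_wc.T @ t_wc`.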


================================================
FILE: training/__init__.py
================================================


================================================
FILE: training/config/default.yaml
================================================
defaults:
  - default_dataset.yaml

exp_name: exp001
img_size: 518
num_workers: 8
seed_value: 42
accum_steps: 2    # We did not use gradient accumulation when training VGGT; it can help if you run into OOM errors.
patch_size: 14
val_epoch_freq: 5
max_img_per_gpu: 48

limit_train_batches: 800
limit_val_batches: 400


data:
  # The code for data still looks too complicated. I should refactor this again (do I have time?...)
  train:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: ${num_workers}
    max_img_per_gpu: ${max_img_per_gpu}
    common_config:
      img_size: ${img_size}
      patch_size: ${patch_size}
      debug: False
      repeat_batch: False
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION
  val:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: ${num_workers}
    max_img_per_gpu: ${max_img_per_gpu}
    common_config:
      img_size: ${img_size}
      patch_size: ${patch_size}
      debug: False
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: test
          CO3D_DIR: /YOUR/PATH/TO/CO3D
          CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION


logging:
  log_dir: logs
  log_visuals: False
  log_freq: 1
  log_level_primary: DEBUG
  log_level_secondary: WARNING
  all_ranks: False
  tensorboard_writer:
    _target_: train_utils.tb_writer.TensorBoardLogger
    path: ${logging.log_dir}/tensorboard
  scalar_keys_to_log:
    train:
      keys_to_log:
        - loss_objective
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        - loss_conf_depth
        - loss_reg_depth
        - loss_grad_depth
    val:
      keys_to_log:
        - loss_objective
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        - loss_conf_depth
        - loss_reg_depth
        - loss_grad_depth



checkpoint:
  save_dir: logs/${exp_name}/ckpts
  save_freq: 5
  resume_checkpoint_path: /YOUR/PATH/TO/CKPT
  strict: False


loss:
  _target_: loss.MultitaskLoss
  camera: 
    weight: 5.0
    loss_type: "l1" # The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss.  
  depth:
    weight: 1.0
    gradient_loss_fn: "grad" 
    valid_range: 0.98
  point: null
  # If you want to enable point, use the following config
  # point: 
  #   weight: 1.0
  #   gradient_loss_fn: "normal" 
  #   valid_range: 0.98
  track: null   




optim:
  param_group_modifiers: False

  optimizer:
    _target_: torch.optim.AdamW
    lr: 5e-5
    weight_decay: 0.05

  frozen_module_names:
      - "*aggregator*"  # example, freeze the aggregator

  amp:
    enabled: True
    amp_dtype: bfloat16
  gradient_clip:
    _target_: train_utils.gradient_clip.GradientClipper
    configs:
      - module_name: ["aggregator"]
        max_norm: 1.0   # feel free to reduce this if you see instabilities
        norm_type: 2
      - module_name: ["depth"]
        max_norm: 1.0   # feel free to reduce this if you see instabilities
        norm_type: 2
      - module_name: ["camera"]
        max_norm: 1.0   # feel free to reduce this if you see instabilities
        norm_type: 2
  options:
    lr:
      - scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1e-8
              end_value: 5e-5
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 5e-5
              end_value: 1e-8
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']
    weight_decay:
      - scheduler:
          _target_: fvcore.common.param_scheduler.ConstantParamScheduler
          value: 0.05




max_epochs: 20

model:
  _target_: vggt.models.vggt.VGGT
  enable_camera: True
  enable_depth: True
  enable_point: False
  enable_track: False


distributed:
  # check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options
  backend: nccl
  comms_dtype: None
  find_unused_parameters: False
  timeout_mins: 30
  gradient_as_bucket_view: True  # Less memory used
  bucket_cap_mb: 25
  broadcast_buffers: True

cuda:
    cudnn_deterministic: False
    cudnn_benchmark: False
    allow_tf32: True


================================================
FILE: training/config/default_dataset.yaml
================================================
# Template for the dataset config
data:
  # The code still looks too complicated. I should refactor this again (do I have time?...)
  train:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: 8
    max_img_per_gpu: 48
    # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory
    # (see: https://github.com/pytorch/pytorch/issues/13246).
    # To avoid this, set shuffle=False and enable common_config.inside_random=True instead.
    shuffle: True
    pin_memory: False
    common_config:  # common config shared by all training datasets
      fix_img_num: -1 # -1 means do not fix the number of images
      fix_aspect_ratio: 1.0
      load_track: False
      track_num: 1024
      training: True
      inside_random: True
      img_size: 224
      patch_size: 14
      rescale: True
      rescale_aug: True
      landscape_check: False
      debug: False
      get_nearby: True
      load_depth: True
      img_nums: [2, 24]
      max_img_per_gpu: 48
      allow_duplicate_img: True
      repeat_batch: False

      augs:
        cojitter: True
        cojitter_ratio: 0.3
        scales: [0.8, 1.2]
        aspects: [0.33, 1.0]
        color_jitter:
          brightness: 0.5
          contrast: 0.5
          saturation: 0.5
          hue: 0.1
          p: 0.9
        gray_scale: True
        gau_blur: False
  val:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: 8
    max_img_per_gpu: 48
    # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory
    # (see: https://github.com/pytorch/pytorch/issues/13246).
    # To avoid this, set shuffle=False and enable common_config.inside_random=True instead.
    shuffle: True
    pin_memory: False
    common_config:  # common config for evaluation
      fix_img_num: -1 # -1 means do not fix the number of images
      fix_aspect_ratio: 1.0
      load_track: False
      track_num: 1024
      training: False
      inside_random: True
      img_size: 224
      patch_size: 14
      rescale: True
      rescale_aug: False
      landscape_check: False
      debug: False
      get_nearby: True
      load_depth: True
      img_nums: [2, 12]
      allow_duplicate_img: True

      augs:
        cojitter: False
        cojitter_ratio: 0.5
        scales: null
        aspects: [1.0, 1.0]
        color_jitter: null
        gray_scale: False
        gau_blur: False

================================================
FILE: training/data/__init__.py
================================================


================================================
FILE: training/data/augmentation.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Dict
from torchvision import transforms


def get_image_augmentation(
    color_jitter: Optional[Dict[str, float]] = None,
    gray_scale: bool = True,
    gau_blur: bool = False
) -> Optional[transforms.Compose]:
    """Create a composition of image augmentations.

    Args:
        color_jitter: Dictionary containing color jitter parameters:
            - brightness: float (default: 0.5)
            - contrast: float (default: 0.5)
            - saturation: float (default: 0.5)
            - hue: float (default: 0.1)
            - p: probability of applying (default: 0.9)
            If None, uses default values
        gray_scale: Whether to apply random grayscale (default: True)
        gau_blur: Whether to apply Gaussian blur (default: False)

    Returns:
        A Compose object of transforms (never None in practice, since color jitter is always added)
    """
    transform_list = []
    default_jitter = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9
    }

    # Handle color jitter
    if color_jitter is not None:
        # Merge with defaults for missing keys
        effective_jitter = {**default_jitter, **color_jitter}
    else:
        effective_jitter = default_jitter

    transform_list.append(
        transforms.RandomApply(
            [
                transforms.ColorJitter(
                    brightness=effective_jitter["brightness"],
                    contrast=effective_jitter["contrast"],
                    saturation=effective_jitter["saturation"],
                    hue=effective_jitter["hue"],
                )
            ],
            p=effective_jitter["p"],
        )
    )

    if gray_scale:
        transform_list.append(transforms.RandomGrayscale(p=0.05))

    if gau_blur:
        transform_list.append(
            transforms.RandomApply(
                [transforms.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05
            )
        )

    return transforms.Compose(transform_list) if transform_list else None


================================================
FILE: training/data/base_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from PIL import Image, ImageFile

from torch.utils.data import Dataset
from .dataset_util import *

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


class BaseDataset(Dataset):
    """
    Base dataset class for VGGT and VGGSfM training.

    This abstract class handles common operations like image resizing,
    augmentation, and coordinate transformations. Concrete dataset
    implementations should inherit from this class.

    Attributes:
        img_size: Target image size (typically the width)
        patch_size: Size of patches for the ViT
        augs.scales: Scale range for data augmentation [min, max]
        rescale: Whether to rescale images
        rescale_aug: Whether to apply augmentation during rescaling
        landscape_check: Whether to handle landscape vs portrait orientation
    """
    def __init__(
        self,
        common_conf,
    ):
        """
        Initialize the base dataset with common configuration.

        Args:
            common_conf: Configuration object with the following properties, shared by all datasets:
                - img_size: Default is 518
                - patch_size: Default is 14
                - augs.scales: Default is [0.8, 1.2]
                - rescale: Default is True
                - rescale_aug: Default is True
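                - landscape_check: Default is True
                - training: Whether the dataset is used for training (enables scale augmentation)
        """
        super().__init__()
        self.img_size = common_conf.img_size
        self.patch_size = common_conf.patch_size
        self.aug_scale = common_conf.augs.scales
        self.rescale = common_conf.rescale
        self.rescale_aug = common_conf.rescale_aug
        self.landscape_check = common_conf.landscape_check
        # process_one_image reads self.training; subclasses may still override it
        self.training = common_conf.training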

    def __len__(self):
        return self.len_train

    def __getitem__(self, idx_N):
        """
        Get an item from the dataset.

        Args:
            idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio)

        Returns:
            Dataset item as returned by get_data()
        """
        seq_index, img_per_seq, aspect_ratio = idx_N
        return self.get_data(
            seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio
        )

    def get_data(self, seq_index=None, img_per_seq=None, seq_name=None, ids=None, aspect_ratio=1.0):
        """
        Abstract method to retrieve data for a given sequence.

        Args:
            seq_index (int, optional): Index of the sequence
            img_per_seq (int, optional): Number of images to sample from the sequence
            seq_name (str, optional): Name of the sequence
            ids (list, optional): List of frame IDs
            aspect_ratio (float, optional): Target aspect ratio.

        Returns:
            Dataset-specific data

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError(
            "This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method."
        )

    def get_target_shape(self, aspect_ratio):
        """
        Calculate the target shape based on the given aspect ratio.

        Args:
            aspect_ratio: Target aspect ratio

        Returns:
            numpy.ndarray: Target image shape [height, width]
        """
        short_size = int(self.img_size * aspect_ratio)
        small_size = self.patch_size

        # ensure the input shape is friendly to vision transformer
        if short_size % small_size != 0:
            short_size = (short_size // small_size) * small_size

        image_shape = np.array([short_size, self.img_size])
        return image_shape

    def process_one_image(
        self,
        image,
        depth_map,
        extri_opencv,
        intri_opencv,
        original_size,
        target_image_shape,
        track=None,
        filepath=None,
        safe_bound=4,
    ):
        """
        Process a single image and its associated data.

        This method handles image transformations, depth processing, and coordinate conversions.

        Args:
            image (numpy.ndarray): Input image array
            depth_map (numpy.ndarray): Depth map array
            extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention)
            intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention)
            original_size (numpy.ndarray): Original image size [height, width]
            target_image_shape (numpy.ndarray): Target image shape after processing
            track (numpy.ndarray, optional): Optional tracking information. Defaults to None.
            filepath (str, optional): Optional file path for debugging. Defaults to None.
            safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4.

        Returns:
            tuple: (
                image (numpy.ndarray): Processed image,
                depth_map (numpy.ndarray): Processed depth map,
                extri_opencv (numpy.ndarray): Updated extrinsic matrix,
                intri_opencv (numpy.ndarray): Updated intrinsic matrix,
                world_coords_points (numpy.ndarray): 3D points in world coordinates,
                cam_coords_points (numpy.ndarray): 3D points in camera coordinates,
                point_mask (numpy.ndarray): Boolean mask of valid points,
                track (numpy.ndarray, optional): Updated tracking information
            )
        """
        # Make copies to avoid in-place operations affecting original data
        image = np.copy(image)
        depth_map = np.copy(depth_map)
        extri_opencv = np.copy(extri_opencv)
        intri_opencv = np.copy(intri_opencv)
        if track is not None:
            track = np.copy(track)

        # Apply random scale augmentation during training if enabled
        if self.training and self.aug_scale:
            random_h_scale, random_w_scale = np.random.uniform(
                self.aug_scale[0], self.aug_scale[1], 2
            )
            # Avoid random padding by capping at 1.0
            random_h_scale = min(random_h_scale, 1.0)
            random_w_scale = min(random_w_scale, 1.0)
            aug_size = original_size * np.array([random_h_scale, random_w_scale])
            aug_size = aug_size.astype(np.int32)
        else:
            aug_size = original_size

        # Move principal point to the image center and crop if necessary
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image, depth_map, intri_opencv, aug_size, track=track, filepath=filepath,
        )

        original_size = np.array(image.shape[:2])  # update original_size
        target_shape = target_image_shape

        # Handle landscape vs. portrait orientation
        rotate_to_portrait = False
        if self.landscape_check:
            # Switch between landscape and portrait if necessary
            if original_size[0] > 1.25 * original_size[1]:
                if (target_image_shape[0] != target_image_shape[1]) and (np.random.rand() > 0.5):
                    target_shape = np.array([target_image_shape[1], target_image_shape[0]])
                    rotate_to_portrait = True

        # Resize images and update intrinsics
        if self.rescale:
            image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic(
                image, depth_map, intri_opencv, target_shape, original_size, track=track,
                safe_bound=safe_bound,
                rescale_aug=self.rescale_aug
            )
        else:
            print("Not rescaling the images")

        # Ensure final crop to target shape
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image, depth_map, intri_opencv, target_shape, track=track, filepath=filepath, strict=True,
        )

        # Apply 90-degree rotation if needed
        if rotate_to_portrait:
            assert self.landscape_check
            clockwise = np.random.rand() > 0.5
            image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees(
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                clockwise=clockwise,
                track=track,
            )

        # Convert depth to world and camera coordinates
        world_coords_points, cam_coords_points, point_mask = (
            depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv)
        )

        return (
            image,
            depth_map,
            extri_opencv,
            intri_opencv,
            world_coords_points,
            cam_coords_points,
            point_mask,
            track,
        )

    def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None):
        """
        TODO: add the function to sample the ids by pose similarity ranking.

        Sample a set of IDs from a sequence close to a given start index.

        You can specify the range either as a ratio of the number of input IDs
        or as a fixed integer window.


        Args:
            ids (list): Initial list of IDs. The first element is used as the anchor.
            full_seq_num (int): Total number of items in the full sequence.
            expand_ratio (float, optional): Factor by which the number of IDs expands
                around the start index. Default is 2.0 if neither expand_ratio nor
                expand_range is provided.
            expand_range (int, optional): Fixed number of items to expand around the
                start index. If provided, expand_ratio is ignored.

        Returns:
            numpy.ndarray: Array of sampled IDs, with the first element being the
                original start index.

        Examples:
            # Using expand_ratio (default behavior)
            # If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0,
            # expand_range = int(3 * 2.0) = 6, so IDs are sampled from [94...105]
            # (if boundaries allow; the upper bound is exclusive).

            # Using expand_range directly
            # If ids=[100,101,102] and full_seq_num=200, with expand_range=10,
            # IDs are sampled from [90...109] (if boundaries allow).

        Raises:
            ValueError: If no IDs are provided.
        """
        if len(ids) == 0:
            raise ValueError("No IDs provided.")

        if expand_range is None and expand_ratio is None:
            expand_ratio = 2.0  # Default behavior

        total_ids = len(ids)
        start_idx = ids[0]

        # Determine the actual expand_range
        if expand_range is None:
            # Use ratio to determine range
            expand_range = int(total_ids * expand_ratio)

        # Calculate valid boundaries
        low_bound = max(0, start_idx - expand_range)
        high_bound = min(full_seq_num, start_idx + expand_range)

        # Create the valid range of indices
        valid_range = np.arange(low_bound, high_bound)

        # Sample 'total_ids - 1' items, because we already have the start_idx
        sampled_ids = np.random.choice(
            valid_range,
            size=(total_ids - 1),
            replace=True,   # we accept the situation that some sampled ids are the same
        )

        # Insert the start_idx at the beginning
        result_ids = np.insert(sampled_ids, 0, start_idx)

        return result_ids


================================================
FILE: training/data/composed_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC

from hydra.utils import instantiate
import torch
import random
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset
import bisect
from .dataset_util import *
from .track_util import *
from .augmentation import get_image_augmentation


class ComposedDataset(Dataset, ABC):
    """
    Composes multiple base datasets and applies common configurations.

    This dataset provides a flexible way to combine multiple base datasets while
    applying shared augmentations, track generation, and other processing steps.
    It handles image normalization, tensor conversion, and other preparations
    needed for training computer vision models with sequences of images.
    """
    def __init__(self, dataset_configs: list, common_config: dict, **kwargs):
        """
        Initializes the ComposedDataset.

        Args:
            dataset_configs (list): List of Hydra configurations for base datasets.
            common_config (dict): Shared configurations (augs, tracks, mode, etc.).
            **kwargs: Additional arguments (unused).
        """
        base_dataset_list = []

        # Instantiate each base dataset with common configuration
        for baseset_dict in dataset_configs:
            baseset = instantiate(baseset_dict, common_conf=common_config)
            base_dataset_list.append(baseset)

        # Use custom concatenation class that supports tuple indexing
        self.base_dataset = TupleConcatDataset(base_dataset_list, common_config)

        # --- Augmentation Settings ---
        # Controls whether to apply identical color jittering across all frames in a sequence
        self.cojitter = common_config.augs.cojitter
        # Probability of falling back to per-frame jitter; shared jitter is used with probability 1 - cojitter_ratio
        self.cojitter_ratio = common_config.augs.cojitter_ratio
        # Initialize image augmentations (color jitter, grayscale, gaussian blur)
        self.image_aug = get_image_augmentation(
            color_jitter=common_config.augs.color_jitter,
            gray_scale=common_config.augs.gray_scale,
            gau_blur=common_config.augs.gau_blur,
        )

        # --- Optional Fixed Settings (useful for debugging) ---
        # Force each sequence to have exactly this many images (if > 0)
        self.fixed_num_images = common_config.fix_img_num
        # Force a specific aspect ratio for all images
        self.fixed_aspect_ratio = common_config.fix_aspect_ratio

        # --- Track Settings ---
        # Whether to include point tracks in the output
        self.load_track = common_config.load_track
        # Number of point tracks to include per sequence
        self.track_num = common_config.track_num

        # --- Mode Settings ---
        # Whether the dataset is being used for training (affects augmentations)
        self.training = common_config.training
        self.common_config = common_config

        self.total_samples = len(self.base_dataset)

    def __len__(self):
        """Returns the total number of sequences in the dataset."""
        return self.total_samples


    def __getitem__(self, idx_tuple):
        """
        Retrieves a data sample (sequence) from the dataset.

        Loads raw data, converts to PyTorch tensors, applies augmentations,
        and prepares tracks if enabled.

        Args:
            idx_tuple (tuple): a tuple of (seq_idx, num_images, aspect_ratio)

        Returns:
            dict: A dictionary containing the sequence data (images, poses, tracks, etc.).
        """
        # If fixed settings are provided, override the tuple values
        if self.fixed_num_images > 0:
            seq_idx = idx_tuple[0] if isinstance(idx_tuple, tuple) else idx_tuple
            idx_tuple = (seq_idx, self.fixed_num_images, self.fixed_aspect_ratio)

        # Retrieve the raw data batch from the appropriate base dataset
        batch = self.base_dataset[idx_tuple]
        seq_name = batch["seq_name"]

        # --- Data Conversion and Preparation ---
        # Convert numpy arrays to tensors
        images = torch.from_numpy(np.stack(batch["images"]).astype(np.float32)).contiguous()
        # Normalize images from [0, 255] to [0, 1]
        images = images.permute(0,3,1,2).to(torch.get_default_dtype()).div(255)

        # Convert other data to tensors with appropriate types
        depths = torch.from_numpy(np.stack(batch["depths"]).astype(np.float32))
        extrinsics = torch.from_numpy(np.stack(batch["extrinsics"]).astype(np.float32))
        intrinsics = torch.from_numpy(np.stack(batch["intrinsics"]).astype(np.float32))
        cam_points = torch.from_numpy(np.stack(batch["cam_points"]).astype(np.float32))
        world_points = torch.from_numpy(np.stack(batch["world_points"]).astype(np.float32))
        point_masks = torch.from_numpy(np.stack(batch["point_masks"])) # Mask indicating valid depths / world points / cam points per frame
        ids = torch.from_numpy(batch["ids"])    # Frame indices sampled from the original sequence


        # --- Apply Color Augmentation (training mode only) ---
        if self.training and self.image_aug is not None:
            if self.cojitter and random.random() > self.cojitter_ratio:
                # Apply the same color jittering transformation to all frames
                images = self.image_aug(images)
            else:
                # Apply different color jittering to each frame individually
                for aug_img_idx in range(len(images)):
                    images[aug_img_idx] = self.image_aug(images[aug_img_idx])


        # --- Prepare Final Sample Dictionary ---
        sample = {
            "seq_name": seq_name,
            "ids": ids,
            "images": images,
            "depths": depths,
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "cam_points": cam_points,
            "world_points": world_points,
            "point_masks": point_masks,
        }

        # --- Track Processing (if enabled) ---
        if self.load_track:
            if batch["tracks"] is not None:
                # Use pre-computed tracks from the dataset
                tracks = torch.from_numpy(np.stack(batch["tracks"]).astype(np.float32))
                track_vis_mask = torch.from_numpy(np.stack(batch["track_masks"]).astype(bool))

                # Sample a subset of tracks randomly
                valid_indices = torch.where(track_vis_mask[0])[0]
                if len(valid_indices) >= self.track_num:
                    # If we have enough tracks, sample without replacement
                    sampled_indices = valid_indices[torch.randperm(len(valid_indices))][:self.track_num]
                else:
                    # If not enough tracks, sample with replacement (allow duplicates)
                    sampled_indices = valid_indices[torch.randint(0, len(valid_indices),
                                                    (self.track_num,),
                                                    dtype=torch.int64,
                                                    device=valid_indices.device)]

                # Extract the sampled tracks and their masks
                tracks = tracks[:, sampled_indices, :]
                track_vis_mask = track_vis_mask[:, sampled_indices]
                track_positive_mask = torch.ones(track_vis_mask.shape[1]).bool()

            else:
                # Generate tracks on-the-fly using depth information
                # This creates synthetic tracks based on the 3D information available
                tracks, track_vis_mask, track_positive_mask = build_tracks_by_depth(
                    extrinsics, intrinsics, world_points, depths, point_masks, images,
                    target_track_num=self.track_num, seq_name=seq_name
                )

            # Add track information to the sample dictionary
            sample["tracks"] = tracks
            sample["track_vis_mask"] = track_vis_mask
            sample["track_positive_mask"] = track_positive_mask

        return sample


class TupleConcatDataset(ConcatDataset):
    """
    A custom ConcatDataset that supports indexing with a tuple.

    Standard PyTorch ConcatDataset only accepts an integer index. This class extends
    that functionality to allow passing a tuple like (sample_idx, num_images, aspect_ratio),
    where the first element is used to determine which sample to fetch, and the full
    tuple is passed down to the selected dataset's __getitem__ method.

    It also supports an option to randomly sample across all datasets, ignoring the
    provided index. This is useful during training when shuffling the entire dataset
    might cause memory issues due to duplicated dictionaries. When using this option,
    set shuffle=False in the PyTorch DataLoader.
    """
    def __init__(self, datasets, common_config):
        """
        Initialize the TupleConcatDataset.

        Args:
            datasets (iterable): An iterable of PyTorch Dataset objects to concatenate.
            common_config: Common configuration object; its `inside_random` attribute enables random sampling.
        """
        super().__init__(datasets)
        # If True, ignores the input index and samples randomly across all datasets
        # This provides an alternative to dataloader shuffling for large datasets
        self.inside_random = common_config.inside_random

    def __getitem__(self, idx):
        """
        Retrieves an item using either an integer index or a tuple index.

        Args:
            idx (int or tuple): The index. If tuple, the first element is the sequence
                               index across the concatenated datasets, and the rest are
                               passed down. If int, it's treated as the sequence index.

        Returns:
            The item returned by the underlying dataset's __getitem__ method.

        Raises:
            ValueError: If the index is out of range or the tuple doesn't have exactly 3 elements.
        """
        idx_tuple = None
        if isinstance(idx, tuple):
            if len(idx) != 3:
                raise ValueError("Tuple index must have exactly three elements")
            idx_tuple = idx
            idx = idx_tuple[0]  # Extract the sequence index

        # Override index with random value if inside_random is enabled
        if self.inside_random:
            total_len = self.cumulative_sizes[-1]
            idx = random.randint(0, total_len - 1)

        # Handle negative and out-of-range indices
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        elif idx >= len(self):
            raise ValueError("index should be smaller than dataset length")

        # Find which dataset the index belongs to
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        # Integer index: pass the local sample index straight through
        if idx_tuple is None:
            return self.datasets[dataset_idx][sample_idx]

        # Tuple index: replace the global index with the local one and keep the rest
        idx_tuple = (sample_idx,) + idx_tuple[1:]
        return self.datasets[dataset_idx][idx_tuple]
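The index arithmetic in `TupleConcatDataset.__getitem__` can be checked in isolation. The sketch below is a minimal re-implementation, assuming plain Python lists for the cumulative sizes; `locate` and `remap_tuple` are hypothetical helper names, not part of this repository.

```python
import bisect
from itertools import accumulate


def locate(cumulative_sizes, idx):
    """Map a global index to (dataset_idx, sample_idx), as ConcatDataset does."""
    total = cumulative_sizes[-1]
    if idx < 0:
        if -idx > total:
            raise ValueError("absolute value of index should not exceed dataset length")
        idx = total + idx
    elif idx >= total:
        raise ValueError("index should be smaller than dataset length")
    # bisect_right finds the first dataset whose cumulative size exceeds idx
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


def remap_tuple(cumulative_sizes, idx_tuple):
    """Rewrite (global_idx, num_images, aspect_ratio) with the local sample index."""
    if len(idx_tuple) != 3:
        raise ValueError("Tuple index must have exactly three elements")
    dataset_idx, sample_idx = locate(cumulative_sizes, idx_tuple[0])
    return dataset_idx, (sample_idx,) + tuple(idx_tuple[1:])


# Three hypothetical datasets with lengths 5, 3 and 4
sizes = list(accumulate([5, 3, 4]))  # [5, 8, 12]
print(locate(sizes, 5))                  # (1, 0)
print(locate(sizes, -1))                 # (2, 3)
print(remap_tuple(sizes, (6, 24, 1.0)))  # (1, (1, 24, 1.0))
```

Only the first tuple element is rewritten. The remaining elements (number of images, aspect ratio) reach the selected dataset unchanged.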


================================================
FILE: training/data/dataset_util.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
import math
import numpy as np
from PIL import Image
import PIL
try:
    lanczos = PIL.Image.Resampling.LANCZOS
    bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
    lanczos = PIL.Image.LANCZOS
    bicubic = PIL.Image.BICUBIC

from vggt.utils.geometry import closed_form_inverse_se3



#####################################################################################################################
def crop_image_depth_and_intrinsic_by_pp(
    image, depth_map, intrinsic, target_shape, track=None, filepath=None, strict=False
):
    """
    Crops the given image and depth map around the camera's principal point, as defined by `intrinsic`.
    Specifically:
      - Ensures that the crop is centered on the principal point (cx, cy).
      - Optionally pads the image (and depth map) if `strict=True` and the result is smaller than `target_shape`.
      - Shifts the camera intrinsic matrix (and `track` if provided) accordingly.

    Args:
        image (np.ndarray):
            Input image array of shape (H, W, 3).
        depth_map (np.ndarray or None):
            Depth map array of shape (H, W), or None if not available.
        intrinsic (np.ndarray):
            Camera intrinsic matrix (3x3) in OpenCV layout. The principal point is
            (cx, cy) = (intrinsic[0, 2], intrinsic[1, 2]), i.e. (column, row).
        target_shape (tuple[int, int]):
            Desired output shape (height, width).
        track (np.ndarray or None):
            Optional array of shape (N, 2). Interpreted as (x, y) pixel coordinates. Shifted in place after cropping.
        filepath (str or None):
            An optional file path for debug logging (only used if strict mode triggers warnings).
        strict (bool):
            If True, will zero-pad to ensure the exact target_shape even if the cropped region is smaller.

    Raises:
        AssertionError:
            If the input image is smaller than `target_shape`.
        ValueError:
            If the cropped image is larger than `target_shape` (in strict mode), which should not normally happen.

    Returns:
        tuple:
            (cropped_image, cropped_depth_map, updated_intrinsic, updated_track)

            - cropped_image (np.ndarray): Cropped (and optionally padded) image.
            - cropped_depth_map (np.ndarray or None): Cropped (and optionally padded) depth map.
            - updated_intrinsic (np.ndarray): Intrinsic matrix adjusted for the crop.
            - updated_track (np.ndarray or None): Track array adjusted for the crop, or None if track was not provided.
    """
    original_size = np.array(image.shape)
    intrinsic = np.copy(intrinsic)

    if original_size[0] < target_shape[0]:
        error_message = (
            f"Height check failed: original height {original_size[0]} "
            f"is less than target height {target_shape[0]}."
        )
        print(error_message)
        raise AssertionError(error_message)

    if original_size[1] < target_shape[1]:
        error_message = (
            f"Width check failed: original width {original_size[1]} "
            f"is less than target width {target_shape[1]}."
        )
        print(error_message)
        raise AssertionError(error_message)

    # Principal point from the intrinsic (OpenCV layout: cx = K[0, 2] is the
    # column coordinate, cy = K[1, 2] is the row coordinate)
    cx = intrinsic[0, 2]
    cy = intrinsic[1, 2]

    # Compute the half-extent of the crop along rows (height) and columns (width)
    if strict:
        half_h = min(target_shape[0] / 2, cy)
        half_w = min(target_shape[1] / 2, cx)
    else:
        half_h = min(target_shape[0] / 2, cy, original_size[0] - cy)
        half_w = min(target_shape[1] / 2, cx, original_size[1] - cx)

    # Compute starting indices
    start_row = math.floor(cy) - math.floor(half_h)
    start_col = math.floor(cx) - math.floor(half_w)

    assert start_row >= 0
    assert start_col >= 0

    # Compute ending indices
    if strict:
        end_row = start_row + target_shape[0]
        end_col = start_col + target_shape[1]
    else:
        end_row = start_row + 2 * math.floor(half_h)
        end_col = start_col + 2 * math.floor(half_w)

    # Perform the crop
    image = image[start_row:end_row, start_col:end_col, :]
    if depth_map is not None:
        depth_map = depth_map[start_row:end_row, start_col:end_col]

    # Shift the principal point in the intrinsic
    intrinsic[0, 2] = intrinsic[0, 2] - start_col
    intrinsic[1, 2] = intrinsic[1, 2] - start_row

    # Adjust track if provided; track columns are (x, y) = (column, row)
    if track is not None:
        track[:, 0] = track[:, 0] - start_col
        track[:, 1] = track[:, 1] - start_row

    # If strict, zero-pad if the new shape is smaller than target_shape
    if strict:
        if np.any(np.asarray(image.shape[:2]) != np.asarray(target_shape)):
            print(f"{filepath} does not meet the target shape")
            current_h, current_w = image.shape[:2]
            target_h, target_w = target_shape[0], target_shape[1]
            pad_h = target_h - current_h
            pad_w = target_w - current_w
            if pad_h < 0 or pad_w < 0:
                raise ValueError(
                    f"The cropped image is bigger than the target shape: "
                    f"cropped=({current_h},{current_w}), "
                    f"target=({target_h},{target_w})."
                )
            image = np.pad(
                image,
                pad_width=((0, pad_h), (0, pad_w), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            if depth_map is not None:
                depth_map = np.pad(
                    depth_map,
                    pad_width=((0, pad_h), (0, pad_w)),
                    mode="constant",
                    constant_values=0,
                )

    return image, depth_map, intrinsic, track
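A minimal sketch of the non-strict branch of the crop above, assuming an (H, W, C) image and an OpenCV-layout K. `crop_around_principal_point` is a hypothetical name. It shows that after cropping, the principal point sits at the center of the crop.

```python
import math

import numpy as np


def crop_around_principal_point(image, K, target_hw):
    """Non-strict crop centered on (cx, cy) = (K[0, 2], K[1, 2]); target_hw is (height, width)."""
    H, W = image.shape[:2]
    cx, cy = K[0, 2], K[1, 2]
    # Half-extents are limited by the target size and by the distance to each border
    half_h = min(target_hw[0] / 2, cy, H - cy)
    half_w = min(target_hw[1] / 2, cx, W - cx)
    top = math.floor(cy) - math.floor(half_h)
    left = math.floor(cx) - math.floor(half_w)
    bottom = top + 2 * math.floor(half_h)
    right = left + 2 * math.floor(half_w)
    K_new = K.copy()
    K_new[0, 2] -= left  # the principal point moves with the crop origin
    K_new[1, 2] -= top
    return image[top:bottom, left:right], K_new


K = np.array([[500.0, 0.0, 330.0],
              [0.0, 500.0, 250.0],
              [0.0, 0.0, 1.0]])
image = np.zeros((480, 640, 3), dtype=np.uint8)

cropped, K_new = crop_around_principal_point(image, K, (400, 400))
print(cropped.shape, K_new[0, 2], K_new[1, 2])  # (400, 400, 3) 200.0 200.0

# Principal point only 100 px from the top border: the crop shrinks to 200 rows
K_edge = K.copy()
K_edge[1, 2] = 100.0
cropped_edge, _ = crop_around_principal_point(image, K_edge, (400, 400))
print(cropped_edge.shape)  # (200, 400, 3)
```

In non-strict mode, the crop shrinks rather than pads when the principal point is closer than half the target size to a border, as the second call shows. Strict mode instead keeps the full target extent and zero-pads.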


def resize_image_depth_and_intrinsic(
    image,
    depth_map,
    intrinsic,
    target_shape,
    original_size,
    track=None,
    pixel_center=True,
    safe_bound=4,
    rescale_aug=True,
):
    """
    Resizes the given image and depth map (if provided) by a single scale factor so that both sides
    end up slightly larger than `target_shape`, updating the intrinsic matrix (and track array if
    present). Optionally enlarges the margin at random (`rescale_aug`) as a scale augmentation.

    Steps:
      1. Compute a scaling factor so that the resized result is at least `target_shape + safe_bound`.
      2. Apply an optional triangular random factor if `rescale_aug=True`.
      3. Resize the image with LANCZOS if downscaling, BICUBIC if upscaling.
      4. Resize the depth map with nearest-neighbor.
      5. Update the camera intrinsic and track coordinates (if any).

    Args:
        image (np.ndarray):
            Input image array (H, W, 3).
        depth_map (np.ndarray or None):
            Depth map array (H, W), or None if unavailable.
        intrinsic (np.ndarray):
            Camera intrinsic matrix (3x3).
        target_shape (np.ndarray or tuple[int, int]):
            Desired final shape (height, width).
        original_size (np.ndarray or tuple[int, int]):
            Original size of the image in (height, width).
        track (np.ndarray or None):
            Optional (N, 2) array of pixel coordinates. Will be scaled.
        pixel_center (bool):
            If True, accounts for 0.5 pixel center shift during resizing.
        safe_bound (int or float):
            Additional margin (in pixels) to add to target_shape before resizing.
        rescale_aug (bool):
            If True, add a random extra margin of up to 30% of the longer target side to `safe_bound`,
            drawn from a triangular distribution with mode 0 (small margins are most likely).

    Returns:
        tuple:
            (resized_image, resized_depth_map, updated_intrinsic, updated_track)

            - resized_image (np.ndarray): The resized image.
            - resized_depth_map (np.ndarray or None): The resized depth map.
            - updated_intrinsic (np.ndarray): Camera intrinsic updated for new resolution.
            - updated_track (np.ndarray or None): Track array updated or None if not provided.

    Raises:
        AssertionError:
            If the shapes of the resized image and depth map do not match.
    """
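    # Accept tuples as documented; the arithmetic below needs arrays
    target_shape = np.asarray(target_shape)
    original_size = np.asarray(original_size)

    if rescale_aug:
        random_boundary = np.random.triangular(0, 0, 0.3)
        safe_bound = safe_bound + random_boundary * target_shape.max()

    resize_scales = (target_shape + safe_bound) / original_size
    max_resize_scale = np.max(resize_scales)
    intrinsic = np.copy(intrinsic)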

    # Convert image to PIL for resizing
    image = Image.fromarray(image)
    input_resolution = np.array(image.size)
    output_resolution = np.floor(input_resolution * max_resize_scale).astype(int)
    image = image.resize(tuple(output_resolution), resample=lanczos if max_resize_scale < 1 else bicubic)
    image = np.array(image)

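    if depth_map is not None:
        # cv2.resize takes dsize as (width, height), the same order as PIL's size.
        # When dsize is given, fx/fy are ignored, so they are omitted here.
        depth_map = cv2.resize(
            depth_map,
            tuple(int(v) for v in output_resolution),
            interpolation=cv2.INTER_NEAREST,
        )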

    actual_size = np.array(image.shape[:2])
    actual_resize_scale = np.max(actual_size / original_size)

    if pixel_center:
        intrinsic[0, 2] = intrinsic[0, 2] + 0.5
        intrinsic[1, 2] = intrinsic[1, 2] + 0.5

    intrinsic[:2, :] = intrinsic[:2, :] * actual_resize_scale

    if track is not None:
        track = track * actual_resize_scale

    if pixel_center:
        intrinsic[0, 2] = intrinsic[0, 2] - 0.5
        intrinsic[1, 2] = intrinsic[1, 2] - 0.5

    if depth_map is not None:
        assert image.shape[:2] == depth_map.shape[:2]
    return image, depth_map, intrinsic, track
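The `pixel_center` handling in the resize function above works in three steps. It shifts to a convention in which pixel i spans [i, i + 1), scales, then shifts back. The following is a minimal sketch, assuming pixel centers at integer coordinates. `scale_intrinsic` and `map_pixel` are hypothetical helpers.

```python
import numpy as np


def scale_intrinsic(K, scale, pixel_center=True):
    """Scale K for an image resized by `scale`, with pixel centers at integer coordinates."""
    K = K.astype(np.float64).copy()
    if pixel_center:
        K[:2, 2] += 0.5  # move to the corner convention, where resizing is a pure multiply
    K[:2, :] *= scale
    if pixel_center:
        K[:2, 2] -= 0.5  # back to the integer-center convention
    return K


def map_pixel(u, scale):
    """Where the center of pixel coordinate u lands after resizing by `scale`."""
    return (u + 0.5) * scale - 0.5


# A 64x48 image whose principal point is exactly at the image center
K = np.array([[100.0, 0.0, 31.5],
              [0.0, 100.0, 23.5],
              [0.0, 0.0, 1.0]])
K2 = scale_intrinsic(K, 2.0)
print(K2[0, 0], K2[0, 2], K2[1, 2])  # 200.0 63.5 47.5
print(map_pixel(31.5, 2.0))          # 63.5
```

Doubling a 64-pixel-wide image gives 128 pixels, whose center lies at 63.5. Scaling cx = 31.5 by 2 directly would give 63.0, which is half a pixel off.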


def threshold_depth_map(
    depth_map: np.ndarray,
    max_percentile: float = 99,
    min_percentile: float = 1,
    max_depth: float = -1,
) -> np.ndarray:
    """
    Thresholds a depth map using percentile-based limits and an optional absolute maximum depth.

    Steps:
      1. If `max_depth > 0`, set all values above `max_depth` to zero.
      2. Compute `max_percentile` and `min_percentile` thresholds using nanpercentile.
         Zeros (invalid pixels) are included in this computation; only NaNs are ignored.
      3. Zero out values above/below these thresholds, if thresholds are > 0.

    Args:
        depth_map (np.ndarray):
            Input depth map (H, W).
        max_percentile (float):
            Upper percentile (0-100). Values above this percentile are set to zero.
            If <= 0, no upper percentile threshold is applied.
        min_percentile (float):
            Lower percentile (0-100). Values below this percentile are set to zero.
            If <= 0, no lower percentile threshold is applied.
        max_depth (float):
            Absolute maximum depth. If > 0, any depth above this is set to zero.
            If <= 0, no maximum-depth cutoff is applied.

    Returns:
        np.ndarray or None:
            Float depth map (H, W) after thresholding; some or all values may be zero.
            Returns None if depth_map is None.
    """
    if depth_map is None:
        return None

    depth_map = depth_map.astype(float, copy=True)

    # Optional clamp by max_depth
    if max_depth > 0:
        depth_map[depth_map > max_depth] = 0.0

    # Percentile-based thresholds
    depth_max_thres = (
        np.nanpercentile(depth_map, max_percentile) if max_percentile > 0 else None
    )
    depth_min_thres = (
        np.nanpercentile(depth_map, min_percentile) if min_percentile > 0 else None
    )

    # Apply the thresholds if they are > 0
    if depth_max_thres is not None and depth_max_thres > 0:
        depth_map[depth_map > depth_max_thres] = 0.0
    if depth_min_thres is not None and depth_min_thres > 0:
        depth_map[depth_map < depth_min_thres] = 0.0

    return depth_map
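A small standalone sketch that follows the same order as `threshold_depth_map`, run on a synthetic 1-100 ramp (hypothetical data). `threshold_depth` is a hypothetical name. Both percentile thresholds are computed before either is applied.

```python
import numpy as np


def threshold_depth(depth, max_percentile=99.0, min_percentile=1.0, max_depth=-1.0):
    """Zero out depths beyond an absolute cap and outside a percentile band."""
    depth = depth.astype(np.float64, copy=True)
    if max_depth > 0:
        depth[depth > max_depth] = 0.0
    # Both thresholds come from the same map; zeros still count toward the percentiles
    hi = np.nanpercentile(depth, max_percentile) if max_percentile > 0 else None
    lo = np.nanpercentile(depth, min_percentile) if min_percentile > 0 else None
    if hi is not None and hi > 0:
        depth[depth > hi] = 0.0
    if lo is not None and lo > 0:
        depth[depth < lo] = 0.0
    return depth


depth = np.arange(1, 101, dtype=np.float64).reshape(10, 10)
out = threshold_depth(depth, max_percentile=90, min_percentile=10)
print(np.count_nonzero(out), out.max(), out[out > 0].min())  # 80 90.0 11.0
```

With linear interpolation, the 90th and 10th percentiles of 1..100 are 90.1 and 10.9. That removes 91..100 and 1..10. Passing a non-positive percentile (as CO3D does with `min_percentile=-1`) skips that bound entirely.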


def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Converts a depth map to world coordinates (HxWx3) given the camera extrinsic and intrinsic.
    Returns both the world coordinates and the intermediate camera coordinates,
    as well as a mask for valid depth.

    Args:
        depth_map (np.ndarray):
            Depth map of shape (H, W).
        extrinsic (np.ndarray):
            Extrinsic matrix of shape (3, 4), representing the camera pose in OpenCV convention (camera-from-world).
        intrinsic (np.ndarray):
            Intrinsic matrix of shape (3, 3).
        eps (float):
            Small epsilon for thresholding valid depth.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            (world_coords_points, cam_coords_points, point_mask)

            - world_coords_points: (H, W, 3) array of 3D points in world frame.
            - cam_coords_points: (H, W, 3) array of 3D points in camera frame.
            - point_mask: (H, W) boolean array where True indicates valid depth (depth > eps).
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates
    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # The extrinsic is camera-from-world, so invert it to transform camera->world
    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
    t_cam_to_world = cam_to_world_extrinsic[:3, 3]

    # Apply the rotation and translation to the camera coordinates
    world_coords_points = (
        np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
    ) # HxWx3, 3x3 -> HxWx3
    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world

    return world_coords_points, cam_coords_points, point_mask
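`depth_to_world_coords_points` uses `closed_form_inverse_se3` to turn the camera-from-world extrinsic into a camera-to-world transform. The sketch below assumes the standard closed form [R^T | -R^T t] for a single 3x4 matrix. `inverse_se3` and `cam_to_world` are hypothetical names; the repository helper operates on batches (hence the `extrinsic[None]` call above).

```python
import numpy as np


def inverse_se3(extrinsic):
    """Closed-form inverse of a 3x4 rigid transform [R | t]: returns [R^T | -R^T t]."""
    R, t = extrinsic[:, :3], extrinsic[:, 3]
    return np.hstack([R.T, (-R.T @ t)[:, None]])


def cam_to_world(cam_points, extrinsic):
    """Map (H, W, 3) camera-frame points to the world frame (extrinsic is camera-from-world)."""
    world_from_cam = inverse_se3(extrinsic)
    R, t = world_from_cam[:, :3], world_from_cam[:, 3]
    return cam_points @ R.T + t  # row-vector form of p_world = R p_cam + t


# Hypothetical camera: 90-degree rotation about the y axis plus a translation
theta = np.pi / 2
R = np.array([[np.cos(theta), 0.0, np.sin(theta)],
              [0.0, 1.0, 0.0],
              [-np.sin(theta), 0.0, np.cos(theta)]])
t = np.array([0.5, -1.0, 2.0])
extrinsic = np.hstack([R, t[:, None]])

world = np.random.default_rng(0).normal(size=(4, 5, 3))
cam = world @ R.T + t  # world -> camera
print(np.allclose(cam_to_world(cam, extrinsic), world))  # True
```

The closed form avoids a general 4x4 matrix inverse. It is exact only when R is a proper rotation.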


def depth_to_cam_coords_points(
    depth_map: np.ndarray, intrinsic: np.ndarray
) -> np.ndarray:
    """
    Unprojects a depth map into camera coordinates, returning (H, W, 3).

    Args:
        depth_map (np.ndarray):
            Depth map of shape (H, W).
        intrinsic (np.ndarray):
            3x3 camera intrinsic matrix.
            Assumes zero skew and standard OpenCV layout:
            [ fx   0   cx ]
            [  0  fy   cy ]
            [  0   0    1 ]

    Returns:
        np.ndarray:
            An (H, W, 3) array, where each pixel is mapped to (x, y, z) in the camera frame.
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert (
        intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0
    ), "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Generate grid of pixel coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    
    # Unproject to camera coordinates
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to form camera coordinates
    return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
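For positive depth, unprojection and pinhole projection are inverses, which makes a quick sanity check for an intrinsic matrix. The sketch mirrors `depth_to_cam_coords_points`, assuming zero skew and pixel centers at integer coordinates (`np.arange`). `unproject` and `project` are hypothetical helpers.

```python
import numpy as np


def unproject(depth, K):
    """(H, W) depth -> (H, W, 3) camera points, pixel centers at integer coordinates."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))  # u: column index, v: row index
    x = (u - K[0, 2]) * depth / K[0, 0]
    y = (v - K[1, 2]) * depth / K[1, 1]
    return np.stack([x, y, depth], axis=-1)


def project(points, K):
    """(..., 3) camera points -> (..., 2) pixel coordinates (u, v)."""
    uvw = points @ K.T
    return uvw[..., :2] / uvw[..., 2:3]


K = np.array([[120.0, 0.0, 15.5],
              [0.0, 110.0, 11.5],
              [0.0, 0.0, 1.0]])
depth = np.random.default_rng(1).uniform(1.0, 5.0, size=(24, 32))
uv = project(unproject(depth, K), K)
u_grid, v_grid = np.meshgrid(np.arange(32), np.arange(24))
print(np.allclose(uv[..., 0], u_grid), np.allclose(uv[..., 1], v_grid))  # True True
```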


def rotate_90_degrees(
    image, depth_map, extri_opencv, intri_opencv, clockwise=True, track=None
):
    """
    Rotates the input image, depth map, and camera parameters by 90 degrees.

    Applies one of two 90-degree rotations:
    - Clockwise
    - Counterclockwise (if clockwise=False)

    The extrinsic and intrinsic matrices are adjusted accordingly to maintain
    correct camera geometry. Track coordinates are also updated if provided.

    Args:
        image (np.ndarray):
            Input image of shape (H, W, 3).
        depth_map (np.ndarray or None):
            Depth map of shape (H, W), or None if not available.
        extri_opencv (np.ndarray):
            Extrinsic matrix (3x4) in OpenCV convention.
        intri_opencv (np.ndarray):
            Intrinsic matrix (3x3).
        clockwise (bool):
            If True, rotates the image 90 degrees clockwise; else 90 degrees counterclockwise.
        track (np.ndarray or None):
            Optional (N, 2) track array. Will be rotated accordingly.

    Returns:
        tuple:
            (
                rotated_image,
                rotated_depth_map,
                new_extri_opencv,
                new_intri_opencv,
                new_track
            )

            Where each is the updated version after the rotation.
    """
    image_height, image_width = image.shape[:2]

    # Rotate the image and depth map
    rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise)
    # Adjust the intrinsic matrix
    new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise)

    if track is not None:
        new_track = adjust_track_rot90(track, image_width, image_height, clockwise)
    else:
        new_track = None

    # Adjust the extrinsic matrix
    new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)

    return (
        rotated_image,
        rotated_depth_map,
        new_extri_opencv,
        new_intri_opencv,
        new_track,
    )


def rotate_image_and_depth_rot90(image, depth_map, clockwise):
    """
    Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise),
    using a transpose+flip pattern.

    Args:
        image (np.ndarray):
            Input image of shape (H, W, 3).
        depth_map (np.ndarray or None):
            Depth map of shape (H, W), or None if not available.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.

    Returns:
        tuple:
            (rotated_image, rotated_depth_map)
    """
    rotated_depth_map = None
    if clockwise:
        rotated_image = np.transpose(image, (1, 0, 2))  # Transpose height and width
        rotated_image = np.flip(rotated_image, axis=1)  # Flip horizontally
        if depth_map is not None:
            rotated_depth_map = np.transpose(depth_map, (1, 0))
            rotated_depth_map = np.flip(rotated_depth_map, axis=1)
    else:
        rotated_image = np.transpose(image, (1, 0, 2))  # Transpose height and width
        rotated_image = np.flip(rotated_image, axis=0)  # Flip vertically
        if depth_map is not None:
            rotated_depth_map = np.transpose(depth_map, (1, 0))
            rotated_depth_map = np.flip(rotated_depth_map, axis=0)
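    rotated_image = np.copy(rotated_image)
    # np.copy(None) would return a 0-d object array, so keep None as None
    if rotated_depth_map is not None:
        rotated_depth_map = np.copy(rotated_depth_map)
    return rotated_image, rotated_depth_map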
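The transpose+flip pattern in `rotate_image_and_depth_rot90` matches NumPy's `np.rot90`. For an image in the (row, column) plane, `k=-1` rotates clockwise and `k=1` counterclockwise. The small array below is illustrative.

```python
import numpy as np

image = np.arange(2 * 3 * 1).reshape(2, 3, 1)  # H=2, W=3, one channel

# Clockwise: transpose rows/cols, then flip horizontally
cw = np.flip(np.transpose(image, (1, 0, 2)), axis=1)
# Counterclockwise: transpose rows/cols, then flip vertically
ccw = np.flip(np.transpose(image, (1, 0, 2)), axis=0)

print(np.array_equal(cw, np.rot90(image, k=-1, axes=(0, 1))))  # True
print(np.array_equal(ccw, np.rot90(image, k=1, axes=(0, 1))))  # True
print(cw[..., 0])
# [[3 0]
#  [4 1]
#  [5 2]]
```

The left column [0, 3] of the original becomes the top row [3, 0], as expected for a clockwise turn.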


def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise):
    """
    Adjusts the extrinsic matrix (3x4) for a 90-degree rotation of the image.

    The rotation is in the image plane. This modifies the camera orientation
    accordingly. The function applies either a clockwise or counterclockwise
    90-degree rotation.

    Args:
        extri_opencv (np.ndarray):
            Extrinsic matrix (3x4) in OpenCV convention.
        clockwise (bool):
            If True, rotate extrinsic for a 90-degree clockwise image rotation;
            otherwise, counterclockwise.

    Returns:
        np.ndarray:
            A new 3x4 extrinsic matrix after the rotation.
    """
    R = extri_opencv[:, :3]
    t = extri_opencv[:, 3]

    if clockwise:
        R_rotation = np.array([
            [0, -1, 0],
            [1,  0, 0],
            [0,  0, 1]
        ])
    else:
        R_rotation = np.array([
            [0, 1, 0],
            [-1, 0, 0],
            [0, 0, 1]
        ])

    new_R = np.dot(R_rotation, R)
    new_t = np.dot(R_rotation, t)
    new_extri_opencv = np.hstack((new_R, new_t.reshape(-1, 1)))
    return new_extri_opencv


def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise):
    """
    Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane.

    Args:
        intri_opencv (np.ndarray):
            Intrinsic matrix (3x3).
        image_width (int):
            Original width of the image.
        image_height (int):
            Original height of the image.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.

    Returns:
        np.ndarray:
            A new 3x3 intrinsic matrix after the rotation.
    """
    fx, fy, cx, cy = (
        intri_opencv[0, 0],
        intri_opencv[1, 1],
        intri_opencv[0, 2],
        intri_opencv[1, 2],
    )

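    new_intri_opencv = np.eye(3)
    # Pixel centers sit at integer coordinates (as in depth_to_cam_coords_points),
    # so mirroring a coordinate c along an axis of length n gives (n - 1) - c
    if clockwise:
        new_intri_opencv[0, 0] = fy
        new_intri_opencv[1, 1] = fx
        new_intri_opencv[0, 2] = image_height - 1 - cy
        new_intri_opencv[1, 2] = cx
    else:
        new_intri_opencv[0, 0] = fy
        new_intri_opencv[1, 1] = fx
        new_intri_opencv[0, 2] = cy
        new_intri_opencv[1, 2] = image_width - 1 - cx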

    return new_intri_opencv


def adjust_track_rot90(track, image_width, image_height, clockwise):
    """
    Adjusts a track (N, 2) for a 90-degree rotation of the image in the image plane.

    Args:
        track (np.ndarray):
            (N, 2) array of pixel coordinates, each row is (x, y).
        image_width (int):
            Original image width.
        image_height (int):
            Original image height.
        clockwise (bool):
            Whether the rotation is 90 degrees clockwise or counterclockwise.

    Returns:
        np.ndarray:
            A new track of shape (N, 2) after rotation.
    """
    if clockwise:
        # (x, y) -> (image_height - 1 - y, x), matching np.rot90(image, k=-1)
        new_track = np.stack((image_height - 1 - track[:, 1], track[:, 0]), axis=-1)
    else:
        # (x, y) -> (y, image_width - 1 - x), matching np.rot90(image, k=1)
        new_track = np.stack((track[:, 1], image_width - 1 - track[:, 0]), axis=-1)

    return new_track
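A consistency check for a clockwise rotation, assuming pixel centers at integer coordinates (the convention that `depth_to_cam_coords_points` uses via `np.arange`). Under that assumption, rotating the extrinsic and intrinsic should make a world point project to (image_height - 1 - y, x). That is also where `np.rot90(..., k=-1)` moves the pixel. `rotate_camera_cw` and `project` are hypothetical helpers.

```python
import numpy as np


def rotate_camera_cw(K, extrinsic, image_height):
    """Update K and [R | t] for a 90-degree clockwise image rotation (integer pixel centers)."""
    R_rot = np.array([[0.0, -1.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0]])
    new_extrinsic = R_rot @ extrinsic  # rotates R and t together: x' = -y, y' = x
    new_K = np.eye(3)
    new_K[0, 0], new_K[1, 1] = K[1, 1], K[0, 0]  # focal lengths swap
    new_K[0, 2] = image_height - 1 - K[1, 2]     # new cx
    new_K[1, 2] = K[0, 2]                        # new cy
    return new_K, new_extrinsic


def project(K, extrinsic, X):
    """Project a world point X (3,) to pixel coordinates (u, v)."""
    p = K @ (extrinsic[:, :3] @ X + extrinsic[:, 3])
    return p[:2] / p[2]


H, W = 6, 8
K = np.array([[50.0, 0.0, 3.7], [0.0, 40.0, 2.2], [0.0, 0.0, 1.0]])
t = np.array([0.1, 0.2, 0.0])
extrinsic = np.hstack([np.eye(3), t[:, None]])

# Build a world point that projects exactly onto pixel (x=5, y=1) at depth 3
x, y, depth = 5, 1, 3.0
cam = np.array([(x - K[0, 2]) / K[0, 0] * depth, (y - K[1, 2]) / K[1, 1] * depth, depth])
X = cam - t  # the extrinsic rotation is the identity here

new_K, new_extrinsic = rotate_camera_cw(K, extrinsic, H)
uv_rot = project(new_K, new_extrinsic, X)
print(np.allclose(uv_rot, [H - 1 - y, x]))  # True: (x, y) -> (H - 1 - y, x)

# The same pixel after rotating the image itself clockwise
image = np.zeros((H, W))
image[y, x] = 1.0
print(np.argwhere(np.rot90(image, k=-1) == 1.0))  # [[5 4]]: row 5 (= x), column 4 (= H - 1 - y)
```

The check ties together the extrinsic, intrinsic and track updates. If any one of them used the opposite rotation direction, the projected pixel and the rotated pixel would disagree.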


def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray:
    """
    Reads an image from disk using OpenCV and returns it as an (H, W, 3) array, in RGB order by default.

    Args:
        path (str):
            File path to the image.
        rgb (bool):
            If True, convert the image to RGB.
            If False, leave the image in BGR/grayscale.

    Returns:
        np.ndarray or None:
            A numpy array of shape (H, W, 3) if successful,
            or None if the file does not exist or could not be read.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        print(f"File does not exist or is empty: {path}")
        return None

    img = cv2.imread(path)
    if img is None:
        print(f"Could not load image={path}. Retrying...")
        img = cv2.imread(path)
        if img is None:
            print("Retry failed.")
            return None

    if rgb:
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img


def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray:
    """
    Reads a depth map from disk in either .exr or .png format. The .exr is loaded with OpenCV,
    which requires the environment variable OPENCV_IO_ENABLE_OPENEXR=1 (set at the top of this
    module). The .png is assumed to be a 16-bit PNG whose pixels hold raw float16 bit patterns.

    Args:
        path (str):
            File path to the depth image. Must end with .exr or .png.
        scale_adjustment (float):
            A multiplier for adjusting the loaded depth values (default=1.0).

    Returns:
        np.ndarray:
            A float32 array (H, W) containing the loaded depth. Zeros or non-finite values
            may indicate invalid regions.

    Raises:
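        ValueError:
            If the file extension is not supported or an .exr file cannot be read.
    """
    if path.lower().endswith(".exr"):
        # Requires OPENCV_IO_ENABLE_OPENEXR=1, which this module sets before importing cv2
        d = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if d is None:
            raise ValueError(f'could not read depth file "{path}"')
        if d.ndim == 3:
            d = d[..., 0]  # Multi-channel EXR: keep the first channel as depth
        d[d > 1e9] = 0.0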
    elif path.lower().endswith(".png"):
        d = load_16big_png_depth(path)
    else:
        raise ValueError(f'unsupported depth file name "{path}"')

    d = d * scale_adjustment
    d[~np.isfinite(d)] = 0.0

    return d


def load_16big_png_depth(depth_png: str) -> np.ndarray:
    """
    Loads a 16-bit PNG as a half-float depth map (H, W), returning a float32 NumPy array.

    Implementation detail:
      - PIL loads 16-bit data as 32-bit "I" mode.
      - We reinterpret the bits as float16, then cast to float32.

    Args:
        depth_png (str):
            File path to the 16-bit PNG.

    Returns:
        np.ndarray:
            A float32 depth array of shape (H, W).
    """
    with Image.open(depth_png) as depth_pil:
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth
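The bit reinterpretation in `load_16big_png_depth` can be exercised without a file. This sketch assumes the PNG simply stores float16 bit patterns as uint16, which is what the loader's `np.frombuffer(..., dtype=np.float16)` implies. The values are illustrative.

```python
import numpy as np

# Values exactly representable in float16 (65504 is the largest finite float16)
depth = np.array([[0.0, 1.5], [2.25, 65504.0]], dtype=np.float32)

# Encode: reinterpret float16 bits as uint16, i.e. what a 16-bit PNG would store
encoded = depth.astype(np.float16).view(np.uint16)

# Decode: the same reinterpretation as load_16big_png_depth, minus the PIL read
decoded = np.frombuffer(encoded.tobytes(), dtype=np.float16).astype(np.float32).reshape(depth.shape)

print(encoded.dtype, decoded.dtype)    # uint16 float32
print(np.array_equal(decoded, depth))  # True
```

Depths that are not exactly representable in float16 (roughly 3 significant decimal digits) come back rounded, so this format trades precision for a compact lossless PNG.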


================================================
FILE: training/data/datasets/co3d.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import json
import os.path as osp
import os
import logging

import cv2
import random
import numpy as np


from data.dataset_util import *
from data.base_dataset import BaseDataset


SEEN_CATEGORIES = [
    "apple",
    "backpack",
    "banana",
    "baseballbat",
    "baseballglove",
    "bench",
    "bicycle",
    "bottle",
    "bowl",
    "broccoli",
    "cake",
    "car",
    "carrot",
    "cellphone",
    "chair",
    "cup",
    "donut",
    "hairdryer",
    "handbag",
    "hydrant",
    "keyboard",
    "laptop",
    "microwave",
    "motorcycle",
    "mouse",
    "orange",
    "parkingmeter",
    "pizza",
    "plant",
    "stopsign",
    "teddybear",
    "toaster",
    "toilet",
    "toybus",
    "toyplane",
    "toytrain",
    "toytruck",
    "tv",
    "umbrella",
    "vase",
    "wineglass",
]


class Co3dDataset(BaseDataset):
    def __init__(
        self,
        common_conf,
        split: str = "train",
        CO3D_DIR: str = None,
        CO3D_ANNOTATION_DIR: str = None,
        min_num_images: int = 24,
        len_train: int = 100000,
        len_test: int = 10000,
    ):
        """
        Initialize the Co3dDataset.

        Args:
            common_conf: Configuration object with common settings.
            split (str): Dataset split, either 'train' or 'test'.
            CO3D_DIR (str): Directory path to CO3D data.
            CO3D_ANNOTATION_DIR (str): Directory path to CO3D annotations.
            min_num_images (int): Minimum number of images per sequence.
            len_train (int): Length of the training dataset.
            len_test (int): Length of the test dataset.
        Raises:
            ValueError: If CO3D_DIR or CO3D_ANNOTATION_DIR is not specified.
        """
        super().__init__(common_conf=common_conf)

        self.debug = common_conf.debug
        self.training = common_conf.training
        self.get_nearby = common_conf.get_nearby
        self.load_depth = common_conf.load_depth
        self.inside_random = common_conf.inside_random
        self.allow_duplicate_img = common_conf.allow_duplicate_img

        if CO3D_DIR is None or CO3D_ANNOTATION_DIR is None:
            raise ValueError("Both CO3D_DIR and CO3D_ANNOTATION_DIR must be specified.")

        category = sorted(SEEN_CATEGORIES)

        if self.debug:
            category = ["apple"]

        if split == "train":
            split_name_list = ["train"]
            self.len_train = len_train
        elif split == "test":
            split_name_list = ["test"]
            self.len_train = len_test
        else:
            raise ValueError(f"Invalid split: {split}")

        self.invalid_sequence = [] # set any invalid sequence names here


        self.category_map = {}
        self.data_store = {}
        self.seqlen = None
        self.min_num_images = min_num_images

        logging.info(f"CO3D_DIR is {CO3D_DIR}")

        self.CO3D_DIR = CO3D_DIR
        self.CO3D_ANNOTATION_DIR = CO3D_ANNOTATION_DIR

        total_frame_num = 0

        for c in category:
            for split_name in split_name_list:
                annotation_file = osp.join(
                    self.CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz"
                )

                try:
                    with gzip.open(annotation_file, "r") as fin:
                        annotation = json.loads(fin.read())
                except FileNotFoundError:
                    logging.error(f"Annotation file not found: {annotation_file}")
                    continue

                for seq_name, seq_data in annotation.items():
                    if len(seq_data) < min_num_images:
                        continue
                    if seq_name in self.invalid_sequence:
                        continue
                    total_frame_num += len(seq_data)
                    self.data_store[seq_name] = seq_data

        self.sequence_list = list(self.data_store.keys())
        self.sequence_list_len = len(self.sequence_list)
        self.total_frame_num = total_frame_num

        status = "Training" if self.training else "Testing"
        logging.info(f"{status}: Co3D Data size: {self.sequence_list_len}")
        logging.info(f"{status}: Co3D Data dataset length: {len(self)}")
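The annotation loop in `__init__` expects each `{category}_{split}.jgz` file to be gzipped JSON that maps sequence names to per-frame lists. The following is a minimal sketch of the load-and-filter step, assuming that structure. `load_sequences` and the frame contents are hypothetical; real frames also carry fields such as `extri` and `intri`, as read in `get_data`.

```python
import gzip
import json
import os
import tempfile


def load_sequences(annotation_file, min_num_images, invalid_sequence=()):
    """Read a gzipped JSON {seq_name: [frame, ...]} file and keep long-enough sequences."""
    with gzip.open(annotation_file, "r") as fin:
        annotation = json.loads(fin.read())
    return {
        seq_name: frames
        for seq_name, frames in annotation.items()
        if len(frames) >= min_num_images and seq_name not in invalid_sequence
    }


# Hypothetical annotation content with one long and one short sequence
fake = {
    "seq_a": [{"filepath": f"apple/seq_a/images/frame{i:06d}.jpg"} for i in range(30)],
    "seq_b": [{"filepath": f"apple/seq_b/images/frame{i:06d}.jpg"} for i in range(10)],
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "apple_train.jgz")
    with gzip.open(path, "wt") as fout:
        json.dump(fake, fout)
    store = load_sequences(path, min_num_images=24)

print(sorted(store))  # ['seq_a']
```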

    def get_data(
        self,
        seq_index: int = None,
        img_per_seq: int = None,
        seq_name: str = None,
        ids: list = None,
        aspect_ratio: float = 1.0,
    ) -> dict:
        """
        Retrieve data for a specific sequence.

        Args:
            seq_index (int): Index of the sequence to retrieve.
            img_per_seq (int): Number of images per sequence.
            seq_name (str): Name of the sequence.
            ids (list): Specific IDs to retrieve.
            aspect_ratio (float): Aspect ratio for image processing.

        Returns:
            dict: A batch of data including images, depths, and other metadata.
        """
        if self.inside_random:
            seq_index = random.randint(0, self.sequence_list_len - 1)
            
        if seq_name is None:
            seq_name = self.sequence_list[seq_index]

        metadata = self.data_store[seq_name]

        if ids is None:
            ids = np.random.choice(
                len(metadata), img_per_seq, replace=self.allow_duplicate_img
            )

        annos = [metadata[i] for i in ids]

        target_image_shape = self.get_target_shape(aspect_ratio)

        images = []
        depths = []
        cam_points = []
        world_points = []
        point_masks = []
        extrinsics = []
        intrinsics = []
        image_paths = []
        original_sizes = []

        for anno in annos:
            filepath = anno["filepath"]

            image_path = osp.join(self.CO3D_DIR, filepath)
            image = read_image_cv2(image_path)

            if self.load_depth:
                depth_path = image_path.replace("/images", "/depths") + ".geometric.png"
                depth_map = read_depth(depth_path, 1.0)

                mvs_mask_path = image_path.replace(
                    "/images", "/depth_masks"
                ).replace(".jpg", ".png")
                mvs_mask = cv2.imread(mvs_mask_path, cv2.IMREAD_GRAYSCALE) > 128
                depth_map[~mvs_mask] = 0

                depth_map = threshold_depth_map(
                    depth_map, min_percentile=-1, max_percentile=98
                )
            else:
                depth_map = None

            original_size = np.array(image.shape[:2])
            extri_opencv = np.array(anno["extri"])
            intri_opencv = np.array(anno["intri"])

            (
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                world_coords_points,
                cam_coords_points,
                point_mask,
                _,
            ) = self.process_one_image(
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                original_size,
                target_image_shape,
                filepath=filepath,
            )

            images.append(image)
            depths.append(depth_map)
            extrinsics.append(extri_opencv)
            intrinsics.append(intri_opencv)
            cam_points.append(cam_coords_points)
            world_points.append(world_coords_points)
            point_masks.append(point_mask)
            image_paths.append(image_path)
            original_sizes.append(original_size)

        set_name = "co3d"

        batch = {
            "seq_name": set_name + "_" + seq_name,
            "ids": ids,
            "frame_num": len(extrinsics),
            "images": images,
            "depths": depths,
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "cam_points": cam_points,
            "world_points": world_points,
            "point_masks": point_masks,
            "original_sizes": original_sizes,
        }
        return batch
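The Co3D constructor reads one gzip-compressed JSON (`.jgz`) annotation file per category and split. It keeps only sequences that have at least `min_num_images` frames and are not listed in `invalid_sequence`. The sketch below isolates that loop in a hypothetical helper, `load_co3d_sequences`, which is not part of the repository. It assumes, as the constructor does, that each file maps sequence names to lists of per-frame dicts.

```python
import gzip
import json
import os
import tempfile


def load_co3d_sequences(annotation_file, min_num_images, invalid_sequences=()):
    """Load one gzipped-JSON (.jgz) annotation file and keep long-enough sequences.

    Hypothetical helper mirroring the per-file loop in the dataset's __init__.
    Returns (kept, total_frames): kept maps sequence name -> list of frame dicts.
    """
    with gzip.open(annotation_file, "r") as fin:
        annotation = json.loads(fin.read())  # json.loads accepts bytes

    kept, total_frames = {}, 0
    for seq_name, seq_data in annotation.items():
        if len(seq_data) < min_num_images or seq_name in invalid_sequences:
            continue
        kept[seq_name] = seq_data
        total_frames += len(seq_data)
    return kept, total_frames


if __name__ == "__main__":
    # Made-up annotation; only the "filepath" key that get_data reads is filled in.
    frame = {"filepath": "apple/seq/images/frame000001.jpg"}
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "apple_train.jgz")
        with gzip.open(path, "wt") as f:
            json.dump({"long_seq": [frame] * 30, "short_seq": [frame] * 5}, f)
        kept, total = load_co3d_sequences(path, min_num_images=24)
        print(sorted(kept), total)  # ['long_seq'] 30
```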


================================================
FILE: training/data/datasets/vkitti.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import os.path as osp
import logging
import random
import glob

import cv2
import numpy as np

from data.dataset_util import *
from data.base_dataset import BaseDataset


class VKittiDataset(BaseDataset):
    def __init__(
        self,
        common_conf,
        split: str = "train",
        VKitti_DIR: str = "/checkpoint/repligen/jianyuan/datasets/vkitti/",
        min_num_images: int = 24,
        len_train: int = 100000,
        len_test: int = 10000,
        expand_ratio: int = 8,
    ):
        """
        Initialize the VKittiDataset.

        Args:
            common_conf: Configuration object with common settings.
            split (str): Dataset split, either 'train' or 'test'.
            VKitti_DIR (str): Directory path to VKitti data.
            min_num_images (int): Minimum number of images per sequence.
            len_train (int): Length of the training dataset.
            len_test (int): Length of the test dataset.
            expand_ratio (int): Expansion ratio passed to get_nearby_ids when
                nearby image selection (common_conf.get_nearby) is enabled.
        """
        super().__init__(common_conf=common_conf)

        self.debug = common_conf.debug
        self.training = common_conf.training
        self.get_nearby = common_conf.get_nearby
        self.inside_random = common_conf.inside_random
        self.allow_duplicate_img = common_conf.allow_duplicate_img
        
        self.expand_ratio = expand_ratio
        self.VKitti_DIR = VKitti_DIR
        self.min_num_images = min_num_images

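        # BaseDataset.__len__ reports self.len_train, so the test split stores
        # its length in the same attribute.
        if split == "train":
            self.len_train = len_train
        elif split == "test":
            self.len_train = len_test
        else:
            raise ValueError(f"Invalid split: {split}")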
        
        logging.info(f"VKitti_DIR is {self.VKitti_DIR}")

        # Load or generate sequence list
        txt_path = osp.join(self.VKitti_DIR, "sequence_list.txt")
        if osp.exists(txt_path):
            with open(txt_path, 'r') as f:
                sequence_list = [line.strip() for line in f.readlines()]
        else:
            # Generate sequence list and save to txt            
            sequence_list = glob.glob(osp.join(self.VKitti_DIR, "*/*/*/rgb/*"))            
            sequence_list = [file_path.split(self.VKitti_DIR)[-1].lstrip('/') for file_path in sequence_list]
            sequence_list = sorted(sequence_list)

            # Save to txt file
            with open(txt_path, 'w') as f:
                f.write('\n'.join(sequence_list))

        self.sequence_list = sequence_list
        self.sequence_list_len = len(self.sequence_list)

        self.depth_max = 80

        status = "Training" if self.training else "Testing"
        logging.info(f"{status}: VKitti Data size: {self.sequence_list_len}")
        logging.info(f"{status}: VKitti Data dataset length: {len(self)}")

    def get_data(
        self,
        seq_index: int = None,
        img_per_seq: int = None,
        seq_name: str = None,
        ids: list = None,
        aspect_ratio: float = 1.0,
    ) -> dict:
        """
        Retrieve data for a specific sequence.

        Args:
            seq_index (int): Index of the sequence to retrieve.
            img_per_seq (int): Number of images per sequence.
            seq_name (str): Name of the sequence.
            ids (list): Specific IDs to retrieve.
            aspect_ratio (float): Aspect ratio for image processing.

        Returns:
            dict: A batch of data including images, depths, and other metadata.
        """
        if self.inside_random and self.training:
            seq_index = random.randint(0, self.sequence_list_len - 1)

        if seq_name is None:
            seq_name = self.sequence_list[seq_index]

        camera_id = int(seq_name[-1])

        # Load camera parameters
        try:
            camera_parameters = np.loadtxt(
                osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "extrinsic.txt"), 
                delimiter=" ", 
                skiprows=1
            )
            camera_parameters = camera_parameters[camera_parameters[:, 1] == camera_id]

            camera_intrinsic = np.loadtxt(
                osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "intrinsic.txt"), 
                delimiter=" ", 
                skiprows=1
            )
            camera_intrinsic = camera_intrinsic[camera_intrinsic[:, 1] == camera_id]
        except Exception as e:
            logging.error(f"Error loading camera parameters for {seq_name}: {e}")
            raise

        num_images = len(camera_parameters)

        if ids is None:
            ids = np.random.choice(num_images, img_per_seq, replace=self.allow_duplicate_img)

        if self.get_nearby:
            ids = self.get_nearby_ids(ids, num_images, expand_ratio=self.expand_ratio)

        target_image_shape = self.get_target_shape(aspect_ratio)

        images = []
        depths = []
        cam_points = []
        world_points = []
        point_masks = []
        extrinsics = []
        intrinsics = []
        original_sizes = []

        for image_idx in ids:
            image_filepath = osp.join(self.VKitti_DIR, seq_name, f"rgb_{image_idx:05d}.jpg")
            depth_filepath = osp.join(self.VKitti_DIR, seq_name, f"depth_{image_idx:05d}.png").replace("/rgb", "/depth")

            image = read_image_cv2(image_filepath)
            depth_map = cv2.imread(depth_filepath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
            # VKitti 2 stores depth as 16-bit PNGs in centimeters; convert to meters.
            depth_map = depth_map / 100
            depth_map = threshold_depth_map(depth_map, max_percentile=-1, min_percentile=-1, max_depth=self.depth_max)

            assert image.shape[:2] == depth_map.shape, f"Image and depth shape mismatch: {image.shape[:2]} vs {depth_map.shape}"

            original_size = np.array(image.shape[:2])

            # Process camera matrices
            extri_opencv = camera_parameters[image_idx][2:].reshape(4, 4)
            extri_opencv = extri_opencv[:3]

            intri_opencv = np.eye(3)
            intri_opencv[0, 0] = camera_intrinsic[image_idx][-4]
            intri_opencv[1, 1] = camera_intrinsic[image_idx][-3]
            intri_opencv[0, 2] = camera_intrinsic[image_idx][-2]
            intri_opencv[1, 2] = camera_intrinsic[image_idx][-1]

            (
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                world_coords_points,
                cam_coords_points,
                point_mask,
                _,
            ) = self.process_one_image(
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                original_size,
                target_image_shape,
                filepath=image_filepath,
            )

            if (image.shape[:2] != target_image_shape).any():
                logging.error(f"Wrong shape for {seq_name}: expected {target_image_shape}, got {image.shape[:2]}")
                continue

            images.append(image)
            depths.append(depth_map)
            extrinsics.append(extri_opencv)
            intrinsics.append(intri_opencv)
            cam_points.append(cam_coords_points)
            world_points.append(world_coords_points)
            point_masks.append(point_mask)
            original_sizes.append(original_size)

        set_name = "vkitti"
        batch = {
            "seq_name": set_name + "_" + seq_name,
            "ids": ids,
            "frame_num": len(extrinsics),
            "images": images,
            "depths": depths,
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "cam_points": cam_points,
            "world_points": world_points,
            "point_masks": point_masks,
            "original_sizes": original_sizes,
        }
        return batch
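`get_data` filters `extrinsic.txt` and `intrinsic.txt` by camera ID and then indexes the remaining rows by frame. From that indexing, the rows appear to be laid out as `[frame, camera_id, 16 values of a row-major 4x4 pose]` and `[frame, camera_id, fx, fy, cx, cy]`. The sketch below restates that parsing with NumPy under that assumption, plus the centimeter-to-meter depth conversion. Both function names are hypothetical, and the numbers in the example are made up.

```python
import numpy as np


def parse_vkitti_camera_rows(extrinsic_rows, intrinsic_rows, camera_id, frame_idx):
    """Build a 3x4 OpenCV extrinsic and a 3x3 intrinsic for one frame.

    Assumed layouts (inferred from VKittiDataset.get_data):
      extrinsic row = [frame, camera_id, 16 values of a row-major 4x4 matrix]
      intrinsic row = [frame, camera_id, fx, fy, cx, cy]
    Like the original, frame_idx indexes rows after filtering by camera_id.
    """
    ext = extrinsic_rows[extrinsic_rows[:, 1] == camera_id]
    intr = intrinsic_rows[intrinsic_rows[:, 1] == camera_id]

    extri = ext[frame_idx][2:].reshape(4, 4)[:3]  # drop the homogeneous bottom row
    fx, fy, cx, cy = intr[frame_idx][-4:]
    intri = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
    return extri, intri


def vkitti_depth_to_meters(depth_png):
    """VKitti 2 depth PNGs hold uint16 centimeters; divide by 100 for meters."""
    return depth_png.astype(np.float64) / 100.0


if __name__ == "__main__":
    pose = np.eye(4)
    pose[:3, 3] = [1.0, 2.0, 3.0]
    ext_rows = np.array([[0, 0, *pose.ravel()], [0, 1, *np.eye(4).ravel()]])
    int_rows = np.array([[0, 0, 725.0, 725.0, 620.5, 187.0], [0, 1, 700.0, 700.0, 600.0, 180.0]])
    extri, intri = parse_vkitti_camera_rows(ext_rows, int_rows, camera_id=0, frame_idx=0)
    print(extri.shape, intri[0, 0])
```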

================================================
FILE: training/data/dynamic_dataloader.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

from hydra.utils import instantiate
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler
from abc import ABC, abstractmethod

from .worker_fn import get_worker_init_fn

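class DynamicTorchDataset(ABC):
    """
    Instantiates a dataset from a Hydra config and builds DataLoaders whose batches
    vary in image count and aspect ratio (via DynamicBatchSampler).
    """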
    def __init__(
        self,
        dataset: dict,
        common_config: dict,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool = True,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        persistent_workers: bool = False,
        seed: int = 42,
        max_img_per_gpu: int = 48,
    ) -> None:
        self.dataset_config = dataset
        self.common_config = common_config
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        self.persistent_workers = persistent_workers
        self.seed = seed
        self.max_img_per_gpu = max_img_per_gpu

        # Instantiate the dataset
        self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False)

        # Extract aspect ratio and image number ranges from the configuration
        self.aspect_ratio_range = common_config.augs.aspects  # e.g., [0.5, 1.0]
        self.image_num_range = common_config.img_nums    # e.g., [2, 24]

        # Validate the aspect ratio and image number ranges
        if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]:
            raise ValueError(f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}")
        if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]:
            raise ValueError(f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}")

        # Create samplers
        self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle)
        self.batch_sampler = DynamicBatchSampler(
            self.sampler,
            self.aspect_ratio_range,
            self.image_num_range,
            seed=seed,
            max_img_per_gpu=max_img_per_gpu
        )

    def get_loader(self, epoch):
        print("Building dynamic dataloader with epoch:", epoch)

        # Set the epoch for the sampler
        self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

        # Create and return the dataloader
        return DataLoader(
            self.dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            batch_sampler=self.batch_sampler,
            collate_fn=self.collate_fn,
            persistent_workers=self.persistent_workers,
            worker_init_fn=get_worker_init_fn(
                seed=self.seed,
                num_workers=self.num_workers,
                epoch=epoch,
                worker_init_fn=self.worker_init_fn,
            ),
        )
        

class DynamicBatchSampler(Sampler):
    """
    A custom batch sampler that dynamically chooses the aspect ratio, image number, and
    batch size for each batch. All samples within a batch share the same aspect ratio and image number.
    """
    def __init__(self,
                 sampler,
                 aspect_ratio_range,
                 image_num_range,
                 epoch=0,
                 seed=42,
                 max_img_per_gpu=48):
        """
        Initializes the dynamic batch sampler.

        Args:
            sampler: Instance of DynamicDistributedSampler.
            aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
            image_num_range: List containing [min_images, max_images] per sample.
            epoch: Current epoch number.
            seed: Random seed for reproducibility.
            max_img_per_gpu: Maximum number of images to fit in GPU memory.
        """
        self.sampler = sampler
        self.aspect_ratio_range = aspect_ratio_range
        self.image_num_range = image_num_range
        self.rng = random.Random()

        # Uniformly sample from the range of possible image numbers
        # For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here.
        self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1]+1)}

        # Possible image numbers, e.g., [2, 3, 4, ..., 24]
        self.possible_nums = np.array([n for n in self.image_num_weights.keys()
                                       if self.image_num_range[0] <= n <= self.image_num_range[1]])

        # Normalize weights for sampling
        weights = [self.image_num_weights[n] for n in self.possible_nums]
        self.normalized_weights = np.array(weights) / sum(weights)

        # Maximum image number per GPU
        self.max_img_per_gpu = max_img_per_gpu

        # Set the epoch for the sampler
        self.set_epoch(epoch + seed)

    def set_epoch(self, epoch):
        """
        Sets the epoch for this sampler, affecting the random sequence.

        Args:
            epoch: The epoch number.
        """
        self.sampler.set_epoch(epoch)
        self.epoch = epoch
        self.rng.seed(epoch * 100)

    def __iter__(self):
        """
        Yields batches of samples with synchronized dynamic parameters.

        Yields:
            Lists of (index, image_num, aspect_ratio) tuples, one list per batch.
        """
        sampler_iterator = iter(self.sampler)

        while True:
            try:
                # Sample random image number and aspect ratio
                random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights))
                random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2)

                # Update sampler parameters
                self.sampler.update_parameters(
                    aspect_ratio=random_aspect_ratio,
                    image_num=random_image_num
                )

                # Calculate batch size based on max images per GPU and current image number
                batch_size = int(self.max_img_per_gpu // random_image_num)
                batch_size = max(1, batch_size)  # Ensure batch size is at least 1

                # Collect samples for the current batch
                current_batch = []
                for _ in range(batch_size):
                    try:
                        item = next(sampler_iterator)  # item is (idx, image_num, aspect_ratio)
                        current_batch.append(item)
                    except StopIteration:
                        break  # No more samples

                if not current_batch:
                    break  # No more data to yield

                yield current_batch

            except StopIteration:
                break  # End of sampler's iterator

    def __len__(self):
        # Return a large dummy length
        return 1000000
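`DynamicBatchSampler` draws one image count and one aspect ratio per batch and stores them on the distributed sampler. The batch size is then chosen so that `batch_size * image_num` stays at or below `max_img_per_gpu`. This works because the sampler is a generator that reads those parameters lazily, so each `next()` picks up the values set just before it. The single-process sketch below reproduces that interplay. `dynamic_batches` is a hypothetical name. It also swaps the original's `np.random.choice` over normalized weights for `random.Random.randint`, which gives the same distribution only when the weights are uniform.

```python
import random


def dynamic_batches(indices, image_num_range, aspect_ratio_range, max_img_per_gpu, seed=0):
    """Yield batches whose items share one (image_num, aspect_ratio) pair.

    Simplified single-process sketch of DynamicBatchSampler.__iter__. Unlike the
    original, image_num is drawn uniformly with random.Random.randint instead of
    np.random.choice over normalized weights (equivalent for uniform weights).
    """
    rng = random.Random(seed)
    params = {"image_num": None, "aspect_ratio": None}

    def tagged_indices():
        # Like DynamicDistributedSampler.__iter__, read the parameters lazily, so
        # each item carries the values set just before next() is called.
        for idx in indices:
            yield (idx, params["image_num"], params["aspect_ratio"])

    it = tagged_indices()
    while True:
        params["image_num"] = rng.randint(*image_num_range)
        params["aspect_ratio"] = round(rng.uniform(*aspect_ratio_range), 2)
        # Fewer sequences per batch when each sequence has more images.
        batch_size = max(1, max_img_per_gpu // params["image_num"])
        batch = []
        for _ in range(batch_size):
            try:
                batch.append(next(it))
            except StopIteration:
                break
        if not batch:
            return
        yield batch
```

With `max_img_per_gpu=48`, an image count of 24 gives 2 sequences per batch and an image count of 5 gives 9, so the number of frames per GPU step stays roughly constant.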


class DynamicDistributedSampler(DistributedSampler):
    """
    Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num
    parameters, which can be passed into the dataset's __getitem__ method.
    """
    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
        drop_last: bool = False,
    ):
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        self.aspect_ratio = None
        self.image_num = None

    def __iter__(self):
        """
        Yields a sequence of (index, image_num, aspect_ratio).
        Relies on the parent class's logic for shuffling/distributing
        the indices across replicas, then attaches extra parameters.
        """
        indices_iter = super().__iter__()

        for idx in indices_iter:
            yield (idx, self.image_num, self.aspect_ratio,)

    def update_parameters(self, aspect_ratio, image_num):
        """
        Updates the dynamic parameters attached to subsequent indices; called before each new batch.

        Args:
            aspect_ratio: The aspect ratio to set.
            image_num: The number of images to set.
        """
        self.aspect_ratio = aspect_ratio
        self.image_num = image_num


================================================
FILE: training/data/preprocess/vkitti.sh
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

mkdir -p vkitti
cd vkitti

wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar
tar -xvf vkitti_2.0.3_rgb.tar

wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar
tar -xvf vkitti_2.0.3_depth.tar

wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_textgt.tar.gz
tar -xvf vkitti_2.0.3_textgt.tar.gz


cd ..
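Once the archives are extracted, `VKittiDataset.__init__` discovers one sequence per camera with the glob pattern `*/*/*/rgb/*` and caches the relative paths in `sequence_list.txt`. The sketch below assumes the layout that pattern expects (for example `Scene01/clone/frames/rgb/Camera_0`). It uses `os.path.relpath` instead of the original's `split`/`lstrip` string handling. `list_vkitti_sequences` is a hypothetical helper.

```python
import glob
import os
import os.path as osp
import tempfile


def list_vkitti_sequences(root):
    """Sorted per-camera RGB folders relative to root (same glob as VKittiDataset)."""
    paths = glob.glob(osp.join(root, "*", "*", "*", "rgb", "*"))
    return sorted(osp.relpath(p, root) for p in paths)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for rel in ["Scene01/clone/frames/rgb/Camera_0", "Scene01/clone/frames/depth/Camera_0"]:
            os.makedirs(osp.join(root, *rel.split("/")))
        for seq in list_vkitti_sequences(root):
            print(seq, "camera", int(seq[-1]))  # camera id parsed like get_data does
```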



================================================
FILE: training/data/track_util.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import cv2
import numpy as np
import torch

import logging

from vggt.utils.geometry import *



def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, point_masks, images, pos_rel_thres=0.05, neg_epipolar_thres=16, 
                          boundary_thres=4, target_track_num=512, neg_ratio = 0.0, neg_sample_size_ratio = 0.5, seq_name=None):
    """
    Args:
        extrinsics: (N, 3, 4)
        intrinsics: (N, 3, 3)
        world_points: (N, H, W, 3)
        depths: (N, H, W)
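        point_masks: (N, H, W)
        images: per-frame images (currently unused by this function)
        pos_rel_thres: float, relative threshold for the positive-track depth check
        neg_epipolar_thres: float, threshold for the negative-track epipolar check; compared against
            kornia's Sampson distance, which is squared by default (16 corresponds to 4 px)
        boundary_thres: int, margin in px; points closer than this to the image border are treated as invisible
        target_track_num: int, total number of tracks to build
        neg_ratio: fraction of final tracks that should be negative
        neg_sample_size_ratio: fraction of W/H used as the maximum random offset for negative tracks
        seq_name: optional sequence name, used only in log messages

    Returns:
        final_tracks: (N, P, 2) float
        final_vis_masks: (N, P) bool
        final_pos_masks: (P,) bool, indicates whether each track is positive (True) or negative (False)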
    """
    # TODO: consider whether tracks should be built before the images are resized.
    
    B, H, W, _ = world_points.shape

    # We use the first frame as the query frame, so [0]
    query_world_points = world_points[0]
    query_point_masks = point_masks[0]


    if (query_point_masks).sum() > 0:
        # at least one point
        valid_query_points = query_world_points[query_point_masks]
        
        # image_points: BxPx2
        # cam_points: Bx3xP (yes 3xP instead of Px3). Probably we can change it in the future
        image_points, cam_points = project_world_points_to_cam(valid_query_points, extrinsics, intrinsics)
        
        # proj_depths: BxP
        proj_depths = cam_points[:, -1]

        # floor to get the left top corner
        uv_int = image_points.floor().long().clone()
        
        uv_inside_flag = (uv_int[..., 0] >= boundary_thres) & (uv_int[..., 0] < (W - boundary_thres)) & (uv_int[..., 1] >= boundary_thres) & (uv_int[..., 1] < (H - boundary_thres))
        uv_int[~uv_inside_flag] = 0
        batch_indices = torch.arange(B).view(B, 1).expand(-1, uv_int.shape[1])

        # Use these indices to sample from the depth map
        # since we interpolate depths by nearest,
        # so assume the left top corner is (x, y)
        # we want to check for (x,y), (x+1,y), (x,y+1), (x+1,y+1)
        
        depth_inside_flag = None
        for shift in [(0,0), (1,0), (0,1), (1,1)]:
            cur_uv_int = uv_int + torch.tensor(shift)
            cur_depth_inside_flag = get_depth_inside_flag(depths, batch_indices, cur_uv_int, proj_depths, pos_rel_thres)
            if depth_inside_flag is None:
                depth_inside_flag = cur_depth_inside_flag
            else:
                depth_inside_flag = torch.logical_or(depth_inside_flag, cur_depth_inside_flag)

        # B, P, 2
        positive_tracks = image_points
        positive_vis_masks = torch.logical_and(uv_inside_flag, depth_inside_flag)
    else:
        print(f"No valid query points in {seq_name}")
        positive_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32)
        positive_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool)

    
    sampled_neg_track_num = target_track_num * 4  # oversample 4x so enough candidates survive the epipolar filter
    
    perb_range = [int(W*neg_sample_size_ratio), int(H*neg_sample_size_ratio)]
    
    # sample negative query points
    us = torch.randint(low=0, high=W, size=(1, sampled_neg_track_num), device=world_points.device)
    vs = torch.randint(low=0, high=H, size=(1, sampled_neg_track_num), device=world_points.device)
    neg_query_uvs = torch.stack([us, vs], dim=-1)
    
    # construct negative tracks
    delta_us = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[0]
    delta_vs = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[1]
    delta_us[0] = 0
    delta_vs[0] = 0
    negative_tracks = neg_query_uvs + torch.stack([delta_us, delta_vs], dim=-1)

    # Do epipolar check here
    negative_sampson_distances = track_epipolar_check(negative_tracks, extrinsics, intrinsics)
    # keep a candidate only if its Sampson distance exceeds neg_epipolar_thres in every frame pair
    negative_epipolar_check = (negative_sampson_distances > neg_epipolar_thres).all(dim=0)
    # i.e., filter out candidates that satisfy the epipolar constraint in any frame
    negative_tracks = negative_tracks[:, negative_epipolar_check]
        
    # Prepare for output
    final_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32)
    final_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool)
    final_pos_masks = torch.zeros(target_track_num, device=world_points.device, dtype=torch.bool)
    
    target_pos_track_num = target_track_num - int(target_track_num * neg_ratio)
    sampled_pos_track_num = 0

    sampled_positive_tracks, sampled_positive_vis_masks = sample_positive_tracks(positive_tracks, positive_vis_masks, target_pos_track_num)
    sampled_pos_track_num = sampled_positive_tracks.shape[1]
    final_tracks[:, :sampled_pos_track_num] = sampled_positive_tracks
    final_vis_masks[:, :sampled_pos_track_num] = sampled_positive_vis_masks
    final_pos_masks[:sampled_pos_track_num] = True


    target_neg_track_num = target_track_num - sampled_pos_track_num

    # Now we need to sample negative tracks
    # just do simple random sampling
    rand_indices = torch.randperm(negative_tracks.shape[1], device=negative_tracks.device)
    sampled_neg_tracks = negative_tracks[:, rand_indices[:target_neg_track_num]]
    sampled_neg_track_num = sampled_neg_tracks.shape[1]
    final_tracks[:, sampled_pos_track_num:sampled_pos_track_num+sampled_neg_track_num] = sampled_neg_tracks
    
    if sampled_pos_track_num+sampled_neg_track_num!=target_track_num:
        logging.warning(f"sampled_pos_track_num+sampled_neg_track_num!=target_track_num: {sampled_pos_track_num+sampled_neg_track_num} != {target_track_num}")
    # Do not need to set final_vis_masks and final_pos_masks, because they are all False
    # Do not need to check the shape of final_tracks, as it is zeroed out
    
        
    # NOTE: We need to do some visual checks

    
    return final_tracks, final_vis_masks, final_pos_masks
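Positive tracks start from valid first-frame world points. These points are projected into every frame, and any projection that lands within `boundary_thres` pixels of the border is marked invisible. The projection itself comes from `project_world_points_to_cam` in `vggt.utils.geometry`, which is not shown here. The NumPy sketch below therefore assumes a standard pinhole model with world-to-camera `[R|t]` extrinsics, and mirrors only the boundary test on floored coordinates. Both function names are hypothetical.

```python
import numpy as np


def project_to_frames(points_world, extrinsics, intrinsics):
    """Pinhole projection of world points into every frame (assumed model).

    points_world: (P, 3); extrinsics: (N, 3, 4) world-to-camera [R|t];
    intrinsics: (N, 3, 3). Returns uv (N, P, 2) in pixels and depth (N, P).
    """
    R, t = extrinsics[:, :, :3], extrinsics[:, :, 3]
    cam = np.einsum("nij,pj->npi", R, points_world) + t[:, None, :]  # (N, P, 3)
    pix = np.einsum("nij,npj->npi", intrinsics, cam)
    depth = pix[..., 2]  # equals the camera z because K's last row is [0, 0, 1]
    uv = pix[..., :2] / depth[..., None]
    return uv, depth


def inside_mask(uv, width, height, boundary=4):
    """Boundary test from build_tracks_by_depth, applied to floored pixel coords."""
    u = np.floor(uv[..., 0]).astype(int)
    v = np.floor(uv[..., 1]).astype(int)
    return (u >= boundary) & (u < width - boundary) & (v >= boundary) & (v < height - boundary)
```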



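def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, rel_thres):
    """
    Return True where the projected depth agrees with the depth map at uv_int, i.e. the
    absolute difference is below rel_thres times both depths (zero map depths always fail).
    """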
    sampled_depths = depths[batch_indices, uv_int[..., 1], uv_int[..., 0]]
    depth_diff = (proj_depths - sampled_depths).abs()
    depth_inside_flag = torch.logical_and(depth_diff < (proj_depths * rel_thres), depth_diff < (sampled_depths * rel_thres))
    return depth_inside_flag
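A projection counts as visible when its depth agrees with the stored depth map relative to *both* depths. Since the tolerance scales with the map depth as well, a zero (invalid) map entry can never pass. `build_tracks_by_depth` ORs this test over the four integer neighbors of the floored pixel. The single-frame NumPy sketch below combines the two steps. `depth_consistent` is a hypothetical name. It also adds index clipping, whereas the original relies on its border check and on zeroing out-of-bounds coordinates.

```python
import numpy as np


def depth_consistent(depth_map, u_int, v_int, proj_depth, rel_thres=0.05):
    """Symmetric relative depth test over the 4-neighborhood, for one frame.

    depth_map: (H, W); u_int, v_int: integer (floored) pixel coords, shape (P,);
    proj_depth: (P,) depths of the projected points.
    """
    h, w = depth_map.shape
    ok = np.zeros_like(proj_depth, dtype=bool)
    for du, dv in [(0, 0), (1, 0), (0, 1), (1, 1)]:
        uu = np.clip(u_int + du, 0, w - 1)  # clipping is an addition of this sketch
        vv = np.clip(v_int + dv, 0, h - 1)
        sampled = depth_map[vv, uu]
        diff = np.abs(proj_depth - sampled)
        ok |= (diff < proj_depth * rel_thres) & (diff < sampled * rel_thres)
    return ok
```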







def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = True, seq_name=None):
    # tracks: (B, T, 2)
    # tracks_mask: (B, T)
    # track_num: int
    # half_top: bool

    # if the query frame is not valid, then the track is not valid
    tracks_mask[:, tracks_mask[0]==False] = False
    
    track_frame_num = tracks_mask.sum(dim=0)
    tracks_mask[:, track_frame_num<=1] = False
    track_frame_num = tracks_mask.sum(dim=0)
    
    _, track_num_sort_idx = track_frame_num.sort(descending=True)
    
    if half_top:
        if len(track_num_sort_idx)//2 > track_num:
            # keep only the better-observed half (tracks with the most valid frames)
            # track_num_sort_idx = track_num_sort_idx[:track_num]
            track_num_sort_idx = track_num_sort_idx[:len(track_num_sort_idx)//2]

    pick_idx = torch.randperm(len(track_num_sort_idx))[:track_num]
    track_num_sort_idx = track_num_sort_idx[pick_idx]
    
    tracks = tracks[:, track_num_sort_idx].clone()
    tracks_mask = tracks_mask[:, track_num_sort_idx].clone()
    
    
    tracks_mask = tracks_mask.bool()    # ensure the type is bool
    return tracks, tracks_mask
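`sample_positive_tracks` invalidates any track that is not visible in the query frame (frame 0) or is visible only there. It ranks the remaining tracks by their number of visible frames. When there are more than twice as many candidates as needed, it keeps only the top half, and it then samples at random. The NumPy sketch below follows the same steps but works on a copy of the mask (the original modifies `tracks_mask` in place). `select_positive_tracks` is a hypothetical name. It also uses a stable sort, so ties may be ordered differently than with `torch.sort`.

```python
import numpy as np


def select_positive_tracks(vis, track_num, half_top=True, rng=None):
    """Pick track columns following the logic of sample_positive_tracks.

    vis: (S, T) bool visibility, frame 0 being the query frame.
    Returns (chosen column indices, cleaned visibility of those columns).
    """
    rng = np.random.default_rng() if rng is None else rng
    vis = vis.copy()                       # avoid mutating the caller's mask
    vis[:, ~vis[0]] = False                # invisible in the query frame -> invalid
    counts = vis.sum(axis=0)
    vis[:, counts <= 1] = False            # seen only in the query frame -> invalid
    counts = vis.sum(axis=0)
    order = np.argsort(-counts, kind="stable")  # most-visible tracks first
    if half_top and len(order) // 2 > track_num:
        order = order[: len(order) // 2]   # keep the better-observed half
    pick = rng.permutation(len(order))[:track_num]
    return order[pick], vis[:, order[pick]]
```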
    
    
    


# Epipolar-geometry helpers (used by build_tracks_by_depth); the color helpers below are only for debugging and visualization.

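def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_mat = False):
    """
    Sampson distances (B-1, T) of tracks in frames 1..B-1 with respect to frame 0.
    With use_essential_mat=False the check is done in pixels via the fundamental matrix.
    """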
    from kornia.geometry.epipolar import sampson_epipolar_distance

    B, T, _ = tracks.shape
    essential_mats = get_essential_matrix(extrinsics[0:1].expand(B-1, -1, -1), extrinsics[1:])

    if use_essential_mat:
        tracks_normalized = cam_from_img(tracks, intrinsics)
        sampson_distances = sampson_epipolar_distance(tracks_normalized[0:1].expand(B-1, -1, -1), tracks_normalized[1:], essential_mats)
    else:
        K1 = intrinsics[0:1].expand(B-1, -1, -1)
        K2 = intrinsics[1:].expand(B-1, -1, -1)
        fundamental_mats = K2.inverse().permute(0, 2, 1).matmul(essential_mats).matmul(K1.inverse())
        sampson_distances = sampson_epipolar_distance(tracks[0:1].expand(B-1, -1, -1), tracks[1:], fundamental_mats)

    return sampson_distances
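Negative tracks are kept only if they violate the epipolar constraint, measured with kornia's `sampson_epipolar_distance`. By default that function returns the *squared* Sampson distance. The NumPy sketch below writes out that quantity for pixel correspondences under the convention `x2^T F x1 = 0`, ignoring kornia's small `eps` in the denominator. `sampson_distance_sq` is a hypothetical name. The test uses a rectified-stereo fundamental matrix (pure x-translation, identity intrinsics), under which matching points must share the same row.

```python
import numpy as np


def sampson_distance_sq(x1, x2, F):
    """Squared Sampson distance for pixel correspondences x1 <-> x2.

    x1, x2: (P, 2) pixel coordinates; F: (3, 3) with x2^T F x1 = 0 for a perfect match.
    """
    h1 = np.hstack([x1, np.ones((len(x1), 1))])
    h2 = np.hstack([x2, np.ones((len(x2), 1))])
    Fx1 = h1 @ F.T   # row p is F @ x1_p
    Ftx2 = h2 @ F    # row p is F^T @ x2_p
    num = np.sum(h2 * Fx1, axis=1) ** 2
    den = Fx1[:, 0] ** 2 + Fx1[:, 1] ** 2 + Ftx2[:, 0] ** 2 + Ftx2[:, 1] ** 2
    return num / den
```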


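def get_essential_matrix(extrinsic1, extrinsic2):
    """
    Essential matrices (up to sign) between batches of world-to-camera extrinsics (B, 3, 4).
    For normalized homogeneous points x1 (camera 1) and x2 (camera 2), x2^T E x1 = 0.
    """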
    R1 = extrinsic1[:, :3, :3]
    t1 = extrinsic1[:, :3, 3]
    R2 = extrinsic2[:, :3, :3]
    t2 = extrinsic2[:, :3, 3]
    
    R12 = R2.matmul(R1.permute(0, 2, 1))
    t12 = t2 - R12.matmul(t1[..., None])[..., 0]
    E_R = R12
    E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0]
    E = E_R.matmul(hat(E_t))
    return E
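`get_essential_matrix` first forms the relative pose `R12 = R2 R1^T`, `t12 = t2 - R12 t1`, and then returns `R12 [-R12^T t12]_x`. Because `[R^T t]_x = R^T [t]_x R` for a rotation `R`, this equals `-[t12]_x R12`, which is the usual essential matrix up to sign. The single-pair NumPy sketch below checks the constraint `x2^T E x1 = 0` on a made-up world point. `skew` and `essential_from_extrinsics` are hypothetical names.

```python
import numpy as np


def skew(v):
    """3x3 matrix with skew(v) @ w == np.cross(v, w) (same layout as hat())."""
    x, y, z = v
    return np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def essential_from_extrinsics(ext1, ext2):
    """NumPy version of get_essential_matrix for one pair of 3x4 world-to-camera poses."""
    R1, t1 = ext1[:, :3], ext1[:, 3]
    R2, t2 = ext2[:, :3], ext2[:, 3]
    R12 = R2 @ R1.T                 # camera-1 -> camera-2 rotation
    t12 = t2 - R12 @ t1             # camera-1 -> camera-2 translation
    return R12 @ skew(-R12.T @ t12)  # equals -[t12]_x R12


def rot_z(a):
    """Rotation by angle a (radians) about the z axis, used for the example poses."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
```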



def hat(v: torch.Tensor) -> torch.Tensor:
    """Batched skew-symmetric matrices (N, 3, 3) such that hat(v)[i] @ w == cross(v[i], w)."""
    N, dim = v.shape
    if dim != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")

    x, y, z = v.unbind(1)

    h_01 = -z.view(N, 1, 1)
    h_02 = y.view(N, 1, 1)
    h_10 = z.view(N, 1, 1)
    h_12 = -x.view(N, 1, 1)
    h_20 = -y.view(N, 1, 1)
    h_21 = x.view(N, 1, 1)

    zeros = torch.zeros((N, 1, 1), dtype=v.dtype, device=v.device)

    row1 = torch.cat((zeros, h_01, h_02), dim=2)
    row2 = torch.cat((h_10, zeros, h_12), dim=2)
    row3 = torch.cat((h_20, h_21, zeros), dim=2)

    h = torch.cat((row1, row2, row3), dim=1)

    return h



def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
    """
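    import matplotlib

    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; the
    # colormap registry (Matplotlib >= 3.5) is the supported lookup.
    cmap = matplotlib.colormaps[cmap_name]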
    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order


def get_track_colors_by_position(
    tracks_b, 
    vis_mask_b=None,
    image_width=None,
    image_height=None,
    cmap_name="hsv"
):
    """
    Given all tracks in one sample (b), compute a (N,3) array of RGB color values
    in [0,255]. The color is determined by the (x,y) position in the first
    visible frame for each track.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: image size used to normalize (x, y); required
            despite the None defaults, since color_from_xy computes W - 1 and H - 1.
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
    """
    S, N, _ = tracks_b.shape
    track_colors = np.zeros((N, 3), dtype=np.uint8)

    if vis_mask_b is None:
        # treat all as visible
        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)

    for i in range(N):
        # Find first visible frame for track i
        visible_frames = torch.where(vis_mask_b[:, i])[0]
        if len(visible_frames) == 0:
            # track is never visible; assign black
            track_colors[i] = (0, 0, 0)
            continue

        first_s = int(visible_frames[0].item())
        # use that frame's (x,y)
        x, y = tracks_b[first_s, i].tolist()

        # map (x,y) -> (R,G,B) in [0,1]
        r, g, b = color_from_xy(
            x, y, 
            W=image_width, 
            H=image_height, 
            cmap_name=cmap_name
        )
        # scale to [0,255]
        r, g, b = int(r*255), int(g*255), int(b*255)
        track_colors[i] = (r, g, b)

    return track_colors
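The per-track loop above looks up each track's first visible frame one at a time. A hedged sketch of a vectorized lookup follows: `np.argmax` over the boolean mask returns the first `True` index per column. The function name `track_colors_vectorized` is hypothetical, and `colorsys` HSV with full saturation and value stands in for matplotlib's `'hsv'` colormap, so colors will be close to, but not identical with, `color_from_xy`.

```python
import colorsys
import numpy as np

def track_colors_vectorized(tracks, vis_mask, width, height):
    """(S, N, 2) tracks + (S, N) bool mask -> (N, 3) uint8 RGB colors.

    Hue is (x_norm + y_norm) / 2 at each track's first visible frame, as in
    color_from_xy; never-visible tracks stay black.
    """
    N = tracks.shape[1]
    ever_visible = vis_mask.any(axis=0)       # (N,)
    first_s = np.argmax(vis_mask, axis=0)     # first True per track (0 if none)
    xy = tracks[first_s, np.arange(N)]        # (N, 2) position at that frame
    c = (xy[:, 0] / max(width - 1, 1) + xy[:, 1] / max(height - 1, 1)) / 2.0
    c = np.clip(c, 0.0, 1.0)                  # clamp like a colormap's range
    colors = np.zeros((N, 3), dtype=np.uint8)
    for i in np.flatnonzero(ever_visible):
        r, g, b = colorsys.hsv_to_rgb(float(c[i]), 1.0, 1.0)
        colors[i] = (int(r * 255), int(g * 255), int(b * 255))
    return colors

# Track 0 visible from frame 0 at (0, 0); track 1 first visible in frame 1
# at (99, 0); track 2 never visible.
tracks = np.zeros((2, 3, 2))
tracks[1, 1] = [99.0, 0.0]
vis = np.array([[True, False, False], [True, True, False]])
cols = track_colors_vectorized(tracks, vis, width=100, height=100)
```

Here track 0 maps to hue 0 (red), track 1 to hue 0.5 (cyan), and track 2 stays black.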


def visualize_tracks_on_images(
    images, 
    tracks, 
    track_vis_mask=None, 
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",   # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv"       # e.g. "hsv", "rainbow", "jet"
):
    """
    Visualizes all frames for each sample (b) in ONE horizontal row, saving
    one PNG per sample. Each track's color is determined by its (x, y) position
    in the first frame where it is visible.
    Points are drawn with OpenCV, which uses BGR channel order, and the row
    is written with cv2.imwrite, which also expects BGR input.

    Args:
        images: torch.Tensor (B, S, 3, H, W) if CHW or (B, S, H, W, 3) if HWC.
        tracks: torch.Tensor (B, S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (B, S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None to cast raw values
            directly to uint8 (assumes they are already in [0, 255]).
        cmap_name: a matplotlib colormap name for color_from_xy.

    Returns:
        None (saves images in out_dir).
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; note this switches the backend process-wide

    os.makedirs(out_dir, exist_ok=True)

    B, S = images.shape[0], images.shape[1]
    _, _, N, _ = tracks.shape  # (B, S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[b, s].shape = (3, H, W)
        H, W = images.shape[3], images.shape[4]
    else:
        # e.g. images[b, s].shape = (H, W, 3)
        H, W = images.shape[2], images.shape[3]

    for b in range(B):
        # Pre-compute the color for each track i based on first visible position
        # in sample b:
        track_colors_rgb = get_track_colors_by_position(
            tracks[b],                   # shape (S, N, 2)
            vis_mask_b=track_vis_mask[b] if track_vis_mask is not None else None,
            image_width=W,
            image_height=H,
            cmap_name=cmap_name
        )
        # We'll accumulate each frame’s drawn image in a list
        frame_images = []

        for s in range(S):
            # shape => either (3, H, W) or (H, W, 3)
            img = images[b, s]
            
            # Convert to (H, W, 3)
            if image_format == "CHW":
                img = img.permute(1, 2, 0)  # (H, W, 3)
            # else "HWC", do nothing

            img = img.numpy().astype(np.float32)

            # Scale to [0,255] if needed
            if normalize_mode == "[0,1]":
                img = np.clip(img, 0, 1) * 255.0
            elif normalize_mode == "[-1,1]":
                img = (img + 1.0) * 0.5 * 255.0
                img = np.clip(img, 0, 255.0)
            # else no normalization

            # Convert to uint8
            img = img.astype(np.uint8)

            # OpenCV drawing functions and cv2.imwrite assume BGR channel order,
            # so convert the (assumed RGB) frame to BGR before drawing.
            # If your images are already BGR, skip this conversion.
            img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            # Draw each visible track
            cur_tracks = tracks[b, s]  # shape (N, 2)
            if track_vis_mask is not None:
                valid_indices = torch.where(track_vis_mask[b, s])[0].tolist()
            else:
                valid_indices = range(N)

            cur_tracks_np = cur_tracks.numpy()
            for i in valid_indices:
                x, y = cur_tracks_np[i]
                pt = (int(round(x)), int(round(y)))

                # track_colors_rgb[i] is (R, G, B); OpenCV expects a BGR tuple.
                # Distinct names avoid shadowing the batch index `b` and size `B`.
                cr, cg, cb = track_colors_rgb[i]
                color_bgr = (int(cb), int(cg), int(cr))
                cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

            # Keep the frame in BGR: cv2.imwrite expects BGR input, so converting
            # back to RGB here would swap red and blue in the saved PNG.
            frame_images.append(img_bgr)

        # Concatenate all frames horizontally: (H, S*W, 3)
        row_img = np.concatenate(frame_images, axis=1)

        out_path = os.path.join(out_dir, f"tracks_b{b}.png")
        cv2.imwrite(out_path, row_img)
        print(f"[INFO] Saved color-by-XY track visualization for sample b={b} -> {out_path}")
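The per-frame preprocessing in `visualize_tracks_on_images` (layout change, range normalization, uint8 cast, horizontal concatenation) and the channel-order requirement of `cv2.imwrite` can be sketched without OpenCV. The helpers `frame_to_uint8_hwc` and `rgb_to_bgr` are hypothetical; reversing the last axis is equivalent to `cv2.COLOR_RGB2BGR` for 3-channel images.

```python
import numpy as np

def frame_to_uint8_hwc(frame, image_format="CHW", normalize_mode="[0,1]"):
    """Convert one float frame to an (H, W, 3) uint8 array, mirroring the steps above."""
    img = np.asarray(frame, dtype=np.float32)
    if image_format == "CHW":
        img = np.transpose(img, (1, 2, 0))  # (3, H, W) -> (H, W, 3)
    if normalize_mode == "[0,1]":
        img = np.clip(img, 0, 1) * 255.0
    elif normalize_mode == "[-1,1]":
        img = np.clip((img + 1.0) * 0.5 * 255.0, 0, 255.0)
    return img.astype(np.uint8)

def rgb_to_bgr(img):
    """Reverse the channel axis (RGB <-> BGR) and return a contiguous copy."""
    return np.ascontiguousarray(img[..., ::-1])

# S = 2 frames in CHW layout with values in [-1, 1]; red channel at max.
frames = np.full((2, 3, 4, 5), -1.0, dtype=np.float32)
frames[:, 0] = 1.0
strip = np.concatenate([frame_to_uint8_hwc(f, "CHW", "[-1,1]") for f in frames], axis=1)
strip_bgr = rgb_to_bgr(strip)  # the channel order cv2.imwrite expects
```

`strip` has shape (4, 10, 3) and holds pure red in RGB order (255, 0, 0); in `strip_bgr` the same pixel reads (0, 0, 255). Passing the RGB array to `cv2.imwrite` would therefore save a blue image.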

================================================
FILE: training/data/worker_fn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for distributed training and deterministic seed generation.
This module provides functions for working with PyTorch's distributed
training capabilities and ensuring reproducible data loading.
"""

import os
import torch
import random
import numpy as np

import torch.distributed as dist
from functools import partial


def is_dist_avail_and_initialized():
    """
    Check if distributed training is available and initialized.
    
    Returns:
        bool: True if distributed training is available and initialized, False otherwise.
    """
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    """
    Get the rank of the current process in distributed training.
    
    Returns:
        int: The rank of the current process, or 0 if distributed training is not initialized.
    """
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    """
    Get the total number of processes in distributed training.
    
    Returns:
        int: The world size, or 1 if distributed training is not initialized.
    """
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def default_worker_init_fn(worker_id, num_workers, epoch, seed=0):
    """
    Default function to initialize random seeds for dataloader workers.
    
    Derives a deterministic seed from the rank, worker ID, epoch, world size,
    and the RANK environment variable, so that workers on different ranks draw
    different random streams while runs stay reproducible.
    
    Args:
        worker_id (int): ID of the dataloader worker.
        num_workers (int): Total number of dataloader workers.
        epoch (int): Current training epoch.
        seed (int, optional): Base seed for randomization. Defaults to 0.
    """
    rank = get_rank()
    world_size = get_world_size()
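    # RANK is set by launchers such as torchrun; default to 0 for single-process
    # runs instead of failing on int(None).
    distributed_rank = int(os.environ.get("RANK", 0))

    # Multipliers offsetting the seed by rank, worker, world size, and epoch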
    RANK_MULTIPLIER = 1
    WORKER_MULTIPLIER = 1
    WORLD_MULTIPLIER = 1
    EPOCH_MULTIPLIER = 12345
    DISTRIBUTED_RANK_MULTIPLIER = 1042
    
    worker_seed = (
        rank * num_workers * RANK_MULTIPLIER + 
        worker_id * WORKER_MULTIPLIER + 
        seed + 
        world_size * WORLD_MULTIPLIER + 
        epoch * EPOCH_MULTIPLIER
        + distributed_rank * DISTRIBUTED_RANK_MULTIPLIER
    )

    print(f"Rank: {rank}, World size: {world_size}, Distributed rank: {distributed_rank}")
    print(f"Worker seed: {worker_seed}")
    
    
    torch.random.manual_seed(worker_seed)
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    return
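To see which seeds the formula in `default_worker_init_fn` produces, the following pure-Python sketch recomputes it (the function `worker_seed` is a hypothetical stand-in that hard-codes the multipliers above) for 2 ranks with 4 workers each, assuming a single-node run where `RANK` equals the process-group rank.

```python
import itertools

def worker_seed(rank, worker_id, num_workers, epoch, world_size, distributed_rank, seed=0):
    """Recompute the seed formula of default_worker_init_fn without torch."""
    return (
        rank * num_workers * 1      # RANK_MULTIPLIER
        + worker_id * 1             # WORKER_MULTIPLIER
        + seed
        + world_size * 1            # WORLD_MULTIPLIER
        + epoch * 12345             # EPOCH_MULTIPLIER
        + distributed_rank * 1042   # DISTRIBUTED_RANK_MULTIPLIER
    )

seeds = {
    (r, w): worker_seed(r, w, num_workers=4, epoch=0, world_size=2, distributed_rank=r)
    for r, w in itertools.product(range(2), range(4))
}
print(sorted(seeds.values()))  # [2, 3, 4, 5, 1048, 1049, 1050, 1051]
```

Within one epoch the seeds are distinct across (rank, worker) pairs as long as `worker_id < num_workers`; the formula by itself does not rule out collisions across epochs at large rank counts.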

def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None):
    """
    Get a worker initialization function for dataloaders.
    
    Args:
        seed (int): Base seed for randomization.
        num_workers (int): Number of dataloader workers.
        epoch (int): Current training epoch.
        worker_init_fn (callable, optional): Custom worker initialization function.
            If provided, this will be returned instead of the default one.
            
    Returns:
        callable: A worker initialization function to use with DataLoader.
    """
    if worker_init_fn is not None:
        return worker_init_fn

    return partial(
        default_worker_init_fn,
        num_workers=num_workers,
        epoch=epoch,
        seed=seed,
    )
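PyTorch's `DataLoader` calls `worker_init_fn(worker_id)` once in each worker process, so `get_worker_init_fn` uses `functools.partial` to pre-bind everything else. The sketch below reproduces that dispatch with a hypothetical recording function (`fake_init_fn`, `make_init_fn`) in place of `default_worker_init_fn`, so it runs without torch.

```python
from functools import partial

calls = []

def fake_init_fn(worker_id, num_workers, epoch, seed=0):
    """Hypothetical stand-in for default_worker_init_fn that records its arguments."""
    calls.append((worker_id, num_workers, epoch, seed))

def make_init_fn(seed, num_workers, epoch, worker_init_fn=None):
    """Same dispatch as get_worker_init_fn, with fake_init_fn as the default."""
    if worker_init_fn is not None:
        return worker_init_fn
    return partial(fake_init_fn, num_workers=num_workers, epoch=epoch, seed=seed)

fn = make_init_fn(seed=42, num_workers=2, epoch=3)
for worker_id in range(2):  # what a DataLoader would do, one call per worker
    fn(worker_id)
print(calls)  # [(0, 2, 3, 42), (1, 2, 3, 42)]
```

Because the epoch is baked into the partial, a fresh init function is needed each epoch. With `persistent_workers=True`, workers are initialized only once, so the epoch term would not advance.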


================================================
FILE: training/launch.py
================================================
# Copyright (c) Meta Platforms, Inc. and affil
SYMBOL INDEX (430 symbols across 62 files)

FILE: demo_colmap.py
  function parse_args (line 42) | def parse_args():
  function run_VGGT (line 65) | def run_VGGT(model, images, dtype, resolution=518):
  function demo_fn (line 93) | def demo_fn(args):
  function rename_colmap_recons_and_rescale_camera (line 254) | def rename_colmap_recons_and_rescale_camera(

FILE: demo_gradio.py
  function run_model (line 44) | def run_model(target_dir, model) -> dict:
  function handle_uploads (line 103) | def handle_uploads(input_video, input_images):
  function update_gallery_on_upload (line 171) | def update_gallery_on_upload(input_video, input_images):
  function gradio_demo (line 186) | def gradio_demo(
  function clear_fields (line 259) | def clear_fields():
  function update_log (line 266) | def update_log():
  function update_visualization (line 273) | def update_visualization(
  function example_pipeline (line 494) | def example_pipeline(

FILE: demo_viser.py
  function viser_wrapper (line 34) | def viser_wrapper(
  function apply_sky_segmentation (line 258) | def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.nd...
  function main (line 321) | def main():

FILE: training/data/augmentation.py
  function get_image_augmentation (line 11) | def get_image_augmentation(

FILE: training/data/base_dataset.py
  class BaseDataset (line 17) | class BaseDataset(Dataset):
    method __init__ (line 33) | def __init__(
    method __len__ (line 57) | def __len__(self):
    method __getitem__ (line 60) | def __getitem__(self, idx_N):
    method get_data (line 75) | def get_data(self, seq_index=None, seq_name=None, ids=None, aspect_rat...
    method get_target_shape (line 95) | def get_target_shape(self, aspect_ratio):
    method process_one_image (line 115) | def process_one_image(
    method get_nearby_ids (line 237) | def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_...

FILE: training/data/composed_dataset.py
  class ComposedDataset (line 21) | class ComposedDataset(Dataset, ABC):
    method __init__ (line 30) | def __init__(self, dataset_configs: dict, common_config: dict, **kwargs):
    method __len__ (line 80) | def __len__(self):
    method __getitem__ (line 85) | def __getitem__(self, idx_tuple):
  class TupleConcatDataset (line 187) | class TupleConcatDataset(ConcatDataset):
    method __init__ (line 201) | def __init__(self, datasets, common_config):
    method __getitem__ (line 214) | def __getitem__(self, idx):

FILE: training/data/dataset_util.py
  function crop_image_depth_and_intrinsic_by_pp (line 26) | def crop_image_depth_and_intrinsic_by_pp(
  function resize_image_depth_and_intrinsic (line 161) | def resize_image_depth_and_intrinsic(
  function threshold_depth_map (line 261) | def threshold_depth_map(
  function depth_to_world_coords_points (line 317) | def depth_to_world_coords_points(
  function depth_to_cam_coords_points (line 369) | def depth_to_cam_coords_points(
  function rotate_90_degrees (line 411) | def rotate_90_degrees(
  function rotate_image_and_depth_rot90 (line 474) | def rotate_image_and_depth_rot90(image, depth_map, clockwise):
  function adjust_extrinsic_matrix_rot90 (line 507) | def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise):
  function adjust_intrinsic_matrix_rot90 (line 548) | def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_heigh...
  function adjust_track_rot90 (line 588) | def adjust_track_rot90(track, image_width, image_height, clockwise):
  function read_image_cv2 (line 616) | def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray:
  function read_depth (line 653) | def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray:
  function load_16big_png_depth (line 689) | def load_16big_png_depth(depth_png: str) -> np.ndarray:

FILE: training/data/datasets/co3d.py
  class Co3dDataset (line 67) | class Co3dDataset(BaseDataset):
    method __init__ (line 68) | def __init__(
    method get_data (line 162) | def get_data(

FILE: training/data/datasets/vkitti.py
  class VKittiDataset (line 20) | class VKittiDataset(BaseDataset):
    method __init__ (line 21) | def __init__(
    method get_data (line 89) | def get_data(

FILE: training/data/dynamic_dataloader.py
  class DynamicTorchDataset (line 17) | class DynamicTorchDataset(ABC):
    method __init__ (line 18) | def __init__(
    method get_loader (line 67) | def get_loader(self, epoch):
  class DynamicBatchSampler (line 94) | class DynamicBatchSampler(Sampler):
    method __init__ (line 99) | def __init__(self,
    method set_epoch (line 140) | def set_epoch(self, epoch):
    method __iter__ (line 151) | def __iter__(self):
    method __len__ (line 194) | def __len__(self):
  class DynamicDistributedSampler (line 199) | class DynamicDistributedSampler(DistributedSampler):
    method __init__ (line 204) | def __init__(
    method __iter__ (line 224) | def __iter__(self):
    method update_parameters (line 235) | def update_parameters(self, aspect_ratio, image_num):

FILE: training/data/track_util.py
  function build_tracks_by_depth (line 19) | def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, ...
  function get_depth_inside_flag (line 149) | def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, re...
  function sample_positive_tracks (line 161) | def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = Tr...
  function track_epipolar_check (line 198) | def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_m...
  function get_essential_matrix (line 216) | def get_essential_matrix(extrinsic1, extrinsic2):
  function hat (line 231) | def hat(v: torch.Tensor) -> torch.Tensor:
  function color_from_xy (line 257) | def color_from_xy(x, y, W, H, cmap_name="hsv"):
  function get_track_colors_by_position (line 281) | def get_track_colors_by_position(
  function visualize_tracks_on_images (line 335) | def visualize_tracks_on_images(

FILE: training/data/worker_fn.py
  function is_dist_avail_and_initialized (line 22) | def is_dist_avail_and_initialized():
  function get_rank (line 36) | def get_rank():
  function get_world_size (line 48) | def get_world_size():
  function default_worker_init_fn (line 60) | def default_worker_init_fn(worker_id, num_workers, epoch, seed=0):
  function get_worker_init_fn (line 102) | def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None):

FILE: training/launch.py
  function main (line 13) | def main():

FILE: training/loss.py
  class MultitaskLoss (line 17) | class MultitaskLoss(torch.nn.Module):
    method __init__ (line 27) | def __init__(self, camera=None, depth=None, point=None, track=None, **...
    method forward (line 35) | def forward(self, predictions, batch) -> torch.Tensor:
  function compute_camera_loss (line 81) | def compute_camera_loss(
  function camera_loss_single (line 157) | def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"):
  function compute_point_loss (line 199) | def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradien...
  function compute_depth_loss (line 239) | def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradien...
  function regression_loss (line 281) | def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, ga...
  function gradient_loss_multi_scale_wrapper (line 370) | def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4...
  function normal_loss (line 398) | def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma...
  function gradient_loss (line 456) | def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=...
  function point_map_to_normal (line 511) | def point_map_to_normal(point_map, mask, eps=1e-6):
  function filter_by_quantile (line 567) | def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard...
  function torch_quantile (line 606) | def torch_quantile(

FILE: training/train_utils/checkpoint.py
  class DDPCheckpointSaver (line 38) | class DDPCheckpointSaver:
    method __init__ (line 39) | def __init__(
    method save_checkpoint (line 52) | def save_checkpoint(
  function robust_torch_save (line 72) | def robust_torch_save(checkpoint: Dict[str, Any], checkpoint_path: str) ...

FILE: training/train_utils/distributed.py
  function get_machine_local_and_dist_rank (line 12) | def get_machine_local_and_dist_rank():

FILE: training/train_utils/freeze.py
  function freeze_modules (line 24) | def freeze_modules(model: nn.Module, patterns: List[str], recursive: boo...
  function _freeze (line 62) | def _freeze(mod: nn.Module, recursive: bool) -> None:
  function _check_every_pattern_used (line 91) | def _check_every_pattern_used(matched_names: set[str], patterns: List[st...

FILE: training/train_utils/general.py
  function check_and_fix_inf_nan (line 29) | def check_and_fix_inf_nan(input_tensor, loss_name="default", hard_max=100):
  function get_resume_checkpoint (line 60) | def get_resume_checkpoint(checkpoint_save_dir):
  class DurationMeter (line 69) | class DurationMeter:
    method __init__ (line 70) | def __init__(self, name, device, fmt=":f"):
    method reset (line 76) | def reset(self):
    method update (line 79) | def update(self, val):
    method add (line 82) | def add(self, val):
    method __str__ (line 85) | def __str__(self):
  function human_readable_time (line 89) | def human_readable_time(time_seconds):
  class ProgressMeter (line 98) | class ProgressMeter:
    method __init__ (line 99) | def __init__(self, num_batches, meters, real_meters, prefix=""):
    method display (line 105) | def display(self, batch):
    method _get_batch_fmtstr (line 119) | def _get_batch_fmtstr(self, num_batches):
  class _CopyableData (line 127) | class _CopyableData(Protocol):
    method to (line 128) | def to(self, device: torch.device, *args: Any, **kwargs: Any):
  function _is_named_tuple (line 133) | def _is_named_tuple(x) -> bool:
  function copy_data_to_device (line 137) | def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs...
  function safe_makedirs (line 197) | def safe_makedirs(path: str):
  function set_seeds (line 215) | def set_seeds(seed_value, max_epochs, dist_rank):
  function log_env_variables (line 233) | def log_env_variables():
  function is_dist_avail_and_initialized (line 243) | def is_dist_avail_and_initialized():
  class AverageMeter (line 252) | class AverageMeter:
    method __init__ (line 260) | def __init__(self, name: str, device: Optional[torch.device] = None, f...
    method reset (line 266) | def reset(self):
    method update (line 273) | def update(self, val, n=1):
    method __str__ (line 282) | def __str__(self) -> str:
    method value (line 288) | def value(self) -> float:
    method average (line 293) | def average(self) -> float:
  function pretty_int (line 302) | def pretty_int(n: int) -> str:
  function model_summary (line 313) | def model_summary(model: torch.nn.Module,
  function get_rank (line 364) | def get_rank():

FILE: training/train_utils/gradient_clip.py
  class GradientClipper (line 12) | class GradientClipper:
    method __init__ (line 17) | def __init__(self, configs, *args, **kwargs):
    method setup_clipping (line 40) | def setup_clipping(self, model: nn.Module) -> None:
    method __call__ (line 80) | def __call__(self, model: nn.Module) -> Optional[torch.Tensor]:

FILE: training/train_utils/logging.py
  function _cached_log_stream (line 22) | def _cached_log_stream(filename):
  function setup_logging (line 30) | def setup_logging(

FILE: training/train_utils/normalization.py
  function check_valid_tensor (line 14) | def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str =...
  function normalize_camera_extrinsics_and_points_batch (line 27) | def normalize_camera_extrinsics_and_points_batch(

FILE: training/train_utils/optimizer.py
  class OptimizerWrapper (line 20) | class OptimizerWrapper:
    method __init__ (line 23) | def __init__(self, optimizer: torch.optim.Optimizer, schedulers=None) ...
    method step (line 33) | def step(self, where: float = 1.0, closure=None):
    method zero_grad (line 38) | def zero_grad(self, *args, **kwargs):
    method _validate_optimizer_schedulers (line 41) | def _validate_optimizer_schedulers(self):
    method step_schedulers (line 51) | def step_schedulers(self, where: float) -> None:
  function validate_param_group_params (line 64) | def validate_param_group_params(param_groups: List[Dict], model: nn.Modu...
  function get_full_parameter_name (line 96) | def get_full_parameter_name(module_name: str, param_name: str) -> str:
  function get_module_cls_to_param_names (line 100) | def get_module_cls_to_param_names(model: nn.Module) -> Dict[type, Set[st...
  function unix_param_pattern_to_parameter_names (line 111) | def unix_param_pattern_to_parameter_names(filter_param_names: Union[List...
  function unix_module_cls_pattern_to_parameter_names (line 125) | def unix_module_cls_pattern_to_parameter_names(filter_module_cls_names: ...
  function _unix_pattern_to_parameter_names (line 142) | def _unix_pattern_to_parameter_names(scheduler_cfg,
  function set_default_parameters (line 161) | def set_default_parameters(scheduler_cfgs: List[dict], all_parameter_nam...
  function name_constraints_to_parameters (line 180) | def name_constraints_to_parameters(param_constraints: List[Set[str]],
  function map_scheduler_cfgs_to_param_groups (line 186) | def map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs: Iterable[List...
  function construct_optimizer (line 208) | def construct_optimizer(model: nn.Module,
  function construct_optimizers (line 262) | def construct_optimizers(model: nn.Module, optim_conf) -> Union[List[Opt...

FILE: training/train_utils/tb_writer.py
  class TensorBoardLogger (line 18) | class TensorBoardLogger:
    method __init__ (line 25) | def __init__(
    method writer (line 62) | def writer(self) -> Optional[SummaryWriter]:
    method path (line 67) | def path(self) -> str:
    method flush (line 71) | def flush(self) -> None:
    method close (line 76) | def close(self) -> None:
    method log_dict (line 85) | def log_dict(self, payload: Dict[str, Any], step: int) -> None:
    method log (line 98) | def log(self, name: str, data: Any, step: int) -> None:
    method log_visuals (line 111) | def log_visuals(

FILE: training/trainer.py
  class Trainer (line 46) | class Trainer:
    method __init__ (line 60) | def __init__(
    method _setup_timers (line 170) | def _setup_timers(self):
    method _setup_env_variables (line 175) | def _setup_env_variables(self, env_variables_conf: Optional[Dict[str, ...
    method _setup_torch_dist_and_backend (line 182) | def _setup_torch_dist_and_backend(self, cuda_conf: Dict, distributed_c...
    method _load_resuming_checkpoint (line 198) | def _load_resuming_checkpoint(self, ckpt_path: str):
    method _setup_device (line 228) | def _setup_device(self, device: str):
    method _setup_components (line 239) | def _setup_components(self):
    method _setup_dataloaders (line 273) | def _setup_dataloaders(self):
    method _setup_ddp_distributed_training (line 289) | def _setup_ddp_distributed_training(self, distributed_conf: Dict, devi...
    method save_checkpoint (line 306) | def save_checkpoint(self, epoch: int, checkpoint_names: Optional[List[...
    method _get_scalar_log_keys (line 359) | def _get_scalar_log_keys(self, phase: str) -> List[str]:
    method run (line 365) | def run(self):
    method run_train (line 377) | def run_train(self):
    method run_val (line 403) | def run_val(self):
    method val_epoch (line 419) | def val_epoch(self, val_loader):
    method train_epoch (line 501) | def train_epoch(self, train_loader):
    method _run_steps_on_batch_chunks (line 638) | def _run_steps_on_batch_chunks(
    method _apply_batch_repetition (line 692) | def _apply_batch_repetition(self, batch: Mapping) -> Mapping:
    method _process_batch (line 716) | def _process_batch(self, batch: Mapping):
    method _step (line 738) | def _step(self, batch, model: nn.Module, phase: str, loss_meters: dict):
    method _update_and_log_scalars (line 760) | def _update_and_log_scalars(self, data: Mapping, phase: str, step: int...
    method _log_tb_visuals (line 772) | def _log_tb_visuals(self, batch: Mapping, phase: str, step: int) -> None:
  function chunk_batch_for_accum_steps (line 823) | def chunk_batch_for_accum_steps(batch: Mapping, accum_steps: int) -> Lis...
  function is_sequence_of_primitives (line 829) | def is_sequence_of_primitives(data: Any) -> bool:
  function get_chunk_from_data (line 838) | def get_chunk_from_data(data: Any, chunk_id: int, num_chunks: int) -> Any:

FILE: vggt/dependency/distortion.py
  function _is_numpy (line 14) | def _is_numpy(x: ArrayLike) -> bool:
  function _is_torch (line 18) | def _is_torch(x: ArrayLike) -> bool:
  function _ensure_torch (line 22) | def _ensure_torch(x: ArrayLike) -> torch.Tensor:
  function single_undistortion (line 32) | def single_undistortion(params, tracks_normalized):
  function iterative_undistortion (line 51) | def iterative_undistortion(params, tracks_normalized, max_iterations=100...
  function apply_distortion (line 99) | def apply_distortion(extra_params, u, v):

FILE: vggt/dependency/np_to_pycolmap.py
  function batch_np_matrix_to_pycolmap (line 12) | def batch_np_matrix_to_pycolmap(
  function pycolmap_to_batch_np_matrix (line 148) | def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_typ...
  function batch_np_matrix_to_pycolmap_wo_track (line 201) | def batch_np_matrix_to_pycolmap_wo_track(
  function _build_pycolmap_intri (line 293) | def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=No...

FILE: vggt/dependency/projection.py
  function img_from_cam_np (line 12) | def img_from_cam_np(
  function project_3D_points_np (line 50) | def project_3D_points_np(
  function project_3D_points (line 105) | def project_3D_points(points3D, extrinsics, intrinsics=None, extra_param...
  function img_from_cam (line 140) | def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):

FILE: vggt/dependency/track_modules/base_track_predictor.py
  class BaseTrackerPredictor (line 15) | class BaseTrackerPredictor(nn.Module):
    method __init__ (line 16) | def __init__(
    method forward (line 71) | def forward(self, query_points, fmaps=None, iters=4, return_feat=False...

FILE: vggt/dependency/track_modules/blocks.py
  class BasicEncoder (line 25) | class BasicEncoder(nn.Module):
    method __init__ (line 26) | def __init__(self, input_dim=3, output_dim=128, stride=4):
    method _make_layer (line 58) | def _make_layer(self, dim, stride=1):
    method forward (line 66) | def forward(self, x):
  class ShallowEncoder (line 90) | class ShallowEncoder(nn.Module):
    method __init__ (line 91) | def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="inst...
    method _make_layer (line 126) | def _make_layer(self, dim, stride=1):
    method forward (line 132) | def forward(self, x):
  function _bilinear_intepolate (line 151) | def _bilinear_intepolate(x, stride, H, W):
  class EfficientUpdateFormer (line 155) | class EfficientUpdateFormer(nn.Module):
    method __init__ (line 160) | def __init__(
    method initialize_weights (line 210) | def initialize_weights(self):
    method forward (line 224) | def forward(self, input_tensor, mask=None):
  class CorrBlock (line 264) | class CorrBlock:
    method __init__ (line 265) | def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats...
    method sample (line 282) | def sample(self, coords):
    method corr (line 309) | def corr(self, targets):

FILE: vggt/dependency/track_modules/modules.py
  function _ntuple (line 19) | def _ntuple(n):
  function exists (line 28) | def exists(val):
  function default (line 32) | def default(val, d):
  class ResidualBlock (line 39) | class ResidualBlock(nn.Module):
    method __init__ (line 44) | def __init__(self, in_planes, planes, norm_fn="group", stride=1, kerne...
    method forward (line 86) | def forward(self, x):
  class Mlp (line 97) | class Mlp(nn.Module):
    method __init__ (line 100) | def __init__(
    method forward (line 124) | def forward(self, x):
  class AttnBlock (line 133) | class AttnBlock(nn.Module):
    method __init__ (line 134) | def __init__(
    method forward (line 155) | def forward(self, x, mask=None):
  class CrossAttnBlock (line 172) | class CrossAttnBlock(nn.Module):
    method __init__ (line 173) | def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4....
    method forward (line 190) | def forward(self, x, context, mask=None):

FILE: vggt/dependency/track_modules/track_refine.py
  function refine_track (line 22) | def refine_track(
  function refine_track_v0 (line 163) | def refine_track_v0(
  function compute_score_fn (line 302) | def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, srad...
  function extract_glimpse (line 381) | def extract_glimpse(

FILE: vggt/dependency/track_modules/utils.py
  function get_2d_sincos_pos_embed (line 19) | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[...
  function get_2d_sincos_pos_embed_from_grid (line 44) | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor...
  function get_1d_sincos_pos_embed_from_grid (line 65) | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor)...
  function get_2d_embedding (line 91) | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) ...
  function bilinear_sampler (line 125) | def bilinear_sampler(input, coords, align_corners=True, padding_mode="bo...
  function sample_features4d (line 186) | def sample_features4d(input, coords):

FILE: vggt/dependency/track_predict.py
  function predict_tracks (line 12) | def predict_tracks(
  function _forward_on_query (line 135) | def _forward_on_query(
  function _augment_non_visible_frames (line 232) | def _augment_non_visible_frames(

FILE: vggt/dependency/vggsfm_tracker.py
  class TrackerPredictor (line 25) | class TrackerPredictor(nn.Module):
    method __init__ (line 26) | def __init__(self, **extra_args):
    method forward (line 58) | def forward(
    method process_images_to_fmaps (line 106) | def process_images_to_fmaps(self, images):

FILE: vggt/dependency/vggsfm_utils.py
  function build_vggsfm_tracker (line 29) | def build_vggsfm_tracker(model_path=None):
  function generate_rank_by_dino (line 51) | def generate_rank_by_dino(
  function farthest_point_sampling (line 118) | def farthest_point_sampling(distance_matrix, num_samples, most_common_fr...
  function calculate_index_mappings (line 153) | def calculate_index_mappings(query_index, S, device=None):
  function switch_tensor_order (line 174) | def switch_tensor_order(tensors, order, dim=1):
  function initialize_feature_extractors (line 189) | def initialize_feature_extractors(max_query_num, det_thres=0.005, extrac...
  function extract_keypoints (line 227) | def extract_keypoints(query_image, extractors, round_keypoints=True):
  function predict_tracks_in_chunks (line 255) | def predict_tracks_in_chunks(

FILE: vggt/heads/camera_head.py
  class CameraHead (line 19) | class CameraHead(nn.Module):
    method __init__ (line 26) | def __init__(
    method forward (line 73) | def forward(self, aggregated_tokens_list: list, num_iterations: int = ...
    method trunk_fn (line 95) | def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> ...
  function modulate (line 144) | def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) ...

FILE: vggt/heads/dpt_head.py
  class DPTHead (line 21) | class DPTHead(nn.Module):
    method __init__ (line 43) | def __init__(
    method forward (line 115) | def forward(
    method _forward_impl (line 172) | def _forward_impl(
    method _apply_pos_embed (line 249) | def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: flo...
    method scratch_forward (line 261) | def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
  function _make_fusion_block (line 299) | def _make_fusion_block(features: int, size: int = None, has_residual: bo...
  function _make_scratch (line 313) | def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, ...
  class ResidualConvUnit (line 344) | class ResidualConvUnit(nn.Module):
    method __init__ (line 347) | def __init__(self, features, activation, bn, groups=1):
    method forward (line 366) | def forward(self, x):
  class FeatureFusionBlock (line 389) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 392) | def __init__(
    method forward (line 432) | def forward(self, *xs, size=None):
  function custom_interpolate (line 459) | def custom_interpolate(

FILE: vggt/heads/head_act.py
  function activate_pose (line 12) | def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", ...
  function base_pose_act (line 38) | def base_pose_act(pose_enc, act_type="linear"):
  function activate_head (line 61) | def activate_head(out, activation="norm_exp", conf_activation="expp1"):
  function inverse_log_transform (line 115) | def inverse_log_transform(y):

FILE: vggt/heads/track_head.py
  class TrackHead (line 12) | class TrackHead(nn.Module):
    method __init__ (line 18) | def __init__(
    method forward (line 72) | def forward(self, aggregated_tokens_list, images, patch_start_idx, que...

FILE: vggt/heads/track_modules/base_track_predictor.py
  class BaseTrackerPredictor (line 17) | class BaseTrackerPredictor(nn.Module):
    method __init__ (line 18) | def __init__(
    method forward (line 82) | def forward(self, query_points, fmaps=None, iters=6, return_feat=False...

FILE: vggt/heads/track_modules/blocks.py
  class EfficientUpdateFormer (line 19) | class EfficientUpdateFormer(nn.Module):
    method __init__ (line 24) | def __init__(
    method initialize_weights (line 80) | def initialize_weights(self):
    method forward (line 90) | def forward(self, input_tensor, mask=None):
  class CorrBlock (line 137) | class CorrBlock:
    method __init__ (line 138) | def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats...
    method corr_sample (line 176) | def corr_sample(self, targets, coords):
  function compute_corr_level (line 231) | def compute_corr_level(fmap1, fmap2s, C):

FILE: vggt/heads/track_modules/modules.py
  function _ntuple (line 19) | def _ntuple(n):
  function exists (line 28) | def exists(val):
  function default (line 32) | def default(val, d):
  class ResidualBlock (line 39) | class ResidualBlock(nn.Module):
    method __init__ (line 44) | def __init__(self, in_planes, planes, norm_fn="group", stride=1, kerne...
    method forward (line 86) | def forward(self, x):
  class Mlp (line 97) | class Mlp(nn.Module):
    method __init__ (line 100) | def __init__(
    method forward (line 124) | def forward(self, x):
  class AttnBlock (line 133) | class AttnBlock(nn.Module):
    method __init__ (line 134) | def __init__(
    method forward (line 156) | def forward(self, x, mask=None):
  class CrossAttnBlock (line 173) | class CrossAttnBlock(nn.Module):
    method __init__ (line 174) | def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4....
    method forward (line 192) | def forward(self, x, context, mask=None):

FILE: vggt/heads/track_modules/utils.py
  function get_2d_sincos_pos_embed (line 18) | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[...
  function get_2d_sincos_pos_embed_from_grid (line 43) | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor...
  function get_1d_sincos_pos_embed_from_grid (line 64) | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor)...
  function get_2d_embedding (line 90) | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) ...
  function bilinear_sampler (line 124) | def bilinear_sampler(input, coords, align_corners=True, padding_mode="bo...
  function sample_features4d (line 193) | def sample_features4d(input, coords):

FILE: vggt/heads/utils.py
  function position_grid_to_embed (line 11) | def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega...
  function make_sincos_pos_embed (line 36) | def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: fl...
  function create_uv_grid (line 66) | def create_uv_grid(

FILE: vggt/layers/attention.py
  class Attention (line 21) | class Attention(nn.Module):
    method __init__ (line 22) | def __init__(
    method forward (line 50) | def forward(self, x: Tensor, pos=None) -> Tensor:
  class MemEffAttention (line 75) | class MemEffAttention(Attention):
    method forward (line 76) | def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:

FILE: vggt/layers/block.py
  class Block (line 27) | class Block(nn.Module):
    method __init__ (line 28) | def __init__(
    method forward (line 77) | def forward(self, x: Tensor, pos=None) -> Tensor:
  function drop_add_residual_stochastic_depth (line 101) | def drop_add_residual_stochastic_depth(
  function get_branges_scales (line 128) | def get_branges_scales(x, sample_drop_ratio=0.0):
  function add_residual (line 136) | def add_residual(x, brange, residual, residual_scale_factor, scaling_vec...
  function get_attn_bias_and_cat (line 151) | def get_attn_bias_and_cat(x_list, branges=None):
  function drop_add_residual_stochastic_depth_list (line 175) | def drop_add_residual_stochastic_depth_list(
  class NestedTensorBlock (line 198) | class NestedTensorBlock(Block):
    method forward_nested (line 199) | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
    method forward (line 239) | def forward(self, x_or_x_list):

FILE: vggt/layers/drop_path.py
  function drop_path (line 14) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  class DropPath (line 26) | class DropPath(nn.Module):
    method __init__ (line 29) | def __init__(self, drop_prob=None):
    method forward (line 33) | def forward(self, x):

FILE: vggt/layers/layer_scale.py
  class LayerScale (line 15) | class LayerScale(nn.Module):
    method __init__ (line 16) | def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5,...
    method forward (line 21) | def forward(self, x: Tensor) -> Tensor:

FILE: vggt/layers/mlp.py
  class Mlp (line 16) | class Mlp(nn.Module):
    method __init__ (line 17) | def __init__(
    method forward (line 34) | def forward(self, x: Tensor) -> Tensor:

FILE: vggt/layers/patch_embed.py
  function make_2tuple (line 16) | def make_2tuple(x):
  class PatchEmbed (line 25) | class PatchEmbed(nn.Module):
    method __init__ (line 37) | def __init__(
    method forward (line 65) | def forward(self, x: Tensor) -> Tensor:
    method flops (line 80) | def flops(self) -> float:

FILE: vggt/layers/rope.py
  class PositionGetter (line 24) | class PositionGetter:
    method __init__ (line 35) | def __init__(self):
    method __call__ (line 39) | def __call__(self, batch_size: int, height: int, width: int, device: t...
  class RotaryPositionEmbedding2D (line 62) | class RotaryPositionEmbedding2D(nn.Module):
    method __init__ (line 79) | def __init__(self, frequency: float = 100.0, scaling_factor: float = 1...
    method _compute_frequency_components (line 86) | def _compute_frequency_components(
    method _rotate_features (line 120) | def _rotate_features(x: torch.Tensor) -> torch.Tensor:
    method _apply_1d_rope (line 133) | def _apply_1d_rope(
    method forward (line 154) | def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> to...

FILE: vggt/layers/swiglu_ffn.py
  class SwiGLUFFN (line 14) | class SwiGLUFFN(nn.Module):
    method __init__ (line 15) | def __init__(
    method forward (line 30) | def forward(self, x: Tensor) -> Tensor:
  class SwiGLUFFNFused (line 54) | class SwiGLUFFNFused(SwiGLU):
    method __init__ (line 55) | def __init__(

FILE: vggt/layers/vision_transformer.py
  function named_apply (line 24) | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=Tr...
  class BlockChunk (line 35) | class BlockChunk(nn.ModuleList):
    method forward (line 36) | def forward(self, x):
  class DinoVisionTransformer (line 42) | class DinoVisionTransformer(nn.Module):
    method __init__ (line 43) | def __init__(
    method init_weights (line 173) | def init_weights(self):
    method interpolate_pos_encoding (line 180) | def interpolate_pos_encoding(self, x, w, h):
    method prepare_tokens_with_masks (line 214) | def prepare_tokens_with_masks(self, x, masks=None):
    method forward_features_list (line 228) | def forward_features_list(self, x_list, masks_list):
    method forward_features (line 252) | def forward_features(self, x, masks=None):
    method _get_intermediate_layers_not_chunked (line 273) | def _get_intermediate_layers_not_chunked(self, x, n=1):
    method _get_intermediate_layers_chunked (line 285) | def _get_intermediate_layers_chunked(self, x, n=1):
    method get_intermediate_layers (line 299) | def get_intermediate_layers(
    method forward (line 325) | def forward(self, *args, is_training=True, **kwargs):
  function init_weights_vit_timm (line 333) | def init_weights_vit_timm(module: nn.Module, name: str = ""):
  function vit_small (line 341) | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_base (line 355) | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_large (line 369) | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_giant2 (line 383) | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):

FILE: vggt/models/aggregator.py
  class Aggregator (line 25) | class Aggregator(nn.Module):
    method __init__ (line 52) | def __init__(
    method __build_patch_embed__ (line 143) | def __build_patch_embed__(
    method forward (line 184) | def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], i...
    method _process_frame_attention (line 260) | def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=...
    method _process_global_attention (line 284) | def _process_global_attention(self, tokens, B, S, P, C, global_idx, po...
  function slice_expand_and_flatten (line 308) | def slice_expand_and_flatten(token_tensor, B, S):

FILE: vggt/models/vggt.py
  class VGGT (line 17) | class VGGT(nn.Module, PyTorchModelHubMixin):
    method __init__ (line 18) | def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
    method forward (line 29) | def forward(self, images: torch.Tensor, query_points: torch.Tensor = N...

FILE: vggt/utils/geometry.py
  function unproject_depth_map_to_point_map (line 15) | def unproject_depth_map_to_point_map(
  function depth_to_world_coords_points (line 47) | def depth_to_world_coords_points(
  function depth_to_cam_coords_points (line 87) | def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndar...
  function closed_form_inverse_se3 (line 120) | def closed_form_inverse_se3(se3, R=None, T=None):
  function project_world_points_to_camera_points_batch (line 175) | def project_world_points_to_camera_points_batch(world_points, cam_extrin...
  function project_world_points_to_cam (line 204) | def project_world_points_to_cam(
  function img_from_cam (line 251) | def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, def...
  function cam_from_img (line 294) | def cam_from_img(pred_tracks, intrinsics, extra_params=None):

FILE: vggt/utils/helper.py
  function randomly_limit_trues (line 10) | def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray:
  function create_pixel_coordinate_grid (line 33) | def create_pixel_coordinate_grid(num_frames, height, width):

FILE: vggt/utils/load_fn.py
  function load_and_preprocess_images_square (line 13) | def load_and_preprocess_images_square(image_path_list, target_size=1024):
  function load_and_preprocess_images (line 97) | def load_and_preprocess_images(image_path_list, mode="crop"):

FILE: vggt/utils/pose_enc.py
  function extri_intri_to_pose_encoding (line 11) | def extri_intri_to_pose_encoding(
  function pose_encoding_to_extri_intri (line 62) | def pose_encoding_to_extri_intri(

FILE: vggt/utils/rotation.py
  function quat_to_mat (line 14) | def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
  function mat_to_quat (line 47) | def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
  function _sqrt_positive_part (line 106) | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
  function standardize_quaternion (line 120) | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:

FILE: vggt/utils/visual_track.py
  function color_from_xy (line 13) | def color_from_xy(x, y, W, H, cmap_name="hsv"):
  function get_track_colors_by_position (line 37) | def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=...
  function visualize_tracks_on_images (line 80) | def visualize_tracks_on_images(

FILE: visual_util.py
  function predictions_to_glb (line 18) | def predictions_to_glb(
  function integrate_camera_into_scene (line 218) | def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndar...
  function apply_scene_alignment (line 263) | def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: ...
  function get_opengl_conversion_matrix (line 287) | def get_opengl_conversion_matrix() -> np.ndarray:
  function transform_points (line 304) | def transform_points(transformation: np.ndarray, points: np.ndarray, dim...
  function compute_camera_faces (line 329) | def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
  function segment_sky (line 365) | def segment_sky(image_path, onnx_session, mask_filename=None):
  function run_skyseg (line 396) | def run_skyseg(onnx_session, input_size, image):
  function download_file_from_url (line 436) | def download_file_from_url(url, filename):
Condensed preview: 83 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitattributes",
    "chars": 122,
    "preview": "# SCM syntax highlighting & preventing 3-way merges\npixi.lock merge=binary linguist-language=YAML linguist-generated=tru"
  },
  {
    "path": ".gitignore",
    "chars": 2020,
    "preview": ".hydra/\noutput/\nckpt/\n# Byte-compiled / optimized / DLL files\n__pycache__/\n**/__pycache__/\n*.py[cod]\n*$py.class\n\n# C ext"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 3537,
    "preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1242,
    "preview": "# Contributing to vggt\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Reques"
  },
  {
    "path": "LICENSE.txt",
    "chars": 10666,
    "preview": "VGGT License\n\nv1 Last Updated: July 29, 2025\n\n“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Res"
  },
  {
    "path": "README.md",
    "chars": 14662,
    "preview": "<div align=\"center\">\n<h1>VGGT: Visual Geometry Grounded Transformer</h1>\n\n<a href=\"https://jytime.github.io/data/VGGT_CV"
  },
  {
    "path": "demo_colmap.py",
    "chars": 12668,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "demo_gradio.py",
    "chars": 25027,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "demo_viser.py",
    "chars": 14794,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "docs/package.md",
    "chars": 1127,
    "preview": "# Alternative Installation Methods\n\nThis document explains how to install VGGT as a package using different package mana"
  },
  {
    "path": "pyproject.toml",
    "chars": 1040,
    "preview": "[project]\nauthors = [{name = \"Jianyuan Wang\", email = \"jianyuan@robots.ox.ac.uk\"}]\ndependencies = [\n    \"numpy<2\",\n    \""
  },
  {
    "path": "requirements.txt",
    "chars": 89,
    "preview": "torch==2.3.1\ntorchvision==0.18.1\nnumpy==1.26.1\nPillow\nhuggingface_hub\neinops\nsafetensors\n"
  },
  {
    "path": "requirements_demo.txt",
    "chars": 298,
    "preview": "gradio==5.17.1\nviser==0.2.23\ntqdm\nhydra-core\nomegaconf\nopencv-python\nscipy\nonnxruntime\nrequests\ntrimesh\nmatplotlib\npydan"
  },
  {
    "path": "training/README.md",
    "chars": 5499,
    "preview": "# Training\n\nThis is a re-implementation of our framework for training VGGT. This document shows how to set up the enviro"
  },
  {
    "path": "training/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "training/config/default.yaml",
    "chars": 4619,
    "preview": "defaults:\n  - default_dataset.yaml\n\nexp_name: exp001\nimg_size: 518\nnum_workers: 8\nseed_value: 42\naccum_steps: 2    # We "
  },
  {
    "path": "training/config/default_dataset.yaml",
    "chars": 2411,
    "preview": "# Template for the dataset config\ndata:\n  # The code still looks too complicated. I should refactor this again (do I hav"
  },
  {
    "path": "training/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "training/data/augmentation.py",
    "chars": 2247,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/base_dataset.py",
    "chars": 11524,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/composed_dataset.py",
    "chars": 11436,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/dataset_util.py",
    "chars": 24624,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/datasets/co3d.py",
    "chars": 8205,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/datasets/vkitti.py",
    "chars": 8090,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/dynamic_dataloader.py",
    "chars": 9007,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/preprocess/vkitti.sh",
    "chars": 595,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/track_util.py",
    "chars": 17058,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/data/worker_fn.py",
    "chars": 3613,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/launch.py",
    "chars": 843,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/loss.py",
    "chars": 30473,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "training/train_utils/checkpoint.py",
    "chars": 2640,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/distributed.py",
    "chars": 652,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/freeze.py",
    "chars": 3136,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/general.py",
    "chars": 11550,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/gradient_clip.py",
    "chars": 4151,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/logging.py",
    "chars": 2363,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/normalization.py",
    "chars": 5077,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/optimizer.py",
    "chars": 11008,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/train_utils/tb_writer.py",
    "chars": 4299,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "training/trainer.py",
    "chars": 33636,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/__init__.py",
    "chars": 185,
    "preview": "from .track_modules.track_refine import refine_track\nfrom .track_modules.blocks import BasicEncoder, ShallowEncoder\nfrom"
  },
  {
    "path": "vggt/dependency/distortion.py",
    "chars": 6511,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/np_to_pycolmap.py",
    "chars": 10764,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/projection.py",
    "chars": 9059,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/track_modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "vggt/dependency/track_modules/base_track_predictor.py",
    "chars": 7425,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/track_modules/blocks.py",
    "chars": 12239,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "vggt/dependency/track_modules/modules.py",
    "chars": 6274,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/track_modules/track_refine.py",
    "chars": 17290,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/track_modules/utils.py",
    "chars": 7825,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/track_predict.py",
    "chars": 12355,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/vggsfm_tracker.py",
    "chars": 4668,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/dependency/vggsfm_utils.py",
    "chars": 11093,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/camera_head.py",
    "chars": 5817,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/dpt_head.py",
    "chars": 17327,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/head_act.py",
    "chars": 3741,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/track_head.py",
    "chars": 4207,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/track_modules/__init__.py",
    "chars": 198,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/track_modules/base_track_predictor.py",
    "chars": 8027,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/track_modules/blocks.py",
    "chars": 9806,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "vggt/heads/track_modules/modules.py",
    "chars": 6132,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/track_modules/utils.py",
    "chars": 8136,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/heads/utils.py",
    "chars": 3941,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/layers/__init__.py",
    "chars": 382,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/layers/attention.py",
    "chars": 3067,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/block.py",
    "chars": 9379,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/drop_path.py",
    "chars": 1157,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/layer_scale.py",
    "chars": 781,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/mlp.py",
    "chars": 1269,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/patch_embed.py",
    "chars": 2794,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/rope.py",
    "chars": 7676,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/swiglu_ffn.py",
    "chars": 2131,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/layers/vision_transformer.py",
    "chars": 15075,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "vggt/models/aggregator.py",
    "chars": 12944,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/models/vggt.py",
    "chars": 4862,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/geometry.py",
    "chars": 11849,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/helper.py",
    "chars": 2317,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/load_fn.py",
    "chars": 8724,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/pose_enc.py",
    "chars": 5517,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/rotation.py",
    "chars": 4584,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "vggt/utils/visual_track.py",
    "chars": 8354,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "visual_util.py",
    "chars": 16797,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  }
]

About this extraction

This page contains the full source code of the facebookresearch/vggt GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 83 files (584.7 KB), approximately 146.8k tokens, and a symbol index with 430 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a free GitHub-repository-to-text converter built by Nikandr Surkov.