Repository: facebookresearch/vggt Branch: main Commit: 44b3afbd1869 Files: 83 Total size: 584.7 KB Directory structure: gitextract_yzgcgvr3/ ├── .gitattributes ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── demo_colmap.py ├── demo_gradio.py ├── demo_viser.py ├── docs/ │ └── package.md ├── pyproject.toml ├── requirements.txt ├── requirements_demo.txt ├── training/ │ ├── README.md │ ├── __init__.py │ ├── config/ │ │ ├── default.yaml │ │ └── default_dataset.yaml │ ├── data/ │ │ ├── __init__.py │ │ ├── augmentation.py │ │ ├── base_dataset.py │ │ ├── composed_dataset.py │ │ ├── dataset_util.py │ │ ├── datasets/ │ │ │ ├── co3d.py │ │ │ └── vkitti.py │ │ ├── dynamic_dataloader.py │ │ ├── preprocess/ │ │ │ └── vkitti.sh │ │ ├── track_util.py │ │ └── worker_fn.py │ ├── launch.py │ ├── loss.py │ ├── train_utils/ │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── distributed.py │ │ ├── freeze.py │ │ ├── general.py │ │ ├── gradient_clip.py │ │ ├── logging.py │ │ ├── normalization.py │ │ ├── optimizer.py │ │ └── tb_writer.py │ └── trainer.py ├── vggt/ │ ├── dependency/ │ │ ├── __init__.py │ │ ├── distortion.py │ │ ├── np_to_pycolmap.py │ │ ├── projection.py │ │ ├── track_modules/ │ │ │ ├── __init__.py │ │ │ ├── base_track_predictor.py │ │ │ ├── blocks.py │ │ │ ├── modules.py │ │ │ ├── track_refine.py │ │ │ └── utils.py │ │ ├── track_predict.py │ │ ├── vggsfm_tracker.py │ │ └── vggsfm_utils.py │ ├── heads/ │ │ ├── camera_head.py │ │ ├── dpt_head.py │ │ ├── head_act.py │ │ ├── track_head.py │ │ ├── track_modules/ │ │ │ ├── __init__.py │ │ │ ├── base_track_predictor.py │ │ │ ├── blocks.py │ │ │ ├── modules.py │ │ │ └── utils.py │ │ └── utils.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ ├── rope.py │ │ ├── swiglu_ffn.py │ │ └── vision_transformer.py │ ├── models/ │ │ ├── aggregator.py │ │ └── vggt.py │ └── utils/ │ ├── 
geometry.py │ ├── helper.py │ ├── load_fn.py │ ├── pose_enc.py │ ├── rotation.py │ └── visual_track.py └── visual_util.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ # SCM syntax highlighting & preventing 3-way merges pixi.lock merge=binary linguist-language=YAML linguist-generated=true ================================================ FILE: .gitignore ================================================ .hydra/ output/ ckpt/ # Byte-compiled / optimized / DLL files __pycache__/ **/__pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. 
#Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Profiling data .prof # Folder specific to your needs **/tmp/ **/outputs/skyseg.onnx skyseg.onnx # pixi environments .pixi *.egg-info ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to vggt We want to make contributing to this project as easy and transparent as possible. ## Pull Requests We actively welcome your pull requests. 1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. 5. Make sure your code lints. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects. 
Complete your CLA here: ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## License By contributing to vggt, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. ================================================ FILE: LICENSE.txt ================================================ VGGT License v1 Last Updated: July 29, 2025 “Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement. “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein. “Documentation” means the specifications, manuals and documentation accompanying Research Materials distributed by Meta. “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). 
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement. By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement. 1. License Rights and Redistribution. a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials. b. Redistribution and Use. i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party. ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication. iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement. 2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. 
Any support provided is “as is”, “with all faults”, and without warranty of any kind. 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS. 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 5. Intellectual Property. a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. 
You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials. 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement. 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. Acceptable Use Policy Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. 
As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials. Prohibited Uses You agree you will not use, or allow others to use, Research Materials to: Violate the law or others’ rights, including to: Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: Violence or terrorism Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material Human trafficking, exploitation, and sexual violence The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. Sexual solicitation Any other criminal activity Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials Create, generate, or facilitate the creation of malicious 
code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following: Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State Guns and illegal weapons (including weapon development) Illegal drugs and regulated/controlled substances Operation of critical infrastructure, transportation technologies, or heavy machinery Self-harm or harm to others, including suicide, cutting, and eating disorders Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual 3. Intentionally deceive or mislead others, including use of Research Materials related to the following: Generating, promoting, or furthering fraud or the creation or promotion of disinformation Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content Generating, promoting, or further distributing spam Impersonating another individual without consent, authorization, or legal right Representing that outputs of research materials or outputs from technology using Research Materials are human-generated Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement 4. Fail to appropriately disclose to end users any known dangers of your Research Materials. ================================================ FILE: README.md ================================================

VGGT: Visual Geometry Grounded Transformer

Paper PDF arXiv Project Page **[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**; **[Meta AI](https://ai.facebook.com/research/)** [Jianyuan Wang](https://jytime.github.io/), [Minghao Chen](https://silent-chen.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/)
```bibtex @inproceedings{wang2025vggt, title={VGGT: Visual Geometry Grounded Transformer}, author={Wang, Jianyuan and Chen, Minghao and Karaev, Nikita and Vedaldi, Andrea and Rupprecht, Christian and Novotny, David}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, year={2025} } ``` ## Updates - [July 29, 2025] We've updated the license for VGGT to permit **commercial use** (excluding military applications). All code in this repository is now under a commercial-use-friendly license. However, only the newly released checkpoint [**VGGT-1B-Commercial**](https://huggingface.co/facebook/VGGT-1B-Commercial) is licensed for commercial usage — the original checkpoint remains non-commercial. Full license details are available [here](https://github.com/facebookresearch/vggt/blob/main/LICENSE.txt). Access to the checkpoint requires completing an application form, which is processed automatically by a system similar to LLaMA's approval workflow. The new checkpoint delivers similar performance to the original model. Please submit an issue if you notice a significant performance discrepancy. - [July 6, 2025] Training code is now available in the `training` folder, including an example of how to fine-tune VGGT on a custom dataset. - [June 13, 2025] Honored to receive the Best Paper Award at CVPR 2025! Apologies if I’m slow to respond to queries or GitHub issues these days. If you’re interested, our oral presentation is available [here](https://docs.google.com/presentation/d/1JVuPnuZx6RgAy-U5Ezobg73XpBi7FrOh/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true). A longer presentation can be found [here](https://docs.google.com/presentation/d/1aSv0e5PmH1mnwn2MowlJIajFUYZkjqgw/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true) (Note: it’s shared in .pptx format with animations — quite large, but feel free to use it as a template if helpful.) 
- [June 2, 2025] Added a script to run VGGT and save predictions in COLMAP format, with optional bundle adjustment. The saved COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) or other NeRF/Gaussian splatting libraries. - [May 3, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available in the [evaluation](https://github.com/facebookresearch/vggt/tree/evaluation) branch. ## Overview Visual Geometry Grounded Transformer (VGGT, CVPR 2025) is a feed-forward neural network that directly infers all key 3D attributes of a scene, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks, **from one, a few, or hundreds of its views, within seconds**. ## Quick Start First, clone this repository to your local machine and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub). ```bash git clone git@github.com:facebookresearch/vggt.git cd vggt pip install -r requirements.txt ``` Alternatively, you can install VGGT as a package (see [docs/package.md](docs/package.md) for details). Now, try the model with just a few lines of code: ```python import torch from vggt.models.vggt import VGGT from vggt.utils.load_fn import load_and_preprocess_images device = "cuda" if torch.cuda.is_available() else "cpu" # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 # Initialize the model and load the pretrained weights. # This will automatically download the model weights the first time it's run, which may take a while. 
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device) # Load and preprocess example images (replace with your own image paths) image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"] images = load_and_preprocess_images(image_names).to(device) with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): # Predict attributes including cameras, depth maps, and point maps. predictions = model(images) ``` The model weights will be automatically downloaded from Hugging Face. If you encounter issues such as slow loading, you can manually download them [here](https://huggingface.co/facebook/VGGT-1B/blob/main/model.pt) and load them with `load_state_dict`, or load them directly from the URL: ```python model = VGGT() _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) ``` ## Detailed Usage
Click to expand You can also optionally choose which attributes (branches) to predict, as shown below. This achieves the same result as the example above. This example uses a batch size of 1 (processing a single scene), but it naturally works for multiple scenes. ```python from vggt.utils.pose_enc import pose_encoding_to_extri_intri from vggt.utils.geometry import unproject_depth_map_to_point_map with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): images = images[None] # add batch dimension aggregated_tokens_list, ps_idx = model.aggregator(images) # Predict Cameras pose_enc = model.camera_head(aggregated_tokens_list)[-1] # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:]) # Predict Depth Maps depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx) # Predict Point Maps point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx) # Construct 3D Points from Depth Maps and Cameras # which usually leads to more accurate 3D points than point map branch point_map_by_unprojection = unproject_depth_map_to_point_map(depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0)) # Predict Tracks # choose your own points to track, with shape (N, 2) for one scene query_points = torch.FloatTensor([[100.0, 200.0], [60.72, 259.94]]).to(device) track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None]) ``` Furthermore, if certain pixels in the input frames are unwanted (e.g., reflective surfaces, sky, or water), you can simply mask them by setting the corresponding pixel values to 0 or 1. Precise segmentation masks aren't necessary - simple bounding box masks work effectively (check this [issue](https://github.com/facebookresearch/vggt/issues/47) for an example).
## Interactive Demo We provide multiple ways to visualize your 3D reconstructions. Before using these visualization tools, install the required dependencies: ```bash pip install -r requirements_demo.txt ``` ### Interactive 3D Visualization **Please note:** VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, independent of VGGT's processing time. Visualization is especially slow when the number of images is large. #### Gradio Web Interface Our Gradio-based interface allows you to upload images/videos, run reconstruction, and interactively explore the 3D scene in your browser. You can launch it on your local machine or try it on [Hugging Face](https://huggingface.co/spaces/facebook/vggt). ```bash python demo_gradio.py ```
Click to preview the Gradio interactive interface ![Gradio Web Interface Preview](https://jytime.github.io/data/vggt_hf_demo_screen.png)
#### Viser 3D Viewer Run the following command to run reconstruction and visualize the point clouds in viser. Note that this script requires a path to a folder containing images and assumes that the folder contains only image files. You can set `--use_point_map` to use the point cloud from the point map branch, instead of the depth-based point cloud. ```bash python demo_viser.py --image_folder path/to/your/images/folder ``` ## Exporting to COLMAP Format We also support exporting VGGT's predictions directly to COLMAP format, as follows: ```bash # Feedforward prediction only python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ # With bundle adjustment python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba # Run with bundle adjustment using reduced parameters for faster processing # Reduces max_query_pts from 4096 (default) to 2048 and query_frame_num from 8 (default) to 5 # Trade-off: Faster execution but potentially less robust reconstruction in complex scenes (you may consider setting query_frame_num equal to your total number of images) # See demo_colmap.py for additional bundle adjustment configuration options python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba --max_query_pts=2048 --query_frame_num=5 ``` Please ensure that the images are stored in `/YOUR/SCENE_DIR/images/`. This folder should contain only the images. Check the examples folder for the desired data structure. The reconstruction result (camera parameters and 3D points) will be automatically saved under `/YOUR/SCENE_DIR/sparse/` in COLMAP format, with the following layout: ``` SCENE_DIR/ ├── images/ └── sparse/ ├── cameras.bin ├── images.bin └── points3D.bin ``` ## Integration with Gaussian Splatting The exported COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) for Gaussian Splatting training.
Install `gsplat` following their official instructions (we recommend `gsplat==1.3.0`). An example command to train the model is: ``` cd gsplat python examples/simple_trainer.py default --data_factor 1 --data_dir /YOUR/SCENE_DIR/ --result_dir /YOUR/RESULT_DIR/ ``` ## Zero-shot Single-view Reconstruction Our model shows surprisingly good performance on single-view reconstruction, although it was never trained for this task. The model does not need to duplicate the single-view image into a pair; instead, it can directly infer the 3D structure from the tokens of the single-view image. Feel free to try it with our demos above, which naturally work for single-view reconstruction. We did not quantitatively test monocular depth estimation performance ourselves, but [@kabouzeid](https://github.com/kabouzeid) generously provided a comparison of VGGT to recent methods [here](https://github.com/facebookresearch/vggt/issues/36). VGGT shows competitive or better results compared to state-of-the-art monocular approaches such as Depth Anything V2 or MoGe, despite never being explicitly trained for single-view tasks. ## Runtime and GPU Memory We benchmark the runtime and GPU memory usage of VGGT's aggregator on a single NVIDIA H100 GPU across various input sizes. | **Input Frames** | 1 | 2 | 4 | 8 | 10 | 20 | 50 | 100 | 200 | |:----------------:|:-:|:-:|:-:|:-:|:--:|:--:|:--:|:---:|:---:| | **Time (s)** | 0.04 | 0.05 | 0.07 | 0.11 | 0.14 | 0.31 | 1.04 | 3.12 | 8.75 | | **Memory (GB)** | 1.88 | 2.07 | 2.45 | 3.23 | 3.63 | 5.58 | 11.41 | 21.15 | 40.63 | Note that these results were obtained using Flash Attention 3, which is faster than the default Flash Attention 2 implementation while maintaining almost the same memory usage. Feel free to compile Flash Attention 3 from source to get better performance. ## Research Progression Our work builds upon a series of previous research projects. If you're interested in understanding how our research evolved, check out our previous works:
Deep SfM Revisited ──┐
PoseDiffusion ─────► VGGSfM ──► VGGT
CoTracker ──┘
## Acknowledgements Thanks to these great repositories: [PoseDiffusion](https://github.com/facebookresearch/PoseDiffusion), [VGGSfM](https://github.com/facebookresearch/vggsfm), [CoTracker](https://github.com/facebookresearch/co-tracker), [DINOv2](https://github.com/facebookresearch/dinov2), [Dust3r](https://github.com/naver/dust3r), [Moge](https://github.com/microsoft/moge), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [Sky Segmentation](https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), [Metric3D](https://github.com/YvanYin/Metric3D) and many other inspiring works in the community. ## Checklist - [x] Release the training code - [ ] Release VGGT-500M and VGGT-200M ## License See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available. Please note that only this [model checkpoint](https://huggingface.co/facebook/VGGT-1B-Commercial) allows commercial usage. This new checkpoint performs on par with the original one (possibly slightly better), e.g., AUC@30 of 90.37 vs. 89.98 on the Co3D dataset. ================================================ FILE: demo_colmap.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree.
import random import numpy as np import glob import os import copy import torch import torch.nn.functional as F # Configure CUDA settings torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False import argparse from pathlib import Path import trimesh import pycolmap from vggt.models.vggt import VGGT from vggt.utils.load_fn import load_and_preprocess_images_square from vggt.utils.pose_enc import pose_encoding_to_extri_intri from vggt.utils.geometry import unproject_depth_map_to_point_map from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues from vggt.dependency.track_predict import predict_tracks from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track # TODO: add support for masks # TODO: add iterative BA # TODO: add support for radial distortion, which needs extra_params # TODO: test with more cases # TODO: test different camera types def parse_args(): parser = argparse.ArgumentParser(description="VGGT Demo") parser.add_argument("--scene_dir", type=str, required=True, help="Directory containing the scene images") parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction") ######### BA parameters ######### parser.add_argument( "--max_reproj_error", type=float, default=8.0, help="Maximum reprojection error for reconstruction" ) parser.add_argument("--shared_camera", action="store_true", default=False, help="Use shared camera for all images") parser.add_argument("--camera_type", type=str, default="SIMPLE_PINHOLE", help="Camera type for reconstruction") parser.add_argument("--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks") parser.add_argument("--query_frame_num", type=int, default=8, help="Number of frames to query") parser.add_argument("--max_query_pts", type=int, 
default=4096, help="Maximum number of query points") parser.add_argument( "--fine_tracking", action="store_true", default=True, help="Use fine tracking (slower but more accurate)" ) parser.add_argument( "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering (wo BA)" ) return parser.parse_args() def run_VGGT(model, images, dtype, resolution=518): # images: [B, 3, H, W] assert len(images.shape) == 4 assert images.shape[1] == 3 # hard-coded to use 518 for VGGT images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False) with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): images = images[None] # add batch dimension aggregated_tokens_list, ps_idx = model.aggregator(images) # Predict Cameras pose_enc = model.camera_head(aggregated_tokens_list)[-1] # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:]) # Predict Depth Maps depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx) extrinsic = extrinsic.squeeze(0).cpu().numpy() intrinsic = intrinsic.squeeze(0).cpu().numpy() depth_map = depth_map.squeeze(0).cpu().numpy() depth_conf = depth_conf.squeeze(0).cpu().numpy() return extrinsic, intrinsic, depth_map, depth_conf def demo_fn(args): # Print configuration print("Arguments:", vars(args)) # Set seed for reproducibility np.random.seed(args.seed) torch.manual_seed(args.seed) random.seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) # for multi-GPU print(f"Setting seed as: {args.seed}") # Set device and dtype dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") print(f"Using dtype: {dtype}") # Run VGGT for camera and depth estimation model = VGGT() _URL = 
"https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) model.eval() model = model.to(device) print(f"Model loaded") # Get image paths and preprocess them image_dir = os.path.join(args.scene_dir, "images") image_path_list = glob.glob(os.path.join(image_dir, "*")) if len(image_path_list) == 0: raise ValueError(f"No images found in {image_dir}") base_image_path_list = [os.path.basename(path) for path in image_path_list] # Load images and original coordinates # Load Image in 1024, while running VGGT with 518 vggt_fixed_resolution = 518 img_load_resolution = 1024 images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution) images = images.to(device) original_coords = original_coords.to(device) print(f"Loaded {len(images)} images from {image_dir}") # Run VGGT to estimate camera and depth # Run with 518x518 images extrinsic, intrinsic, depth_map, depth_conf = run_VGGT(model, images, dtype, vggt_fixed_resolution) points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic) if args.use_ba: image_size = np.array(images.shape[-2:]) scale = img_load_resolution / vggt_fixed_resolution shared_camera = args.shared_camera with torch.cuda.amp.autocast(dtype=dtype): # Predicting Tracks # Using VGGSfM tracker instead of VGGT tracker for efficiency # VGGT tracker requires multiple backbone runs to query different frames (this is a problem caused by the training process) # Will be fixed in VGGT v2 # You can also change the pred_tracks to tracks from any other methods # e.g., from COLMAP, from CoTracker, or by chaining 2D matches from Lightglue/LoFTR. 
pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = predict_tracks( images, conf=depth_conf, points_3d=points_3d, masks=None, max_query_pts=args.max_query_pts, query_frame_num=args.query_frame_num, keypoint_extractor="aliked+sp", fine_tracking=args.fine_tracking, ) torch.cuda.empty_cache() # rescale the intrinsic matrix from 518 to 1024 intrinsic[:, :2, :] *= scale track_mask = pred_vis_scores > args.vis_thresh # TODO: radial distortion, iterative BA, masks reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap( points_3d, extrinsic, intrinsic, pred_tracks, image_size, masks=track_mask, max_reproj_error=args.max_reproj_error, shared_camera=shared_camera, camera_type=args.camera_type, points_rgb=points_rgb, ) if reconstruction is None: raise ValueError("No reconstruction can be built with BA") # Bundle Adjustment ba_options = pycolmap.BundleAdjustmentOptions() pycolmap.bundle_adjustment(reconstruction, ba_options) reconstruction_resolution = img_load_resolution else: conf_thres_value = args.conf_thres_value max_points_for_colmap = 100000 # randomly sample 3D points shared_camera = False # in the feedforward manner, we do not support shared camera camera_type = "PINHOLE" # in the feedforward manner, we only support PINHOLE camera image_size = np.array([vggt_fixed_resolution, vggt_fixed_resolution]) num_frames, height, width, _ = points_3d.shape points_rgb = F.interpolate( images, size=(vggt_fixed_resolution, vggt_fixed_resolution), mode="bilinear", align_corners=False ) points_rgb = (points_rgb.cpu().numpy() * 255).astype(np.uint8) points_rgb = points_rgb.transpose(0, 2, 3, 1) # (S, H, W, 3), with x, y coordinates and frame indices points_xyf = create_pixel_coordinate_grid(num_frames, height, width) conf_mask = depth_conf >= conf_thres_value # at most writing 100000 3d points to colmap reconstruction object conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap) points_3d = points_3d[conf_mask] points_xyf = points_xyf[conf_mask] 
points_rgb = points_rgb[conf_mask] print("Converting to COLMAP format") reconstruction = batch_np_matrix_to_pycolmap_wo_track( points_3d, points_xyf, points_rgb, extrinsic, intrinsic, image_size, shared_camera=shared_camera, camera_type=camera_type, ) reconstruction_resolution = vggt_fixed_resolution reconstruction = rename_colmap_recons_and_rescale_camera( reconstruction, base_image_path_list, original_coords.cpu().numpy(), img_size=reconstruction_resolution, shift_point2d_to_original_res=True, shared_camera=shared_camera, ) print(f"Saving reconstruction to {args.scene_dir}/sparse") sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse") os.makedirs(sparse_reconstruction_dir, exist_ok=True) reconstruction.write(sparse_reconstruction_dir) # Save point cloud for fast visualization trimesh.PointCloud(points_3d, colors=points_rgb).export(os.path.join(args.scene_dir, "sparse/points.ply")) return True def rename_colmap_recons_and_rescale_camera( reconstruction, image_paths, original_coords, img_size, shift_point2d_to_original_res=False, shared_camera=False ): rescale_camera = True for pyimageid in reconstruction.images: # Reshaped the padded&resized image to the original size # Rename the images to the original names pyimage = reconstruction.images[pyimageid] pycamera = reconstruction.cameras[pyimage.camera_id] pyimage.name = image_paths[pyimageid - 1] if rescale_camera: # Rescale the camera parameters pred_params = copy.deepcopy(pycamera.params) real_image_size = original_coords[pyimageid - 1, -2:] resize_ratio = max(real_image_size) / img_size pred_params = pred_params * resize_ratio real_pp = real_image_size / 2 pred_params[-2:] = real_pp # center of the image pycamera.params = pred_params pycamera.width = real_image_size[0] pycamera.height = real_image_size[1] if shift_point2d_to_original_res: # Also shift the point2D to original resolution top_left = original_coords[pyimageid - 1, :2] for point2D in pyimage.points2D: point2D.xy = (point2D.xy - top_left) 
* resize_ratio if shared_camera: # If shared_camera, all images share the same camera # no need to rescale any more rescale_camera = False return reconstruction if __name__ == "__main__": args = parse_args() with torch.no_grad(): demo_fn(args) # Work in Progress (WIP) """ VGGT Runner Script ================= A script to run the VGGT model for 3D reconstruction from image sequences. Directory Structure ------------------ Input: input_folder/ └── images/ # Source images for reconstruction Output: output_folder/ ├── images/ ├── sparse/ # Reconstruction results │ ├── cameras.bin # Camera parameters (COLMAP format) │ ├── images.bin # Pose for each image (COLMAP format) │ ├── points3D.bin # 3D points (COLMAP format) │ └── points.ply # Point cloud visualization file └── visuals/ # Visualization outputs TODO Key Features ----------- • Dual-mode Support: Run reconstructions using either VGGT or VGGT+BA • Resolution Preservation: Maintains original image resolution in camera parameters and tracks • COLMAP Compatibility: Exports results in standard COLMAP sparse reconstruction format """ ================================================ FILE: demo_gradio.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import os import cv2 import torch import numpy as np import gradio as gr import sys import shutil from datetime import datetime import glob import gc import time sys.path.append("vggt/") from visual_util import predictions_to_glb from vggt.models.vggt import VGGT from vggt.utils.load_fn import load_and_preprocess_images from vggt.utils.pose_enc import pose_encoding_to_extri_intri from vggt.utils.geometry import unproject_depth_map_to_point_map device = "cuda" if torch.cuda.is_available() else "cpu" print("Initializing and loading VGGT model...") # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model model = VGGT() _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) model.eval() model = model.to(device) # ------------------------------------------------------------------------- # 1) Core model inference # ------------------------------------------------------------------------- def run_model(target_dir, model) -> dict: """ Run the VGGT model on images in the 'target_dir/images' folder and return predictions. """ print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") # Move model to device model = model.to(device) model.eval() # Load and preprocess images image_names = glob.glob(os.path.join(target_dir, "images", "*")) image_names = sorted(image_names) print(f"Found {len(image_names)} images") if len(image_names) == 0: raise ValueError("No images found. 
Check your upload.") images = load_and_preprocess_images(image_names).to(device) print(f"Preprocessed images shape: {images.shape}") # Run inference print("Running inference...") dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): predictions = model(images) # Convert pose encoding to extrinsic and intrinsic matrices print("Converting pose encoding to extrinsic and intrinsic matrices...") extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) predictions["extrinsic"] = extrinsic predictions["intrinsic"] = intrinsic # Convert tensors to numpy for key in predictions.keys(): if isinstance(predictions[key], torch.Tensor): predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension predictions['pose_enc_list'] = None # remove pose_enc_list # Generate world points from depth map print("Computing world points from depth map...") depth_map = predictions["depth"] # (S, H, W, 1) world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"]) predictions["world_points_from_depth"] = world_points # Clean up torch.cuda.empty_cache() return predictions # ------------------------------------------------------------------------- # 2) Handle uploaded video/images --> produce target_dir + images # ------------------------------------------------------------------------- def handle_uploads(input_video, input_images): """ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded images or extracted frames from video into it. Return (target_dir, image_paths). 
""" start_time = time.time() gc.collect() torch.cuda.empty_cache() # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] # --- Handle images --- if input_images is not None: for file_data in input_images: if isinstance(file_data, dict) and "name" in file_data: file_path = file_data["name"] else: file_path = file_data dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) # --- Handle video --- if input_video is not None: if isinstance(input_video, dict) and "name" in input_video: video_path = input_video["name"] else: video_path = input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) frame_interval = int(fps * 1) # 1 frame/sec count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 # Sort final images for gallery image_paths = sorted(image_paths) end_time = time.time() print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds") return target_dir, image_paths # ------------------------------------------------------------------------- # 3) Update gallery on upload # ------------------------------------------------------------------------- def update_gallery_on_upload(input_video, input_images): """ Whenever user uploads or changes files, immediately handle them and show in the gallery. Return (target_dir, image_paths). If nothing is uploaded, returns "None" and empty list. 
""" if not input_video and not input_images: return None, None, None, None target_dir, image_paths = handle_uploads(input_video, input_images) return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing." # ------------------------------------------------------------------------- # 4) Reconstruction: uses the target_dir plus any viz parameters # ------------------------------------------------------------------------- def gradio_demo( target_dir, conf_thres=3.0, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Pointmap Regression", ): """ Perform reconstruction using the already-created target_dir/images. """ if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] frame_filter_choices = ["All"] + all_files print("Running run_model...") with torch.no_grad(): predictions = run_model(target_dir, model) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) # Handle None frame_filter if frame_filter is None: frame_filter = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb", ) # Convert predictions to GLB glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, 
mask_sky=mask_sky, target_dir=target_dir, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) # Cleanup del predictions gc.collect() torch.cuda.empty_cache() end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True) # ------------------------------------------------------------------------- # 5) Helper functions for UI resets + re-visualization # ------------------------------------------------------------------------- def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def update_visualization( target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example ): """ Reload saved predictions from npz, create (or reuse) the GLB for new parameters, and return it for the 3D viewer. If is_example == "True", skip. """ # If it's an example click, skip as requested if is_example == "True": return None, "No reconstruction available. Please click the Reconstruct button first." if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return None, "No reconstruction available. Please click the Reconstruct button first." predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first." 
key_list = [ "pose_enc", "depth", "depth_conf", "world_points", "world_points_conf", "images", "extrinsic", "intrinsic", "world_points_from_depth", ] loaded = np.load(predictions_path) predictions = {key: np.array(loaded[key]) for key in key_list} glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb", ) if not os.path.exists(glbfile): glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, mask_sky=mask_sky, target_dir=target_dir, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) return glbfile, "Updating Visualization" # ------------------------------------------------------------------------- # Example images # ------------------------------------------------------------------------- great_wall_video = "examples/videos/great_wall.mp4" colosseum_video = "examples/videos/Colosseum.mp4" room_video = "examples/videos/room.mp4" kitchen_video = "examples/videos/kitchen.mp4" fern_video = "examples/videos/fern.mp4" single_cartoon_video = "examples/videos/single_cartoon.mp4" single_oil_painting_video = "examples/videos/single_oil_painting.mp4" pyramid_video = "examples/videos/pyramid.mp4" # ------------------------------------------------------------------------- # 6) Build Gradio UI # ------------------------------------------------------------------------- theme = gr.themes.Ocean() theme.set( checkbox_label_background_fill_selected="*button_primary_background_fill", checkbox_label_text_color_selected="*button_primary_text_color", ) with gr.Blocks( theme=theme, css=""" .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; 
background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .example-log * { font-style: italic; font-size: 16px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; color: transparent !important; } #my_radio .wrap { display: flex; flex-wrap: nowrap; justify-content: center; align-items: center; } #my_radio .wrap label { display: flex; width: 50%; justify-content: center; align-items: center; margin: 0; padding: 10px 0; box-sizing: border-box; } """, ) as demo: # Instead of gr.State, we use a hidden Textbox: is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") gr.HTML( """

🏛️ VGGT: Visual Geometry Grounded Transformer

🐙 GitHub Repository | Project Page

Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.

Getting Started:

  1. Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
  2. Preview: Your uploaded images will appear in the gallery on the left.
  3. Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
  4. Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.
  5. Adjust Visualization (Optional): After reconstruction, you can fine-tune the visualization using the options below
    (click to expand):
    • Confidence Threshold: Adjust the filtering of points based on confidence.
    • Show Points from Frame: Select specific frames to display in the point cloud.
    • Show Camera: Toggle the display of estimated camera positions.
    • Filter Sky / Filter Black Background: Remove sky or black-background points.
    • Select a Prediction Mode: Choose between "Depthmap and Camera Branch" or "Pointmap Branch."

Please note: VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time.

""" ) target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") with gr.Row(): with gr.Column(scale=2): input_video = gr.Video(label="Upload Video", interactive=True) input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) image_gallery = gr.Gallery( label="Preview", columns=4, height="300px", show_download_button=True, object_fit="contain", preview=True, ) with gr.Column(scale=4): with gr.Column(): gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**") log_output = gr.Markdown( "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"] ) reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5) with gr.Row(): submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") clear_btn = gr.ClearButton( [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery], scale=1, ) with gr.Row(): prediction_mode = gr.Radio( ["Depthmap and Camera Branch", "Pointmap Branch"], label="Select a Prediction Mode", value="Depthmap and Camera Branch", scale=1, elem_id="my_radio", ) with gr.Row(): conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)") frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame") with gr.Column(): show_cam = gr.Checkbox(label="Show Camera", value=True) mask_sky = gr.Checkbox(label="Filter Sky", value=False) mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False) mask_white_bg = gr.Checkbox(label="Filter White Background", value=False) # ---------------------- Examples section ---------------------- examples = [ [colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"], [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"], [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"], 
[single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"], [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"], [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"], [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"], ] def example_pipeline( input_video, num_images_str, input_images, conf_thres, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example_str, ): """ 1) Copy example images to new target_dir 2) Reconstruct 3) Return model3D + logs + new_dir + updated dropdown + gallery We do NOT return is_example. It's just an input. """ target_dir, image_paths = handle_uploads(input_video, input_images) # Always use "All" for frame_filter in examples frame_filter = "All" glbfile, log_msg, dropdown = gradio_demo( target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode ) return glbfile, log_msg, target_dir, dropdown, image_paths gr.Markdown("Click any row to load an example.", elem_classes=["example-log"]) gr.Examples( examples=examples, inputs=[ input_video, num_images, input_images, conf_thres, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery], fn=example_pipeline, cache_examples=False, examples_per_page=50, ) # ------------------------------------------------------------------------- # "Reconstruct" button logic: # - Clear fields # - Update log # - gradio_demo(...) 
with the existing target_dir # - Then set is_example = "False" # ------------------------------------------------------------------------- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( fn=update_log, inputs=[], outputs=[log_output] ).then( fn=gradio_demo, inputs=[ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, ], outputs=[reconstruction_output, log_output, frame_filter], ).then( fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False" ) # ------------------------------------------------------------------------- # Real-time Visualization Updates # ------------------------------------------------------------------------- conf_thres.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) frame_filter.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) mask_black_bg.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) mask_white_bg.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) show_cam.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) mask_sky.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, 
log_output], ) prediction_mode.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example, ], [reconstruction_output, log_output], ) # ------------------------------------------------------------------------- # Auto-update gallery whenever user uploads or changes their files # ------------------------------------------------------------------------- input_video.change( fn=update_gallery_on_upload, inputs=[input_video, input_images], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) input_images.change( fn=update_gallery_on_upload, inputs=[input_video, input_images], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) demo.queue(max_size=20).launch(show_error=True, share=True) ================================================ FILE: demo_viser.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import glob import time import threading import argparse from typing import List, Optional import numpy as np import torch from tqdm.auto import tqdm import viser import viser.transforms as viser_tf import cv2 try: import onnxruntime except ImportError: print("onnxruntime not found. 
Sky segmentation may not work.") from visual_util import segment_sky, download_file_from_url from vggt.models.vggt import VGGT from vggt.utils.load_fn import load_and_preprocess_images from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map from vggt.utils.pose_enc import pose_encoding_to_extri_intri def viser_wrapper( pred_dict: dict, port: int = 8080, init_conf_threshold: float = 50.0, # represents percentage (e.g., 50 means filter lowest 50%) use_point_map: bool = False, background_mode: bool = False, mask_sky: bool = False, image_folder: str = None, ): """ Visualize predicted 3D points and camera poses with viser. Args: pred_dict (dict): { "images": (S, 3, H, W) - Input images, "world_points": (S, H, W, 3), "world_points_conf": (S, H, W), "depth": (S, H, W, 1), "depth_conf": (S, H, W), "extrinsic": (S, 3, 4), "intrinsic": (S, 3, 3), } port (int): Port number for the viser server. init_conf_threshold (float): Initial percentage of low-confidence points to filter out. use_point_map (bool): Whether to visualize world_points or use depth-based points. background_mode (bool): Whether to run the server in background thread. mask_sky (bool): Whether to apply sky segmentation to filter out sky points. image_folder (str): Path to the folder containing input images. 
""" print(f"Starting viser server on port {port}") server = viser.ViserServer(host="0.0.0.0", port=port) server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") # Unpack prediction dict images = pred_dict["images"] # (S, 3, H, W) world_points_map = pred_dict["world_points"] # (S, H, W, 3) conf_map = pred_dict["world_points_conf"] # (S, H, W) depth_map = pred_dict["depth"] # (S, H, W, 1) depth_conf = pred_dict["depth_conf"] # (S, H, W) extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) # Compute world points from depth if not using the precomputed point map if not use_point_map: world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) conf = depth_conf else: world_points = world_points_map conf = conf_map # Apply sky segmentation if enabled if mask_sky and image_folder is not None: conf = apply_sky_segmentation(conf, image_folder) # Convert images from (S, 3, H, W) to (S, H, W, 3) # Then flatten everything for the point cloud colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) S, H, W, _ = world_points.shape # Flatten points = world_points.reshape(-1, 3) colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8) conf_flat = conf.reshape(-1) cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically # For convenience, we store only (3,4) portion cam_to_world = cam_to_world_mat[:, :3, :] # Compute scene center and recenter scene_center = np.mean(points, axis=0) points_centered = points - scene_center cam_to_world[..., -1] -= scene_center # Store frame indices so we can filter by frame frame_indices = np.repeat(np.arange(S), H * W) # Build the viser GUI gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True) # Now the slider represents percentage of points to filter out gui_points_conf = server.gui.add_slider( "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold ) 
gui_frame_selector = server.gui.add_dropdown( "Show Points from Frames", options=["All"] + [str(i) for i in range(S)], initial_value="All" ) # Create the main point cloud handle # Compute the threshold value as the given percentile init_threshold_val = np.percentile(conf_flat, init_conf_threshold) init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1) point_cloud = server.scene.add_point_cloud( name="viser_pcd", points=points_centered[init_conf_mask], colors=colors_flat[init_conf_mask], point_size=0.001, point_shape="circle", ) # We will store references to frames & frustums so we can toggle visibility frames: List[viser.FrameHandle] = [] frustums: List[viser.CameraFrustumHandle] = [] def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None: """ Add camera frames and frustums to the scene. extrinsics: (S, 3, 4) images_: (S, 3, H, W) """ # Clear any existing frames or frustums for f in frames: f.remove() frames.clear() for fr in frustums: fr.remove() frustums.clear() # Optionally attach a callback that sets the viewpoint to the chosen camera def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None: @frustum.on_click def _(_) -> None: for client in server.get_clients().values(): client.camera.wxyz = frame.wxyz client.camera.position = frame.position img_ids = range(S) for img_id in tqdm(img_ids): cam2world_3x4 = extrinsics[img_id] T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4) # Add a small frame axis frame_axis = server.scene.add_frame( f"frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), axes_length=0.05, axes_radius=0.002, origin_radius=0.002, ) frames.append(frame_axis) # Convert the image for the frustum img = images_[img_id] # shape (3, H, W) img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) h, w = img.shape[:2] # If you want correct FOV from intrinsics, do something like: # fx = intrinsics_cam[img_id, 0, 0] # fov = 2 * np.arctan2(h/2, fx) # 
For demonstration, we pick a simple approximate FOV: fy = 1.1 * h fov = 2 * np.arctan2(h / 2, fy) # Add the frustum frustum_cam = server.scene.add_camera_frustum( f"frame_{img_id}/frustum", fov=fov, aspect=w / h, scale=0.05, image=img, line_width=1.0 ) frustums.append(frustum_cam) attach_callback(frustum_cam, frame_axis) def update_point_cloud() -> None: """Update the point cloud based on current GUI selections.""" # Here we compute the threshold value based on the current percentage current_percentage = gui_points_conf.value threshold_val = np.percentile(conf_flat, current_percentage) print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%") conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5) if gui_frame_selector.value == "All": frame_mask = np.ones_like(conf_mask, dtype=bool) else: selected_idx = int(gui_frame_selector.value) frame_mask = frame_indices == selected_idx combined_mask = conf_mask & frame_mask point_cloud.points = points_centered[combined_mask] point_cloud.colors = colors_flat[combined_mask] @gui_points_conf.on_update def _(_) -> None: update_point_cloud() @gui_frame_selector.on_update def _(_) -> None: update_point_cloud() @gui_show_frames.on_update def _(_) -> None: """Toggle visibility of camera frames and frustums.""" for f in frames: f.visible = gui_show_frames.value for fr in frustums: fr.visible = gui_show_frames.value # Add the camera frames to the scene visualize_frames(cam_to_world, images) print("Starting viser server...") # If background_mode is True, spawn a daemon thread so the main thread can continue. if background_mode: def server_loop(): while True: time.sleep(0.001) thread = threading.Thread(target=server_loop, daemon=True) thread.start() else: while True: time.sleep(0.01) return server # Helper functions for sky segmentation def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray: """ Apply sky segmentation to confidence scores. 
Args: conf (np.ndarray): Confidence scores with shape (S, H, W) image_folder (str): Path to the folder containing input images Returns: np.ndarray: Updated confidence scores with sky regions masked out """ S, H, W = conf.shape sky_masks_dir = image_folder.rstrip("/") + "_sky_masks" os.makedirs(sky_masks_dir, exist_ok=True) # Download skyseg.onnx if it doesn't exist if not os.path.exists("skyseg.onnx"): print("Downloading skyseg.onnx...") download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx") skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") image_files = sorted(glob.glob(os.path.join(image_folder, "*"))) sky_mask_list = [] print("Generating sky masks...") for i, image_path in enumerate(tqdm(image_files[:S])): # Limit to the number of images in the batch image_name = os.path.basename(image_path) mask_filepath = os.path.join(sky_masks_dir, image_name) if os.path.exists(mask_filepath): sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) else: sky_mask = segment_sky(image_path, skyseg_session, mask_filepath) # Resize mask to match H×W if needed if sky_mask.shape[0] != H or sky_mask.shape[1] != W: sky_mask = cv2.resize(sky_mask, (W, H)) sky_mask_list.append(sky_mask) # Convert list to numpy array with shape S×H×W sky_mask_array = np.array(sky_mask_list) # Apply sky mask to confidence scores sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32) conf = conf * sky_mask_binary print("Sky segmentation applied successfully") return conf parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization") parser.add_argument( "--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images" ) parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points") parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode") parser.add_argument("--port", 
type=int, default=8080, help="Port number for the viser server") parser.add_argument( "--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out" ) parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points") def main(): """ Main function for the VGGT demo with viser for 3D visualization. This function: 1. Loads the VGGT model 2. Processes input images from the specified folder 3. Runs inference to generate 3D points and camera poses 4. Optionally applies sky segmentation to filter out sky points 5. Visualizes the results using viser Command-line arguments: --image_folder: Path to folder containing input images --use_point_map: Use point map instead of depth-based points --background_mode: Run the viser server in background mode --port: Port number for the viser server --conf_threshold: Initial percentage of low-confidence points to filter out --mask_sky: Apply sky segmentation to filter out sky points """ args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") print("Initializing and loading VGGT model...") # model = VGGT.from_pretrained("facebook/VGGT-1B") model = VGGT() _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) model.eval() model = model.to(device) # Use the provided image folder path print(f"Loading images from {args.image_folder}...") image_names = glob.glob(os.path.join(args.image_folder, "*")) print(f"Found {len(image_names)} images") images = load_and_preprocess_images(image_names).to(device) print(f"Preprocessed images shape: {images.shape}") print("Running inference...") dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): predictions = model(images) print("Converting pose encoding to extrinsic and intrinsic 
matrices...") extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) predictions["extrinsic"] = extrinsic predictions["intrinsic"] = intrinsic print("Processing model outputs...") for key in predictions.keys(): if isinstance(predictions[key], torch.Tensor): predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy if args.use_point_map: print("Visualizing 3D points from point map") else: print("Visualizing 3D points by unprojecting depth map by cameras") if args.mask_sky: print("Sky segmentation enabled - will filter out sky points") print("Starting viser visualization...") viser_server = viser_wrapper( predictions, port=args.port, init_conf_threshold=args.conf_threshold, use_point_map=args.use_point_map, background_mode=args.background_mode, mask_sky=args.mask_sky, image_folder=args.image_folder, ) print("Visualization complete") if __name__ == "__main__": main() ================================================ FILE: docs/package.md ================================================ # Alternative Installation Methods This document explains how to install VGGT as a package using different package managers. ## Prerequisites Before installing VGGT as a package, you need to install PyTorch and torchvision. We don't list these as dependencies to avoid CUDA version mismatches. Install them first, with an example as: ```bash # install pytorch 2.3.1 with cuda 12.1 pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 ``` ## Installation Options ### Install with pip The simplest way to install VGGT is using pip: ```bash pip install -e . ``` ### Install and run with pixi [Pixi](https://pixi.sh) is a package management tool for creating reproducible environments. 1. First, [download and install pixi](https://pixi.sh/latest/get_started/) 2. 
Then run: ```bash pixi run -e demo python demo_gradio.py ``` ### Install and run with uv [uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver. 1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/) 2. Then run: ```bash uv run --extra demo demo_gradio.py ``` ================================================ FILE: pyproject.toml ================================================ [project] authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] dependencies = [ "numpy<2", "Pillow", "huggingface_hub", "einops", "safetensors", "opencv-python", ] name = "vggt" requires-python = ">= 3.10" version = "0.0.1" [project.optional-dependencies] demo = [ "gradio==5.17.1", "viser==0.2.23", "tqdm", "hydra-core", "omegaconf", "opencv-python", "scipy", "onnxruntime", "requests", "trimesh", "matplotlib", ] # Using setuptools as the build backend [build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" # setuptools configuration [tool.setuptools.packages.find] where = ["."] include = ["vggt*"] # Pixi configuration [tool.pixi.workspace] channels = ["conda-forge"] platforms = ["linux-64"] [tool.pixi.pypi-dependencies] vggt = { path = ".", editable = true } [tool.pixi.environments] default = { solve-group = "default" } demo = { features = ["demo"], solve-group = "default" } [tool.pixi.tasks] ================================================ FILE: requirements.txt ================================================ torch==2.3.1 torchvision==0.18.1 numpy==1.26.1 Pillow huggingface_hub einops safetensors ================================================ FILE: requirements_demo.txt ================================================ gradio==5.17.1 viser==0.2.23 tqdm hydra-core omegaconf opencv-python scipy onnxruntime requests trimesh matplotlib pydantic==2.10.6 # feel free to skip the dependencies below if you do not need demo_colmap.py pycolmap==3.10.0 pyceres==2.3
git+https://github.com/jytime/LightGlue.git#egg=lightglue ================================================ FILE: training/README.md ================================================ # Training This is a re-implementation of our framework for training VGGT. This document shows how to set up the environment and run VGGT training. I have aimed to faithfully reproduce the original training framework, but please open an issue if anything looks off. ## 1. Prerequisites Before you begin, ensure you have completed the following steps: 1. **Install VGGT as a package:** ```bash pip install -e . ``` 2. **Prepare the dataset and annotations:** - Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d). - Download the required annotation files from [Hugging Face](https://huggingface.co/datasets/JianyuanWang/co3d_anno/tree/main). ## 2. Configuration After downloading the dataset and annotations, configure the paths in `training/config/default.yaml`. ### Required Path Configuration 1. Open `training/config/default.yaml` 2. Update the following paths with your absolute directory paths: - `CO3D_DIR`: Path to your Co3D dataset - `CO3D_ANNOTATION_DIR`: Path to your Co3D annotation files - `resume_checkpoint_path`: Path to your pre-trained VGGT checkpoint ### Configuration Example ```yaml data: train: dataset: dataset_configs: - _target_: data.datasets.co3d.Co3dDataset split: train CO3D_DIR: /YOUR/PATH/TO/CO3D CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION # ... same for val ... checkpoint: resume_checkpoint_path: /YOUR/PATH/TO/CKPT ``` ## 3. Fine-tuning on Co3D To fine-tune the provided pre-trained model on the Co3D dataset, run the following command. This example uses 4 GPUs with PyTorch Distributed Data Parallel (DDP): ```bash torchrun --nproc_per_node=4 launch.py ``` The default configuration in `training/config/default.yaml` is set up for fine-tuning. 
It automatically resumes from a checkpoint and freezes the model's `aggregator` module during training. ## 4. Training on Multiple Datasets The dataloader supports multiple datasets naturally. For example, if you have downloaded VKitti using `preprocess/vkitti.sh`, you can train on Co3D+VKitti by configuring: ```yaml data: train: dataset: _target_: data.composed_dataset.ComposedDataset dataset_configs: - _target_: data.datasets.co3d.Co3dDataset split: train CO3D_DIR: /YOUR/PATH/TO/CO3D CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION len_train: 100000 - _target_: data.datasets.vkitti.VKittiDataset split: train VKitti_DIR: /YOUR/PATH/TO/VKitti len_train: 100000 expand_ratio: 8 ``` The ratio of different datasets can be controlled by setting `len_train`. For example, Co3D with `len_train: 10000` and VKitti with `len_train: 2000` will result in Co3D being sampled five times more frequently than VKitti. ## 5. Common Questions ### Memory Management If you encounter out-of-memory (OOM) errors on your GPU, consider adjusting the following parameters in `training/config/default.yaml`: - `max_img_per_gpu`: Reduce this value to decrease the batch size per GPU - `accum_steps`: Sets the number of gradient accumulation steps (default is 2). This feature splits batches into smaller chunks to save memory, though it may slightly increase training time. Note that gradient accumulation was not used for the original VGGT model. ### Learning Rate Tuning The main hyperparameter to be careful about is learning rate. Note that learning rate depends on the effective batch size, which is `batch_size_per_gpu × num_gpus`. Therefore, I highly recommend trying several learning rates based on your training setup. Generally, trying values like `5e-6`, `1e-5`, `5e-5`, `1e-4`, `5e-4` should be sufficient. ### Tracking Head The tracking head can slightly improve accuracy but is not necessary. 
For general cases, especially when GPU resources are limited, we suggest fine-tuning the pre-trained model only with camera and depth heads, which is the setting in `default.yaml`. This will provide good enough results. ### Dataloader Validation To check if your dataloader is working correctly, the best approach is to visualize its output. You can save the 3D world points as follows and then visually inspect the PLY files: ```python def save_ply(points, colors, filename): import open3d as o3d if torch.is_tensor(points): points_visual = points.reshape(-1, 3).cpu().numpy() else: points_visual = points.reshape(-1, 3) if torch.is_tensor(colors): points_visual_rgb = colors.reshape(-1, 3).cpu().numpy() else: points_visual_rgb = colors.reshape(-1, 3) pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points_visual.astype(np.float64)) pcd.colors = o3d.utility.Vector3dVector(points_visual_rgb.astype(np.float64)) o3d.io.write_point_cloud(filename, pcd, write_ascii=True) # Usage example save_ply( batch["world_points"][0].reshape(-1, 3), batch["images"][0].permute(0, 2, 3, 1).reshape(-1, 3), "debug.ply" ) ``` ### Handling Unordered Sequences For unordered sequences, you can check how we compute the ranking (similarity) between one frame and all other frames, as discussed in [Issue #82](https://github.com/facebookresearch/vggt/issues/82). ### Expected Coordinate System Camera poses are expected to follow the OpenCV `camera-from-world` convention. Depth maps should be aligned with their corresponding camera poses. 
================================================ FILE: training/__init__.py ================================================ ================================================ FILE: training/config/default.yaml ================================================ defaults: - default_dataset.yaml exp_name: exp001 img_size: 518 num_workers: 8 seed_value: 42 accum_steps: 2 # We did not use gradient accumulation in our training, while if you suffer from OOM, you can try to use it. patch_size: 14 val_epoch_freq: 5 max_img_per_gpu: 48 limit_train_batches: 800 limit_val_batches: 400 data: # The code for data still looks too complicated. I should refactor this again (do I have time?...) train: _target_: data.dynamic_dataloader.DynamicTorchDataset num_workers: ${num_workers} max_img_per_gpu: ${max_img_per_gpu} common_config: img_size: ${img_size} patch_size: ${patch_size} debug: False repeat_batch: False dataset: _target_: data.composed_dataset.ComposedDataset dataset_configs: - _target_: data.datasets.co3d.Co3dDataset split: train CO3D_DIR: /YOUR/PATH/TO/CO3D CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION val: _target_: data.dynamic_dataloader.DynamicTorchDataset num_workers: ${num_workers} max_img_per_gpu: ${max_img_per_gpu} common_config: img_size: ${img_size} patch_size: ${patch_size} debug: False dataset: _target_: data.composed_dataset.ComposedDataset dataset_configs: - _target_: data.datasets.co3d.Co3dDataset split: test CO3D_DIR: /YOUR/PATH/TO/CO3D CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION logging: log_dir: logs log_visuals: False log_freq: 1 log_level_primary: DEBUG log_level_secondary: WARNING all_ranks: False tensorboard_writer: _target_: train_utils.tb_writer.TensorBoardLogger path: ${logging.log_dir}/tensorboard scalar_keys_to_log: train: keys_to_log: - loss_objective - loss_camera - loss_T - loss_R - loss_FL - loss_conf_depth - loss_reg_depth - loss_grad_depth val: keys_to_log: - loss_objective - loss_camera - loss_T - loss_R - loss_FL - loss_conf_depth 
- loss_reg_depth - loss_grad_depth checkpoint: save_dir: logs/${exp_name}/ckpts save_freq: 5 resume_checkpoint_path: /YOUR/PATH/TO/CKPT strict: False loss: _target_: loss.MultitaskLoss camera: weight: 5.0 loss_type: "l1" # The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss. depth: weight: 1.0 gradient_loss_fn: "grad" valid_range: 0.98 point: null # If you want to enable point, use the following config # point: # weight: 1.0 # gradient_loss_fn: "normal" # valid_range: 0.98 track: null optim: param_group_modifiers: False optimizer: _target_: torch.optim.AdamW lr: 5e-5 weight_decay: 0.05 frozen_module_names: - "*aggregator*" # example, freeze the aggregator amp: enabled: True amp_dtype: bfloat16 gradient_clip: _target_: train_utils.gradient_clip.GradientClipper configs: - module_name: ["aggregator"] max_norm: 1.0 # feel free to reduce this if you see instabilities norm_type: 2 - module_name: ["depth"] max_norm: 1.0 # feel free to reduce this if you see instabilities norm_type: 2 - module_name: ["camera"] max_norm: 1.0 # feel free to reduce this if you see instabilities norm_type: 2 options: lr: - scheduler: _target_: fvcore.common.param_scheduler.CompositeParamScheduler schedulers: - _target_: fvcore.common.param_scheduler.LinearParamScheduler start_value: 1e-8 end_value: 5e-5 - _target_: fvcore.common.param_scheduler.CosineParamScheduler start_value: 5e-5 end_value: 1e-8 lengths: [0.05, 0.95] interval_scaling: ['rescaled', 'rescaled'] weight_decay: - scheduler: _target_: fvcore.common.param_scheduler.ConstantParamScheduler value: 0.05 max_epochs: 20 model: _target_: vggt.models.vggt.VGGT enable_camera: True enable_depth: True enable_point: False enable_track: False distributed: # check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options backend: nccl comms_dtype: None find_unused_parameters: False timeout_mins: 30 gradient_as_bucket_view: True # Less memory used 
bucket_cap_mb: 25 broadcast_buffers: True cuda: cudnn_deterministic: False cudnn_benchmark: False allow_tf32: True ================================================ FILE: training/config/default_dataset.yaml ================================================ # Template for the dataset config data: # The code still looks too complicated. I should refactor this again (do I have time?...) train: _target_: data.dynamic_dataloader.DynamicTorchDataset num_workers: 8 max_img_per_gpu: 48 # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory # (see: https://github.com/pytorch/pytorch/issues/13246). # To avoid this, set shuffle=False and enable common_config.inside_random=True instead. shuffle: True pin_memory: False common_config: # common config for evaluation fix_img_num: -1 # -1 means do not fix the number of images fix_aspect_ratio: 1.0 load_track: False track_num: 1024 training: True inside_random: True img_size: 224 patch_size: 14 rescale: True rescale_aug: True landscape_check: False debug: False get_nearby: True load_depth: True img_nums: [2, 24] max_img_per_gpu: 48 allow_duplicate_img: True repeat_batch: False augs: cojitter: True cojitter_ratio: 0.3 scales: [0.8, 1.2] aspects: [0.33, 1.0] color_jitter: brightness: 0.5 contrast: 0.5 saturation: 0.5 hue: 0.1 p: 0.9 gray_scale: True gau_blur: False val: _target_: data.dynamic_dataloader.DynamicTorchDataset num_workers: 8 max_img_per_gpu: 48 # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory # (see: https://github.com/pytorch/pytorch/issues/13246). # To avoid this, set shuffle=False and enable common_config.inside_random=True instead. 
shuffle: True pin_memory: False common_config: # common config for evaluation fix_img_num: -1 # -1 means do not fix the number of images fix_aspect_ratio: 1.0 load_track: False track_num: 1024 training: False inside_random: True img_size: 224 patch_size: 14 rescale: True rescale_aug: False landscape_check: False debug: False get_nearby: True load_depth: True img_nums: [2, 12] allow_duplicate_img: True augs: cojitter: False cojitter_ratio: 0.5 scales: null aspects: [1.0, 1.0] color_jitter: null gray_scale: False gau_blur: False ================================================ FILE: training/data/__init__.py ================================================ ================================================ FILE: training/data/augmentation.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Dict from torchvision import transforms def get_image_augmentation( color_jitter: Optional[Dict[str, float]] = None, gray_scale: bool = True, gau_blur: bool = False ) -> Optional[transforms.Compose]: """Create a composition of image augmentations. 
Args: color_jitter: Dictionary containing color jitter parameters: - brightness: float (default: 0.5) - contrast: float (default: 0.5) - saturation: float (default: 0.5) - hue: float (default: 0.1) - p: probability of applying (default: 0.9) If None, uses default values gray_scale: Whether to apply random grayscale (default: True) gau_blur: Whether to apply gaussian blur (default: False) Returns: A Compose object of transforms or None if no transforms are added """ transform_list = [] default_jitter = { "brightness": 0.5, "contrast": 0.5, "saturation": 0.5, "hue": 0.1, "p": 0.9 } # Handle color jitter if color_jitter is not None: # Merge with defaults for missing keys effective_jitter = {**default_jitter, **color_jitter} else: effective_jitter = default_jitter transform_list.append( transforms.RandomApply( [ transforms.ColorJitter( brightness=effective_jitter["brightness"], contrast=effective_jitter["contrast"], saturation=effective_jitter["saturation"], hue=effective_jitter["hue"], ) ], p=effective_jitter["p"], ) ) if gray_scale: transform_list.append(transforms.RandomGrayscale(p=0.05)) if gau_blur: transform_list.append( transforms.RandomApply( [transforms.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05 ) ) return transforms.Compose(transform_list) if transform_list else None ================================================ FILE: training/data/base_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import numpy as np from PIL import Image, ImageFile from torch.utils.data import Dataset from .dataset_util import * Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True class BaseDataset(Dataset): """ Base dataset class for VGGT and VGGSfM training. 
This abstract class handles common operations like image resizing, augmentation, and coordinate transformations. Concrete dataset implementations should inherit from this class. Attributes: img_size: Target image size (typically the width) patch_size: Size of patches for vit augs.scales: Scale range for data augmentation [min, max] rescale: Whether to rescale images rescale_aug: Whether to apply augmentation during rescaling landscape_check: Whether to handle landscape vs portrait orientation """ def __init__( self, common_conf, ): """ Initialize the base dataset with common configuration. Args: common_conf: Configuration object with the following properties, shared by all datasets: - img_size: Default is 518 - patch_size: Default is 14 - augs.scales: Default is [0.8, 1.2] - rescale: Default is True - rescale_aug: Default is True - landscape_check: Default is True """ super().__init__() self.img_size = common_conf.img_size self.patch_size = common_conf.patch_size self.aug_scale = common_conf.augs.scales self.rescale = common_conf.rescale self.rescale_aug = common_conf.rescale_aug self.landscape_check = common_conf.landscape_check def __len__(self): return self.len_train def __getitem__(self, idx_N): """ Get an item from the dataset. Args: idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio) Returns: Dataset item as returned by get_data() """ seq_index, img_per_seq, aspect_ratio = idx_N return self.get_data( seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio ) def get_data(self, seq_index=None, seq_name=None, ids=None, aspect_ratio=1.0): """ Abstract method to retrieve data for a given sequence. Args: seq_index (int, optional): Index of the sequence seq_name (str, optional): Name of the sequence ids (list, optional): List of frame IDs aspect_ratio (float, optional): Target aspect ratio. 
Returns: Dataset-specific data Raises: NotImplementedError: This method must be implemented by subclasses """ raise NotImplementedError( "This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method." ) def get_target_shape(self, aspect_ratio): """ Calculate the target shape based on the given aspect ratio. Args: aspect_ratio: Target aspect ratio Returns: numpy.ndarray: Target image shape [height, width] """ short_size = int(self.img_size * aspect_ratio) small_size = self.patch_size # ensure the input shape is friendly to vision transformer if short_size % small_size != 0: short_size = (short_size // small_size) * small_size image_shape = np.array([short_size, self.img_size]) return image_shape def process_one_image( self, image, depth_map, extri_opencv, intri_opencv, original_size, target_image_shape, track=None, filepath=None, safe_bound=4, ): """ Process a single image and its associated data. This method handles image transformations, depth processing, and coordinate conversions. Args: image (numpy.ndarray): Input image array depth_map (numpy.ndarray): Depth map array extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention) intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention) original_size (numpy.ndarray): Original image size [height, width] target_image_shape (numpy.ndarray): Target image shape after processing track (numpy.ndarray, optional): Optional tracking information. Defaults to None. filepath (str, optional): Optional file path for debugging. Defaults to None. safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4. 
Returns: tuple: ( image (numpy.ndarray): Processed image, depth_map (numpy.ndarray): Processed depth map, extri_opencv (numpy.ndarray): Updated extrinsic matrix, intri_opencv (numpy.ndarray): Updated intrinsic matrix, world_coords_points (numpy.ndarray): 3D points in world coordinates, cam_coords_points (numpy.ndarray): 3D points in camera coordinates, point_mask (numpy.ndarray): Boolean mask of valid points, track (numpy.ndarray, optional): Updated tracking information ) """ # Make copies to avoid in-place operations affecting original data image = np.copy(image) depth_map = np.copy(depth_map) extri_opencv = np.copy(extri_opencv) intri_opencv = np.copy(intri_opencv) if track is not None: track = np.copy(track) # Apply random scale augmentation during training if enabled if self.training and self.aug_scale: random_h_scale, random_w_scale = np.random.uniform( self.aug_scale[0], self.aug_scale[1], 2 ) # Avoid random padding by capping at 1.0 random_h_scale = min(random_h_scale, 1.0) random_w_scale = min(random_w_scale, 1.0) aug_size = original_size * np.array([random_h_scale, random_w_scale]) aug_size = aug_size.astype(np.int32) else: aug_size = original_size # Move principal point to the image center and crop if necessary image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp( image, depth_map, intri_opencv, aug_size, track=track, filepath=filepath, ) original_size = np.array(image.shape[:2]) # update original_size target_shape = target_image_shape # Handle landscape vs. 
portrait orientation rotate_to_portrait = False if self.landscape_check: # Switch between landscape and portrait if necessary if original_size[0] > 1.25 * original_size[1]: if (target_image_shape[0] != target_image_shape[1]) and (np.random.rand() > 0.5): target_shape = np.array([target_image_shape[1], target_image_shape[0]]) rotate_to_portrait = True # Resize images and update intrinsics if self.rescale: image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic( image, depth_map, intri_opencv, target_shape, original_size, track=track, safe_bound=safe_bound, rescale_aug=self.rescale_aug ) else: print("Not rescaling the images") # Ensure final crop to target shape image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp( image, depth_map, intri_opencv, target_shape, track=track, filepath=filepath, strict=True, ) # Apply 90-degree rotation if needed if rotate_to_portrait: assert self.landscape_check clockwise = np.random.rand() > 0.5 image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees( image, depth_map, extri_opencv, intri_opencv, clockwise=clockwise, track=track, ) # Convert depth to world and camera coordinates world_coords_points, cam_coords_points, point_mask = ( depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv) ) return ( image, depth_map, extri_opencv, intri_opencv, world_coords_points, cam_coords_points, point_mask, track, ) def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None): """ TODO: add the function to sample the ids by pose similarity ranking. Sample a set of IDs from a sequence close to a given start index. You can specify the range either as a ratio of the number of input IDs or as a fixed integer window. Args: ids (list): Initial list of IDs. The first element is used as the anchor. full_seq_num (int): Total number of items in the full sequence. expand_ratio (float, optional): Factor by which the number of IDs expands around the start index. 
Default is 2.0 if neither expand_ratio nor expand_range is provided. expand_range (int, optional): Fixed number of items to expand around the start index. If provided, expand_ratio is ignored. Returns: numpy.ndarray: Array of sampled IDs, with the first element being the original start index. Examples: # Using expand_ratio (default behavior) # If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0, # expand_range = int(3 * 2.0) = 6, so IDs sampled from [94...106] (if boundaries allow). # Using expand_range directly # If ids=[100,101,102] and full_seq_num=200, with expand_range=10, # IDs are sampled from [90...110] (if boundaries allow). Raises: ValueError: If no IDs are provided. """ if len(ids) == 0: raise ValueError("No IDs provided.") if expand_range is None and expand_ratio is None: expand_ratio = 2.0 # Default behavior total_ids = len(ids) start_idx = ids[0] # Determine the actual expand_range if expand_range is None: # Use ratio to determine range expand_range = int(total_ids * expand_ratio) # Calculate valid boundaries low_bound = max(0, start_idx - expand_range) high_bound = min(full_seq_num, start_idx + expand_range) # Create the valid range of indices valid_range = np.arange(low_bound, high_bound) # Sample 'total_ids - 1' items, because we already have the start_idx sampled_ids = np.random.choice( valid_range, size=(total_ids - 1), replace=True, # we accept the situation that some sampled ids are the same ) # Insert the start_idx at the beginning result_ids = np.insert(sampled_ids, 0, start_idx) return result_ids ================================================ FILE: training/data/composed_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from abc import ABC from hydra.utils import instantiate import torch import random import numpy as np from torch.utils.data import Dataset from torch.utils.data import ConcatDataset import bisect from .dataset_util import * from .track_util import * from .augmentation import get_image_augmentation class ComposedDataset(Dataset, ABC): """ Composes multiple base datasets and applies common configurations. This dataset provides a flexible way to combine multiple base datasets while applying shared augmentations, track generation, and other processing steps. It handles image normalization, tensor conversion, and other preparations needed for training computer vision models with sequences of images. """ def __init__(self, dataset_configs: dict, common_config: dict, **kwargs): """ Initializes the ComposedDataset. Args: dataset_configs (dict): List of Hydra configurations for base datasets. common_config (dict): Shared configurations (augs, tracks, mode, etc.). **kwargs: Additional arguments (unused). """ base_dataset_list = [] # Instantiate each base dataset with common configuration for baseset_dict in dataset_configs: baseset = instantiate(baseset_dict, common_conf=common_config) base_dataset_list.append(baseset) # Use custom concatenation class that supports tuple indexing self.base_dataset = TupleConcatDataset(base_dataset_list, common_config) # --- Augmentation Settings --- # Controls whether to apply identical color jittering across all frames in a sequence self.cojitter = common_config.augs.cojitter # Probability of using shared jitter vs. 
frame-specific jitter self.cojitter_ratio = common_config.augs.cojitter_ratio # Initialize image augmentations (color jitter, grayscale, gaussian blur) self.image_aug = get_image_augmentation( color_jitter=common_config.augs.color_jitter, gray_scale=common_config.augs.gray_scale, gau_blur=common_config.augs.gau_blur, ) # --- Optional Fixed Settings (useful for debugging) --- # Force each sequence to have exactly this many images (if > 0) self.fixed_num_images = common_config.fix_img_num # Force a specific aspect ratio for all images self.fixed_aspect_ratio = common_config.fix_aspect_ratio # --- Track Settings --- # Whether to include point tracks in the output self.load_track = common_config.load_track # Number of point tracks to include per sequence self.track_num = common_config.track_num # --- Mode Settings --- # Whether the dataset is being used for training (affects augmentations) self.training = common_config.training self.common_config = common_config self.total_samples = len(self.base_dataset) def __len__(self): """Returns the total number of sequences in the dataset.""" return self.total_samples def __getitem__(self, idx_tuple): """ Retrieves a data sample (sequence) from the dataset. Loads raw data, converts to PyTorch tensors, applies augmentations, and prepares tracks if enabled. Args: idx_tuple (tuple): a tuple of (seq_idx, num_images, aspect_ratio) Returns: dict: A dictionary containing the sequence data (images, poses, tracks, etc.). 
""" # If fixed settings are provided, override the tuple values if self.fixed_num_images > 0: seq_idx = idx_tuple[0] if isinstance(idx_tuple, tuple) else idx_tuple idx_tuple = (seq_idx, self.fixed_num_images, self.fixed_aspect_ratio) # Retrieve the raw data batch from the appropriate base dataset batch = self.base_dataset[idx_tuple] seq_name = batch["seq_name"] # --- Data Conversion and Preparation --- # Convert numpy arrays to tensors images = torch.from_numpy(np.stack(batch["images"]).astype(np.float32)).contiguous() # Normalize images from [0, 255] to [0, 1] images = images.permute(0,3,1,2).to(torch.get_default_dtype()).div(255) # Convert other data to tensors with appropriate types depths = torch.from_numpy(np.stack(batch["depths"]).astype(np.float32)) extrinsics = torch.from_numpy(np.stack(batch["extrinsics"]).astype(np.float32)) intrinsics = torch.from_numpy(np.stack(batch["intrinsics"]).astype(np.float32)) cam_points = torch.from_numpy(np.stack(batch["cam_points"]).astype(np.float32)) world_points = torch.from_numpy(np.stack(batch["world_points"]).astype(np.float32)) point_masks = torch.from_numpy(np.stack(batch["point_masks"])) # Mask indicating valid depths / world points / cam points per frame ids = torch.from_numpy(batch["ids"]) # Frame indices sampled from the original sequence # --- Apply Color Augmentation (training mode only) --- if self.training and self.image_aug is not None: if self.cojitter and random.random() > self.cojitter_ratio: # Apply the same color jittering transformation to all frames images = self.image_aug(images) else: # Apply different color jittering to each frame individually for aug_img_idx in range(len(images)): images[aug_img_idx] = self.image_aug(images[aug_img_idx]) # --- Prepare Final Sample Dictionary --- sample = { "seq_name": seq_name, "ids": ids, "images": images, "depths": depths, "extrinsics": extrinsics, "intrinsics": intrinsics, "cam_points": cam_points, "world_points": world_points, "point_masks": point_masks, } # 
--- Track Processing (if enabled) --- if self.load_track: if batch["tracks"] is not None: # Use pre-computed tracks from the dataset tracks = torch.from_numpy(np.stack(batch["tracks"]).astype(np.float32)) track_vis_mask = torch.from_numpy(np.stack(batch["track_masks"]).astype(bool)) # Sample a subset of tracks randomly valid_indices = torch.where(track_vis_mask[0])[0] if len(valid_indices) >= self.track_num: # If we have enough tracks, sample without replacement sampled_indices = valid_indices[torch.randperm(len(valid_indices))][:self.track_num] else: # If not enough tracks, sample with replacement (allow duplicates) sampled_indices = valid_indices[torch.randint(0, len(valid_indices), (self.track_num,), dtype=torch.int64, device=valid_indices.device)] # Extract the sampled tracks and their masks tracks = tracks[:, sampled_indices, :] track_vis_mask = track_vis_mask[:, sampled_indices] track_positive_mask = torch.ones(track_vis_mask.shape[1]).bool() else: # Generate tracks on-the-fly using depth information # This creates synthetic tracks based on the 3D information available tracks, track_vis_mask, track_positive_mask = build_tracks_by_depth( extrinsics, intrinsics, world_points, depths, point_masks, images, target_track_num=self.track_num, seq_name=seq_name ) # Add track information to the sample dictionary sample["tracks"] = tracks sample["track_vis_mask"] = track_vis_mask sample["track_positive_mask"] = track_positive_mask return sample class TupleConcatDataset(ConcatDataset): """ A custom ConcatDataset that supports indexing with a tuple. Standard PyTorch ConcatDataset only accepts an integer index. This class extends that functionality to allow passing a tuple like (sample_idx, num_images, aspect_ratio), where the first element is used to determine which sample to fetch, and the full tuple is passed down to the selected dataset's __getitem__ method. It also supports an option to randomly sample across all datasets, ignoring the provided index. 
This is useful during training when shuffling the entire dataset might cause memory issues due to duplicating dictionaries. If doing this, you can set pytorch's dataloader shuffle to False. """ def __init__(self, datasets, common_config): """ Initialize the TupleConcatDataset. Args: datasets (iterable): An iterable of PyTorch Dataset objects to concatenate. common_config (dict): Common configuration dict, used to check for random sampling. """ super().__init__(datasets) # If True, ignores the input index and samples randomly across all datasets # This provides an alternative to dataloader shuffling for large datasets self.inside_random = common_config.inside_random def __getitem__(self, idx): """ Retrieves an item using either an integer index or a tuple index. Args: idx (int or tuple): The index. If tuple, the first element is the sequence index across the concatenated datasets, and the rest are passed down. If int, it's treated as the sequence index. Returns: The item returned by the underlying dataset's __getitem__ method. Raises: ValueError: If the index is out of range or the tuple doesn't have exactly 3 elements. 
""" idx_tuple = None if isinstance(idx, tuple): idx_tuple = idx idx = idx_tuple[0] # Extract the sequence index # Override index with random value if inside_random is enabled if self.inside_random: total_len = self.cumulative_sizes[-1] idx = random.randint(0, total_len - 1) # Handle negative indices if idx < 0: if -idx > len(self): raise ValueError( "absolute value of index should not exceed dataset length" ) idx = len(self) + idx # Find which dataset the index belongs to dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] # Create the tuple to pass to the underlying dataset if len(idx_tuple) == 3: idx_tuple = (sample_idx,) + idx_tuple[1:] else: raise ValueError("Tuple index must have exactly three elements") # Pass the modified tuple to the appropriate dataset return self.datasets[dataset_idx][idx_tuple] ================================================ FILE: training/data/dataset_util.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" import cv2 import math import numpy as np from PIL import Image import PIL try: lanczos = PIL.Image.Resampling.LANCZOS bicubic = PIL.Image.Resampling.BICUBIC except AttributeError: lanczos = PIL.Image.LANCZOS bicubic = PIL.Image.BICUBIC from vggt.utils.geometry import closed_form_inverse_se3 ##################################################################################################################### def crop_image_depth_and_intrinsic_by_pp( image, depth_map, intrinsic, target_shape, track=None, filepath=None, strict=False ): """ TODO: some names of width and height seem not consistent. Need to check. 
Crops the given image and depth map around the camera's principal point, as defined by `intrinsic`. Specifically: - Ensures that the crop is centered on (cx, cy). - Optionally pads the image (and depth map) if `strict=True` and the result is smaller than `target_shape`. - Shifts the camera intrinsic matrix (and `track` if provided) accordingly. Args: image (np.ndarray): Input image array of shape (H, W, 3). depth_map (np.ndarray or None): Depth map array of shape (H, W), or None if not available. intrinsic (np.ndarray): Camera intrinsic matrix (3x3). The principal point is assumed to be at (intrinsic[1,2], intrinsic[0,2]). target_shape (tuple[int, int]): Desired output shape. track (np.ndarray or None): Optional array of shape (N, 2). Interpreted as (x, y) pixel coordinates. Will be shifted after cropping. filepath (str or None): An optional file path for debug logging (only used if strict mode triggers warnings). strict (bool): If True, will zero-pad to ensure the exact target_shape even if the cropped region is smaller. Raises: AssertionError: If the input image is smaller than `target_shape`. ValueError: If the cropped image is larger than `target_shape` (in strict mode), which should not normally happen. Returns: tuple: (cropped_image, cropped_depth_map, updated_intrinsic, updated_track) - cropped_image (np.ndarray): Cropped (and optionally padded) image. - cropped_depth_map (np.ndarray or None): Cropped (and optionally padded) depth map. - updated_intrinsic (np.ndarray): Intrinsic matrix adjusted for the crop. - updated_track (np.ndarray or None): Track array adjusted for the crop, or None if track was not provided. """ original_size = np.array(image.shape) intrinsic = np.copy(intrinsic) if original_size[0] < target_shape[0]: error_message = ( f"Width check failed: original width {original_size[0]} " f"is less than target width {target_shape[0]}." 
) print(error_message) raise AssertionError(error_message) if original_size[1] < target_shape[1]: error_message = ( f"Height check failed: original height {original_size[1]} " f"is less than target height {target_shape[1]}." ) print(error_message) raise AssertionError(error_message) # Identify principal point (cx, cy) from intrinsic cx = (intrinsic[1, 2]) cy = (intrinsic[0, 2]) # Compute how far we can crop in each direction if strict: half_x = min((target_shape[0] / 2), cx) half_y = min((target_shape[1] / 2), cy) else: half_x = min((target_shape[0] / 2), cx, original_size[0] - cx) half_y = min((target_shape[1] / 2), cy, original_size[1] - cy) # Compute starting indices start_x = math.floor(cx) - math.floor(half_x) start_y = math.floor(cy) - math.floor(half_y) assert start_x >= 0 assert start_y >= 0 # Compute ending indices if strict: end_x = start_x + target_shape[0] end_y = start_y + target_shape[1] else: end_x = start_x + 2 * math.floor(half_x) end_y = start_y + 2 * math.floor(half_y) # Perform the crop image = image[start_x:end_x, start_y:end_y, :] if depth_map is not None: depth_map = depth_map[start_x:end_x, start_y:end_y] # Shift the principal point in the intrinsic intrinsic[1, 2] = intrinsic[1, 2] - start_x intrinsic[0, 2] = intrinsic[0, 2] - start_y # Adjust track if provided if track is not None: track[:, 1] = track[:, 1] - start_x track[:, 0] = track[:, 0] - start_y # If strict, zero-pad if the new shape is smaller than target_shape if strict: if (image.shape[:2] != target_shape).any(): print(f"{filepath} does not meet the target shape") current_h, current_w = image.shape[:2] target_h, target_w = target_shape[0], target_shape[1] pad_h = target_h - current_h pad_w = target_w - current_w if pad_h < 0 or pad_w < 0: raise ValueError( f"The cropped image is bigger than the target shape: " f"cropped=({current_h},{current_w}), " f"target=({target_h},{target_w})." 
) image = np.pad( image, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode="constant", constant_values=0, ) if depth_map is not None: depth_map = np.pad( depth_map, pad_width=((0, pad_h), (0, pad_w)), mode="constant", constant_values=0, ) return image, depth_map, intrinsic, track def resize_image_depth_and_intrinsic( image, depth_map, intrinsic, target_shape, original_size, track=None, pixel_center=True, safe_bound=4, rescale_aug=True, ): """ Resizes the given image and depth map (if provided) to slightly larger than `target_shape`, updating the intrinsic matrix (and track array if present). Optionally uses random rescaling to create some additional margin (based on `rescale_aug`). Steps: 1. Compute a scaling factor so that the resized result is at least `target_shape + safe_bound`. 2. Apply an optional triangular random factor if `rescale_aug=True`. 3. Resize the image with LANCZOS if downscaling, BICUBIC if upscaling. 4. Resize the depth map with nearest-neighbor. 5. Update the camera intrinsic and track coordinates (if any). Args: image (np.ndarray): Input image array (H, W, 3). depth_map (np.ndarray or None): Depth map array (H, W), or None if unavailable. intrinsic (np.ndarray): Camera intrinsic matrix (3x3). target_shape (np.ndarray or tuple[int, int]): Desired final shape (height, width). original_size (np.ndarray or tuple[int, int]): Original size of the image in (height, width). track (np.ndarray or None): Optional (N, 2) array of pixel coordinates. Will be scaled. pixel_center (bool): If True, accounts for 0.5 pixel center shift during resizing. safe_bound (int or float): Additional margin (in pixels) to add to target_shape before resizing. rescale_aug (bool): If True, randomly increase the `safe_bound` within a certain range to simulate augmentation. Returns: tuple: (resized_image, resized_depth_map, updated_intrinsic, updated_track) - resized_image (np.ndarray): The resized image. - resized_depth_map (np.ndarray or None): The resized depth map. 
- updated_intrinsic (np.ndarray): Camera intrinsic updated for new resolution. - updated_track (np.ndarray or None): Track array updated or None if not provided. Raises: AssertionError: If the shapes of the resized image and depth map do not match. """ if rescale_aug: random_boundary = np.random.triangular(0, 0, 0.3) safe_bound = safe_bound + random_boundary * target_shape.max() resize_scales = (target_shape + safe_bound) / original_size max_resize_scale = np.max(resize_scales) intrinsic = np.copy(intrinsic) # Convert image to PIL for resizing image = Image.fromarray(image) input_resolution = np.array(image.size) output_resolution = np.floor(input_resolution * max_resize_scale).astype(int) image = image.resize(tuple(output_resolution), resample=lanczos if max_resize_scale < 1 else bicubic) image = np.array(image) if depth_map is not None: depth_map = cv2.resize( depth_map, output_resolution, fx=max_resize_scale, fy=max_resize_scale, interpolation=cv2.INTER_NEAREST, ) actual_size = np.array(image.shape[:2]) actual_resize_scale = np.max(actual_size / original_size) if pixel_center: intrinsic[0, 2] = intrinsic[0, 2] + 0.5 intrinsic[1, 2] = intrinsic[1, 2] + 0.5 intrinsic[:2, :] = intrinsic[:2, :] * actual_resize_scale if track is not None: track = track * actual_resize_scale if pixel_center: intrinsic[0, 2] = intrinsic[0, 2] - 0.5 intrinsic[1, 2] = intrinsic[1, 2] - 0.5 assert image.shape[:2] == depth_map.shape[:2] return image, depth_map, intrinsic, track def threshold_depth_map( depth_map: np.ndarray, max_percentile: float = 99, min_percentile: float = 1, max_depth: float = -1, ) -> np.ndarray: """ Thresholds a depth map using percentile-based limits and optional maximum depth clamping. Steps: 1. If `max_depth > 0`, clamp all values above `max_depth` to zero. 2. Compute `max_percentile` and `min_percentile` thresholds using nanpercentile. 3. Zero out values above/below these thresholds, if thresholds are > 0. Args: depth_map (np.ndarray): Input depth map (H, W). 
max_percentile (float): Upper percentile (0-100). Values above this will be set to zero. min_percentile (float): Lower percentile (0-100). Values below this will be set to zero. max_depth (float): Absolute maximum depth. If > 0, any depth above this is set to zero. If <= 0, no maximum-depth clamp is applied. Returns: np.ndarray: Depth map (H, W) after thresholding. Some or all values may be zero. Returns None if depth_map is None. """ if depth_map is None: return None depth_map = depth_map.astype(float, copy=True) # Optional clamp by max_depth if max_depth > 0: depth_map[depth_map > max_depth] = 0.0 # Percentile-based thresholds depth_max_thres = ( np.nanpercentile(depth_map, max_percentile) if max_percentile > 0 else None ) depth_min_thres = ( np.nanpercentile(depth_map, min_percentile) if min_percentile > 0 else None ) # Apply the thresholds if they are > 0 if depth_max_thres is not None and depth_max_thres > 0: depth_map[depth_map > depth_max_thres] = 0.0 if depth_min_thres is not None and depth_min_thres > 0: depth_map[depth_map < depth_min_thres] = 0.0 return depth_map def depth_to_world_coords_points( depth_map: np.ndarray, extrinsic: np.ndarray, intrinsic: np.ndarray, eps=1e-8, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Converts a depth map to world coordinates (HxWx3) given the camera extrinsic and intrinsic. Returns both the world coordinates and the intermediate camera coordinates, as well as a mask for valid depth. Args: depth_map (np.ndarray): Depth map of shape (H, W). extrinsic (np.ndarray): Extrinsic matrix of shape (3, 4), representing the camera pose in OpenCV convention (camera-from-world). intrinsic (np.ndarray): Intrinsic matrix of shape (3, 3). eps (float): Small epsilon for thresholding valid depth. Returns: tuple[np.ndarray, np.ndarray, np.ndarray]: (world_coords_points, cam_coords_points, point_mask) - world_coords_points: (H, W, 3) array of 3D points in world frame. - cam_coords_points: (H, W, 3) array of 3D points in camera frame. 
- point_mask: (H, W) boolean array where True indicates valid (non-zero) depth. """ if depth_map is None: return None, None, None # Valid depth mask point_mask = depth_map > eps # Convert depth map to camera coordinates cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) # The extrinsic is camera-from-world, so invert it to transform camera->world cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] R_cam_to_world = cam_to_world_extrinsic[:3, :3] t_cam_to_world = cam_to_world_extrinsic[:3, 3] # Apply the rotation and translation to the camera coordinates world_coords_points = ( np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world ) # HxWx3, 3x3 -> HxWx3 # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world return world_coords_points, cam_coords_points, point_mask def depth_to_cam_coords_points( depth_map: np.ndarray, intrinsic: np.ndarray ) -> np.ndarray: """ Unprojects a depth map into camera coordinates, returning (H, W, 3). Args: depth_map (np.ndarray): Depth map of shape (H, W). intrinsic (np.ndarray): 3x3 camera intrinsic matrix. Assumes zero skew and standard OpenCV layout: [ fx 0 cx ] [ 0 fy cy ] [ 0 0 1 ] Returns: np.ndarray: An (H, W, 3) array, where each pixel is mapped to (x, y, z) in the camera frame. 
""" H, W = depth_map.shape assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" assert ( intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0 ), "Intrinsic matrix must have zero skew" # Intrinsic parameters fu, fv = intrinsic[0, 0], intrinsic[1, 1] cu, cv = intrinsic[0, 2], intrinsic[1, 2] # Generate grid of pixel coordinates u, v = np.meshgrid(np.arange(W), np.arange(H)) # Unproject to camera coordinates x_cam = (u - cu) * depth_map / fu y_cam = (v - cv) * depth_map / fv z_cam = depth_map # Stack to form camera coordinates return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) def rotate_90_degrees( image, depth_map, extri_opencv, intri_opencv, clockwise=True, track=None ): """ Rotates the input image, depth map, and camera parameters by 90 degrees. Applies one of two 90-degree rotations: - Clockwise - Counterclockwise (if clockwise=False) The extrinsic and intrinsic matrices are adjusted accordingly to maintain correct camera geometry. Track coordinates are also updated if provided. Args: image (np.ndarray): Input image of shape (H, W, 3). depth_map (np.ndarray or None): Depth map of shape (H, W), or None if not available. extri_opencv (np.ndarray): Extrinsic matrix (3x4) in OpenCV convention. intri_opencv (np.ndarray): Intrinsic matrix (3x3). clockwise (bool): If True, rotates the image 90 degrees clockwise; else 90 degrees counterclockwise. track (np.ndarray or None): Optional (N, 2) track array. Will be rotated accordingly. Returns: tuple: ( rotated_image, rotated_depth_map, new_extri_opencv, new_intri_opencv, new_track ) Where each is the updated version after the rotation. 
""" image_height, image_width = image.shape[:2] # Rotate the image and depth map rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise) # Adjust the intrinsic matrix new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise) if track is not None: new_track = adjust_track_rot90(track, image_width, image_height, clockwise) else: new_track = None # Adjust the extrinsic matrix new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise) return ( rotated_image, rotated_depth_map, new_extri_opencv, new_intri_opencv, new_track, ) def rotate_image_and_depth_rot90(image, depth_map, clockwise): """ Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise), using a transpose+flip pattern. Args: image (np.ndarray): Input image of shape (H, W, 3). depth_map (np.ndarray or None): Depth map of shape (H, W), or None if not available. clockwise (bool): If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. Returns: tuple: (rotated_image, rotated_depth_map) """ rotated_depth_map = None if clockwise: rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width rotated_image = np.flip(rotated_image, axis=1) # Flip horizontally if depth_map is not None: rotated_depth_map = np.transpose(depth_map, (1, 0)) rotated_depth_map = np.flip(rotated_depth_map, axis=1) else: rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width rotated_image = np.flip(rotated_image, axis=0) # Flip vertically if depth_map is not None: rotated_depth_map = np.transpose(depth_map, (1, 0)) rotated_depth_map = np.flip(rotated_depth_map, axis=0) return np.copy(rotated_image), np.copy(rotated_depth_map) def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise): """ Adjusts the extrinsic matrix (3x4) for a 90-degree rotation of the image. The rotation is in the image plane. This modifies the camera orientation accordingly. 
The function applies either a clockwise or counterclockwise 90-degree rotation. Args: extri_opencv (np.ndarray): Extrinsic matrix (3x4) in OpenCV convention. clockwise (bool): If True, rotate extrinsic for a 90-degree clockwise image rotation; otherwise, counterclockwise. Returns: np.ndarray: A new 3x4 extrinsic matrix after the rotation. """ R = extri_opencv[:, :3] t = extri_opencv[:, 3] if clockwise: R_rotation = np.array([ [0, -1, 0], [1, 0, 0], [0, 0, 1] ]) else: R_rotation = np.array([ [0, 1, 0], [-1, 0, 0], [0, 0, 1] ]) new_R = np.dot(R_rotation, R) new_t = np.dot(R_rotation, t) new_extri_opencv = np.hstack((new_R, new_t.reshape(-1, 1))) return new_extri_opencv def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise): """ Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane. Args: intri_opencv (np.ndarray): Intrinsic matrix (3x3). image_width (int): Original width of the image. image_height (int): Original height of the image. clockwise (bool): If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. Returns: np.ndarray: A new 3x3 intrinsic matrix after the rotation. """ fx, fy, cx, cy = ( intri_opencv[0, 0], intri_opencv[1, 1], intri_opencv[0, 2], intri_opencv[1, 2], ) new_intri_opencv = np.eye(3) if clockwise: new_intri_opencv[0, 0] = fy new_intri_opencv[1, 1] = fx new_intri_opencv[0, 2] = image_height - cy new_intri_opencv[1, 2] = cx else: new_intri_opencv[0, 0] = fy new_intri_opencv[1, 1] = fx new_intri_opencv[0, 2] = cy new_intri_opencv[1, 2] = image_width - cx return new_intri_opencv def adjust_track_rot90(track, image_width, image_height, clockwise): """ Adjusts a track (N, 2) for a 90-degree rotation of the image in the image plane. Args: track (np.ndarray): (N, 2) array of pixel coordinates, each row is (x, y). image_width (int): Original image width. image_height (int): Original image height. 
clockwise (bool): Whether the rotation is 90 degrees clockwise or counterclockwise. Returns: np.ndarray: A new track of shape (N, 2) after rotation. """ if clockwise: # (x, y) -> (y, image_width - 1 - x) new_track = np.stack((track[:, 1], image_width - 1 - track[:, 0]), axis=-1) else: # (x, y) -> (image_height - 1 - y, x) new_track = np.stack((image_height - 1 - track[:, 1], track[:, 0]), axis=-1) return new_track def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray: """ Reads an image from disk using OpenCV, returning it as an RGB image array (H, W, 3). Args: path (str): File path to the image. rgb (bool): If True, convert the image to RGB. If False, leave the image in BGR/grayscale. Returns: np.ndarray or None: A numpy array of shape (H, W, 3) if successful, or None if the file does not exist or could not be read. """ if not os.path.exists(path) or os.path.getsize(path) == 0: print(f"File does not exist or is empty: {path}") return None img = cv2.imread(path) if img is None: print(f"Could not load image={path}. Retrying...") img = cv2.imread(path) if img is None: print("Retry failed.") return None if rgb: if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) else: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray: """ Reads a depth map from disk in either .exr or .png format. The .exr is loaded using OpenCV with the environment variable OPENCV_IO_ENABLE_OPENEXR=1. The .png is assumed to be a 16-bit PNG (converted from half float). Args: path (str): File path to the depth image. Must end with .exr or .png. scale_adjustment (float): A multiplier for adjusting the loaded depth values (default=1.0). Returns: np.ndarray: A float32 array (H, W) containing the loaded depth. Zeros or non-finite values may indicate invalid regions. Raises: ValueError: If the file extension is not supported. 
""" if path.lower().endswith(".exr"): # Ensure OPENCV_IO_ENABLE_OPENEXR is set to "1" d = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[..., 0] d[d > 1e9] = 0.0 elif path.lower().endswith(".png"): d = load_16big_png_depth(path) else: raise ValueError(f'unsupported depth file name "{path}"') d = d * scale_adjustment d[~np.isfinite(d)] = 0.0 return d def load_16big_png_depth(depth_png: str) -> np.ndarray: """ Loads a 16-bit PNG as a half-float depth map (H, W), returning a float32 NumPy array. Implementation detail: - PIL loads 16-bit data as 32-bit "I" mode. - We reinterpret the bits as float16, then cast to float32. Args: depth_png (str): File path to the 16-bit PNG. Returns: np.ndarray: A float32 depth array of shape (H, W). """ with Image.open(depth_png) as depth_pil: depth = ( np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) .astype(np.float32) .reshape((depth_pil.size[1], depth_pil.size[0])) ) return depth ================================================ FILE: training/data/datasets/co3d.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import gzip import json import os.path as osp import os import logging import cv2 import random import numpy as np from data.dataset_util import * from data.base_dataset import BaseDataset SEEN_CATEGORIES = [ "apple", "backpack", "banana", "baseballbat", "baseballglove", "bench", "bicycle", "bottle", "bowl", "broccoli", "cake", "car", "carrot", "cellphone", "chair", "cup", "donut", "hairdryer", "handbag", "hydrant", "keyboard", "laptop", "microwave", "motorcycle", "mouse", "orange", "parkingmeter", "pizza", "plant", "stopsign", "teddybear", "toaster", "toilet", "toybus", "toyplane", "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass", ] class Co3dDataset(BaseDataset): def __init__( self, common_conf, split: str = "train", CO3D_DIR: str = None, CO3D_ANNOTATION_DIR: str = None, min_num_images: int = 24, len_train: int = 100000, len_test: int = 10000, ): """ Initialize the Co3dDataset. Args: common_conf: Configuration object with common settings. split (str): Dataset split, either 'train' or 'test'. CO3D_DIR (str): Directory path to CO3D data. CO3D_ANNOTATION_DIR (str): Directory path to CO3D annotations. min_num_images (int): Minimum number of images per sequence. len_train (int): Length of the training dataset. len_test (int): Length of the test dataset. Raises: ValueError: If CO3D_DIR or CO3D_ANNOTATION_DIR is not specified. 
""" super().__init__(common_conf=common_conf) self.debug = common_conf.debug self.training = common_conf.training self.get_nearby = common_conf.get_nearby self.load_depth = common_conf.load_depth self.inside_random = common_conf.inside_random self.allow_duplicate_img = common_conf.allow_duplicate_img if CO3D_DIR is None or CO3D_ANNOTATION_DIR is None: raise ValueError("Both CO3D_DIR and CO3D_ANNOTATION_DIR must be specified.") category = sorted(SEEN_CATEGORIES) if self.debug: category = ["apple"] if split == "train": split_name_list = ["train"] self.len_train = len_train elif split == "test": split_name_list = ["test"] self.len_train = len_test else: raise ValueError(f"Invalid split: {split}") self.invalid_sequence = [] # set any invalid sequence names here self.category_map = {} self.data_store = {} self.seqlen = None self.min_num_images = min_num_images logging.info(f"CO3D_DIR is {CO3D_DIR}") self.CO3D_DIR = CO3D_DIR self.CO3D_ANNOTATION_DIR = CO3D_ANNOTATION_DIR total_frame_num = 0 for c in category: for split_name in split_name_list: annotation_file = osp.join( self.CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz" ) try: with gzip.open(annotation_file, "r") as fin: annotation = json.loads(fin.read()) except FileNotFoundError: logging.error(f"Annotation file not found: {annotation_file}") continue for seq_name, seq_data in annotation.items(): if len(seq_data) < min_num_images: continue if seq_name in self.invalid_sequence: continue total_frame_num += len(seq_data) self.data_store[seq_name] = seq_data self.sequence_list = list(self.data_store.keys()) self.sequence_list_len = len(self.sequence_list) self.total_frame_num = total_frame_num status = "Training" if self.training else "Testing" logging.info(f"{status}: Co3D Data size: {self.sequence_list_len}") logging.info(f"{status}: Co3D Data dataset length: {len(self)}") def get_data( self, seq_index: int = None, img_per_seq: int = None, seq_name: str = None, ids: list = None, aspect_ratio: float = 1.0, ) -> dict: """ 
Retrieve data for a specific sequence. Args: seq_index (int): Index of the sequence to retrieve. img_per_seq (int): Number of images per sequence. seq_name (str): Name of the sequence. ids (list): Specific IDs to retrieve. aspect_ratio (float): Aspect ratio for image processing. Returns: dict: A batch of data including images, depths, and other metadata. """ if self.inside_random: seq_index = random.randint(0, self.sequence_list_len - 1) if seq_name is None: seq_name = self.sequence_list[seq_index] metadata = self.data_store[seq_name] if ids is None: ids = np.random.choice( len(metadata), img_per_seq, replace=self.allow_duplicate_img ) annos = [metadata[i] for i in ids] target_image_shape = self.get_target_shape(aspect_ratio) images = [] depths = [] cam_points = [] world_points = [] point_masks = [] extrinsics = [] intrinsics = [] image_paths = [] original_sizes = [] for anno in annos: filepath = anno["filepath"] image_path = osp.join(self.CO3D_DIR, filepath) image = read_image_cv2(image_path) if self.load_depth: depth_path = image_path.replace("/images", "/depths") + ".geometric.png" depth_map = read_depth(depth_path, 1.0) mvs_mask_path = image_path.replace( "/images", "/depth_masks" ).replace(".jpg", ".png") mvs_mask = cv2.imread(mvs_mask_path, cv2.IMREAD_GRAYSCALE) > 128 depth_map[~mvs_mask] = 0 depth_map = threshold_depth_map( depth_map, min_percentile=-1, max_percentile=98 ) else: depth_map = None original_size = np.array(image.shape[:2]) extri_opencv = np.array(anno["extri"]) intri_opencv = np.array(anno["intri"]) ( image, depth_map, extri_opencv, intri_opencv, world_coords_points, cam_coords_points, point_mask, _, ) = self.process_one_image( image, depth_map, extri_opencv, intri_opencv, original_size, target_image_shape, filepath=filepath, ) images.append(image) depths.append(depth_map) extrinsics.append(extri_opencv) intrinsics.append(intri_opencv) cam_points.append(cam_coords_points) world_points.append(world_coords_points) point_masks.append(point_mask) 
image_paths.append(image_path) original_sizes.append(original_size) set_name = "co3d" batch = { "seq_name": set_name + "_" + seq_name, "ids": ids, "frame_num": len(extrinsics), "images": images, "depths": depths, "extrinsics": extrinsics, "intrinsics": intrinsics, "cam_points": cam_points, "world_points": world_points, "point_masks": point_masks, "original_sizes": original_sizes, } return batch ================================================ FILE: training/data/datasets/vkitti.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import os.path as osp import logging import random import glob import cv2 import numpy as np from data.dataset_util import * from data.base_dataset import BaseDataset class VKittiDataset(BaseDataset): def __init__( self, common_conf, split: str = "train", VKitti_DIR: str = "/checkpoint/repligen/jianyuan/datasets/vkitti/", min_num_images: int = 24, len_train: int = 100000, len_test: int = 10000, expand_ratio: int = 8, ): """ Initialize the VKittiDataset. Args: common_conf: Configuration object with common settings. split (str): Dataset split, either 'train' or 'test'. VKitti_DIR (str): Directory path to VKitti data. min_num_images (int): Minimum number of images per sequence. len_train (int): Length of the training dataset. len_test (int): Length of the test dataset. expand_range (int): Range for expanding nearby image selection. get_nearby_thres (int): Threshold for nearby image selection. 
""" super().__init__(common_conf=common_conf) self.debug = common_conf.debug self.training = common_conf.training self.get_nearby = common_conf.get_nearby self.inside_random = common_conf.inside_random self.allow_duplicate_img = common_conf.allow_duplicate_img self.expand_ratio = expand_ratio self.VKitti_DIR = VKitti_DIR self.min_num_images = min_num_images if split == "train": self.len_train = len_train elif split == "test": self.len_train = len_test else: raise ValueError(f"Invalid split: {split}") logging.info(f"VKitti_DIR is {self.VKitti_DIR}") # Load or generate sequence list txt_path = osp.join(self.VKitti_DIR, "sequence_list.txt") if osp.exists(txt_path): with open(txt_path, 'r') as f: sequence_list = [line.strip() for line in f.readlines()] else: # Generate sequence list and save to txt sequence_list = glob.glob(osp.join(self.VKitti_DIR, "*/*/*/rgb/*")) sequence_list = [file_path.split(self.VKitti_DIR)[-1].lstrip('/') for file_path in sequence_list] sequence_list = sorted(sequence_list) # Save to txt file with open(txt_path, 'w') as f: f.write('\n'.join(sequence_list)) self.sequence_list = sequence_list self.sequence_list_len = len(self.sequence_list) self.depth_max = 80 status = "Training" if self.training else "Testing" logging.info(f"{status}: VKitti Real Data size: {self.sequence_list_len}") logging.info(f"{status}: VKitti Data dataset length: {len(self)}") def get_data( self, seq_index: int = None, img_per_seq: int = None, seq_name: str = None, ids: list = None, aspect_ratio: float = 1.0, ) -> dict: """ Retrieve data for a specific sequence. Args: seq_index (int): Index of the sequence to retrieve. img_per_seq (int): Number of images per sequence. seq_name (str): Name of the sequence. ids (list): Specific IDs to retrieve. aspect_ratio (float): Aspect ratio for image processing. Returns: dict: A batch of data including images, depths, and other metadata. 
""" if self.inside_random and self.training: seq_index = random.randint(0, self.sequence_list_len - 1) if seq_name is None: seq_name = self.sequence_list[seq_index] camera_id = int(seq_name[-1]) # Load camera parameters try: camera_parameters = np.loadtxt( osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "extrinsic.txt"), delimiter=" ", skiprows=1 ) camera_parameters = camera_parameters[camera_parameters[:, 1] == camera_id] camera_intrinsic = np.loadtxt( osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "intrinsic.txt"), delimiter=" ", skiprows=1 ) camera_intrinsic = camera_intrinsic[camera_intrinsic[:, 1] == camera_id] except Exception as e: logging.error(f"Error loading camera parameters for {seq_name}: {e}") raise num_images = len(camera_parameters) if ids is None: ids = np.random.choice(num_images, img_per_seq, replace=self.allow_duplicate_img) if self.get_nearby: ids = self.get_nearby_ids(ids, num_images, expand_ratio=self.expand_ratio) target_image_shape = self.get_target_shape(aspect_ratio) images = [] depths = [] cam_points = [] world_points = [] point_masks = [] extrinsics = [] intrinsics = [] original_sizes = [] for image_idx in ids: image_filepath = osp.join(self.VKitti_DIR, seq_name, f"rgb_{image_idx:05d}.jpg") depth_filepath = osp.join(self.VKitti_DIR, seq_name, f"depth_{image_idx:05d}.png").replace("/rgb", "/depth") image = read_image_cv2(image_filepath) depth_map = cv2.imread(depth_filepath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) depth_map = depth_map / 100 depth_map = threshold_depth_map(depth_map, max_percentile=-1, min_percentile=-1, max_depth=self.depth_max) assert image.shape[:2] == depth_map.shape, f"Image and depth shape mismatch: {image.shape[:2]} vs {depth_map.shape}" original_size = np.array(image.shape[:2]) # Process camera matrices extri_opencv = camera_parameters[image_idx][2:].reshape(4, 4) extri_opencv = extri_opencv[:3] intri_opencv = np.eye(3) intri_opencv[0, 0] = camera_intrinsic[image_idx][-4] 
intri_opencv[1, 1] = camera_intrinsic[image_idx][-3] intri_opencv[0, 2] = camera_intrinsic[image_idx][-2] intri_opencv[1, 2] = camera_intrinsic[image_idx][-1] ( image, depth_map, extri_opencv, intri_opencv, world_coords_points, cam_coords_points, point_mask, _, ) = self.process_one_image( image, depth_map, extri_opencv, intri_opencv, original_size, target_image_shape, filepath=image_filepath, ) if (image.shape[:2] != target_image_shape).any(): logging.error(f"Wrong shape for {seq_name}: expected {target_image_shape}, got {image.shape[:2]}") continue images.append(image) depths.append(depth_map) extrinsics.append(extri_opencv) intrinsics.append(intri_opencv) cam_points.append(cam_coords_points) world_points.append(world_coords_points) point_masks.append(point_mask) original_sizes.append(original_size) set_name = "vkitti" batch = { "seq_name": set_name + "_" + seq_name, "ids": ids, "frame_num": len(extrinsics), "images": images, "depths": depths, "extrinsics": extrinsics, "intrinsics": intrinsics, "cam_points": cam_points, "world_points": world_points, "point_masks": point_masks, "original_sizes": original_sizes, } return batch ================================================ FILE: training/data/dynamic_dataloader.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Callable, Optional from hydra.utils import instantiate import random import numpy as np from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler from abc import ABC, abstractmethod from .worker_fn import get_worker_init_fn class DynamicTorchDataset(ABC): def __init__( self, dataset: dict, common_config: dict, num_workers: int, shuffle: bool, pin_memory: bool, drop_last: bool = True, collate_fn: Optional[Callable] = None, worker_init_fn: Optional[Callable] = None, persistent_workers: bool = False, seed: int = 42, max_img_per_gpu: int = 48, ) -> None: self.dataset_config = dataset self.common_config = common_config self.num_workers = num_workers self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last self.collate_fn = collate_fn self.worker_init_fn = worker_init_fn self.persistent_workers = persistent_workers self.seed = seed self.max_img_per_gpu = max_img_per_gpu # Instantiate the dataset self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False) # Extract aspect ratio and image number ranges from the configuration self.aspect_ratio_range = common_config.augs.aspects # e.g., [0.5, 1.0] self.image_num_range = common_config.img_nums # e.g., [2, 24] # Validate the aspect ratio and image number ranges if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]: raise ValueError(f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}") if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]: raise ValueError(f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}") # Create samplers self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle) self.batch_sampler = DynamicBatchSampler( self.sampler, self.aspect_ratio_range, self.image_num_range, seed=seed, max_img_per_gpu=max_img_per_gpu ) 
def get_loader(self, epoch): print("Building dynamic dataloader with epoch:", epoch) # Set the epoch for the sampler self.sampler.set_epoch(epoch) if hasattr(self.dataset, "epoch"): self.dataset.epoch = epoch if hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(epoch) # Create and return the dataloader return DataLoader( self.dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, batch_sampler=self.batch_sampler, collate_fn=self.collate_fn, persistent_workers=self.persistent_workers, worker_init_fn=get_worker_init_fn( seed=self.seed, num_workers=self.num_workers, epoch=epoch, worker_init_fn=self.worker_init_fn, ), ) class DynamicBatchSampler(Sampler): """ A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number for each sample. Batches within a sample share the same aspect ratio and image number. """ def __init__(self, sampler, aspect_ratio_range, image_num_range, epoch=0, seed=42, max_img_per_gpu=48): """ Initializes the dynamic batch sampler. Args: sampler: Instance of DynamicDistributedSampler. aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio]. image_num_range: List containing [min_images, max_images] per sample. epoch: Current epoch number. seed: Random seed for reproducibility. max_img_per_gpu: Maximum number of images to fit in GPU memory. """ self.sampler = sampler self.aspect_ratio_range = aspect_ratio_range self.image_num_range = image_num_range self.rng = random.Random() # Uniformly sample from the range of possible image numbers # For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here. 
self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1]+1)} # Possible image numbers, e.g., [2, 3, 4, ..., 24] self.possible_nums = np.array([n for n in self.image_num_weights.keys() if self.image_num_range[0] <= n <= self.image_num_range[1]]) # Normalize weights for sampling weights = [self.image_num_weights[n] for n in self.possible_nums] self.normalized_weights = np.array(weights) / sum(weights) # Maximum image number per GPU self.max_img_per_gpu = max_img_per_gpu # Set the epoch for the sampler self.set_epoch(epoch + seed) def set_epoch(self, epoch): """ Sets the epoch for this sampler, affecting the random sequence. Args: epoch: The epoch number. """ self.sampler.set_epoch(epoch) self.epoch = epoch self.rng.seed(epoch * 100) def __iter__(self): """ Yields batches of samples with synchronized dynamic parameters. Returns: Iterator yielding batches of indices with associated parameters. """ sampler_iterator = iter(self.sampler) while True: try: # Sample random image number and aspect ratio random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights)) random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2) # Update sampler parameters self.sampler.update_parameters( aspect_ratio=random_aspect_ratio, image_num=random_image_num ) # Calculate batch size based on max images per GPU and current image number batch_size = self.max_img_per_gpu / random_image_num batch_size = np.floor(batch_size).astype(int) batch_size = max(1, batch_size) # Ensure batch size is at least 1 # Collect samples for the current batch current_batch = [] for _ in range(batch_size): try: item = next(sampler_iterator) # item is (idx, aspect_ratio, image_num) current_batch.append(item) except StopIteration: break # No more samples if not current_batch: break # No more data to yield yield current_batch except StopIteration: break # End of sampler's iterator def __len__(self): # 
Return a large dummy length return 1000000 class DynamicDistributedSampler(DistributedSampler): """ Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num parameters, which can be passed into the dataset's __getitem__ method. """ def __init__( self, dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = False, seed: int = 0, drop_last: bool = False, ): super().__init__( dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last ) self.aspect_ratio = None self.image_num = None def __iter__(self): """ Yields a sequence of (index, image_num, aspect_ratio). Relies on the parent class's logic for shuffling/distributing the indices across replicas, then attaches extra parameters. """ indices_iter = super().__iter__() for idx in indices_iter: yield (idx, self.image_num, self.aspect_ratio,) def update_parameters(self, aspect_ratio, image_num): """ Updates dynamic parameters for each new epoch or iteration. Args: aspect_ratio: The aspect ratio to set. image_num: The number of images to set. """ self.aspect_ratio = aspect_ratio self.image_num = image_num ================================================ FILE: training/data/preprocess/vkitti.sh ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. mkdir vkitti cd vkitti wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar tar -xvf vkitti_2.0.3_rgb.tar wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar tar -xvf vkitti_2.0.3_depth.tar wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_textgt.tar.gz tar -xvf vkitti_2.0.3_textgt.tar.gz cd .. 
================================================ FILE: training/data/track_util.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import cv2 import numpy as np import torch import logging from vggt.utils.geometry import * def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, point_masks, images, pos_rel_thres=0.05, neg_epipolar_thres=16, boundary_thres=4, target_track_num=512, neg_ratio = 0.0, neg_sample_size_ratio = 0.5, seq_name=None): """ Args: extrinsics: (N, 3, 4) intrinsics: (N, 3, 3) world_points: (N, H, W, 3) depths: (N, H, W) point_masks: (N, H, W) pos_rel_thres: float, relative threshold for positive track depth check neg_epipolar_thres: float, threshold for negative track epipolar check, in px boundary_thres: int, boundary in px to skip near edges target_track_num: int, total # tracks to build neg_ratio: fraction of final tracks that should be negative neg_sample_size_ratio: fraction of W/H used for random offset Returns: final_tracks: (N, P, 2) float final_vis_masks: (N, P) bool final_pos_masks: (P) bool, indicate if a mask is positive or negative """ # Wait, should we do this before resizing the image? B, H, W, _ = world_points.shape # We use the first frame as the query frame, so [0] query_world_points = world_points[0] query_point_masks = point_masks[0] if (query_point_masks).sum() > 0: # at least one point valid_query_points = query_world_points[query_point_masks] # image_points: BxPx2 # cam_points: Bx3xP (yes 3xP instead of Px3). 
Probably we can change it in the future image_points, cam_points = project_world_points_to_cam(valid_query_points, extrinsics, intrinsics) # proj_depths: BxP proj_depths = cam_points[:, -1] # floor to get the left top corner uv_int = image_points.floor().long().clone() uv_inside_flag = (uv_int[..., 0] >= boundary_thres) & (uv_int[..., 0] < (W - boundary_thres)) & (uv_int[..., 1] >= boundary_thres) & (uv_int[..., 1] < (H - boundary_thres)) uv_int[~uv_inside_flag] = 0 batch_indices = torch.arange(B).view(B, 1).expand(-1, uv_int.shape[1]) # Use these indices to sample from the depth map # since we interpolate depths by nearest, # so assume the left top corner is (x, y) # we want to check for (x,y), (x+1,y), (x,y+1), (x+1,y+1) depth_inside_flag = None for shift in [(0,0), (1,0), (0,1), (1,1)]: cur_uv_int = uv_int + torch.tensor(shift) cur_depth_inside_flag = get_depth_inside_flag(depths, batch_indices, cur_uv_int, proj_depths, pos_rel_thres) if depth_inside_flag is None: depth_inside_flag = cur_depth_inside_flag else: depth_inside_flag = torch.logical_or(depth_inside_flag, cur_depth_inside_flag) # B, P, 2 positive_tracks = image_points positive_vis_masks = torch.logical_and(uv_inside_flag, depth_inside_flag) else: print(f"No valid query points in {seq_name}") positive_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32) positive_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool) sampled_neg_track_num = target_track_num * 4 # we sample more negative tracks to ensure the quality perb_range = [int(W*neg_sample_size_ratio), int(H*neg_sample_size_ratio)] # sample negative query points us = torch.randint(low=0, high=W, size=(1, sampled_neg_track_num), device=world_points.device) vs = torch.randint(low=0, high=H, size=(1, sampled_neg_track_num), device=world_points.device) neg_query_uvs = torch.stack([us, vs], dim=-1) # construct negative tracks delta_us = torch.rand(size=(B, 
sampled_neg_track_num), device=world_points.device) * perb_range[0] delta_vs = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[1] delta_us[0] = 0 delta_vs[0] = 0 negative_tracks = neg_query_uvs + torch.stack([delta_us, delta_vs], dim=-1) # Do epipolar check here negative_sampson_distances = track_epipolar_check(negative_tracks, extrinsics, intrinsics) negative_epipolar_check = (negative_sampson_distances > neg_epipolar_thres).all(dim=0) # we set the threshold to 5 px # Filter out those satifsfying epipolar check negative_tracks = negative_tracks[:, negative_epipolar_check] # Prepare for output final_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32) final_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool) final_pos_masks = torch.zeros(target_track_num, device=world_points.device, dtype=torch.bool) target_pos_track_num = target_track_num - int(target_track_num * neg_ratio) sampled_pos_track_num = 0 sampled_positive_tracks, sampled_positive_vis_masks = sample_positive_tracks(positive_tracks, positive_vis_masks, target_pos_track_num) sampled_pos_track_num = sampled_positive_tracks.shape[1] final_tracks[:, :sampled_pos_track_num] = sampled_positive_tracks final_vis_masks[:, :sampled_pos_track_num] = sampled_positive_vis_masks final_pos_masks[:sampled_pos_track_num] = True target_neg_track_num = target_track_num - sampled_pos_track_num # Now we need to sample negative tracks # just do simple random sampling rand_indices = torch.randperm(negative_tracks.shape[1], device=negative_tracks.device) sampled_neg_tracks = negative_tracks[:, rand_indices[:target_neg_track_num]] sampled_neg_track_num = sampled_neg_tracks.shape[1] final_tracks[:, sampled_pos_track_num:sampled_pos_track_num+sampled_neg_track_num] = sampled_neg_tracks if sampled_pos_track_num+sampled_neg_track_num!=target_track_num: 
logging.warning(f"sampled_pos_track_num+sampled_neg_track_num!=target_track_num: {sampled_pos_track_num+sampled_neg_track_num} != {target_track_num}") # Do not need to set final_vis_masks and final_pos_masks, because they are all False # Do not need to check the shape of final_tracks, as it is zeroed out # NOTE: We need to do some visual checks return final_tracks, final_vis_masks, final_pos_masks def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, rel_thres): sampled_depths = depths[batch_indices, uv_int[..., 1], uv_int[..., 0]] depth_diff = (proj_depths - sampled_depths).abs() depth_inside_flag = torch.logical_and(depth_diff < (proj_depths * rel_thres), depth_diff < (sampled_depths * rel_thres)) return depth_inside_flag def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = True, seq_name=None): # tracks: (B, T, 2) # tracks_mask: (B, T) # track_num: int # half_top: bool # if the query frame is not valid, then the track is not valid tracks_mask[:, tracks_mask[0]==False] = False track_frame_num = tracks_mask.sum(dim=0) tracks_mask[:, track_frame_num<=1] = False track_frame_num = tracks_mask.sum(dim=0) _, track_num_sort_idx = track_frame_num.sort(descending=True) if half_top: if len(track_num_sort_idx)//2 > track_num: # drop those tracks with too small number of valid frames # track_num_sort_idx = track_num_sort_idx[:track_num] track_num_sort_idx = track_num_sort_idx[:len(track_num_sort_idx)//2] pick_idx = torch.randperm(len(track_num_sort_idx))[:track_num] track_num_sort_idx = track_num_sort_idx[pick_idx] tracks = tracks[:, track_num_sort_idx].clone() tracks_mask = tracks_mask[:, track_num_sort_idx].clone() tracks_mask = tracks_mask.bool() # ensure the type is bool return tracks, tracks_mask # Only for Debugging and Visualization def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_mat = False): from kornia.geometry.epipolar import sampson_epipolar_distance B, T, _ = tracks.shape essential_mats = 
get_essential_matrix(extrinsics[0:1].expand(B-1, -1, -1), extrinsics[1:]) if use_essential_mat: tracks_normalized = cam_from_img(tracks, intrinsics) sampson_distances = sampson_epipolar_distance(tracks_normalized[0:1].expand(B-1, -1, -1), tracks_normalized[1:], essential_mats) else: K1 = intrinsics[0:1].expand(B-1, -1, -1) K2 = intrinsics[1:].expand(B-1, -1, -1) fundamental_mats = K2.inverse().permute(0, 2, 1).matmul(essential_mats).matmul(K1.inverse()) sampson_distances = sampson_epipolar_distance(tracks[0:1].expand(B-1, -1, -1), tracks[1:], fundamental_mats) return sampson_distances def get_essential_matrix(extrinsic1, extrinsic2): R1 = extrinsic1[:, :3, :3] t1 = extrinsic1[:, :3, 3] R2 = extrinsic2[:, :3, :3] t2 = extrinsic2[:, :3, 3] R12 = R2.matmul(R1.permute(0, 2, 1)) t12 = t2 - R12.matmul(t1[..., None])[..., 0] E_R = R12 E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0] E = E_R.matmul(hat(E_t)) return E def hat(v: torch.Tensor) -> torch.Tensor: N, dim = v.shape if dim != 3: raise ValueError("Input vectors have to be 3-dimensional.") x, y, z = v.unbind(1) h_01 = -z.view(N, 1, 1) h_02 = y.view(N, 1, 1) h_10 = z.view(N, 1, 1) h_12 = -x.view(N, 1, 1) h_20 = -y.view(N, 1, 1) h_21 = x.view(N, 1, 1) zeros = torch.zeros((N, 1, 1), dtype=v.dtype, device=v.device) row1 = torch.cat((zeros, h_01, h_02), dim=2) row2 = torch.cat((h_10, zeros, h_12), dim=2) row3 = torch.cat((h_20, h_21, zeros), dim=2) h = torch.cat((row1, row2, row3), dim=1) return h def color_from_xy(x, y, W, H, cmap_name="hsv"): """ Map (x, y) -> color in (R, G, B). 1) Normalize x,y to [0,1]. 2) Combine them into a single scalar c in [0,1]. 3) Use matplotlib's colormap to convert c -> (R,G,B). You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). 
""" import matplotlib.cm import matplotlib.colors x_norm = x / max(W - 1, 1) y_norm = y / max(H - 1, 1) # Simple combination: c = (x_norm + y_norm) / 2.0 cmap = matplotlib.cm.get_cmap(cmap_name) # cmap(c) -> (r,g,b,a) in [0,1] rgba = cmap(c) r, g, b = rgba[0], rgba[1], rgba[2] return (r, g, b) # in [0,1], RGB order def get_track_colors_by_position( tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv" ): """ Given all tracks in one sample (b), compute a (N,3) array of RGB color values in [0,255]. The color is determined by the (x,y) position in the first visible frame for each track. Args: tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. vis_mask_b: (S, N) boolean mask; if None, assume all are visible. image_width, image_height: used for normalizing (x, y). cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). Returns: track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. """ S, N, _ = tracks_b.shape track_colors = np.zeros((N, 3), dtype=np.uint8) if vis_mask_b is None: # treat all as visible vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) for i in range(N): # Find first visible frame for track i visible_frames = torch.where(vis_mask_b[:, i])[0] if len(visible_frames) == 0: # track is never visible; just assign black or something track_colors[i] = (0, 0, 0) continue first_s = int(visible_frames[0].item()) # use that frame's (x,y) x, y = tracks_b[first_s, i].tolist() # map (x,y) -> (R,G,B) in [0,1] r, g, b = color_from_xy( x, y, W=image_width, H=image_height, cmap_name=cmap_name ) # scale to [0,255] r, g, b = int(r*255), int(g*255), int(b*255) track_colors[i] = (r, g, b) return track_colors def visualize_tracks_on_images( images, tracks, track_vis_mask=None, out_dir="track_visuals_concat_by_xy", image_format="CHW", # "CHW" or "HWC" normalize_mode="[0,1]", cmap_name="hsv" # e.g. 
"hsv", "rainbow", "jet" ): """ Visualizes all frames for each sample (b) in ONE horizontal row, saving one PNG per sample. Each track's color is determined by its (x,y) position in the first visible frame (or frame 0 if always visible). Finally convert the BGR result to RGB before saving. Args: images: torch.Tensor (B, S, 3, H, W) if CHW or (B, S, H, W, 3) if HWC. tracks: torch.Tensor (B, S, N, 2), last dim = (x, y). track_vis_mask: torch.Tensor (B, S, N) or None. out_dir: folder to save visualizations. image_format: "CHW" or "HWC". normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 cmap_name: a matplotlib colormap name for color_from_xy. Returns: None (saves images in out_dir). """ import matplotlib matplotlib.use('Agg') # for non-interactive (optional) os.makedirs(out_dir, exist_ok=True) B, S = images.shape[0], images.shape[1] _, _, N, _ = tracks.shape # (B, S, N, 2) # Move to CPU images = images.cpu().clone() tracks = tracks.cpu().clone() if track_vis_mask is not None: track_vis_mask = track_vis_mask.cpu().clone() # Infer H, W from images shape if image_format == "CHW": # e.g. images[b, s].shape = (3, H, W) H, W = images.shape[3], images.shape[4] else: # e.g. 
images[b, s].shape = (H, W, 3) H, W = images.shape[2], images.shape[3] for b in range(B): # Pre-compute the color for each track i based on first visible position # in sample b: track_colors_rgb = get_track_colors_by_position( tracks[b], # shape (S, N, 2) vis_mask_b=track_vis_mask[b] if track_vis_mask is not None else None, image_width=W, image_height=H, cmap_name=cmap_name ) # We'll accumulate each frame’s drawn image in a list frame_images = [] for s in range(S): # shape => either (3, H, W) or (H, W, 3) img = images[b, s] # Convert to (H, W, 3) if image_format == "CHW": img = img.permute(1, 2, 0) # (H, W, 3) # else "HWC", do nothing img = img.numpy().astype(np.float32) # Scale to [0,255] if needed if normalize_mode == "[0,1]": img = np.clip(img, 0, 1) * 255.0 elif normalize_mode == "[-1,1]": img = (img + 1.0) * 0.5 * 255.0 img = np.clip(img, 0, 255.0) # else no normalization # Convert to uint8 img = img.astype(np.uint8) # For drawing in OpenCV, the image is assumed BGR, # but *currently* it's in (R,G,B) if your original is truly RGB. # We'll do the color conversion AFTER drawing so that we can call # cv2.circle(...) with BGR color. # That means we need to swap the channels now to get BGR for drawing. # If your images are actually BGR, you may skip or adapt. img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Draw each visible track cur_tracks = tracks[b, s] # shape (N, 2) if track_vis_mask is not None: valid_indices = torch.where(track_vis_mask[b, s])[0] else: valid_indices = range(N) cur_tracks_np = cur_tracks.numpy() for i in valid_indices: x, y = cur_tracks_np[i] pt = (int(round(x)), int(round(y))) # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR R, G, B = track_colors_rgb[i] color_bgr = (int(B), int(G), int(R)) cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) # Convert back to RGB for consistent final saving: img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) frame_images.append(img_rgb) # Concatenate all frames horizontally: (H, S*W, 3) row_img = np.concatenate(frame_images, axis=1) out_path = os.path.join(out_dir, f"tracks_b{b}.png") cv2.imwrite(out_path, row_img) print(f"[INFO] Saved color-by-XY track visualization for sample b={b} -> {out_path}") ================================================ FILE: training/data/worker_fn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Utilities for distributed training and deterministic seed generation. This module provides functions for working with PyTorch's distributed training capabilities and ensuring reproducible data loading. """ import os import torch import random import numpy as np import torch.distributed as dist from functools import partial def is_dist_avail_and_initialized(): """ Check if distributed training is available and initialized. Returns: bool: True if distributed training is available and initialized, False otherwise. """ if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_rank(): """ Get the rank of the current process in distributed training. Returns: int: The rank of the current process, or 0 if distributed training is not initialized. """ if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def get_world_size(): """ Get the total number of processes in distributed training. Returns: int: The world size, or 1 if distributed training is not initialized. 
""" if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def default_worker_init_fn(worker_id, num_workers, epoch, seed=0): """ Default function to initialize random seeds for dataloader workers. Ensures that each worker across different ranks, epochs, and world sizes gets a unique random seed for reproducibility. Args: worker_id (int): ID of the dataloader worker. num_workers (int): Total number of dataloader workers. epoch (int): Current training epoch. seed (int, optional): Base seed for randomization. Defaults to 0. """ rank = get_rank() world_size = get_world_size() distributed_rank = int(os.environ.get("RANK", None)) # Use prime numbers for better distribution RANK_MULTIPLIER = 1 WORKER_MULTIPLIER = 1 WORLD_MULTIPLIER = 1 EPOCH_MULTIPLIER = 12345 DISTRIBUTED_RANK_MULTIPLIER = 1042 worker_seed = ( rank * num_workers * RANK_MULTIPLIER + worker_id * WORKER_MULTIPLIER + seed + world_size * WORLD_MULTIPLIER + epoch * EPOCH_MULTIPLIER + distributed_rank * DISTRIBUTED_RANK_MULTIPLIER ) print(f"Rank: {rank}, World size: {world_size}, Distributed rank: {distributed_rank}") print(f"Worker seed: {worker_seed}") torch.random.manual_seed(worker_seed) np.random.seed(worker_seed) random.seed(worker_seed) return def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None): """ Get a worker initialization function for dataloaders. Args: seed (int): Base seed for randomization. num_workers (int): Number of dataloader workers. epoch (int): Current training epoch. worker_init_fn (callable, optional): Custom worker initialization function. If provided, this will be returned instead of the default one. Returns: callable: A worker initialization function to use with DataLoader. 
""" if worker_init_fn is not None: return worker_init_fn return partial( default_worker_init_fn, num_workers=num_workers, epoch=epoch, seed=seed, ) ================================================ FILE: training/launch.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import argparse from hydra import initialize, compose from omegaconf import DictConfig, OmegaConf from trainer import Trainer def main(): parser = argparse.ArgumentParser(description="Train model with configurable YAML file") parser.add_argument( "--config", type=str, default="default", help="Name of the config file (without .yaml extension, default: default)" ) args = parser.parse_args() with initialize(version_base=None, config_path="config"): cfg = compose(config_name=args.config) trainer = Trainer(**cfg) trainer.run() if __name__ == "__main__": main() ================================================ FILE: training/loss.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn.functional as F from dataclasses import dataclass from vggt.utils.pose_enc import extri_intri_to_pose_encoding from train_utils.general import check_and_fix_inf_nan from math import ceil, floor @dataclass(eq=False) class MultitaskLoss(torch.nn.Module): """ Multi-task loss module that combines different loss types for VGGT. 
Supports: - Camera loss - Depth loss - Point loss - Tracking loss (not cleaned yet, dirty code is at the bottom of this file) """ def __init__(self, camera=None, depth=None, point=None, track=None, **kwargs): super().__init__() # Loss configuration dictionaries for each task self.camera = camera self.depth = depth self.point = point self.track = track def forward(self, predictions, batch) -> torch.Tensor: """ Compute the total multi-task loss. Args: predictions: Dict containing model predictions for different tasks batch: Dict containing ground truth data and masks Returns: Dict containing individual losses and total objective """ total_loss = 0 loss_dict = {} # Camera pose loss - if pose encodings are predicted if "pose_enc_list" in predictions: camera_loss_dict = compute_camera_loss(predictions, batch, **self.camera) camera_loss = camera_loss_dict["loss_camera"] * self.camera["weight"] total_loss = total_loss + camera_loss loss_dict.update(camera_loss_dict) # Depth estimation loss - if depth maps are predicted if "depth" in predictions: depth_loss_dict = compute_depth_loss(predictions, batch, **self.depth) depth_loss = depth_loss_dict["loss_conf_depth"] + depth_loss_dict["loss_reg_depth"] + depth_loss_dict["loss_grad_depth"] depth_loss = depth_loss * self.depth["weight"] total_loss = total_loss + depth_loss loss_dict.update(depth_loss_dict) # 3D point reconstruction loss - if world points are predicted if "world_points" in predictions: point_loss_dict = compute_point_loss(predictions, batch, **self.point) point_loss = point_loss_dict["loss_conf_point"] + point_loss_dict["loss_reg_point"] + point_loss_dict["loss_grad_point"] point_loss = point_loss * self.point["weight"] total_loss = total_loss + point_loss loss_dict.update(point_loss_dict) # Tracking loss - not cleaned yet, dirty code is at the bottom of this file if "track" in predictions: raise NotImplementedError("Track loss is not cleaned up yet") loss_dict["objective"] = total_loss return loss_dict def 
compute_camera_loss( pred_dict, # predictions dict, contains pose encodings batch_data, # ground truth and mask batch dict loss_type="l1", # "l1" or "l2" loss gamma=0.6, # temporal decay weight for multi-stage training pose_encoding_type="absT_quaR_FoV", weight_trans=1.0, # weight for translation loss weight_rot=1.0, # weight for rotation loss weight_focal=0.5, # weight for focal length loss **kwargs ): # List of predicted pose encodings per stage pred_pose_encodings = pred_dict['pose_enc_list'] # Binary mask for valid points per frame (B, N, H, W) point_masks = batch_data['point_masks'] # Only consider frames with enough valid points (>100) valid_frame_mask = point_masks[:, 0].sum(dim=[-1, -2]) > 100 # Number of prediction stages n_stages = len(pred_pose_encodings) # Get ground truth camera extrinsics and intrinsics gt_extrinsics = batch_data['extrinsics'] gt_intrinsics = batch_data['intrinsics'] image_hw = batch_data['images'].shape[-2:] # Encode ground truth pose to match predicted encoding format gt_pose_encoding = extri_intri_to_pose_encoding( gt_extrinsics, gt_intrinsics, image_hw, pose_encoding_type=pose_encoding_type ) # Initialize loss accumulators for translation, rotation, focal length total_loss_T = total_loss_R = total_loss_FL = 0 # Compute loss for each prediction stage with temporal weighting for stage_idx in range(n_stages): # Later stages get higher weight (gamma^0 = 1.0 for final stage) stage_weight = gamma ** (n_stages - stage_idx - 1) pred_pose_stage = pred_pose_encodings[stage_idx] if valid_frame_mask.sum() == 0: # If no valid frames, set losses to zero to avoid gradient issues loss_T_stage = (pred_pose_stage * 0).mean() loss_R_stage = (pred_pose_stage * 0).mean() loss_FL_stage = (pred_pose_stage * 0).mean() else: # Only consider valid frames for loss computation loss_T_stage, loss_R_stage, loss_FL_stage = camera_loss_single( pred_pose_stage[valid_frame_mask].clone(), gt_pose_encoding[valid_frame_mask].clone(), loss_type=loss_type ) # 
Accumulate weighted losses across stages total_loss_T += loss_T_stage * stage_weight total_loss_R += loss_R_stage * stage_weight total_loss_FL += loss_FL_stage * stage_weight # Average over all stages avg_loss_T = total_loss_T / n_stages avg_loss_R = total_loss_R / n_stages avg_loss_FL = total_loss_FL / n_stages # Compute total weighted camera loss total_camera_loss = ( avg_loss_T * weight_trans + avg_loss_R * weight_rot + avg_loss_FL * weight_focal ) # Return loss dictionary with individual components return { "loss_camera": total_camera_loss, "loss_T": avg_loss_T, "loss_R": avg_loss_R, "loss_FL": avg_loss_FL } def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"): """ Computes translation, rotation, and focal loss for a batch of pose encodings. Args: pred_pose_enc: (N, D) predicted pose encoding gt_pose_enc: (N, D) ground truth pose encoding loss_type: "l1" (abs error) or "l2" (euclidean error) Returns: loss_T: translation loss (mean) loss_R: rotation loss (mean) loss_FL: focal length/intrinsics loss (mean) NOTE: The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss. So here we use l1 loss. 
""" if loss_type == "l1": # Translation: first 3 dims; Rotation: next 4 (quaternion); Focal/Intrinsics: last dims loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).abs() loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).abs() loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).abs() elif loss_type == "l2": # L2 norm for each component loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).norm(dim=-1, keepdim=True) loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).norm(dim=-1) loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).norm(dim=-1) else: raise ValueError(f"Unknown loss type: {loss_type}") # Check/fix numerical issues (nan/inf) for each loss component loss_T = check_and_fix_inf_nan(loss_T, "loss_T") loss_R = check_and_fix_inf_nan(loss_R, "loss_R") loss_FL = check_and_fix_inf_nan(loss_FL, "loss_FL") # Clamp outlier translation loss to prevent instability, then average loss_T = loss_T.clamp(max=100).mean() loss_R = loss_R.mean() loss_FL = loss_FL.mean() return loss_T, loss_R, loss_FL def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs): """ Compute point loss. 
Args: predictions: Dict containing 'world_points' and 'world_points_conf' batch: Dict containing ground truth 'world_points' and 'point_masks' gamma: Weight for confidence loss alpha: Weight for confidence regularization gradient_loss_fn: Type of gradient loss to apply valid_range: Quantile range for outlier filtering """ pred_points = predictions['world_points'] pred_points_conf = predictions['world_points_conf'] gt_points = batch['world_points'] gt_points_mask = batch['point_masks'] gt_points = check_and_fix_inf_nan(gt_points, "gt_points") if gt_points_mask.sum() < 100: # If there are less than 100 valid points, skip this batch dummy_loss = (0.0 * pred_points).mean() loss_dict = {f"loss_conf_point": dummy_loss, f"loss_reg_point": dummy_loss, f"loss_grad_point": dummy_loss,} return loss_dict # Compute confidence-weighted regression loss with optional gradient loss loss_conf, loss_grad, loss_reg = regression_loss(pred_points, gt_points, gt_points_mask, conf=pred_points_conf, gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range) loss_dict = { f"loss_conf_point": loss_conf, f"loss_reg_point": loss_reg, f"loss_grad_point": loss_grad, } return loss_dict def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs): """ Compute depth loss. 
Args: predictions: Dict containing 'depth' and 'depth_conf' batch: Dict containing ground truth 'depths' and 'point_masks' gamma: Weight for confidence loss alpha: Weight for confidence regularization gradient_loss_fn: Type of gradient loss to apply valid_range: Quantile range for outlier filtering """ pred_depth = predictions['depth'] pred_depth_conf = predictions['depth_conf'] gt_depth = batch['depths'] gt_depth = check_and_fix_inf_nan(gt_depth, "gt_depth") gt_depth = gt_depth[..., None] # (B, H, W, 1) gt_depth_mask = batch['point_masks'].clone() # 3D points derived from depth map, so we use the same mask if gt_depth_mask.sum() < 100: # If there are less than 100 valid points, skip this batch dummy_loss = (0.0 * pred_depth).mean() loss_dict = {f"loss_conf_depth": dummy_loss, f"loss_reg_depth": dummy_loss, f"loss_grad_depth": dummy_loss,} return loss_dict # NOTE: we put conf inside regression_loss so that we can also apply conf loss to the gradient loss in a multi-scale manner # this is hacky, but very easier to implement loss_conf, loss_grad, loss_reg = regression_loss(pred_depth, gt_depth, gt_depth_mask, conf=pred_depth_conf, gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range) loss_dict = { f"loss_conf_depth": loss_conf, f"loss_reg_depth": loss_reg, f"loss_grad_depth": loss_grad, } return loss_dict def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, gamma=1.0, alpha=0.2, valid_range=-1): """ Core regression loss function with confidence weighting and optional gradient loss. Computes: 1. gamma * ||pred - gt||^2 * conf - alpha * log(conf) 2. Optional gradient loss Args: pred: (B, S, H, W, C) predicted values gt: (B, S, H, W, C) ground truth values mask: (B, S, H, W) valid pixel mask conf: (B, S, H, W) confidence weights (optional) gradient_loss_fn: Type of gradient loss ("normal", "grad", etc.) 
gamma: Weight for confidence loss alpha: Weight for confidence regularization valid_range: Quantile range for outlier filtering Returns: loss_conf: Confidence-weighted loss loss_grad: Gradient loss (0 if not specified) loss_reg: Regular L2 loss """ bb, ss, hh, ww, nc = pred.shape # Compute L2 distance between predicted and ground truth points loss_reg = torch.norm(gt[mask] - pred[mask], dim=-1) loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg") # Confidence-weighted loss: gamma * loss * conf - alpha * log(conf) # This encourages the model to be confident on easy examples and less confident on hard ones loss_conf = gamma * loss_reg * conf[mask] - alpha * torch.log(conf[mask]) loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf") # Initialize gradient loss loss_grad = 0 # Prepare confidence for gradient loss if needed if "conf" in gradient_loss_fn: to_feed_conf = conf.reshape(bb*ss, hh, ww) else: to_feed_conf = None # Compute gradient loss if specified for spatial smoothness if "normal" in gradient_loss_fn: # Surface normal-based gradient loss loss_grad = gradient_loss_multi_scale_wrapper( pred.reshape(bb*ss, hh, ww, nc), gt.reshape(bb*ss, hh, ww, nc), mask.reshape(bb*ss, hh, ww), gradient_loss_fn=normal_loss, scales=3, conf=to_feed_conf, ) elif "grad" in gradient_loss_fn: # Standard gradient-based loss loss_grad = gradient_loss_multi_scale_wrapper( pred.reshape(bb*ss, hh, ww, nc), gt.reshape(bb*ss, hh, ww, nc), mask.reshape(bb*ss, hh, ww), gradient_loss_fn=gradient_loss, conf=to_feed_conf, ) # Process confidence-weighted loss if loss_conf.numel() > 0: # Filter out outliers using quantile-based thresholding if valid_range>0: loss_conf = filter_by_quantile(loss_conf, valid_range) loss_conf = check_and_fix_inf_nan(loss_conf, f"loss_conf_depth") loss_conf = loss_conf.mean() else: loss_conf = (0.0 * pred).mean() # Process regular regression loss if loss_reg.numel() > 0: # Filter out outliers using quantile-based thresholding if valid_range>0: loss_reg = 
filter_by_quantile(loss_reg, valid_range) loss_reg = check_and_fix_inf_nan(loss_reg, f"loss_reg_depth") loss_reg = loss_reg.mean() else: loss_reg = (0.0 * pred).mean() return loss_conf, loss_grad, loss_reg def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4, gradient_loss_fn = None, conf=None): """ Multi-scale gradient loss wrapper. Applies gradient loss at multiple scales by subsampling the input. This helps capture both fine and coarse spatial structures. Args: prediction: (B, H, W, C) predicted values target: (B, H, W, C) ground truth values mask: (B, H, W) valid pixel mask scales: Number of scales to use gradient_loss_fn: Gradient loss function to apply conf: (B, H, W) confidence weights (optional) """ total = 0 for scale in range(scales): step = pow(2, scale) # Subsample by 2^scale total += gradient_loss_fn( prediction[:, ::step, ::step], target[:, ::step, ::step], mask[:, ::step, ::step], conf=conf[:, ::step, ::step] if conf is not None else None ) total = total / scales return total def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma=1.0, alpha=0.2): """ Surface normal-based loss for geometric consistency. Computes surface normals from 3D point maps using cross products of neighboring points, then measures the angle between predicted and ground truth normals. 
Args: prediction: (B, H, W, 3) predicted 3D coordinates/points target: (B, H, W, 3) ground-truth 3D coordinates/points mask: (B, H, W) valid pixel mask cos_eps: Epsilon for numerical stability in cosine computation conf: (B, H, W) confidence weights (optional) gamma: Weight for confidence loss alpha: Weight for confidence regularization """ # Convert point maps to surface normals using cross products pred_normals, pred_valids = point_map_to_normal(prediction, mask, eps=cos_eps) gt_normals, gt_valids = point_map_to_normal(target, mask, eps=cos_eps) # Only consider regions where both predicted and GT normals are valid all_valid = pred_valids & gt_valids # shape: (4, B, H, W) # Early return if not enough valid points divisor = torch.sum(all_valid) if divisor < 10: return 0 # Extract valid normals pred_normals = pred_normals[all_valid].clone() gt_normals = gt_normals[all_valid].clone() # Compute cosine similarity between corresponding normals dot = torch.sum(pred_normals * gt_normals, dim=-1) # Clamp dot product to [-1, 1] for numerical stability dot = torch.clamp(dot, -1 + cos_eps, 1 - cos_eps) # Compute loss as 1 - cos(theta), instead of arccos(dot) for numerical stability loss = 1 - dot # Return mean loss if we have enough valid points if loss.numel() < 10: return 0 else: loss = check_and_fix_inf_nan(loss, "normal_loss") if conf is not None: # Apply confidence weighting conf = conf[None, ...].expand(4, -1, -1, -1) conf = conf[all_valid].clone() loss = gamma * loss * conf - alpha * torch.log(conf) return loss.mean() else: return loss.mean() def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2): """ Gradient-based loss. Computes the L1 difference between adjacent pixels in x and y directions. 
Args: prediction: (B, H, W, C) predicted values target: (B, H, W, C) ground truth values mask: (B, H, W) valid pixel mask conf: (B, H, W) confidence weights (optional) gamma: Weight for confidence loss alpha: Weight for confidence regularization """ # Expand mask to match prediction channels mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1]) M = torch.sum(mask, (1, 2, 3)) # Compute difference between prediction and target diff = prediction - target diff = torch.mul(mask, diff) # Compute gradients in x direction (horizontal) grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) grad_x = torch.mul(mask_x, grad_x) # Compute gradients in y direction (vertical) grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) grad_y = torch.mul(mask_y, grad_y) # Clamp gradients to prevent outliers grad_x = grad_x.clamp(max=100) grad_y = grad_y.clamp(max=100) # Apply confidence weighting if provided if conf is not None: conf = conf[..., None].expand(-1, -1, -1, prediction.shape[-1]) conf_x = conf[:, :, 1:] conf_y = conf[:, 1:, :] grad_x = gamma * grad_x * conf_x - alpha * torch.log(conf_x) grad_y = gamma * grad_y * conf_y - alpha * torch.log(conf_y) # Sum gradients and normalize by number of valid pixels grad_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3)) divisor = torch.sum(M) if divisor == 0: return 0 else: grad_loss = torch.sum(grad_loss) / divisor return grad_loss def point_map_to_normal(point_map, mask, eps=1e-6): """ Convert 3D point map to surface normal vectors using cross products. Computes normals by taking cross products of neighboring point differences. Uses 4 different cross-product directions for robustness. 
Args: point_map: (B, H, W, 3) 3D points laid out in a 2D grid mask: (B, H, W) valid pixels (bool) eps: Epsilon for numerical stability in normalization Returns: normals: (4, B, H, W, 3) normal vectors for each of the 4 cross-product directions valids: (4, B, H, W) corresponding valid masks """ with torch.cuda.amp.autocast(enabled=False): # Pad inputs to avoid boundary issues padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) pts = F.pad(point_map.permute(0, 3, 1, 2), (1,1,1,1), mode='constant', value=0).permute(0, 2, 3, 1) # Get neighboring points for each pixel center = pts[:, 1:-1, 1:-1, :] # B,H,W,3 up = pts[:, :-2, 1:-1, :] left = pts[:, 1:-1, :-2 , :] down = pts[:, 2:, 1:-1, :] right = pts[:, 1:-1, 2:, :] # Compute direction vectors from center to neighbors up_dir = up - center left_dir = left - center down_dir = down - center right_dir = right - center # Compute four cross products for different normal directions n1 = torch.cross(up_dir, left_dir, dim=-1) # up x left n2 = torch.cross(left_dir, down_dir, dim=-1) # left x down n3 = torch.cross(down_dir, right_dir, dim=-1) # down x right n4 = torch.cross(right_dir,up_dir, dim=-1) # right x up # Validity masks - require both direction pixels to be valid v1 = padded_mask[:, :-2, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, :-2] v2 = padded_mask[:, 1:-1, :-2 ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 2:, 1:-1] v3 = padded_mask[:, 2:, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, 2:] v4 = padded_mask[:, 1:-1, 2: ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, :-2, 1:-1] # Stack normals and validity masks normals = torch.stack([n1, n2, n3, n4], dim=0) # shape [4, B, H, W, 3] valids = torch.stack([v1, v2, v3, v4], dim=0) # shape [4, B, H, W] # Normalize normal vectors normals = F.normalize(normals, p=2, dim=-1, eps=eps) return normals, valids def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100): """ Filter loss tensor by keeping only values 
below a certain quantile threshold. This helps remove outliers that could destabilize training. Args: loss_tensor: Tensor containing loss values valid_range: Float between 0 and 1 indicating the quantile threshold min_elements: Minimum number of elements required to apply filtering hard_max: Maximum allowed value for any individual loss Returns: Filtered and clamped loss tensor """ if loss_tensor.numel() <= min_elements: # Too few elements, just return as-is return loss_tensor # Randomly sample if tensor is too large to avoid memory issues if loss_tensor.numel() > 100000000: # Flatten and randomly select 1M elements indices = torch.randperm(loss_tensor.numel(), device=loss_tensor.device)[:1_000_000] loss_tensor = loss_tensor.view(-1)[indices] # First clamp individual values to prevent extreme outliers loss_tensor = loss_tensor.clamp(max=hard_max) # Compute quantile threshold quantile_thresh = torch_quantile(loss_tensor.detach(), valid_range) quantile_thresh = min(quantile_thresh, hard_max) # Apply quantile filtering if enough elements remain quantile_mask = loss_tensor < quantile_thresh if quantile_mask.sum() > min_elements: return loss_tensor[quantile_mask] return loss_tensor def torch_quantile( input, q, dim = None, keepdim: bool = False, *, interpolation: str = "nearest", out: torch.Tensor = None, ) -> torch.Tensor: """Better torch.quantile for one SCALAR quantile. Using torch.kthvalue. Better than torch.quantile because: - No 2**24 input size limit (pytorch/issues/67592), - Much faster, at least on big input sizes. Arguments: input (torch.Tensor): See torch.quantile. q (float): See torch.quantile. Supports only scalar input currently. dim (int | None): See torch.quantile. keepdim (bool): See torch.quantile. Supports only False currently. interpolation: {"nearest", "lower", "higher"} See torch.quantile. out (torch.Tensor | None): See torch.quantile. Supports only None currently. 
""" # https://github.com/pytorch/pytorch/issues/64947 # Sanitization: q try: q = float(q) assert 0 <= q <= 1 except Exception: raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q})!") # Handle dim=None case if dim_was_none := dim is None: dim = 0 input = input.reshape((-1,) + (1,) * (input.ndim - 1)) # Set interpolation method if interpolation == "nearest": inter = round elif interpolation == "lower": inter = floor elif interpolation == "higher": inter = ceil else: raise ValueError( "Supported interpolations currently are {'nearest', 'lower', 'higher'} " f"(got '{interpolation}')!" ) # Validate out parameter if out is not None: raise ValueError(f"Only None value is currently supported for out (got {out})!") # Compute k-th value k = inter(q * (input.shape[dim] - 1)) + 1 out = torch.kthvalue(input, k, dim, keepdim=True, out=out)[0] # Handle keepdim and dim=None cases if keepdim: return out if dim_was_none: return out.squeeze() else: return out.squeeze(dim) return out ######################################################################################## ######################################################################################## # Dirty code for tracking loss: ######################################################################################## ######################################################################################## ''' def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch): """Compute tracking losses using sequence_loss""" gt_tracks = batch["tracks"] # B, S, N, 2 gt_track_vis_mask = batch["track_vis_mask"] # B, S, N # if self.training and hasattr(self, "train_query_points"): train_query_points = coord_preds[-1].shape[2] gt_tracks = gt_tracks[:, :, :train_query_points] gt_tracks = check_and_fix_inf_nan(gt_tracks, "gt_tracks", hard_max=None) gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points] # Create validity mask that filters out tracks not visible in first frame valids = 
torch.ones_like(gt_track_vis_mask) mask = gt_track_vis_mask[:, 0, :] == True valids = valids * mask.unsqueeze(1) if not valids.any(): print("No valid tracks found in first frame") print("seq_name: ", batch["seq_name"]) print("ids: ", batch["ids"]) print("time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) dummy_coord = coord_preds[0].mean() * 0 # keeps graph & grads dummy_vis = vis_scores.mean() * 0 if conf_scores is not None: dummy_conf = conf_scores.mean() * 0 else: dummy_conf = 0 return dummy_coord, dummy_vis, dummy_conf # three scalar zeros # Compute tracking loss using sequence_loss track_loss = sequence_loss( flow_preds=coord_preds, flow_gt=gt_tracks, vis=gt_track_vis_mask, valids=valids, **self.loss_kwargs ) vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float()) vis_loss = check_and_fix_inf_nan(vis_loss, "vis_loss", hard_max=None) # within 3 pixels if conf_scores is not None: gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3 conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float()) conf_loss = check_and_fix_inf_nan(conf_loss, "conf_loss", hard_max=None) else: conf_loss = 0 return track_loss, vis_loss, conf_loss def reduce_masked_mean(x, mask, dim=None, keepdim=False): for a, b in zip(x.size(), mask.size()): assert a == b prod = x * mask if dim is None: numer = torch.sum(prod) denom = torch.sum(mask) else: numer = torch.sum(prod, dim=dim, keepdim=keepdim) denom = torch.sum(mask, dim=dim, keepdim=keepdim) mean = numer / denom.clamp(min=1) mean = torch.where(denom > 0, mean, torch.zeros_like(mean)) return mean def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs): """Loss function defined over sequence of flow predictions""" B, S, N, D = flow_gt.shape assert D == 2 B, S1, N = vis.shape B, S2, N = valids.shape assert S == S1 assert S == S2 n_predictions = len(flow_preds) flow_loss 
= 0.0 for i in range(n_predictions): i_weight = gamma ** (n_predictions - i - 1) flow_pred = flow_preds[i] i_loss = (flow_pred - flow_gt).abs() # B, S, N, 2 i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_{i}", hard_max=None) i_loss = torch.mean(i_loss, dim=3) # B, S, N # Combine valids and vis for per-frame valid masking. combined_mask = torch.logical_and(valids, vis) num_valid_points = combined_mask.sum() if vis_aware: combined_mask = combined_mask.float() * (1.0 + vis_aware_w) # Add, don't add to the mask itself. flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask) else: if num_valid_points > 2: i_loss = i_loss[combined_mask] flow_loss += i_weight * i_loss.mean() else: i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_safe_check_{i}", hard_max=None) flow_loss += 0 * i_loss.mean() # Avoid division by zero if n_predictions is 0 (though it shouldn't be). if n_predictions > 0: flow_loss = flow_loss / n_predictions return flow_loss ''' ================================================ FILE: training/train_utils/__init__.py ================================================ ================================================ FILE: training/train_utils/checkpoint.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging from typing import ( Any, Dict, List, ) import torch import torch.nn as nn import os from iopath.common.file_io import g_pathmgr from wcmatch import fnmatch # ------------------------------------------------------------ # Glob‑matching flags (behave like the Unix shell) # ------------------------------------------------------------ GLOB_FLAGS = ( fnmatch.CASE # case‑sensitive | fnmatch.DOTMATCH # '*' also matches '.' 
| fnmatch.EXTMATCH # extended patterns like *(foo|bar) | fnmatch.SPLIT # "pat1|pat2" works out‑of‑the‑box ) class DDPCheckpointSaver: def __init__( self, checkpoint_folder: str, checkpoint_names: List[str], rank: int, epoch: int, ): super().__init__() self.checkpoint_folder = checkpoint_folder self.checkpoint_names = checkpoint_names self.worker_id = rank self.epoch = epoch def save_checkpoint( self, model: nn.Module, **kwargs: Any, ) -> None: checkpoint = dict(**kwargs) checkpoint["model"] = model.state_dict() if self.worker_id == 0: for ckpt_name in self.checkpoint_names: checkpoint_path = os.path.join( self.checkpoint_folder, f"{ckpt_name}.pt" ) logging.info( f"Saving checkpoint at epoch {self.epoch} to {checkpoint_path}" ) robust_torch_save(checkpoint, checkpoint_path) def robust_torch_save(checkpoint: Dict[str, Any], checkpoint_path: str) -> None: """ A more robust version of torch.save that works better with preemptions and corruptions if a job is preempted during save. """ # Move the existing checkpoint to a backup location backup_checkpoint_path = checkpoint_path + ".bak" backup_checkpoint_path_saved = False if g_pathmgr.exists(checkpoint_path): assert not g_pathmgr.exists( backup_checkpoint_path ), f"this should not exist... {backup_checkpoint_path}" g_pathmgr.mv(checkpoint_path, backup_checkpoint_path) backup_checkpoint_path_saved = True # Save the checkpoint with g_pathmgr.open(checkpoint_path, "wb") as f: torch.save(checkpoint, f) # Remove the backup checkpoint if backup_checkpoint_path_saved: g_pathmgr.rm(backup_checkpoint_path) ================================================ FILE: training/train_utils/distributed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import os import time import torch def get_machine_local_and_dist_rank(): """ Get the distributed and local rank of the current gpu. """ local_rank = int(os.environ.get("LOCAL_RANK", None)) distributed_rank = int(os.environ.get("RANK", None)) assert ( local_rank is not None and distributed_rank is not None ), "Please the set the RANK and LOCAL_RANK environment variables." return local_rank, distributed_rank ================================================ FILE: training/train_utils/freeze.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from wcmatch import fnmatch from functools import wraps from typing import List import torch.nn as nn # ------------------------------------------------------------ # Glob‑matching flags (behave like the Unix shell) # ------------------------------------------------------------ GLOB_FLAGS = ( fnmatch.CASE # case‑sensitive | fnmatch.DOTMATCH # '*' also matches '.' | fnmatch.EXTMATCH # extended patterns like *(foo|bar) | fnmatch.SPLIT # "pat1|pat2" works out‑of‑the‑box ) def freeze_modules(model: nn.Module, patterns: List[str], recursive: bool = True) -> nn.Module: """Freeze (stop training) parts of *model* whose *name* matches *patterns*. Parameters ---------- model : nn.Module The complete model you are working with. patterns : list[str] Glob patterns to match sub‑module names. Example: ``["encoder.*", "cls_head"]`` recursive : bool, default = True • ``True`` → also freeze every child of a matched module. • ``False`` → freeze only the matched module itself. Returns ------- nn.Module The same model object, now with some parts frozen. Example ------- >>> freeze_modules(model, ["encoder.*", "decoder.layer1"], recursive=True) """ matched: set[str] = set() for name, mod in model.named_modules(): # does *name* match ANY user pattern? 
if any(fnmatch.fnmatch(name, p, flags=GLOB_FLAGS) for p in patterns): matched.add(name) _freeze(mod, recursive) _check_every_pattern_used(matched, patterns) return model # ------------------------------------------------------------ # helpers # ------------------------------------------------------------ def _freeze(mod: nn.Module, recursive: bool) -> None: """Put *mod* in eval mode and lock its parameters.""" if recursive: mod.eval() # affects the whole subtree else: mod.training = False # only this exact module original_train = mod.train @wraps(original_train) def locked_train(mode: bool = True): if recursive: return original_train(False) # ignore user's *mode* out = original_train(mode) # children follow user's choice out.training = False # but this module stays frozen return out mod.train = locked_train # type: ignore[attr-defined] param_iter = ( mod.parameters() # default recurse=True if recursive else mod.parameters(recurse=False) ) for p in param_iter: p.requires_grad = False def _check_every_pattern_used(matched_names: set[str], patterns: List[str]): unused = [p for p in patterns if not any(fnmatch.fnmatch(n, p, flags=GLOB_FLAGS) for n in matched_names)] if unused: raise ValueError(f"These patterns matched nothing: {unused}") ================================================ FILE: training/train_utils/general.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch
import torch.nn as nn
import os
import math
import random
import numpy as np
from typing import Union, Optional
import logging
from iopath.common.file_io import g_pathmgr
import torch.distributed as dist
from pathlib import Path
from typing import Dict, Iterable, List
from collections import defaultdict
from dataclasses import fields, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable


def check_and_fix_inf_nan(input_tensor, loss_name="default", hard_max=100):
    """
    Zero out non-finite entries of a tensor and optionally clamp its magnitude.

    Args:
        input_tensor (torch.Tensor | None): Tensor to sanitize. ``None`` is returned as-is.
        loss_name (str): Label used in the warning emitted when NaN/Inf is found.
        hard_max (float, optional): If not None, the result is clamped to
            ``[-hard_max, hard_max]``. Defaults to 100.

    Returns:
        The sanitized tensor (or ``None`` when the input was ``None``).
    """
    if input_tensor is None:
        return None

    # One mask serves both the detection and the replacement step.
    non_finite = torch.isnan(input_tensor) | torch.isinf(input_tensor)
    if non_finite.any():
        logging.warning(f"Tensor {loss_name} contains inf or nan values. Replacing with zeros.")
        input_tensor = torch.where(non_finite, torch.zeros_like(input_tensor), input_tensor)

    if hard_max is not None:
        input_tensor = torch.clamp(input_tensor, min=-hard_max, max=hard_max)

    return input_tensor


def get_resume_checkpoint(checkpoint_save_dir):
    """Return ``<checkpoint_save_dir>/checkpoint.pt`` if that file exists, else ``None``."""
    if g_pathmgr.isdir(checkpoint_save_dir):
        candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
        if g_pathmgr.isfile(candidate):
            return candidate
    return None


class DurationMeter:
    """Holds a duration in seconds; ``str()`` renders it via ``human_readable_time``."""

    def __init__(self, name, device, fmt=":f"):
        self.name, self.device, self.fmt = name, device, fmt
        self.val = 0

    def reset(self):
        self.val = 0

    def update(self, val):
        self.val = val

    def add(self, val):
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"


def human_readable_time(time_seconds):
    """Format a number of seconds as ``"DDd HHh MMm"`` (seconds are dropped)."""
    total = int(time_seconds)
    days, rest = divmod(total, 86400)
    hours, rest = divmod(rest, 3600)
    minutes = rest // 60
    return f"{days:02}d {hours:02}h {minutes:02}m"


class ProgressMeter:
    """Log a one-line progress summary built from a set of meters."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch):
        pieces = [self.prefix + self.batch_fmtstr.format(batch)]
        pieces.extend(str(m) for m in self.meters)
        for group_name, group_meter in self.real_meters.items():
            pieces.append(
                " | ".join(
                    f"{os.path.join(group_name, sub)}: {value:.4f}"
                    for sub, value in group_meter.compute().items()
                )
            )
        logging.info(" | ".join(pieces))

    def _get_batch_fmtstr(self, num_batches):
        width = len(str(num_batches // 1))
        fmt = f"{{:{width}d}}"
        return f"[{fmt}/{fmt.format(num_batches)}]"


@runtime_checkable
class _CopyableData(Protocol):
    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Copy data to the specified device"""
        ...
def _is_named_tuple(x) -> bool: return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields") def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any): """Function that recursively copies data to a torch.device. Args: data: The data to copy to device device: The device to which the data should be copied args: positional arguments that will be passed to the `to` call kwargs: keyword arguments that will be passed to the `to` call Returns: The data on the correct device """ if _is_named_tuple(data): return type(data)( **copy_data_to_device(data._asdict(), device, *args, **kwargs) ) elif isinstance(data, (list, tuple)): return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data) elif isinstance(data, defaultdict): return type(data)( data.default_factory, { k: copy_data_to_device(v, device, *args, **kwargs) for k, v in data.items() }, ) elif isinstance(data, Mapping) and not is_dataclass(data): # handing FrameData-like things return type(data)( { k: copy_data_to_device(v, device, *args, **kwargs) for k, v in data.items() } ) elif is_dataclass(data) and not isinstance(data, type): new_data_class = type(data)( **{ field.name: copy_data_to_device( getattr(data, field.name), device, *args, **kwargs ) for field in fields(data) if field.init } ) for field in fields(data): if not field.init: setattr( new_data_class, field.name, copy_data_to_device( getattr(data, field.name), device, *args, **kwargs ), ) return new_data_class elif isinstance(data, _CopyableData): return data.to(device, *args, **kwargs) return data def safe_makedirs(path: str): if not path: logging.warning("safe_makedirs called with an empty path. No operation performed.") return False try: os.makedirs(path, exist_ok=True) return True except OSError as e: logging.error(f"Failed to create directory '{path}'. Reason: {e}") raise except Exception as e: # Catch any other unexpected errors. 
logging.error(f"An unexpected error occurred while creating directory '{path}'. Reason: {e}") raise def set_seeds(seed_value, max_epochs, dist_rank): """ Set the python random, numpy and torch seed for each gpu. Also set the CUDA seeds if the CUDA is available. This ensures deterministic nature of the training. """ seed_value = (seed_value + dist_rank) * max_epochs logging.info(f"GPU SEED: {seed_value}") random.seed(seed_value) np.random.seed(seed_value) torch.manual_seed(seed_value) if torch.cuda.is_available(): torch.cuda.manual_seed(seed_value) torch.cuda.manual_seed_all(seed_value) # for multi-GPU def log_env_variables(): env_keys = sorted(list(os.environ.keys())) st = "" for k in env_keys: v = os.environ[k] st += f"{k}={v}\n" logging.info("Logging ENV_VARIABLES") logging.info(st) def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True class AverageMeter: """Computes and stores the average and current value. Args: name (str): Name of the metric being tracked device (torch.device, optional): Device for tensor operations. Defaults to None. fmt (str): Format string for displaying values. 
Defaults to ":f" """ def __init__(self, name: str, device: Optional[torch.device] = None, fmt: str = ":f"): self.name = name self.fmt = fmt self.device = device self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 self._allow_updates = True def update(self, val, n=1): if n <= 0: raise ValueError(f"n must be positive, got {n}") self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count if self.count > 0 else 0.0 def __str__(self) -> str: """String representation showing current and average values.""" fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format(**self.__dict__) @property def value(self) -> float: """Get the current value.""" return self.val @property def average(self) -> float: """Get the running average.""" return self.avg ################# _UNITS = ('', ' K', ' M', ' B', ' T') # U+202F = thin-space for nicer look def pretty_int(n: int) -> str: """Abbreviate a non-negative integer (0 → 0, 12_345 → '12.3 K').""" assert n >= 0, 'pretty_int() expects a non-negative int' if n < 1_000: return f'{n:,}' exp = int(math.log10(n) // 3) # group of 3 digits exp = min(exp, len(_UNITS) - 1) # cap at trillions value = n / 10 ** (3 * exp) return f'{value:.1f}'.rstrip('0').rstrip('.') + _UNITS[exp] def model_summary(model: torch.nn.Module, *, log_file = None, prefix: str = '') -> None: """ Print / save a compact parameter summary. Args ---- model : The PyTorch nn.Module to inspect. log_file : Optional path – if given, the full `str(model)` and per-parameter lists are written there (three separate *.txt files). prefix : Optional string printed at the beginning of every log line (handy when several models share the same stdout). 
""" if get_rank(): # only rank-0 prints return # --- counts ------------------------------------------------------------- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) frozen = total - trainable print(prefix + '='*60) print(prefix + f'Model type : {model.__class__.__name__}') print(prefix + f'Total : {pretty_int(total)} parameters') print(prefix + f' trainable: {pretty_int(trainable)}') print(prefix + f' frozen : {pretty_int(frozen)}') print(prefix + '='*60) # --- optional file dump ------------------------------------------------- if log_file is None: return log_file = Path(log_file) log_file.write_text(str(model)) # full architecture # two extra detailed lists def _dump(names: Iterable[str], fname: str): """Write a formatted per-parameter list to *log_file.with_name(fname)*.""" with open(log_file.with_name(fname), 'w') as f: for n in names: p = dict(model.named_parameters())[n] shape = str(tuple(p.shape)) f.write(f'{n:<60s} {shape:<20} {p.numel()}\n') named = dict(model.named_parameters()) _dump([n for n,p in named.items() if p.requires_grad], 'trainable.txt') _dump([n for n,p in named.items() if not p.requires_grad], 'frozen.txt') def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() ================================================ FILE: training/train_utils/gradient_clip.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from typing import Union, Optional class GradientClipper: """ Gradient clipping utils that works for both FSDP and DDP with support for different clipping configurations for different parts of the model. 
""" def __init__(self, configs, *args, **kwargs): """ Args: configs: List of dictionaries, each containing: - module_name: str or list of str, module names to apply clipping to - max_norm: float, maximum norm for gradient clipping - norm_type: int, type of norm (default: 2) """ self.configs = [] self.params_to_clip_by_config = None self.is_initialized = False for config in configs: module_names = config['module_name'] if isinstance(module_names, str): module_names = [module_names] self.configs.append({ 'module_names': module_names, 'max_norm': float(config['max_norm']) if config['max_norm'] is not None else None, 'norm_type': config.get('norm_type', 2) }) def setup_clipping(self, model: nn.Module) -> None: """ Set up gradient clipping by finding all parameters that should be clipped based on module names and validating that all parameters are covered. This should be called once at the beginning of training. Args: model: The model to set up gradient clipping for """ # First, collect all parameters that should be clipped based on module names params_to_clip_by_config = [] all_clipped_params = set() for config in self.configs: current_config_params = [] for name, param in model.named_parameters(): if param.requires_grad: for module_name in config['module_names']: if module_name in name: current_config_params.append(param) all_clipped_params.add(param) break params_to_clip_by_config.append((config, current_config_params)) # Check for remaining parameters remaining_params = [] for name, param in model.named_parameters(): if param.requires_grad and param not in all_clipped_params: remaining_params.append(param) if len(remaining_params) > 0: print(f"Found {len(remaining_params)} parameters that won't be clipped") print(remaining_params) raise ValueError("Some parameters are not configured for gradient clipping") # Store the computed parameters self.params_to_clip_by_config = params_to_clip_by_config self.is_initialized = True def __call__(self, model: nn.Module) -> 
Optional[torch.Tensor]: """ Perform gradient clipping using the pre-computed parameter groups. Args: model: The model (not used, kept for backward compatibility) Returns: Dictionary of gradient norms for each configuration """ if not self.is_initialized: raise RuntimeError("GradientClipper must be initialized with setup_clipping() before use") grad_norms = {} for config, params_to_clip in self.params_to_clip_by_config: if not params_to_clip or config['max_norm'] is None: continue grad_norm = nn.utils.clip_grad_norm_( params_to_clip, max_norm=config['max_norm'], norm_type=config['norm_type'] ) if grad_norm is None: continue grad_norms[",".join(config['module_names'])] = grad_norm.item() return grad_norms ================================================ FILE: training/train_utils/logging.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import os import copy import sys import atexit import functools from .general import safe_makedirs from iopath.common.file_io import g_pathmgr # cache the opened file object, so that different calls # with the same file name can safely write to the same file. @functools.lru_cache(maxsize=None) def _cached_log_stream(filename): log_buffer_kb = 1 * 1024 # 1KB io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) atexit.register(io.close) return io def setup_logging( name, output_dir=None, rank=0, log_level_primary="INFO", log_level_secondary="ERROR", all_ranks: bool = False, ): """ Setup various logging streams: stdout and file handlers. For file handlers, we only setup for the master gpu. 
""" global LOGGING_STATE LOGGING_STATE = copy.deepcopy(locals()) # get the filename if we want to log to the file as well log_filename = None if output_dir: safe_makedirs(output_dir) if rank == 0: log_filename = f"{output_dir}/log.txt" elif all_ranks: log_filename = f"{output_dir}/log_{rank}.txt" logger = logging.getLogger(name) logger.setLevel(log_level_primary) # create formatter FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" formatter = logging.Formatter(FORMAT) # clean up any pre-existing handlers for h in logger.handlers: logger.removeHandler(h) logger.root.handlers = [] logging.root.handlers = [] # setup the console handler console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(formatter) if rank == 0: console_handler.setLevel(log_level_primary) else: console_handler.setLevel(log_level_secondary) logger.addHandler(console_handler) # we log to file as well if user wants if log_filename is not None: file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) file_handler.setLevel(log_level_primary) file_handler.setFormatter(formatter) logger.addHandler(file_handler) logging.root = logger ================================================ FILE: training/train_utils/normalization.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import logging from typing import Optional, Tuple from vggt.utils.geometry import closed_form_inverse_se3 from train_utils.general import check_and_fix_inf_nan def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str = "tensor") -> None: """ Check if a tensor contains NaN or Inf values and log a warning if found. 
Args: input_tensor: The tensor to check name: Name of the tensor for logging purposes """ if input_tensor is not None: if torch.isnan(input_tensor).any() or torch.isinf(input_tensor).any(): logging.warning(f"NaN or Inf found in tensor: {name}") def normalize_camera_extrinsics_and_points_batch( extrinsics: torch.Tensor, cam_points: Optional[torch.Tensor] = None, world_points: Optional[torch.Tensor] = None, depths: Optional[torch.Tensor] = None, scale_by_points: bool = True, point_masks: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Normalize camera extrinsics and corresponding 3D points. This function transforms the coordinate system to be centered at the first camera and optionally scales the scene to have unit average distance. Args: extrinsics: Camera extrinsic matrices of shape (B, S, 3, 4) cam_points: 3D points in camera coordinates of shape (B, S, H, W, 3) or (*,3) world_points: 3D points in world coordinates of shape (B, S, H, W, 3) or (*,3) depths: Depth maps of shape (B, S, H, W) scale_by_points: Whether to normalize the scale based on point distances point_masks: Boolean masks for valid points of shape (B, S, H, W) Returns: Tuple containing: - Normalized camera extrinsics of shape (B, S, 3, 4) - Normalized camera points (same shape as input cam_points) - Normalized world points (same shape as input world_points) - Normalized depths (same shape as input depths) """ # Validate inputs check_valid_tensor(extrinsics, "extrinsics") check_valid_tensor(cam_points, "cam_points") check_valid_tensor(world_points, "world_points") check_valid_tensor(depths, "depths") B, S, _, _ = extrinsics.shape device = extrinsics.device assert device == torch.device("cpu") # Convert extrinsics to homogeneous form: (B, N,4,4) extrinsics_homog = torch.cat( [ extrinsics, torch.zeros((B, S, 1, 4), device=device), ], dim=-2, ) extrinsics_homog[:, :, -1, -1] = 1.0 # first_cam_extrinsic_inv, the inverse 
of the first camera's extrinsic matrix # which can be also viewed as the cam_to_world extrinsic matrix first_cam_extrinsic_inv = closed_form_inverse_se3(extrinsics_homog[:, 0]) # new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv) new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv.unsqueeze(1)) # (B,N,4,4) if world_points is not None: # since we are transforming the world points to the first camera's coordinate system # we directly use the cam_from_world extrinsic matrix of the first camera # instead of using the inverse of the first camera's extrinsic matrix R = extrinsics[:, 0, :3, :3] t = extrinsics[:, 0, :3, 3] new_world_points = (world_points @ R.transpose(-1, -2).unsqueeze(1).unsqueeze(2)) + t.unsqueeze(1).unsqueeze(2).unsqueeze(3) else: new_world_points = None if scale_by_points: new_cam_points = cam_points.clone() new_depths = depths.clone() dist = new_world_points.norm(dim=-1) dist_sum = (dist * point_masks).sum(dim=[1,2,3]) valid_count = point_masks.sum(dim=[1,2,3]) avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6) new_world_points = new_world_points / avg_scale.view(-1, 1, 1, 1, 1) new_extrinsics[:, :, :3, 3] = new_extrinsics[:, :, :3, 3] / avg_scale.view(-1, 1, 1) if depths is not None: new_depths = new_depths / avg_scale.view(-1, 1, 1, 1) if cam_points is not None: new_cam_points = new_cam_points / avg_scale.view(-1, 1, 1, 1, 1) else: return new_extrinsics[:, :, :3], cam_points, new_world_points, depths new_extrinsics = new_extrinsics[:, :, :3] # 4x4 -> 3x4 new_extrinsics = check_and_fix_inf_nan(new_extrinsics, "new_extrinsics", hard_max=None) new_cam_points = check_and_fix_inf_nan(new_cam_points, "new_cam_points", hard_max=None) new_world_points = check_and_fix_inf_nan(new_world_points, "new_world_points", hard_max=None) new_depths = check_and_fix_inf_nan(new_depths, "new_depths", hard_max=None) return new_extrinsics, new_cam_points, new_world_points, new_depths 
================================================ FILE: training/train_utils/optimizer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import itertools from typing import Any, Dict, List, Mapping, Iterable, Set, Tuple, Union import hydra import torch import torch.nn as nn from torch import Tensor # ----------------------------------------------------------------------------- # Optimizer wrapper # ----------------------------------------------------------------------------- class OptimizerWrapper: """Wraps a torch.optim.Optimizer and its schedulers (if any).""" def __init__(self, optimizer: torch.optim.Optimizer, schedulers=None) -> None: self.optimizer = optimizer self.schedulers = schedulers self._validate_optimizer_schedulers() self.step_schedulers(0.0) # --------------------------------------------------------------------- # Public API mirroring torch.optim.Optimizer # --------------------------------------------------------------------- def step(self, where: float = 1.0, closure=None): """Update the optimizer & its schedulers.""" self.step_schedulers(where) return self.optimizer.step(closure) def zero_grad(self, *args, **kwargs): return self.optimizer.zero_grad(*args, **kwargs) def _validate_optimizer_schedulers(self): if self.schedulers is None: return for _, sched_map in enumerate(self.schedulers): for option, _ in sched_map.items(): assert option in self.optimizer.defaults, ( f"Optimizer option {option} not found in {self.optimizer}. 
" f"Valid options are {self.optimizer.defaults.keys()}" ) def step_schedulers(self, where: float) -> None: if self.schedulers is None: return for i, param_group in enumerate(self.optimizer.param_groups): for option, scheduler in self.schedulers[i].items(): param_group[option] = scheduler(where) # ----------------------------------------------------------------------------- # Validation helpers # ----------------------------------------------------------------------------- def validate_param_group_params(param_groups: List[Dict], model: nn.Module): """Ensure param groups are non-overlapping and include all model params.""" for pg in param_groups: assert len(pg["params"]) == len(set(pg["params"])) parameters = [set(pg["params"]) for pg in param_groups] model_parameters = {p for _, p in model.named_parameters()} for p1, p2 in itertools.permutations(parameters, 2): assert p1.isdisjoint(p2), "Parameter groups should be disjoint" assert set.union(*parameters) == model_parameters, ( "Parameter groups must cover ALL model parameters " f"(found {len(set.union(*parameters))} / {len(model_parameters)})" ) # ----------------------------------------------------------------------------- # Glob helpers for pattern matching # ----------------------------------------------------------------------------- from wcmatch import fnmatch GLOB_FLAGS = ( fnmatch.CASE # case-sensitive | fnmatch.DOTMATCH # '*' also matches '.' 
| fnmatch.EXTMATCH # extended patterns like *(foo|bar) | fnmatch.SPLIT # "pat1|pat2" works out-of-the-box ) def get_full_parameter_name(module_name: str, param_name: str) -> str: return param_name if module_name == "" else f"{module_name}.{param_name}" def get_module_cls_to_param_names(model: nn.Module) -> Dict[type, Set[str]]: """Map each module class to the *immediate* param names it owns.""" mapping: Dict[type, Set[str]] = {} for module_name, module in model.named_modules(): module_cls = type(module) mapping.setdefault(module_cls, set()) for pname, _ in module.named_parameters(recurse=False): mapping[module_cls].add(get_full_parameter_name(module_name, pname)) return mapping def unix_param_pattern_to_parameter_names(filter_param_names: Union[List[str], None], parameter_names: Set[str]) -> Set[str]: if filter_param_names is None: return set() allowed = [] for pat in filter_param_names: matches = set(fnmatch.filter(parameter_names, pat, flags=GLOB_FLAGS)) if not matches: raise AssertionError(f"Pattern {pat} matched no parameters") logging.info(f"Matches for param pattern [{pat}]: {matches}") allowed.append(matches) return set.union(*allowed) def unix_module_cls_pattern_to_parameter_names(filter_module_cls_names: Union[List[str], None], module_cls_to_param_names: Dict[type, Set[str]]) -> Set[str]: if filter_module_cls_names is None: return set() allowed = [] for cls_name in filter_module_cls_names: module_cls = hydra.utils.get_class(cls_name) if module_cls not in module_cls_to_param_names: raise AssertionError(f"Module class {cls_name} not found in model") params = module_cls_to_param_names[module_cls] if not params: raise AssertionError(f"Module class {cls_name} has no parameters") logging.info(f"Matches for module [{cls_name}]: {params}") allowed.append(params) return set.union(*allowed) def _unix_pattern_to_parameter_names(scheduler_cfg, parameter_names: Set[str], module_cls_to_param_names: Dict[type, Set[str]]): if "param_names" not in scheduler_cfg and 
"module_cls_names" not in scheduler_cfg: return None return unix_param_pattern_to_parameter_names( scheduler_cfg.get("param_names"), parameter_names ).union( unix_module_cls_pattern_to_parameter_names( scheduler_cfg.get("module_cls_names"), module_cls_to_param_names ) ) # ----------------------------------------------------------------------------- # Scheduler helpers # ----------------------------------------------------------------------------- def set_default_parameters(scheduler_cfgs: List[dict], all_parameter_names: Set[str]): """Ensure exactly one scheduler per option acts as the default.""" specified = [cfg["parameter_names"] for cfg in scheduler_cfgs if cfg["parameter_names"]] default_params = ( all_parameter_names if not specified else all_parameter_names - set.union(*specified) ) default_count = 0 for cfg in scheduler_cfgs: if cfg["parameter_names"] is None: cfg["parameter_names"] = default_params default_count += 1 assert default_count <= 1, "At most one default scheduler per option" if default_count == 0: scheduler_cfgs.append({"parameter_names": default_params}) def name_constraints_to_parameters(param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]) -> List[Tensor]: matching_names = set.intersection(*param_constraints) return [v for k, v in named_parameters.items() if k in matching_names] def map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs: Iterable[List[dict]], named_parameters: Dict[str, Tensor]): """Produce param groups & schedulers that torch.optim can consume.""" schedulers: List[Dict[str, Any]] = [] param_groups: List[Dict[str, List[Tensor]]] = [] for cfgs in itertools.product(*all_scheduler_cfgs): param_constraints = [cfg["parameter_names"] for cfg in cfgs] matching = name_constraints_to_parameters(param_constraints, named_parameters) if not matching: continue # no intersection of params for this combo schedulers.append({cfg["option"]: cfg["scheduler"] for cfg in cfgs if "option" in cfg}) param_groups.append({"params": 
matching}) return schedulers, param_groups # ----------------------------------------------------------------------------- # Public factory functions # ----------------------------------------------------------------------------- def construct_optimizer(model: nn.Module, optimizer_conf: Any, options_conf: Union[Mapping[str, List], None] = None, param_group_modifiers_conf: Union[List, None] = None, validate_param_groups: bool = True) -> OptimizerWrapper: """Build an OptimizerWrapper from hydra configs. *No* allowlist handling – we always optimize *all* model parameters. """ named_parameters = dict(model.named_parameters()) all_parameter_names = set(named_parameters.keys()) module_cls_to_all_param_names = get_module_cls_to_param_names(model) # ────────────────────────────────────────────────────────────────── # No scheduler case – simple & fast # ────────────────────────────────────────────────────────────────── if not options_conf: optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values()) return OptimizerWrapper(optimizer) # ────────────────────────────────────────────────────────────────── # Build option-specific scheduler configs # ────────────────────────────────────────────────────────────────── scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf) all_scheduler_cfgs: List[List[dict]] = [] for option, cfg_list in scheduler_cfgs_per_option.items(): for cfg in cfg_list: cfg.option = option # annotate cfg.parameter_names = _unix_pattern_to_parameter_names( cfg, all_parameter_names, module_cls_to_all_param_names ) set_default_parameters(cfg_list, all_parameter_names) all_scheduler_cfgs.append(cfg_list) # User-provided modifiers (rare) if param_group_modifiers_conf: for modifier in param_group_modifiers_conf: modifier = hydra.utils.instantiate(modifier) all_scheduler_cfgs = modifier(scheduler_cfgs=all_scheduler_cfgs, model=model) # Map scheduler cfg combos to optimizer param groups schedulers, param_groups = 
map_scheduler_cfgs_to_param_groups( all_scheduler_cfgs, named_parameters ) if validate_param_groups: validate_param_group_params(param_groups, model) optimizer = hydra.utils.instantiate(optimizer_conf, param_groups) return OptimizerWrapper(optimizer, schedulers) def construct_optimizers(model: nn.Module, optim_conf) -> Union[List[OptimizerWrapper], None]: """Convenience wrapper producing a *single* OptimizerWrapper list.""" if optim_conf is None: return None optimizer = construct_optimizer( model, optim_conf.optimizer, optim_conf.options, validate_param_groups=True, ) return [optimizer] ================================================ FILE: training/train_utils/tb_writer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import atexit import logging import uuid from typing import Any, Dict, Optional, Union import torch from torch.utils.tensorboard import SummaryWriter from .distributed import get_machine_local_and_dist_rank class TensorBoardLogger: """A wrapper around TensorBoard SummaryWriter with distributed training support. This logger only writes from rank 0 in distributed settings to avoid conflicts. Automatically handles cleanup on exit. """ def __init__( self, path: str, *args: Any, filename_suffix: Optional[str] = None, summary_writer_method: Any = SummaryWriter, **kwargs: Any, ) -> None: """Initialize TensorBoard logger. Args: path: Directory path where TensorBoard logs will be stored filename_suffix: Optional suffix for log filename. 
If None, uses random UUID summary_writer_method: SummaryWriter class or compatible alternative *args, **kwargs: Additional arguments passed to SummaryWriter """ self._writer: Optional[SummaryWriter] = None _, self._rank = get_machine_local_and_dist_rank() self._path: str = path if self._rank == 0: logging.info( f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" ) self._writer = summary_writer_method( log_dir=path, *args, filename_suffix=filename_suffix or str(uuid.uuid4()), **kwargs, ) else: logging.debug( f"Not logging on this process because rank {self._rank} != 0" ) atexit.register(self.close) @property def writer(self) -> Optional[SummaryWriter]: """Get the underlying SummaryWriter instance.""" return self._writer @property def path(self) -> str: """Get the log directory path.""" return self._path def flush(self) -> None: """Write pending logs to disk.""" if self._writer: self._writer.flush() def close(self) -> None: """Close writer and flush pending logs to disk. Logs cannot be written after close() is called. """ if self._writer: self._writer.close() self._writer = None def log_dict(self, payload: Dict[str, Any], step: int) -> None: """Log multiple scalar values to TensorBoard. Args: payload: Dictionary mapping tag names to scalar values step: Step value to record """ if not self._writer: return for key, value in payload.items(): self.log(key, value, step) def log(self, name: str, data: Any, step: int) -> None: """Log scalar data to TensorBoard. Args: name: Tag name used to group scalars data: Scalar data to log (float/int/Tensor) step: Step value to record """ if not self._writer: return self._writer.add_scalar(name, data, global_step=step, new_style=True) def log_visuals( self, name: str, data: Union[torch.Tensor, Any], step: int, fps: int = 4 ) -> None: """Log image or video data to TensorBoard. 
Args: name: Tag name used to group visuals data: Image tensor (3D) or video tensor (5D) step: Step value to record fps: Frames per second for video data Raises: ValueError: If data dimensions are not supported (must be 3D or 5D) """ if not self._writer: return if data.ndim == 3: self._writer.add_image(name, data, global_step=step) elif data.ndim == 5: self._writer.add_video(name, data, global_step=step, fps=fps) else: raise ValueError( f"Unsupported data dimensions: {data.ndim}. " "Expected 3D for images or 5D for videos." ) ================================================ FILE: training/trainer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os # --- Environment Variable Setup for Performance and Debugging --- # Helps with memory fragmentation in PyTorch's memory allocator. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' # Specifies the threading layer for MKL, can prevent hangs in some environments. os.environ["MKL_THREADING_LAYER"] = "GNU" # Provides full Hydra stack traces on error for easier debugging. os.environ["HYDRA_FULL_ERROR"] = "1" # Enables asynchronous error handling for NCCL, which can prevent hangs. 
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" import contextlib import gc import json import logging import math import time from datetime import timedelta from typing import Any, Dict, List, Mapping, Optional, Sequence import torch import torch.distributed as dist import torch.nn as nn import torchvision from hydra.utils import instantiate from iopath.common.file_io import g_pathmgr from train_utils.checkpoint import DDPCheckpointSaver from train_utils.distributed import get_machine_local_and_dist_rank from train_utils.freeze import freeze_modules from train_utils.general import * from train_utils.logging import setup_logging from train_utils.normalization import normalize_camera_extrinsics_and_points_batch from train_utils.optimizer import construct_optimizers class Trainer: """ A generic trainer for DDP training. This should naturally support multi-node training. This class orchestrates the entire training and validation process, including: - Setting up the distributed environment (DDP). - Initializing the model, optimizers, loss functions, and data loaders. - Handling checkpointing for resuming training. - Executing the main training and validation loops. - Logging metrics and visualizations to TensorBoard. """ EPSILON = 1e-8 def __init__( self, *, data: Dict[str, Any], model: Dict[str, Any], logging: Dict[str, Any], checkpoint: Dict[str, Any], max_epochs: int, mode: str = "train", device: str = "cuda", seed_value: int = 123, val_epoch_freq: int = 1, distributed: Dict[str, bool] = None, cuda: Dict[str, bool] = None, limit_train_batches: Optional[int] = None, limit_val_batches: Optional[int] = None, optim: Optional[Dict[str, Any]] = None, loss: Optional[Dict[str, Any]] = None, env_variables: Optional[Dict[str, Any]] = None, accum_steps: int = 1, **kwargs, ): """ Initializes the Trainer. Args: data: Hydra config for datasets and dataloaders. model: Hydra config for the model. logging: Hydra config for logging (TensorBoard, log frequencies). 
checkpoint: Hydra config for checkpointing. max_epochs: Total number of epochs to train. mode: "train" for training and validation, "val" for validation only. device: "cuda" or "cpu". seed_value: A random seed for reproducibility. val_epoch_freq: Frequency (in epochs) to run validation. distributed: Hydra config for DDP settings. cuda: Hydra config for CUDA-specific settings (e.g., cuDNN). limit_train_batches: Limit the number of training batches per epoch (for debugging). limit_val_batches: Limit the number of validation batches per epoch (for debugging). optim: Hydra config for optimizers and schedulers. loss: Hydra config for the loss function. env_variables: Dictionary of environment variables to set. accum_steps: Number of steps to accumulate gradients before an optimizer step. """ self._setup_env_variables(env_variables) self._setup_timers() # Store Hydra configurations self.data_conf = data self.model_conf = model self.loss_conf = loss self.logging_conf = logging self.checkpoint_conf = checkpoint self.optim_conf = optim # Store hyperparameters self.accum_steps = accum_steps self.max_epochs = max_epochs self.mode = mode self.val_epoch_freq = val_epoch_freq self.limit_train_batches = limit_train_batches self.limit_val_batches = limit_val_batches self.seed_value = seed_value # 'where' tracks training progress from 0.0 to 1.0 for schedulers self.where = 0.0 self._setup_device(device) self._setup_torch_dist_and_backend(cuda, distributed) # Setup logging directory and configure logger safe_makedirs(self.logging_conf.log_dir) setup_logging( __name__, output_dir=self.logging_conf.log_dir, rank=self.rank, log_level_primary=self.logging_conf.log_level_primary, log_level_secondary=self.logging_conf.log_level_secondary, all_ranks=self.logging_conf.all_ranks, ) set_seeds(seed_value, self.max_epochs, self.distributed_rank) assert is_dist_avail_and_initialized(), "Torch distributed needs to be initialized before calling the trainer." 
# Instantiate components (model, loss, etc.) self._setup_components() self._setup_dataloaders() # Move model to the correct device self.model.to(self.device) self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.4f") # Construct optimizers (after moving model to device) if self.mode != "val": self.optims = construct_optimizers(self.model, self.optim_conf) # Load checkpoint if available or specified if self.checkpoint_conf.resume_checkpoint_path is not None: self._load_resuming_checkpoint(self.checkpoint_conf.resume_checkpoint_path) else: ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir) if ckpt_path is not None: self._load_resuming_checkpoint(ckpt_path) # Wrap the model with DDP self._setup_ddp_distributed_training(distributed, device) # Barrier to ensure all processes are synchronized before starting dist.barrier() def _setup_timers(self): """Initializes timers for tracking total elapsed time.""" self.start_time = time.time() self.ckpt_time_elapsed = 0 def _setup_env_variables(self, env_variables_conf: Optional[Dict[str, Any]]) -> None: """Sets environment variables from the configuration.""" if env_variables_conf: for variable_name, value in env_variables_conf.items(): os.environ[variable_name] = value logging.info(f"Environment:\n{json.dumps(dict(os.environ), sort_keys=True, indent=2)}") def _setup_torch_dist_and_backend(self, cuda_conf: Dict, distributed_conf: Dict) -> None: """Initializes the distributed process group and configures PyTorch backends.""" if torch.cuda.is_available(): # Configure CUDA backend settings for performance torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark torch.backends.cuda.matmul.allow_tf32 = cuda_conf.allow_tf32 torch.backends.cudnn.allow_tf32 = cuda_conf.allow_tf32 # Initialize the DDP process group dist.init_process_group( backend=distributed_conf.backend, timeout=timedelta(minutes=distributed_conf.timeout_mins) ) 
self.rank = dist.get_rank() def _load_resuming_checkpoint(self, ckpt_path: str): """Loads a checkpoint from the given path to resume training.""" logging.info(f"Resuming training from {ckpt_path} (rank {self.rank})") with g_pathmgr.open(ckpt_path, "rb") as f: checkpoint = torch.load(f, map_location="cpu") # Load model state model_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint missing, unexpected = self.model.load_state_dict( model_state_dict, strict=self.checkpoint_conf.strict ) if self.rank == 0: logging.info(f"Model state loaded. Missing keys: {missing or 'None'}. Unexpected keys: {unexpected or 'None'}.") # Load optimizer state if available and in training mode if "optimizer" in checkpoint: logging.info(f"Loading optimizer state dict (rank {self.rank})") self.optims.optimizer.load_state_dict(checkpoint["optimizer"]) # Load training progress if "epoch" in checkpoint: self.epoch = checkpoint["epoch"] self.steps = checkpoint["steps"] if "steps" in checkpoint else {"train": 0, "val": 0} self.ckpt_time_elapsed = checkpoint.get("time_elapsed", 0) # Load AMP scaler state if available if self.optim_conf.amp.enabled and "scaler" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler"]) def _setup_device(self, device: str): """Sets up the device for training (CPU or CUDA).""" self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank() if device == "cuda": self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.local_rank) elif device == "cpu": self.device = torch.device("cpu") else: raise ValueError(f"Unsupported device: {device}") def _setup_components(self): """Initializes all core training components using Hydra configs.""" logging.info("Setting up components: Model, Loss, Logger, etc.") self.epoch = 0 self.steps = {'train': 0, 'val': 0} # Instantiate components from configs self.tb_writer = instantiate(self.logging_conf.tensorboard_writer, _recursive_=False) self.model = 
instantiate(self.model_conf, _recursive_=False) self.loss = instantiate(self.loss_conf, _recursive_=False) self.gradient_clipper = instantiate(self.optim_conf.gradient_clip) self.scaler = torch.cuda.amp.GradScaler(enabled=self.optim_conf.amp.enabled) # Freeze specified model parameters if any if getattr(self.optim_conf, "frozen_module_names", None): logging.info( f"[Start] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}" ) self.model = freeze_modules( self.model, patterns=self.optim_conf.frozen_module_names, ) logging.info( f"[Done] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}" ) # Log model summary on rank 0 if self.rank == 0: model_summary_path = os.path.join(self.logging_conf.log_dir, "model.txt") model_summary(self.model, log_file=model_summary_path) logging.info(f"Model summary saved to {model_summary_path}") logging.info("Successfully initialized training components.") def _setup_dataloaders(self): """Initializes train and validation datasets and dataloaders.""" self.train_dataset = None self.val_dataset = None if self.mode in ["train", "val"]: self.val_dataset = instantiate( self.data_conf.get('val', None), _recursive_=False ) if self.val_dataset is not None: self.val_dataset.seed = self.seed_value if self.mode in ["train"]: self.train_dataset = instantiate(self.data_conf.train, _recursive_=False) self.train_dataset.seed = self.seed_value def _setup_ddp_distributed_training(self, distributed_conf: Dict, device: str): """Wraps the model with DistributedDataParallel (DDP).""" assert isinstance(self.model, torch.nn.Module) ddp_options = dict( find_unused_parameters=distributed_conf.find_unused_parameters, gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view, bucket_cap_mb=distributed_conf.bucket_cap_mb, broadcast_buffers=distributed_conf.broadcast_buffers, ) self.model = nn.parallel.DistributedDataParallel( self.model, device_ids=[self.local_rank] if device == "cuda" 
else [], **ddp_options, ) def save_checkpoint(self, epoch: int, checkpoint_names: Optional[List[str]] = None): """ Saves a training checkpoint. Args: epoch: The current epoch number. checkpoint_names: A list of names for the checkpoint file (e.g., "checkpoint_latest"). If None, saves "checkpoint" and "checkpoint_{epoch}" on frequency. """ checkpoint_folder = self.checkpoint_conf.save_dir safe_makedirs(checkpoint_folder) if checkpoint_names is None: checkpoint_names = ["checkpoint"] if ( self.checkpoint_conf.save_freq > 0 and int(epoch) % self.checkpoint_conf.save_freq == 0 and (int(epoch) > 0 or self.checkpoint_conf.save_freq == 1) ): checkpoint_names.append(f"checkpoint_{int(epoch)}") checkpoint_content = { "prev_epoch": epoch, "steps": self.steps, "time_elapsed": self.time_elapsed_meter.val, "optimizer": [optim.optimizer.state_dict() for optim in self.optims], } if len(self.optims) == 1: checkpoint_content["optimizer"] = checkpoint_content["optimizer"][0] if self.optim_conf.amp.enabled: checkpoint_content["scaler"] = self.scaler.state_dict() # Save the checkpoint for DDP only saver = DDPCheckpointSaver( checkpoint_folder, checkpoint_names=checkpoint_names, rank=self.distributed_rank, epoch=epoch, ) if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): model = self.model.module saver.save_checkpoint( model=model, ema_models = None, skip_saving_parameters=[], **checkpoint_content, ) def _get_scalar_log_keys(self, phase: str) -> List[str]: """Retrieves keys for scalar values to be logged for a given phase.""" if self.logging_conf.scalar_keys_to_log: return self.logging_conf.scalar_keys_to_log[phase].keys_to_log return [] def run(self): """Main entry point to start the training or validation process.""" assert self.mode in ["train", "val"], f"Invalid mode: {self.mode}" if self.mode == "train": self.run_train() # Optionally run a final validation after all training is done self.run_val() elif self.mode == "val": self.run_val() else: raise 
ValueError(f"Invalid mode: {self.mode}") def run_train(self): """Runs the main training loop over all epochs.""" while self.epoch < self.max_epochs: set_seeds(self.seed_value + self.epoch * 100, self.max_epochs, self.distributed_rank) dataloader = self.train_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank)) self.train_epoch(dataloader) # Save checkpoint after each training epoch self.save_checkpoint(self.epoch) # Clean up memory del dataloader gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # Run validation at the specified frequency # Skips validation after the last training epoch, as it can be run separately. if self.epoch % self.val_epoch_freq == 0 and self.epoch < self.max_epochs - 1: self.run_val() self.epoch += 1 self.epoch -= 1 def run_val(self): """Runs a full validation epoch if a validation dataset is available.""" if not self.val_dataset: logging.info("No validation dataset configured. Skipping validation.") return dataloader = self.val_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank)) self.val_epoch(dataloader) del dataloader gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @torch.no_grad() def val_epoch(self, val_loader): batch_time = AverageMeter("Batch Time", self.device, ":.4f") data_time = AverageMeter("Data Time", self.device, ":.4f") mem = AverageMeter("Mem (GB)", self.device, ":.4f") data_times = [] phase = 'val' loss_names = self._get_scalar_log_keys(phase) loss_names = [f"Loss/{phase}_{name}" for name in loss_names] loss_meters = { name: AverageMeter(name, self.device, ":.4f") for name in loss_names } progress = ProgressMeter( num_batches=len(val_loader), meters=[ batch_time, data_time, mem, self.time_elapsed_meter, *loss_meters.values(), ], real_meters={}, prefix="Val Epoch: [{}]".format(self.epoch), ) self.model.eval() end = time.time() iters_per_epoch = len(val_loader) limit_val_batches = ( iters_per_epoch if self.limit_val_batches is None else 
self.limit_val_batches ) for data_iter, batch in enumerate(val_loader): if data_iter > limit_val_batches: break # measure data loading time data_time.update(time.time() - end) data_times.append(data_time.val) with torch.cuda.amp.autocast(enabled=False): batch = self._process_batch(batch) batch = copy_data_to_device(batch, self.device, non_blocking=True) amp_type = self.optim_conf.amp.amp_dtype assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}" if amp_type == "bfloat16": amp_type = torch.bfloat16 else: amp_type = torch.float16 # compute output with torch.no_grad(): with torch.cuda.amp.autocast( enabled=self.optim_conf.amp.enabled, dtype=amp_type, ): val_loss_dict = self._step( batch, self.model, phase, loss_meters ) # measure elapsed time batch_time.update(time.time() - end) end = time.time() self.time_elapsed_meter.update( time.time() - self.start_time + self.ckpt_time_elapsed ) if torch.cuda.is_available(): mem.update(torch.cuda.max_memory_allocated() // 1e9) if data_iter % self.logging_conf.log_freq == 0: progress.display(data_iter) return True def train_epoch(self, train_loader): batch_time = AverageMeter("Batch Time", self.device, ":.4f") data_time = AverageMeter("Data Time", self.device, ":.4f") mem = AverageMeter("Mem (GB)", self.device, ":.4f") data_times = [] phase = 'train' loss_names = self._get_scalar_log_keys(phase) loss_names = [f"Loss/{phase}_{name}" for name in loss_names] loss_meters = { name: AverageMeter(name, self.device, ":.4f") for name in loss_names } for config in self.gradient_clipper.configs: param_names = ",".join(config['module_names']) loss_meters[f"Grad/{param_names}"] = AverageMeter(f"Grad/{param_names}", self.device, ":.4f") progress = ProgressMeter( num_batches=len(train_loader), meters=[ batch_time, data_time, mem, self.time_elapsed_meter, *loss_meters.values(), ], real_meters={}, prefix="Train Epoch: [{}]".format(self.epoch), ) self.model.train() end = time.time() iters_per_epoch = len(train_loader) 
limit_train_batches = ( iters_per_epoch if self.limit_train_batches is None else self.limit_train_batches ) if self.gradient_clipper is not None: # setup gradient clipping at the beginning of training self.gradient_clipper.setup_clipping(self.model) for data_iter, batch in enumerate(train_loader): if data_iter > limit_train_batches: break # measure data loading time data_time.update(time.time() - end) data_times.append(data_time.val) with torch.cuda.amp.autocast(enabled=False): batch = self._process_batch(batch) batch = copy_data_to_device(batch, self.device, non_blocking=True) accum_steps = self.accum_steps if accum_steps==1: chunked_batches = [batch] else: chunked_batches = chunk_batch_for_accum_steps(batch, accum_steps) self._run_steps_on_batch_chunks( chunked_batches, phase, loss_meters ) # compute gradient and do SGD step assert data_iter <= limit_train_batches # allow for off by one errors exact_epoch = self.epoch + float(data_iter) / limit_train_batches self.where = float(exact_epoch) / self.max_epochs assert self.where <= 1 + self.EPSILON if self.where < 1.0: for optim in self.optims: optim.step_schedulers(self.where) else: logging.warning( f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]." 
) # Log schedulers if self.steps[phase] % self.logging_conf.log_freq == 0: for i, optim in enumerate(self.optims): for j, param_group in enumerate(optim.optimizer.param_groups): for option in optim.schedulers[j]: optim_prefix = ( f"{i}_" if len(self.optims) > 1 else ( "" + f"{j}_" if len(optim.optimizer.param_groups) > 1 else "" ) ) self.tb_writer.log( os.path.join("Optim", f"{optim_prefix}", option), param_group[option], self.steps[phase], ) self.tb_writer.log( os.path.join("Optim", "where"), self.where, self.steps[phase], ) # Clipping gradients and detecting diverging gradients if self.gradient_clipper is not None: for optim in self.optims: self.scaler.unscale_(optim.optimizer) grad_norm_dict = self.gradient_clipper(model=self.model) for key, grad_norm in grad_norm_dict.items(): loss_meters[f"Grad/{key}"].update(grad_norm) # Optimizer step for optim in self.optims: self.scaler.step(optim.optimizer) self.scaler.update() # Measure elapsed time batch_time.update(time.time() - end) end = time.time() self.time_elapsed_meter.update( time.time() - self.start_time + self.ckpt_time_elapsed ) mem.update(torch.cuda.max_memory_allocated() // 1e9) if data_iter % self.logging_conf.log_freq == 0: progress.display(data_iter) return True def _run_steps_on_batch_chunks( self, chunked_batches: List[Any], phase: str, loss_meters: Dict[str, AverageMeter], ): """ Run the forward / backward as many times as there are chunks in the batch, accumulating the gradients on each backward """ for optim in self.optims: optim.zero_grad(set_to_none=True) accum_steps = len(chunked_batches) amp_type = self.optim_conf.amp.amp_dtype assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}" if amp_type == "bfloat16": amp_type = torch.bfloat16 else: amp_type = torch.float16 for i, chunked_batch in enumerate(chunked_batches): ddp_context = ( self.model.no_sync() if i < accum_steps - 1 else contextlib.nullcontext() ) with ddp_context: with torch.cuda.amp.autocast( 
enabled=self.optim_conf.amp.enabled, dtype=amp_type, ): loss_dict = self._step( chunked_batch, self.model, phase, loss_meters ) loss = loss_dict["objective"] loss_key = f"Loss/{phase}_loss_objective" batch_size = chunked_batch["images"].shape[0] if not math.isfinite(loss.item()): error_msg = f"Loss is {loss.item()}, attempting to stop training" logging.error(error_msg) return loss /= accum_steps self.scaler.scale(loss).backward() loss_meters[loss_key].update(loss.item(), batch_size) def _apply_batch_repetition(self, batch: Mapping) -> Mapping: """ Applies a data augmentation by concatenating the original batch with a flipped version of itself. """ tensor_keys = [ "images", "depths", "extrinsics", "intrinsics", "cam_points", "world_points", "point_masks", ] string_keys = ["seq_name"] for key in tensor_keys: if key in batch: original_tensor = batch[key] batch[key] = torch.concatenate([original_tensor, torch.flip(original_tensor, dims=[1])], dim=0) for key in string_keys: if key in batch: batch[key] = batch[key] * 2 return batch def _process_batch(self, batch: Mapping): if self.data_conf.train.common_config.repeat_batch: batch = self._apply_batch_repetition(batch) # Normalize camera extrinsics and points. The function returns new tensors. normalized_extrinsics, normalized_cam_points, normalized_world_points, normalized_depths = \ normalize_camera_extrinsics_and_points_batch( extrinsics=batch["extrinsics"], cam_points=batch["cam_points"], world_points=batch["world_points"], depths=batch["depths"], point_masks=batch["point_masks"], ) # Replace the original values in the batch with the normalized ones. batch["extrinsics"] = normalized_extrinsics batch["cam_points"] = normalized_cam_points batch["world_points"] = normalized_world_points batch["depths"] = normalized_depths return batch def _step(self, batch, model: nn.Module, phase: str, loss_meters: dict): """ Performs a single forward pass, computes loss, and logs results. 
Returns: A dictionary containing the computed losses. """ # Forward pass y_hat = model(images=batch["images"]) # Loss computation loss_dict = self.loss(y_hat, batch) # Combine all data for logging log_data = {**y_hat, **loss_dict, **batch} self._update_and_log_scalars(log_data, phase, self.steps[phase], loss_meters) self._log_tb_visuals(log_data, phase, self.steps[phase]) self.steps[phase] += 1 return loss_dict def _update_and_log_scalars(self, data: Mapping, phase: str, step: int, loss_meters: dict): """Updates average meters and logs scalar values to TensorBoard.""" keys_to_log = self._get_scalar_log_keys(phase) batch_size = data['extrinsics'].shape[0] for key in keys_to_log: if key in data: value = data[key].item() if torch.is_tensor(data[key]) else data[key] loss_meters[f"Loss/{phase}_{key}"].update(value, batch_size) if step % self.logging_conf.log_freq == 0 and self.rank == 0: self.tb_writer.log(f"Values/{phase}/{key}", value, step) def _log_tb_visuals(self, batch: Mapping, phase: str, step: int) -> None: """Logs image or video visualizations to TensorBoard.""" if not ( self.logging_conf.log_visuals and (phase in self.logging_conf.log_visual_frequency) and self.logging_conf.log_visual_frequency[phase] > 0 and (step % self.logging_conf.log_visual_frequency[phase] == 0) and (self.logging_conf.visuals_keys_to_log is not None) ): return if phase in self.logging_conf.visuals_keys_to_log: keys_to_log = self.logging_conf.visuals_keys_to_log[phase][ "keys_to_log" ] assert ( len(keys_to_log) > 0 ), "Need to include some visual keys to log" modality = self.logging_conf.visuals_keys_to_log[phase][ "modality" ] assert modality in [ "image", "video", ], "Currently only support video or image logging" name = f"Visuals/{phase}" visuals_to_log = torchvision.utils.make_grid( [ torchvision.utils.make_grid( batch[key][0], # Ensure batch[key][0] is tensor and has at least 3 dimensions nrow=self.logging_conf.visuals_per_batch_to_log, ) for key in keys_to_log if key in batch and 
batch[key][0].dim() >= 3 ], nrow=1, ).clamp(-1, 1) visuals_to_log = visuals_to_log.cpu() if visuals_to_log.dtype == torch.bfloat16: visuals_to_log = visuals_to_log.to(torch.float16) visuals_to_log = visuals_to_log.numpy() self.tb_writer.log_visuals( name, visuals_to_log, step, self.logging_conf.video_logging_fps ) def chunk_batch_for_accum_steps(batch: Mapping, accum_steps: int) -> List[Mapping]: """Splits a batch into smaller chunks for gradient accumulation.""" if accum_steps == 1: return [batch] return [get_chunk_from_data(batch, i, accum_steps) for i in range(accum_steps)] def is_sequence_of_primitives(data: Any) -> bool: """Checks if data is a sequence of primitive types (str, int, float, bool).""" return ( isinstance(data, Sequence) and not isinstance(data, str) and len(data) > 0 and isinstance(data[0], (str, int, float, bool)) ) def get_chunk_from_data(data: Any, chunk_id: int, num_chunks: int) -> Any: """ Recursively splits tensors and sequences within a data structure into chunks. Args: data: The data structure to split (e.g., a dictionary of tensors). chunk_id: The index of the chunk to retrieve. num_chunks: The total number of chunks to split the data into. Returns: A chunk of the original data structure. 
""" if isinstance(data, torch.Tensor) or is_sequence_of_primitives(data): # either a tensor or a list of primitive objects # assert len(data) % num_chunks == 0 start = (len(data) // num_chunks) * chunk_id end = (len(data) // num_chunks) * (chunk_id + 1) return data[start:end] elif isinstance(data, Mapping): return { key: get_chunk_from_data(value, chunk_id, num_chunks) for key, value in data.items() } elif isinstance(data, str): # NOTE: this is a hack to support string keys in the batch return data elif isinstance(data, Sequence): return [get_chunk_from_data(value, chunk_id, num_chunks) for value in data] else: return data ================================================ FILE: vggt/dependency/__init__.py ================================================ from .track_modules.track_refine import refine_track from .track_modules.blocks import BasicEncoder, ShallowEncoder from .track_modules.base_track_predictor import BaseTrackerPredictor ================================================ FILE: vggt/dependency/distortion.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import numpy as np from typing import Union ArrayLike = Union[np.ndarray, torch.Tensor] def _is_numpy(x: ArrayLike) -> bool: return isinstance(x, np.ndarray) def _is_torch(x: ArrayLike) -> bool: return isinstance(x, torch.Tensor) def _ensure_torch(x: ArrayLike) -> torch.Tensor: """Convert input to torch tensor if it's not already one.""" if _is_numpy(x): return torch.from_numpy(x) elif _is_torch(x): return x else: return torch.tensor(x) def single_undistortion(params, tracks_normalized): """ Apply undistortion to the normalized tracks using the given distortion parameters once. Args: params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. 
tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. Returns: torch.Tensor: Undistorted normalized tracks tensor. """ params = _ensure_torch(params) tracks_normalized = _ensure_torch(tracks_normalized) u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() u_undist, v_undist = apply_distortion(params, u, v) return torch.stack([u_undist, v_undist], dim=-1) def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6): """ Iteratively undistort the normalized tracks using the given distortion parameters. Args: params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. max_iterations (int): Maximum number of iterations for the undistortion process. max_step_norm (float): Maximum step norm for convergence. rel_step_size (float): Relative step size for numerical differentiation. Returns: torch.Tensor: Undistorted normalized tracks tensor. 
""" params = _ensure_torch(params) tracks_normalized = _ensure_torch(tracks_normalized) B, N, _ = tracks_normalized.shape u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() original_u, original_v = u.clone(), v.clone() eps = torch.finfo(u.dtype).eps for idx in range(max_iterations): u_undist, v_undist = apply_distortion(params, u, v) dx = original_u - u_undist dy = original_v - v_undist step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u) J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v) J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u) J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v) J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2) delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) u += delta[..., 0] v += delta[..., 1] if torch.max((delta**2).sum(dim=-1)) < max_step_norm: break return torch.stack([u, v], dim=-1) def apply_distortion(extra_params, u, v): """ Applies radial or OpenCV distortion to the given 2D points. Args: extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. Returns: points2D (torch.Tensor): Distorted 2D points of shape BxNx2. 
""" extra_params = _ensure_torch(extra_params) u = _ensure_torch(u) v = _ensure_torch(v) num_params = extra_params.shape[1] if num_params == 1: # Simple radial distortion k = extra_params[:, 0] u2 = u * u v2 = v * v r2 = u2 + v2 radial = k[:, None] * r2 du = u * radial dv = v * radial elif num_params == 2: # RadialCameraModel distortion k1, k2 = extra_params[:, 0], extra_params[:, 1] u2 = u * u v2 = v * v r2 = u2 + v2 radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 du = u * radial dv = v * radial elif num_params == 4: # OpenCVCameraModel distortion k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3]) u2 = u * u v2 = v * v uv = u * v r2 = u2 + v2 radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) else: raise ValueError("Unsupported number of distortion parameters") u = u.clone() + du v = v.clone() + dv return u, v if __name__ == "__main__": import random import pycolmap max_diff = 0 for i in range(1000): # Define distortion parameters (assuming 1 parameter for simplicity) B = random.randint(1, 500) track_num = random.randint(100, 1000) params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points # Undistort the tracks undistorted_tracks = iterative_undistortion(params, tracks_normalized) for b in range(B): pycolmap_intri = np.array([1, 0, 0, params[b].item()]) pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0) undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy()) diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() max_diff = max(max_diff, diff) print(f"diff: {diff}, max_diff: {max_diff}") import pdb pdb.set_trace() ================================================ FILE: 
vggt/dependency/np_to_pycolmap.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import numpy as np import pycolmap from .projection import project_3D_points_np def batch_np_matrix_to_pycolmap( points3d, extrinsics, intrinsics, tracks, image_size, masks=None, max_reproj_error=None, max_points3D_val=3000, shared_camera=False, camera_type="SIMPLE_PINHOLE", extra_params=None, min_inlier_per_frame=64, points_rgb=None, ): """ Convert Batched NumPy Arrays to PyCOLMAP Check https://github.com/colmap/pycolmap for more details about its format NOTE that colmap expects images/cameras/points3D to be 1-indexed so there is a +1 offset between colmap index and batch index NOTE: different from VGGSfM, this function: 1. Use np instead of torch 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) """ # points3d: Px3 # extrinsics: Nx3x4 # intrinsics: Nx3x3 # tracks: NxPx2 # masks: NxP # image_size: 2, assume all the frames have been padded to the same size # where N is the number of frames and P is the number of tracks N, P, _ = tracks.shape assert len(extrinsics) == N assert len(intrinsics) == N assert len(points3d) == P assert image_size.shape[0] == 2 reproj_mask = None if max_reproj_error is not None: projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics) projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 reproj_mask = projected_diff < max_reproj_error if masks is not None and reproj_mask is not None: masks = np.logical_and(masks, reproj_mask) elif masks is not None: masks = masks else: masks = reproj_mask assert masks is not None if masks.sum(1).min() < min_inlier_per_frame: print(f"Not enough inliers per frame, skip BA.") 
return None, None # Reconstruction object, following the format of PyCOLMAP/COLMAP reconstruction = pycolmap.Reconstruction() inlier_num = masks.sum(0) valid_mask = inlier_num >= 2 # a track is invalid if without two inliers valid_idx = np.nonzero(valid_mask)[0] # Only add 3D points that have sufficient 2D points for vidx in valid_idx: # Use RGB colors if provided, otherwise use zeros rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) num_points3D = len(valid_idx) camera = None # frame idx for fidx in range(N): # set camera if camera is None or (not shared_camera): pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params) camera = pycolmap.Camera( model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 ) # add camera reconstruction.add_camera(camera) # set image cam_from_world = pycolmap.Rigid3d( pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] ) # Rot and Trans image = pycolmap.Image( id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world ) points2D_list = [] point2D_idx = 0 # NOTE point3D_id start by 1 for point3D_id in range(1, num_points3D + 1): original_track_idx = valid_idx[point3D_id - 1] if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): if masks[fidx][original_track_idx]: # It seems we don't need +0.5 for BA point2D_xy = tracks[fidx][original_track_idx] # Please note when adding the Point2D object # It not only requires the 2D xy location, but also the id to 3D point points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) # add element track = reconstruction.points3D[point3D_id].track track.add_element(fidx + 1, point2D_idx) point2D_idx += 1 assert point2D_idx == len(points2D_list) try: image.points2D = pycolmap.ListPoint2D(points2D_list) image.registered = True except: print(f"frame {fidx + 1} is out of BA") 
image.registered = False # add image reconstruction.add_image(image) return reconstruction, valid_mask def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"): """ Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. Args: reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. device (str): Ignored in NumPy version (kept for API compatibility). camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). Returns: tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. """ num_images = len(reconstruction.images) max_points3D_id = max(reconstruction.point3D_ids()) points3D = np.zeros((max_points3D_id, 3)) for point3D_id in reconstruction.points3D: points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz extrinsics = [] intrinsics = [] extra_params = [] if camera_type == "SIMPLE_RADIAL" else None for i in range(num_images): # Extract and append extrinsics pyimg = reconstruction.images[i + 1] pycam = reconstruction.cameras[pyimg.camera_id] matrix = pyimg.cam_from_world.matrix() extrinsics.append(matrix) # Extract and append intrinsics calibration_matrix = pycam.calibration_matrix() intrinsics.append(calibration_matrix) if camera_type == "SIMPLE_RADIAL": extra_params.append(pycam.params[-1]) # Convert lists to NumPy arrays instead of torch tensors extrinsics = np.stack(extrinsics) intrinsics = np.stack(intrinsics) if camera_type == "SIMPLE_RADIAL": extra_params = np.stack(extra_params) extra_params = extra_params[:, None] return points3D, extrinsics, intrinsics, extra_params ######################################################## def batch_np_matrix_to_pycolmap_wo_track( points3d, points_xyf, points_rgb, extrinsics, intrinsics, image_size, shared_camera=False, camera_type="SIMPLE_PINHOLE", ): """ Convert Batched NumPy Arrays to PyCOLMAP Different from batch_np_matrix_to_pycolmap, this function does not use tracks. 
It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. Do NOT use this for BA. """ # points3d: Px3 # points_xyf: Px3, with x, y coordinates and frame indices # points_rgb: Px3, rgb colors # extrinsics: Nx3x4 # intrinsics: Nx3x3 # image_size: 2, assume all the frames have been padded to the same size # where N is the number of frames and P is the number of tracks N = len(extrinsics) P = len(points3d) # Reconstruction object, following the format of PyCOLMAP/COLMAP reconstruction = pycolmap.Reconstruction() for vidx in range(P): reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) camera = None # frame idx for fidx in range(N): # set camera if camera is None or (not shared_camera): pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) camera = pycolmap.Camera( model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 ) # add camera reconstruction.add_camera(camera) # set image cam_from_world = pycolmap.Rigid3d( pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] ) # Rot and Trans image = pycolmap.Image( id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world ) points2D_list = [] point2D_idx = 0 points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] for point3D_batch_idx in points_belong_to_fidx: point3D_id = point3D_batch_idx + 1 point2D_xyf = points_xyf[point3D_batch_idx] point2D_xy = point2D_xyf[:2] points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) # add element track = reconstruction.points3D[point3D_id].track track.add_element(fidx + 1, point2D_idx) point2D_idx += 1 assert point2D_idx == len(points2D_list) try: image.points2D = pycolmap.ListPoint2D(points2D_list) image.registered = True except: print(f"frame {fidx + 1} does not have any points") image.registered = False # add image 
reconstruction.add_image(image) return reconstruction def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): """ Helper function to get camera parameters based on camera type. Args: fidx: Frame index intrinsics: Camera intrinsic parameters camera_type: Type of camera model extra_params: Additional parameters for certain camera types Returns: pycolmap_intri: NumPy array of camera parameters """ if camera_type == "PINHOLE": pycolmap_intri = np.array( [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] ) elif camera_type == "SIMPLE_PINHOLE": focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) elif camera_type == "SIMPLE_RADIAL": raise NotImplementedError("SIMPLE_RADIAL is not supported yet") focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]]) else: raise ValueError(f"Camera type {camera_type} is not supported yet") return pycolmap_intri ================================================ FILE: vggt/dependency/projection.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import numpy as np from .distortion import apply_distortion def img_from_cam_np( intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0 ) -> np.ndarray: """ Apply intrinsics (and optional radial distortion) to camera-space points. Args ---- intrinsics : (B,3,3) camera matrix K. points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. default : value used for np.nan replacement. 
Returns ------- points2D : (B,N,2) pixel coordinates. """ # 1. perspective divide ─────────────────────────────────────── z = points_cam[:, 2:3, :] # (B,1,N) points_cam_norm = points_cam / z # (B,3,N) uv = points_cam_norm[:, :2, :] # (B,2,N) # 2. optional distortion ────────────────────────────────────── if extra_params is not None: uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) uv = np.stack([uu, vv], axis=1) # (B,2,N) # 3. homogeneous coords then K multiplication ───────────────── ones = np.ones_like(uv[:, :1, :]) # (B,1,N) points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) # batched mat-mul: K · [u v 1]ᵀ points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) return points2D.transpose(0, 2, 1) # (B,N,2) def project_3D_points_np( points3D: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray | None = None, extra_params: np.ndarray | None = None, *, default: float = 0.0, only_points_cam: bool = False, ): """ NumPy clone of ``project_3D_points``. Parameters ---------- points3D : (N,3) world-space points. extrinsics : (B,3,4) [R|t] matrix for each of B cameras. intrinsics : (B,3,3) K matrix (optional if you only need cam-space). extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. default : value used to replace NaNs. only_points_cam : if True, skip the projection and return points_cam with points2D as None. Returns ------- (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, and points_cam is (B,3,N) camera-space coordinates. """ # ----- 0. prep sizes ----------------------------------------------------- N = points3D.shape[0] # #points B = extrinsics.shape[0] # #cameras # ----- 1. 
world → homogeneous ------------------------------------------- w_h = np.ones((N, 1), dtype=points3D.dtype) points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) # broadcast to every camera (no actual copying with np.broadcast_to) ------ points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) # ----- 2. apply extrinsics (camera frame) ------------------------------ # X_cam = E · X_hom # einsum: E_(b i j) · X_(b n j) → (b n i) points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) if only_points_cam: return None, points_cam # ----- 3. intrinsics + distortion --------------------------------------- if intrinsics is None: raise ValueError("`intrinsics` must be provided unless only_points_cam=True") points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default) return points2D, points_cam def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False): """ Transforms 3D points to 2D using extrinsic and intrinsic parameters. Args: points3D (torch.Tensor): 3D points of shape Px3. extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. default (float): Default value to replace NaNs. only_points_cam (bool): If True, skip the projection and return points2D as None. Returns: tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, and points_cam is of shape Bx3xN. 
""" with torch.cuda.amp.autocast(dtype=torch.double): N = points3D.shape[0] # Number of points B = extrinsics.shape[0] # Batch size, i.e., number of cameras points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 # Reshape for batch processing points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 # Step 1: Apply extrinsic parameters # Transform 3D points to camera coordinate system for all cameras points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) if only_points_cam: return None, points_cam # Step 2: Apply intrinsic parameters and (optional) distortion points2D = img_from_cam(intrinsics, points_cam, extra_params, default) return points2D, points_cam def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): """ Applies intrinsic parameters and optional distortion to the given 3D points. Args: intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. default (float, optional): Default value to replace NaNs in the output. Returns: points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
""" # Normalize by the third coordinate (homogeneous division) points_cam = points_cam / points_cam[:, 2:3, :] # Extract uv uv = points_cam[:, :2, :] # Apply distortion if extra_params are provided if extra_params is not None: uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) uv = torch.stack([uu, vv], dim=1) # Prepare points_cam for batch matrix multiplication points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN # Apply intrinsic parameters using batch matrix multiplication points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN # Extract x and y coordinates points2D = points2D_homo[:, :2, :] # Bx2xN # Replace NaNs with default value points2D = torch.nan_to_num(points2D, nan=default) return points2D.transpose(1, 2) # BxNx2 if __name__ == "__main__": # Set up example input B, N = 24, 10240 for _ in range(100): points3D = np.random.rand(N, 3).astype(np.float64) extrinsics = np.random.rand(B, 3, 4).astype(np.float64) intrinsics = np.random.rand(B, 3, 3).astype(np.float64) # Convert to torch tensors points3D_torch = torch.tensor(points3D) extrinsics_torch = torch.tensor(extrinsics) intrinsics_torch = torch.tensor(intrinsics) # Run NumPy implementation points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics) # Run torch implementation points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch) # Convert torch output to numpy points2D_torch_np = points2D_torch.detach().numpy() points_cam_torch_np = points_cam_torch.detach().numpy() # Compute difference diff = np.abs(points2D_np - points2D_torch_np) print("Difference between NumPy and PyTorch implementations:") print(diff) # Check max error max_diff = np.max(diff) print(f"Maximum difference: {max_diff}") if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): print("Implementations match closely.") else: print("Significant differences detected.") if points_cam_np is not None: points_cam_diff = 
np.abs(points_cam_np - points_cam_torch_np) print("Difference between NumPy and PyTorch camera-space coordinates:") print(points_cam_diff) # Check max error max_cam_diff = np.max(points_cam_diff) print(f"Maximum camera-space coordinate difference: {max_cam_diff}") if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): print("Camera-space coordinates match closely.") else: print("Significant differences detected in camera-space coordinates.") ================================================ FILE: vggt/dependency/track_modules/__init__.py ================================================ ================================================ FILE: vggt/dependency/track_modules/base_track_predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from einops import rearrange, repeat from .blocks import EfficientUpdateFormer, CorrBlock from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed class BaseTrackerPredictor(nn.Module): def __init__( self, stride=4, corr_levels=5, corr_radius=4, latent_dim=128, hidden_size=384, use_spaceatt=True, depth=6, fine=False, ): super(BaseTrackerPredictor, self).__init__() """ The base template to create a track predictor Modified from https://github.com/facebookresearch/co-tracker/ """ self.stride = stride self.latent_dim = latent_dim self.corr_levels = corr_levels self.corr_radius = corr_radius self.hidden_size = hidden_size self.fine = fine self.flows_emb_dim = latent_dim // 2 self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 if self.fine: # TODO this is the old dummy code, will remove this when we train next model self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 else: self.transformer_dim += (4 - 
self.transformer_dim % 4) % 4 space_depth = depth if use_spaceatt else 0 time_depth = depth self.updateformer = EfficientUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=self.transformer_dim, hidden_size=self.hidden_size, output_dim=self.latent_dim + 2, mlp_ratio=4.0, add_space_attn=use_spaceatt, ) self.norm = nn.GroupNorm(1, self.latent_dim) # A linear layer to update track feats at each iteration self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) if not self.fine: self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): """ query_points: B x N x 2, the number of batches, tracks, and xy fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. note HH and WW is the size of feature maps instead of original images """ B, N, D = query_points.shape B, S, C, HH, WW = fmaps.shape assert D == 2 # Scale the input query_points because we may downsample the images # by down_ratio or self.stride # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map # its query_points should be query_points/4 if down_ratio > 1: query_points = query_points / float(down_ratio) query_points = query_points / float(self.stride) # Init with coords as the query points # It means the search will start from the position of query points at the reference frames coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) # Sample/extract the features of the query points in the query frame query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) # init track feats by query feats track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C # back up the init coords coords_backup = coords.clone() # Construct the correlation block fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) coord_preds = [] # Iterative Refinement for itr in range(iters): # 
Detach the gradients from the last iteration # (in my experience, not very important for performance) coords = coords.detach() # Compute the correlation (check the implementation of CorrBlock) fcorr_fn.corr(track_feats) fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim corrdim = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) # Movement of current coords relative to query points flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) # (In my trials, it is also okay to just add the flows_emb instead of concat) flows_emb = torch.cat([flows_emb, flows], dim=-1) track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) # Concatenate them as the input for the transformers transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) if transformer_input.shape[2] < self.transformer_dim: # pad the features to match the dimension pad_dim = self.transformer_dim - transformer_input.shape[2] pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) transformer_input = torch.cat([transformer_input, pad], dim=2) # 2D positional embed # TODO: this can be much simplified pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) x = transformer_input + sampled_pos_emb # B, N, S, C x = rearrange(x, "(b n) s d -> b n s d", b=B) # Compute the delta coordinates and delta track features delta = self.updateformer(x) # BN, S, C delta = rearrange(delta, " b n s d -> (b n) s d", b=B) delta_coords_ = delta[:, :, :2] delta_feats_ = delta[:, :, 2:] track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) # Update the track features track_feats_ = 
self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC # B x S x N x 2 coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) # Force coord0 as query # because we assume the query points should not be changed coords[:, 0] = coords_backup[:, 0] # The predicted tracks are in the original image scale if down_ratio > 1: coord_preds.append(coords * self.stride * down_ratio) else: coord_preds.append(coords * self.stride) # B, S, N if not self.fine: vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) vis_e = torch.sigmoid(vis_e) else: vis_e = None if return_feat: return coord_preds, vis_e, track_feats, query_track_feat else: return coord_preds, vis_e ================================================ FILE: vggt/dependency/track_modules/blocks.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# Modified from https://github.com/facebookresearch/co-tracker/


import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat

from .utils import bilinear_sampler

from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock


class BasicEncoder(nn.Module):
    """
    Multi-scale convolutional feature encoder.

    A stride-2 7x7 stem is followed by four residual stages (strides 1, 2, 2, 2).
    Every stage output is bilinearly resized to (H // stride, W // stride),
    concatenated along channels, and fused by conv2 (3x3) and conv3 (1x1)
    into ``output_dim`` channels.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()

        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros")
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        # Input channels = sum of the four stage widths:
        # output_dim // 2 + output_dim // 4 * 3 + output_dim + output_dim
        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build two ResidualBlocks (the first may downsample) and advance ``self.in_planes`` to ``dim``."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode images ``x`` (B, C, H, W) into features of size (H // stride, W // stride)."""
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        a = _bilinear_intepolate(a, self.stride, H, W)
        b = _bilinear_intepolate(b, self.stride, H, W)
        c = _bilinear_intepolate(c, self.stride, H, W)
        d = _bilinear_intepolate(d, self.stride, H, W)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x


class ShallowEncoder(nn.Module):
    """
    Lightweight convolutional encoder.

    A stride-2 3x3 stem is followed by two stride-2 residual stages; each stage
    output is upsampled back to the stem resolution and added to it. A 1x1 conv
    residual follows, then a final resize to (H // stride, W // stride).
    """

    def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
        super(ShallowEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = norm_fn
        self.in_planes = output_dim

        # NOTE(review): with norm_fn == "none" self.norm2 is never defined; this is harmless
        # only because forward() does not use norm2.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(self.in_planes)
            self.norm2 = nn.BatchNorm2d(output_dim * 2)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(self.in_planes)
            self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros")
        self.relu1 = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(output_dim, stride=2)

        self.layer2 = self._make_layer(output_dim, stride=2)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return a single ResidualBlock mapping ``dim`` -> ``dim`` channels."""
        self.in_planes = dim

        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        return layer1

    def forward(self, x):
        """Encode images ``x`` (B, C, H, W) into features of size (H // stride, W // stride)."""
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        # Each deeper stage is upsampled back to the stem resolution and added as a residual.
        tmp = self.layer1(x)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
        tmp = self.layer2(tmp)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
        tmp = None
        x = self.conv2(x) + x

        x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True)

        return x


def _bilinear_intepolate(x, stride, H, W):
    """Bilinearly resize ``x`` to (H // stride, W // stride) with align_corners=True."""
    return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True)


class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Alternates attention along the time axis with (optional) space attention
    routed through ``num_virtual_tracks`` learned virtual tokens.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        # Attribute name (including the "virual" typo) is part of the state_dict key; do not rename.
        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        """
        Weight-initialization hook.

        NOTE(review): neither helper below is applied (there is no ``self.apply(...)``), so this
        method currently leaves PyTorch's default initialization untouched. ``trunc_normal_`` is
        also not imported in this file, so ``init_weights_vit_timm`` would raise NameError if it
        were ever called. Changing this alters training-from-scratch behavior — confirm intent.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def init_weights_vit_timm(module: nn.Module, name: str = ""):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_tensor, mask=None):
        """
        Run time/space attention over track tokens.

        ``input_tensor`` is indexed as (B, N_tracks, T, input_dim) — dim 1 is concatenated with the
        virtual tracks and the sequence is reshaped to (B * N, T, C) for time attention.
        Returns ``flow_head`` outputs of shape (B, N_tracks, T, output_dim).
        """
        tokens = self.input_transform(input_tensor)

        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        # j indexes the space blocks, which run every len(time_blocks) // len(space_virtual_blocks) time blocks
        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C

            time_tokens = self.time_blocks[i](time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                # virtual tokens gather from points, self-attend, then scatter back to points
                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            # drop the virtual tracks before the output head
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        flow = self.flow_head(tokens)
        return flow


class CorrBlock:
    """
    Multi-level correlation volume between track features and frame feature maps.

    ``fmaps`` is (B, S, C, H, W). The constructor builds a ``num_levels`` pyramid by
    2x average pooling. ``corr`` stores, per level, the dot product of every track
    feature with every feature-map location scaled by 1/sqrt(C); ``sample`` reads a
    (2r+1) x (2r+1) window of those correlations around each coordinate.
    """

    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """
        Sample correlation windows around ``coords`` (B, S, N, 2) at every pyramid level.

        Must be called after ``corr``. Returns (B, S, N, num_levels * (2r+1)**2).
        """
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        # NOTE(review): these values are overwritten per level below and never read here.
        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            # NOTE(review): delta[..., 0] comes from dy but is added to the first coordinate.
            # Because dx and dy are identical, this only fixes the ordering of the sampled window,
            # which trained weights depend on — do not "fix" without retraining.
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)

            # coordinates are halved per level to match the pooled resolution
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode)
            corrs = corrs.view(B, S, N, -1)

            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1).contiguous()  # B, S, N, LRR*2
        return out

    def corr(self, targets):
        """
        Compute and cache correlations for track features ``targets`` (B, S, N, C).

        With ``multiple_track_feats`` the channel dim is split into ``num_levels`` chunks,
        one per pyramid level. Results go to ``self.corrs_pyramid`` as (B, S, N, H_l, W_l).
        """
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W ->  B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            self.corrs_pyramid.append(corrs)

================================================
FILE: vggt/dependency/track_modules/modules.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from typing import Callable import collections from torch import Tensor from itertools import repeat # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def exists(val): return val is not None def default(val, d): return val if exists(val) else d to_2tuple = _ntuple(2) class ResidualBlock(nn.Module): """ ResidualBlock: construct a block of two conv layers with residual connections """ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" ) self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: self.norm3 = nn.Sequential() else: raise NotImplementedError if stride == 1: self.downsample = None else: self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) 
if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class AttnBlock(nn.Module): def __init__( self, hidden_size, num_heads, attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, mlp_ratio=4.0, **block_kwargs, ): """ Self attention block """ super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) def forward(self, x, mask=None): # Prepare the mask for PyTorch's attention (it expects a different format) # attn_mask = mask if mask is not None else None # Normalize before attention x = self.norm1(x) # PyTorch's MultiheadAttention returns attn_output, attn_output_weights # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) attn_output, _ = self.attn(x, x, x) # Add & Norm x = x + attn_output x = x + self.mlp(self.norm2(x)) return x class CrossAttnBlock(nn.Module): def 
__init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): """ Cross attention block """ super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm_context = nn.LayerNorm(hidden_size) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.cross_attn = nn.MultiheadAttention( embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs ) mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) def forward(self, x, context, mask=None): # Normalize inputs x = self.norm1(x) context = self.norm_context(context) # Apply cross attention # Note: nn.MultiheadAttention returns attn_output, attn_output_weights attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) # Add & Norm x = x + attn_output x = x + self.mlp(self.norm2(x)) return x ================================================ FILE: vggt/dependency/track_modules/track_refine.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from torch import nn, einsum from einops import rearrange, repeat from einops.layers.torch import Rearrange, Reduce from PIL import Image import os from typing import Union, Tuple def refine_track( images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960 ): """ Refines the tracking of images using a fine track predictor and a fine feature network. Check https://arxiv.org/abs/2312.04563 for more details. Args: images (torch.Tensor): The images to be tracked. fine_fnet (nn.Module): The fine feature network. 
fine_tracker (nn.Module): The fine track predictor. coarse_pred (torch.Tensor): The coarse predictions of tracks. compute_score (bool, optional): Whether to compute the score. Defaults to False. pradius (int, optional): The radius of a patch. Defaults to 15. sradius (int, optional): The search radius. Defaults to 2. Returns: torch.Tensor: The refined tracks. torch.Tensor, optional: The score. """ # coarse_pred shape: BxSxNx2, # where B is the batch, S is the video/images length, and N is the number of tracks # now we are going to extract patches with the center at coarse_pred # Please note that the last dimension indicates x and y, and hence has a dim number of 2 B, S, N, _ = coarse_pred.shape _, _, _, H, W = images.shape # Given the raidus of a patch, compute the patch size psize = pradius * 2 + 1 # Note that we assume the first frame is the query frame # so the 2D locations of the first frame are the query points query_points = coarse_pred[:, 0] # Given 2D positions, we can use grid_sample to extract patches # but it takes too much memory. # Instead, we use the floored track xy to sample patches. # For example, if the query point xy is (128.16, 252.78), # and the patch size is (31, 31), # our goal is to extract the content of a rectangle # with left top: (113.16, 237.78) # and right bottom: (143.16, 267.78). # However, we record the floored left top: (113, 237) # and the offset (0.16, 0.78) # Then what we need is just unfolding the images like in CNN, # picking the content at [(113, 237), (143, 267)]. 
# Such operations are highly optimized at pytorch # (well if you really want to use interpolation, check the function extract_glimpse() below) with torch.no_grad(): content_to_extract = images.reshape(B * S, 3, H, W) C_in = content_to_extract.shape[1] # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html # for the detailed explanation of unfold() # Here it runs sliding windows (psize x psize) to build patches # The shape changes from # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize # where Psize is the size of patch content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) # Floor the coarse predictions to get integers and save the fractional/decimal track_int = coarse_pred.floor().int() track_frac = coarse_pred - track_int # Note the points represent the center of patches # now we get the location of the top left corner of patches # because the ouput of pytorch unfold are indexed by top left corner topleft = track_int - pradius topleft_BSN = topleft.clone() # clamp the values so that we will not go out of indexes # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
# You need to seperately clamp x and y if H!=W topleft = topleft.clamp(0, H - psize) # Reshape from BxSxNx2 -> (B*S)xNx2 topleft = topleft.reshape(B * S, N, 2) # Prepare batches for indexing, shape: (B*S)xN batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) # extracted_patches: (B*S) x N x C_in x Psize x Psize extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] if chunk < 0: # Extract image patches based on top left corners # Feed patches to fine fent for features patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) else: patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) patch_feat_list = [] for p in torch.split(patches, chunk): patch_feat_list += [fine_fnet(p)] patch_feat = torch.cat(patch_feat_list, 0) C_out = patch_feat.shape[1] # Refine the coarse tracks by fine_tracker # reshape back to B x S x N x C_out x Psize x Psize patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") # Prepare for the query points for fine tracker # They are relative to the patch left top corner, # instead of the image top left corner now # patch_query_points: N x 1 x 2 # only 1 here because for each patch we only have 1 query point patch_query_points = track_frac[:, 0] + pradius patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) # Feed the PATCH query points and tracks into fine tracker fine_pred_track_lists, _, _, query_point_feat = fine_tracker( query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True ) # relative the patch top left fine_pred_track = fine_pred_track_lists[-1].clone() # From (relative to the patch top left) to (relative to the image top left) for idx in range(len(fine_pred_track_lists)): fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) fine_level = fine_level.squeeze(-2) fine_level = 
fine_level + topleft_BSN fine_pred_track_lists[idx] = fine_level # relative to the image top left refined_tracks = fine_pred_track_lists[-1].clone() refined_tracks[:, 0] = query_points score = None if compute_score: score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) return refined_tracks, score def refine_track_v0( images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6 ): """ COPIED FROM VGGSfM Refines the tracking of images using a fine track predictor and a fine feature network. Check https://arxiv.org/abs/2312.04563 for more details. Args: images (torch.Tensor): The images to be tracked. fine_fnet (nn.Module): The fine feature network. fine_tracker (nn.Module): The fine track predictor. coarse_pred (torch.Tensor): The coarse predictions of tracks. compute_score (bool, optional): Whether to compute the score. Defaults to False. pradius (int, optional): The radius of a patch. Defaults to 15. sradius (int, optional): The search radius. Defaults to 2. Returns: torch.Tensor: The refined tracks. torch.Tensor, optional: The score. """ # coarse_pred shape: BxSxNx2, # where B is the batch, S is the video/images length, and N is the number of tracks # now we are going to extract patches with the center at coarse_pred # Please note that the last dimension indicates x and y, and hence has a dim number of 2 B, S, N, _ = coarse_pred.shape _, _, _, H, W = images.shape # Given the raidus of a patch, compute the patch size psize = pradius * 2 + 1 # Note that we assume the first frame is the query frame # so the 2D locations of the first frame are the query points query_points = coarse_pred[:, 0] # Given 2D positions, we can use grid_sample to extract patches # but it takes too much memory. # Instead, we use the floored track xy to sample patches. 
# For example, if the query point xy is (128.16, 252.78), # and the patch size is (31, 31), # our goal is to extract the content of a rectangle # with left top: (113.16, 237.78) # and right bottom: (143.16, 267.78). # However, we record the floored left top: (113, 237) # and the offset (0.16, 0.78) # Then what we need is just unfolding the images like in CNN, # picking the content at [(113, 237), (143, 267)]. # Such operations are highly optimized at pytorch # (well if you really want to use interpolation, check the function extract_glimpse() below) with torch.no_grad(): content_to_extract = images.reshape(B * S, 3, H, W) C_in = content_to_extract.shape[1] # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html # for the detailed explanation of unfold() # Here it runs sliding windows (psize x psize) to build patches # The shape changes from # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize # where Psize is the size of patch content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) # Floor the coarse predictions to get integers and save the fractional/decimal track_int = coarse_pred.floor().int() track_frac = coarse_pred - track_int # Note the points represent the center of patches # now we get the location of the top left corner of patches # because the ouput of pytorch unfold are indexed by top left corner topleft = track_int - pradius topleft_BSN = topleft.clone() # clamp the values so that we will not go out of indexes # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
# You need to seperately clamp x and y if H!=W topleft = topleft.clamp(0, H - psize) # Reshape from BxSxNx2 -> (B*S)xNx2 topleft = topleft.reshape(B * S, N, 2) # Prepare batches for indexing, shape: (B*S)xN batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) # Extract image patches based on top left corners # extracted_patches: (B*S) x N x C_in x Psize x Psize extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] # Feed patches to fine fent for features patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) C_out = patch_feat.shape[1] # Refine the coarse tracks by fine_tracker # reshape back to B x S x N x C_out x Psize x Psize patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") # Prepare for the query points for fine tracker # They are relative to the patch left top corner, # instead of the image top left corner now # patch_query_points: N x 1 x 2 # only 1 here because for each patch we only have 1 query point patch_query_points = track_frac[:, 0] + pradius patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) # Feed the PATCH query points and tracks into fine tracker fine_pred_track_lists, _, _, query_point_feat = fine_tracker( query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True ) # relative the patch top left fine_pred_track = fine_pred_track_lists[-1].clone() # From (relative to the patch top left) to (relative to the image top left) for idx in range(len(fine_pred_track_lists)): fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) fine_level = fine_level.squeeze(-2) fine_level = fine_level + topleft_BSN fine_pred_track_lists[idx] = fine_level # relative to the image top left refined_tracks = fine_pred_track_lists[-1].clone() refined_tracks[:, 0] = query_points score = None if compute_score: score = 
compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) return refined_tracks, score ################################## NOTE: NOT USED ################################## def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): """ Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, given the query point features and reference frame feature maps """ from kornia.utils.grid import create_meshgrid from kornia.geometry.subpix import dsnt # query_point_feat initial shape: B x N x C_out, # query_point_feat indicates the feat at the coorponsing query points # Therefore we don't have S dimension here query_point_feat = query_point_feat.reshape(B, N, C_out) # reshape and expand to B x (S-1) x N x C_out query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) # and reshape to (B*(S-1)*N) x C_out query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) # Radius and size for computing the score ssize = sradius * 2 + 1 # Reshape, you know it, so many reshaping operations patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) # Again, we unfold the patches to smaller patches # so that we can then focus on smaller patches # patch_feat_unfold shape: # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize # well a bit scary, but actually not patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) # Do the same stuffs above, i.e., the same as extracting patches fine_prediction_floor = fine_pred_track.floor().int() fine_level_floor_topleft = fine_prediction_floor - sradius # Clamp to ensure the smaller patch is valid fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) # Prepare the batch indices and xy locations batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN 
batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices reference_frame_feat = patch_feat_unfold.reshape( B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize ) # Note again, according to pytorch convention # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) # pick the frames other than the first one, so we have S-1 frames here reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) # Compute similarity sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) softmax_temp = 1.0 / C_out**0.5 heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) # 2D heatmaps heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( 1, -1, 2 ) var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability score = std.reshape(B, S - 1, N) # set score as 1 for the query frame score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) return score def extract_glimpse( tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None ): B, C, W, H = tensor.shape h, w = size xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 vy, 
vx = torch.meshgrid(ys, xs) grid = torch.stack([vx, vy], dim=-1) # h, w, 2 grid = grid[None] B, N, _ = offsets.shape offsets = offsets.reshape((B * N), 1, 1, 2) offsets_grid = offsets + grid # normalised grid to [-1, 1] offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) # BxCxHxW -> Bx1xCxHxW tensor = tensor[:, None] # Bx1xCxHxW -> BxNxCxHxW tensor = tensor.expand(-1, N, -1, -1, -1) # BxNxCxHxW -> (B*N)xCxHxW tensor = tensor.reshape((B * N), C, W, H) sampled = torch.nn.functional.grid_sample( tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode ) # NOTE: I am not sure it should be h, w or w, h here # but okay for sqaures sampled = sampled.reshape(B, N, C, h, w) return sampled ================================================ FILE: vggt/dependency/track_modules/utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from https://github.com/facebookresearch/PoseDiffusion # and https://github.com/facebookresearch/co-tracker/tree/main import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union from einops import rearrange, repeat def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: """ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. It is a wrapper of get_2d_sincos_pos_embed_from_grid. Args: - embed_dim: The embedding dimension. - grid_size: The grid size. Returns: - pos_embed: The generated 2D positional embedding. 
""" if isinstance(grid_size, tuple): grid_size_h, grid_size_w = grid_size else: grid_size_h = grid_size_w = grid_size grid_h = torch.arange(grid_size_h, dtype=torch.float) grid_w = torch.arange(grid_size_w, dtype=torch.float) grid = torch.meshgrid(grid_w, grid_h, indexing="xy") grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if return_grid: return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: """ This function generates a 2D positional embedding from a given grid using sine and cosine functions. Args: - embed_dim: The embedding dimension. - grid: The grid to generate the embedding from. Returns: - emb: The generated 2D positional embedding. """ assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: """ This function generates a 1D positional embedding from a given grid using sine and cosine functions. Args: - embed_dim: The embedding dimension. - pos: The position to generate the embedding from. Returns: - emb: The generated 1D positional embedding. 
""" assert embed_dim % 2 == 0 omega = torch.arange(embed_dim // 2, dtype=torch.double) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) return emb[None].float() def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: """ This function generates a 2D positional embedding from given coordinates using sine and cosine functions. Args: - xy: The coordinates to generate the embedding from. - C: The size of the embedding. - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. Returns: - pe: The generated 2D positional embedding. """ B, N, D = xy.shape assert D == 2 x = xy[:, :, 0:1] y = xy[:, :, 1:2] div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term) pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) if cat_coords: pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) return pe def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): r"""Sample a tensor using bilinear interpolation `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at coordinates :attr:`coords` using bilinear interpolation. It is the same as `torch.nn.functional.grid_sample()` but with a different coordinate convention. 
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where :math:`B` is the batch size, :math:`C` is the number of channels, :math:`H` is the height of the image, and :math:`W` is the width of the image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note that in this case the order of the components is slightly different from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. If `align_corners` is `True`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W-1]`, with 0 corresponding to the center of the left-most image pixel :math:`W-1` to the center of the right-most pixel. If `align_corners` is `False`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W]`, with 0 corresponding to the left edge of the left-most pixel :math:`W` to the right edge of the right-most pixel. Similar conventions apply to the :math:`y` for the range :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range :math:`[0,T-1]` and :math:`[0,T]`. Args: input (Tensor): batch of input images. coords (Tensor): batch of coordinates. align_corners (bool, optional): Coordinate convention. Defaults to `True`. padding_mode (str, optional): Padding mode. Defaults to `"border"`. Returns: Tensor: sampled points. 
""" sizes = input.shape[2:] assert len(sizes) in [2, 3] if len(sizes) == 3: # t x y -> x y t to match dimensions T H W in grid_sample coords = coords[..., [1, 2, 0]] if align_corners: coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) else: coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) coords -= 1 return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) def sample_features4d(input, coords): r"""Sample spatial features `sample_features4d(input, coords)` samples the spatial features :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. The field is sampled at coordinates :attr:`coords` using bilinear interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the same convention as :func:`bilinear_sampler` with `align_corners=True`. The output tensor has one feature per point, and has shape :math:`(B, R, C)`. Args: input (Tensor): spatial features. coords (Tensor): points. Returns: Tensor: sampled features. """ B, _, _, _ = input.shape # B R 2 -> B R 1 2 coords = coords.unsqueeze(2) # B C R 1 feats = bilinear_sampler(input, coords) return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C ================================================ FILE: vggt/dependency/track_predict.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch import numpy as np from .vggsfm_utils import * def predict_tracks( images, conf=None, points_3d=None, masks=None, max_query_pts=2048, query_frame_num=5, keypoint_extractor="aliked+sp", max_points_num=163840, fine_tracking=True, complete_non_vis=True, ): """ Predict tracks for the given images and masks. TODO: support non-square images TODO: support masks This function predicts the tracks for the given images and masks using the specified query method and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. Args: images: Tensor of shape [S, 3, H, W] containing the input images. conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. points_3d: Tensor containing 3D points. Default is None. masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None. max_query_pts: Maximum number of query points. Default is 2048. query_frame_num: Number of query frames to use. Default is 5. keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". max_points_num: Maximum number of points to process at once. Default is 163840. fine_tracking: Whether to use fine tracking. Default is True. complete_non_vis: Whether to augment non-visible frames. Default is True. Returns: pred_tracks: Numpy array containing the predicted tracks. pred_vis_scores: Numpy array containing the visibility scores for the tracks. pred_confs: Numpy array containing the confidence scores for the tracks. pred_points_3d: Numpy array containing the 3D points for the tracks. pred_colors: Numpy array containing the point colors for the tracks. 
(0, 255) """ device = images.device dtype = images.dtype tracker = build_vggsfm_tracker().to(device, dtype) # Find query frames query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device) # Add the first image to the front if not already present if 0 in query_frame_indexes: query_frame_indexes.remove(0) query_frame_indexes = [0, *query_frame_indexes] # TODO: add the functionality to handle the masks keypoint_extractors = initialize_feature_extractors( max_query_pts, extractor_method=keypoint_extractor, device=device ) pred_tracks = [] pred_vis_scores = [] pred_confs = [] pred_points_3d = [] pred_colors = [] fmaps_for_tracker = tracker.process_images_to_fmaps(images) if fine_tracking: print("For faster inference, consider disabling fine_tracking") for query_index in query_frame_indexes: print(f"Predicting tracks for query frame {query_index}") pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( query_index, images, conf, points_3d, fmaps_for_tracker, keypoint_extractors, tracker, max_points_num, fine_tracking, device, ) pred_tracks.append(pred_track) pred_vis_scores.append(pred_vis) pred_confs.append(pred_conf) pred_points_3d.append(pred_point_3d) pred_colors.append(pred_color) if complete_non_vis: pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames( pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors, images, conf, points_3d, fmaps_for_tracker, keypoint_extractors, tracker, max_points_num, fine_tracking, min_vis=500, non_vis_thresh=0.1, device=device, ) pred_tracks = np.concatenate(pred_tracks, axis=1) pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None # from vggt.utils.visual_track import 
visualize_tracks_on_images # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors def _forward_on_query( query_index, images, conf, points_3d, fmaps_for_tracker, keypoint_extractors, tracker, max_points_num, fine_tracking, device, ): """ Process a single query frame for track prediction. Args: query_index: Index of the query frame images: Tensor of shape [S, 3, H, W] containing the input images conf: Confidence tensor points_3d: 3D points tensor fmaps_for_tracker: Feature maps for the tracker keypoint_extractors: Initialized feature extractors tracker: VGG-SFM tracker max_points_num: Maximum number of points to process at once fine_tracking: Whether to use fine tracking device: Device to use for computation Returns: pred_track: Predicted tracks pred_vis: Visibility scores for the tracks pred_conf: Confidence scores for the tracks pred_point_3d: 3D points for the tracks pred_color: Point colors for the tracks (0, 255) """ frame_num, _, height, width = images.shape query_image = images[query_index] query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False) query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] # Extract the color at the keypoint locations query_points_long = query_points.squeeze(0).round().long() pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]] pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) # Query the confidence and points_3d at the keypoint locations if (conf is not None) and (points_3d is not None): assert height == width assert conf.shape[-2] == conf.shape[-1] assert conf.shape[:3] == points_3d.shape[:3] scale = conf.shape[-1] / width query_points_scaled = (query_points.squeeze(0) * scale).round().long() query_points_scaled = 
query_points_scaled.cpu().numpy() pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] # heuristic to remove low confidence points # should I export this as an input parameter? valid_mask = pred_conf > 1.2 if valid_mask.sum() > 512: query_points = query_points[:, valid_mask] # Make sure shape is compatible pred_conf = pred_conf[valid_mask] pred_point_3d = pred_point_3d[valid_mask] pred_color = pred_color[valid_mask] else: pred_conf = None pred_point_3d = None reorder_index = calculate_index_mappings(query_index, frame_num, device=device) images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0) images_feed = images_feed[None] # add batch dimension fmaps_feed = fmaps_feed[None] # add batch dimension all_points_num = images_feed.shape[1] * query_points.shape[1] # Don't need to be scared, this is just chunking to make GPU happy if all_points_num > max_points_num: num_splits = (all_points_num + max_points_num - 1) // max_points_num query_points = torch.chunk(query_points, num_splits, dim=1) else: query_points = [query_points] pred_track, pred_vis, _ = predict_tracks_in_chunks( tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking ) pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1) pred_track = pred_track.squeeze(0).float().cpu().numpy() pred_vis = pred_vis.squeeze(0).float().cpu().numpy() return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color def _augment_non_visible_frames( pred_tracks: list, # ← running list of np.ndarrays pred_vis_scores: list, # ← running list of np.ndarrays pred_confs: list, # ← running list of np.ndarrays for confidence scores pred_points_3d: list, # ← running list of np.ndarrays for 3D points pred_colors: list, # ← running list of np.ndarrays for colors images: torch.Tensor, conf, points_3d, fmaps_for_tracker, 
keypoint_extractors, tracker, max_points_num: int, fine_tracking: bool, *, min_vis: int = 500, non_vis_thresh: float = 0.1, device: torch.device = None, ): """ Augment tracking for frames with insufficient visibility. Args: pred_tracks: List of numpy arrays containing predicted tracks. pred_vis_scores: List of numpy arrays containing visibility scores. pred_confs: List of numpy arrays containing confidence scores. pred_points_3d: List of numpy arrays containing 3D points. pred_colors: List of numpy arrays containing point colors. images: Tensor of shape [S, 3, H, W] containing the input images. conf: Tensor of shape [S, 1, H, W] containing confidence scores points_3d: Tensor containing 3D points fmaps_for_tracker: Feature maps for the tracker keypoint_extractors: Initialized feature extractors tracker: VGG-SFM tracker max_points_num: Maximum number of points to process at once fine_tracking: Whether to use fine tracking min_vis: Minimum visibility threshold non_vis_thresh: Non-visibility threshold device: Device to use for computation Returns: Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
""" last_query = -1 final_trial = False cur_extractors = keypoint_extractors # may be replaced on the final trial while True: # Visibility per frame vis_array = np.concatenate(pred_vis_scores, axis=1) # Count frames with sufficient visibility using numpy sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() if len(non_vis_frames) == 0: break print("Processing non visible frames:", non_vis_frames) # Decide the frames & extractor for this round if non_vis_frames[0] == last_query: # Same frame failed twice - final "all-in" attempt final_trial = True cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device) query_frame_list = non_vis_frames # blast them all at once else: query_frame_list = [non_vis_frames[0]] # Process one at a time last_query = non_vis_frames[0] # Run the tracker for every selected frame for query_index in query_frame_list: new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( query_index, images, conf, points_3d, fmaps_for_tracker, cur_extractors, tracker, max_points_num, fine_tracking, device, ) pred_tracks.append(new_track) pred_vis_scores.append(new_vis) pred_confs.append(new_conf) pred_points_3d.append(new_point_3d) pred_colors.append(new_color) if final_trial: break # Stop after final attempt return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors ================================================ FILE: vggt/dependency/vggsfm_tracker.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from torch import nn, einsum from einops import rearrange, repeat from einops.layers.torch import Rearrange, Reduce from hydra.utils import instantiate from omegaconf import OmegaConf from .track_modules.track_refine import refine_track from .track_modules.blocks import BasicEncoder, ShallowEncoder from .track_modules.base_track_predictor import BaseTrackerPredictor class TrackerPredictor(nn.Module): def __init__(self, **extra_args): super(TrackerPredictor, self).__init__() """ Initializes the tracker predictor. Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, check track_modules/base_track_predictor.py Both coarse_fnet and fine_fnet are constructed as a 2D CNN network check track_modules/blocks.py for BasicEncoder and ShallowEncoder """ # Define coarse predictor configuration coarse_stride = 4 self.coarse_down_ratio = 2 # Create networks directly instead of using instantiate self.coarse_fnet = BasicEncoder(stride=coarse_stride) self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) # Create fine predictor with stride = 1 self.fine_fnet = ShallowEncoder(stride=1) self.fine_predictor = BaseTrackerPredictor( stride=1, depth=4, corr_levels=3, corr_radius=3, latent_dim=32, hidden_size=256, fine=True, use_spaceatt=False, ) def forward( self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960 ): """ Args: images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. inference (bool, optional): Whether to perform inference. Defaults to True. 
fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. Returns: tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. """ if fmaps is None: batch_num, frame_num, image_dim, height, width = images.shape reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) fmaps = self.process_images_to_fmaps(reshaped_image) fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) if inference: torch.cuda.empty_cache() # Coarse prediction coarse_pred_track_lists, pred_vis = self.coarse_predictor( query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio ) coarse_pred_track = coarse_pred_track_lists[-1] if inference: torch.cuda.empty_cache() if fine_tracking: # Refine the coarse prediction fine_pred_track, pred_score = refine_track( images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk ) if inference: torch.cuda.empty_cache() else: fine_pred_track = coarse_pred_track pred_score = torch.ones_like(pred_vis) return fine_pred_track, coarse_pred_track, pred_vis, pred_score def process_images_to_fmaps(self, images): """ This function processes images for inference. Args: images (torch.Tensor): The images to be processed with shape S x 3 x H x W. Returns: torch.Tensor: The processed feature maps. """ if self.coarse_down_ratio > 1: # whether or not scale down the input images to save memory fmaps = self.coarse_fnet( F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True) ) else: fmaps = self.coarse_fnet(images) return fmaps ================================================ FILE: vggt/dependency/vggsfm_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import logging import warnings from typing import Dict, List, Optional, Tuple, Union import numpy as np import pycolmap import torch import torch.nn.functional as F from lightglue import ALIKED, SIFT, SuperPoint from .vggsfm_tracker import TrackerPredictor # Suppress verbose logging from dependencies logging.getLogger("dinov2").setLevel(logging.WARNING) warnings.filterwarnings("ignore", message="xFormers is available") warnings.filterwarnings("ignore", message="dinov2") # Constants _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] def build_vggsfm_tracker(model_path=None): """ Build and initialize the VGGSfM tracker. Args: model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. Returns: Initialized tracker model in eval mode. """ tracker = TrackerPredictor() if model_path is None: default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) else: tracker.load_state_dict(torch.load(model_path)) tracker.eval() return tracker def generate_rank_by_dino( images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False ): """ Generate a ranking of frames using DINO ViT features. 
Args: images: Tensor of shape (S, 3, H, W) with values in range [0, 1] query_frame_num: Number of frames to select image_size: Size to resize images to before processing model_name: Name of the DINO model to use device: Device to run the model on spatial_similarity: Whether to use spatial token similarity or CLS token similarity Returns: List of frame indices ranked by their representativeness """ # Resize images to the target size images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False) # Load DINO model dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) dino_v2_model.eval() dino_v2_model = dino_v2_model.to(device) # Normalize images using ResNet normalization resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) images_resnet_norm = (images - resnet_mean) / resnet_std with torch.no_grad(): frame_feat = dino_v2_model(images_resnet_norm, is_training=True) # Process features based on similarity type if spatial_similarity: frame_feat = frame_feat["x_norm_patchtokens"] frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) # Compute the similarity matrix frame_feat_norm = frame_feat_norm.permute(1, 0, 2) similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) similarity_matrix = similarity_matrix.mean(dim=0) else: frame_feat = frame_feat["x_norm_clstoken"] frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) distance_matrix = 100 - similarity_matrix.clone() # Ignore self-pairing similarity_matrix.fill_diagonal_(-100) similarity_sum = similarity_matrix.sum(dim=1) # Find the most common frame most_common_frame_index = torch.argmax(similarity_sum).item() # Conduct FPS sampling starting from the most common frame fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) # Clean up all 
tensors and models to free memory del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix del dino_v2_model torch.cuda.empty_cache() return fps_idx def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): """ Farthest point sampling algorithm to select diverse frames. Args: distance_matrix: Matrix of distances between frames num_samples: Number of frames to select most_common_frame_index: Index of the first frame to select Returns: List of selected frame indices """ distance_matrix = distance_matrix.clamp(min=0) N = distance_matrix.size(0) # Initialize with the most common frame selected_indices = [most_common_frame_index] check_distances = distance_matrix[selected_indices] while len(selected_indices) < num_samples: # Find the farthest point from the current set of selected points farthest_point = torch.argmax(check_distances) selected_indices.append(farthest_point.item()) check_distances = distance_matrix[farthest_point] # Mark already selected points to avoid selecting them again check_distances[selected_indices] = 0 # Break if all points have been selected if len(selected_indices) == N: break return selected_indices def calculate_index_mappings(query_index, S, device=None): """ Construct an order that switches [query_index] and [0] so that the content of query_index would be placed at [0]. Args: query_index: Index to swap with 0 S: Total number of elements device: Device to place the tensor on Returns: Tensor of indices with the swapped order """ new_order = torch.arange(S) new_order[0] = query_index new_order[query_index] = 0 if device is not None: new_order = new_order.to(device) return new_order def switch_tensor_order(tensors, order, dim=1): """ Reorder tensors along a specific dimension according to the given order. 
Args: tensors: List of tensors to reorder order: Tensor of indices specifying the new order dim: Dimension along which to reorder Returns: List of reordered tensors """ return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"): """ Initialize feature extractors that can be reused based on a method string. Args: max_query_num: Maximum number of keypoints to extract det_thres: Detection threshold for keypoint extraction extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") device: Device to run extraction on Returns: Dictionary of initialized extractors """ extractors = {} methods = extractor_method.lower().split("+") for method in methods: method = method.strip() if method == "aliked": aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) extractors["aliked"] = aliked_extractor.to(device).eval() elif method == "sp": sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres) extractors["sp"] = sp_extractor.to(device).eval() elif method == "sift": sift_extractor = SIFT(max_num_keypoints=max_query_num) extractors["sift"] = sift_extractor.to(device).eval() else: print(f"Warning: Unknown feature extractor '{method}', ignoring.") if not extractors: print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) extractors["aliked"] = aliked_extractor.to(device).eval() return extractors def extract_keypoints(query_image, extractors, round_keypoints=True): """ Extract keypoints using pre-initialized feature extractors. 
Args: query_image: Input image tensor (3xHxW, range [0, 1]) extractors: Dictionary of initialized extractors Returns: Tensor of keypoint coordinates (1xNx2) """ query_points = None with torch.no_grad(): for extractor_name, extractor in extractors.items(): query_points_data = extractor.extract(query_image, invalid_mask=None) extractor_points = query_points_data["keypoints"] if round_keypoints: extractor_points = extractor_points.round() if query_points is not None: query_points = torch.cat([query_points, extractor_points], dim=1) else: query_points = extractor_points return query_points def predict_tracks_in_chunks( track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960 ): """ Process a list of query points to avoid memory issues. Args: track_predictor (object): The track predictor object used for predicting tracks. images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. fine_tracking (bool): Whether to perform fine tracking. num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. Returns: tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
""" # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility if not isinstance(query_points_list, (list, tuple)): query_points = query_points_list if num_splits is None: num_splits = 1 query_points_list = torch.chunk(query_points, num_splits, dim=1) # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) if isinstance(query_points_list, tuple): query_points_list = list(query_points_list) fine_pred_track_list = [] pred_vis_list = [] pred_score_list = [] for split_points in query_points_list: # Feed into track predictor for each split fine_pred_track, _, pred_vis, pred_score = track_predictor( images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk ) fine_pred_track_list.append(fine_pred_track) pred_vis_list.append(pred_vis) pred_score_list.append(pred_score) # Concatenate the results from all splits fine_pred_track = torch.cat(fine_pred_track_list, dim=2) pred_vis = torch.cat(pred_vis_list, dim=2) if pred_score is not None: pred_score = torch.cat(pred_score_list, dim=2) else: pred_score = None return fine_pred_track, pred_vis, pred_score ================================================ FILE: vggt/heads/camera_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from vggt.layers import Mlp from vggt.layers.block import Block from vggt.heads.head_act import activate_pose class CameraHead(nn.Module): """ CameraHead predicts camera parameters from token representations using iterative refinement. It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. 
""" def __init__( self, dim_in: int = 2048, trunk_depth: int = 4, pose_encoding_type: str = "absT_quaR_FoV", num_heads: int = 16, mlp_ratio: int = 4, init_values: float = 0.01, trans_act: str = "linear", quat_act: str = "linear", fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. ): super().__init__() if pose_encoding_type == "absT_quaR_FoV": self.target_dim = 9 else: raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") self.trans_act = trans_act self.quat_act = quat_act self.fl_act = fl_act self.trunk_depth = trunk_depth # Build the trunk using a sequence of transformer blocks. self.trunk = nn.Sequential( *[ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) for _ in range(trunk_depth) ] ) # Normalizations for camera token and trunk output. self.token_norm = nn.LayerNorm(dim_in) self.trunk_norm = nn.LayerNorm(dim_in) # Learnable empty camera pose token. self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) self.embed_pose = nn.Linear(self.target_dim, dim_in) # Module for producing modulation parameters: shift, scale, and a gate. self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) # Adaptive layer normalization without affine parameters. self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: """ Forward pass to predict camera parameters. Args: aggregated_tokens_list (list): List of token tensors from the network; the last tensor is used for prediction. num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. Returns: list: A list of predicted camera encodings (post-activation) from each iteration. """ # Use tokens from the last block for camera prediction. 
tokens = aggregated_tokens_list[-1] # Extract the camera tokens pose_tokens = tokens[:, :, 0] pose_tokens = self.token_norm(pose_tokens) pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) return pred_pose_enc_list def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: """ Iteratively refine camera pose predictions. Args: pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. num_iterations (int): Number of refinement iterations. Returns: list: List of activated camera encodings from each iteration. """ B, S, C = pose_tokens.shape pred_pose_enc = None pred_pose_enc_list = [] for _ in range(num_iterations): # Use a learned empty pose for the first iteration. if pred_pose_enc is None: module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) else: # Detach the previous prediction to avoid backprop through time. pred_pose_enc = pred_pose_enc.detach() module_input = self.embed_pose(pred_pose_enc) # Generate modulation parameters and split them into shift, scale, and gate components. shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) # Adaptive layer normalization and modulation. pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) pose_tokens_modulated = pose_tokens_modulated + pose_tokens pose_tokens_modulated = self.trunk(pose_tokens_modulated) # Compute the delta update for the pose encoding. pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) if pred_pose_enc is None: pred_pose_enc = pred_pose_enc_delta else: pred_pose_enc = pred_pose_enc + pred_pose_enc_delta # Apply final activation functions for translation, quaternion, and field-of-view. 
activated_pose = activate_pose( pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act ) pred_pose_enc_list.append(activated_pose) return pred_pose_enc_list def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """ Modulate the input tensor using scaling and shifting parameters. """ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 return x * (1 + scale) + shift ================================================ FILE: vggt/heads/dpt_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Inspired by https://github.com/DepthAnything/Depth-Anything-V2 import os from typing import List, Dict, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from .head_act import activate_head from .utils import create_uv_grid, position_grid_to_embed class DPTHead(nn.Module): """ DPT Head for dense prediction tasks. This implementation follows the architecture described in "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer backbone and produces dense predictions by fusing multi-scale features. Args: dim_in (int): Input dimension (channels). patch_size (int, optional): Patch size. Default is 14. output_dim (int, optional): Number of output channels. Default is 4. activation (str, optional): Activation type. Default is "inv_log". conf_activation (str, optional): Confidence activation type. Default is "expp1". features (int, optional): Feature channels for intermediate representations. Default is 256. out_channels (List[int], optional): Output channels for each intermediate layer. 
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. pos_embed (bool, optional): Whether to use positional embedding. Default is True. feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. """ def __init__( self, dim_in: int, patch_size: int = 14, output_dim: int = 4, activation: str = "inv_log", conf_activation: str = "expp1", features: int = 256, out_channels: List[int] = [256, 512, 1024, 1024], intermediate_layer_idx: List[int] = [4, 11, 17, 23], pos_embed: bool = True, feature_only: bool = False, down_ratio: int = 1, ) -> None: super(DPTHead, self).__init__() self.patch_size = patch_size self.activation = activation self.conf_activation = conf_activation self.pos_embed = pos_embed self.feature_only = feature_only self.down_ratio = down_ratio self.intermediate_layer_idx = intermediate_layer_idx self.norm = nn.LayerNorm(dim_in) # Projection layers for each output channel from tokens. self.projects = nn.ModuleList( [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] ) # Resize layers for upsampling feature maps. self.resize_layers = nn.ModuleList( [ nn.ConvTranspose2d( in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 ), nn.ConvTranspose2d( in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 ), nn.Identity(), nn.Conv2d( in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 ), ] ) self.scratch = _make_scratch(out_channels, features, expand=False) # Attach additional modules to scratch. 
self.scratch.stem_transpose = None self.scratch.refinenet1 = _make_fusion_block(features) self.scratch.refinenet2 = _make_fusion_block(features) self.scratch.refinenet3 = _make_fusion_block(features) self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) head_features_1 = features head_features_2 = 32 if feature_only: self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) else: self.scratch.output_conv1 = nn.Conv2d( head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 ) conv2_in_channels = head_features_1 // 2 self.scratch.output_conv2 = nn.Sequential( nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), ) def forward( self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int, frames_chunk_size: int = 8, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Forward pass through the DPT head, supports processing by chunking frames. Args: aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. patch_start_idx (int): Starting index for patch tokens in the token sequence. Used to separate patch tokens from other tokens (e.g., camera or register tokens). frames_chunk_size (int, optional): Number of frames to process in each chunk. If None or larger than S, all frames are processed at once. Default: 8. 
Returns: Tensor or Tuple[Tensor, Tensor]: - If feature_only=True: Feature maps with shape [B, S, C, H, W] - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] """ B, S, _, H, W = images.shape # If frames_chunk_size is not specified or greater than S, process all frames at once if frames_chunk_size is None or frames_chunk_size >= S: return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) # Otherwise, process frames in chunks to manage memory usage assert frames_chunk_size > 0 # Process frames in batches all_preds = [] all_conf = [] for frames_start_idx in range(0, S, frames_chunk_size): frames_end_idx = min(frames_start_idx + frames_chunk_size, S) # Process batch of frames if self.feature_only: chunk_output = self._forward_impl( aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx ) all_preds.append(chunk_output) else: chunk_preds, chunk_conf = self._forward_impl( aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx ) all_preds.append(chunk_preds) all_conf.append(chunk_conf) # Concatenate results along the sequence dimension if self.feature_only: return torch.cat(all_preds, dim=1) else: return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) def _forward_impl( self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int, frames_start_idx: int = None, frames_end_idx: int = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Implementation of the forward pass through the DPT head. This method processes a specific chunk of frames from the sequence. Args: aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. images (Tensor): Input images with shape [B, S, 3, H, W]. patch_start_idx (int): Starting index for patch tokens. frames_start_idx (int, optional): Starting index for frames to process. frames_end_idx (int, optional): Ending index for frames to process. 
Returns: Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). """ if frames_start_idx is not None and frames_end_idx is not None: images = images[:, frames_start_idx:frames_end_idx].contiguous() B, S, _, H, W = images.shape patch_h, patch_w = H // self.patch_size, W // self.patch_size out = [] dpt_idx = 0 for layer_idx in self.intermediate_layer_idx: x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] # Select frames if processing a chunk if frames_start_idx is not None and frames_end_idx is not None: x = x[:, frames_start_idx:frames_end_idx] x = x.reshape(B * S, -1, x.shape[-1]) x = self.norm(x) x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) x = self.projects[dpt_idx](x) if self.pos_embed: x = self._apply_pos_embed(x, W, H) x = self.resize_layers[dpt_idx](x) out.append(x) dpt_idx += 1 # Fuse features from multiple layers. out = self.scratch_forward(out) # Interpolate fused output to match target image resolution. out = custom_interpolate( out, (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), mode="bilinear", align_corners=True, ) if self.pos_embed: out = self._apply_pos_embed(out, W, H) if self.feature_only: return out.view(B, S, *out.shape[1:]) out = self.scratch.output_conv2(out) preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) preds = preds.view(B, S, *preds.shape[1:]) conf = conf.view(B, S, *conf.shape[1:]) return preds, conf def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: """ Apply positional embedding to tensor x. 
""" patch_w = x.shape[-1] patch_h = x.shape[-2] pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) pos_embed = pos_embed * ratio pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) return x + pos_embed def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: """ Forward pass through the fusion blocks. Args: features (List[Tensor]): List of feature maps from different layers. Returns: Tensor: Fused feature map. """ layer_1, layer_2, layer_3, layer_4 = features layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) del layer_4_rn, layer_4 out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) del layer_3_rn, layer_3 out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) del layer_2_rn, layer_2 out = self.scratch.refinenet1(out, layer_1_rn) del layer_1_rn, layer_1 out = self.scratch.output_conv1(out) return out ################################################################################ # Modules ################################################################################ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: return FeatureFusionBlock( features, nn.ReLU(inplace=True), deconv=False, bn=False, expand=False, align_corners=True, size=size, has_residual=has_residual, groups=groups, ) def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: scratch = nn.Module() out_shape1 = out_shape out_shape2 = out_shape out_shape3 = out_shape if len(in_shape) >= 4: out_shape4 = out_shape if expand: out_shape1 = out_shape out_shape2 = out_shape * 2 out_shape3 = out_shape * 4 if 
len(in_shape) >= 4: out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer2_rn = nn.Conv2d( in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer3_rn = nn.Conv2d( in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) if len(in_shape) >= 4: scratch.layer4_rn = nn.Conv2d( in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) return scratch class ResidualConvUnit(nn.Module): """Residual convolution module.""" def __init__(self, features, activation, bn, groups=1): """Init. Args: features (int): number of features """ super().__init__() self.bn = bn self.groups = groups self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) self.norm1 = None self.norm2 = None self.activation = activation self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.activation(x) out = self.conv1(out) if self.norm1 is not None: out = self.norm1(out) out = self.activation(out) out = self.conv2(out) if self.norm2 is not None: out = self.norm2(out) return self.skip_add.add(out, x) class FeatureFusionBlock(nn.Module): """Feature fusion block.""" def __init__( self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None, has_residual=True, groups=1, ): """Init. 
Args: features (int): number of features """ super(FeatureFusionBlock, self).__init__() self.deconv = deconv self.align_corners = align_corners self.groups = groups self.expand = expand out_features = features if self.expand == True: out_features = features // 2 self.out_conv = nn.Conv2d( features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups ) if has_residual: self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) self.has_residual = has_residual self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) self.skip_add = nn.quantized.FloatFunctional() self.size = size def forward(self, *xs, size=None): """Forward pass. Returns: tensor: output """ output = xs[0] if self.has_residual: res = self.resConfUnit1(xs[1]) output = self.skip_add.add(output, res) output = self.resConfUnit2(output) if (size is None) and (self.size is None): modifier = {"scale_factor": 2} elif size is None: modifier = {"size": self.size} else: modifier = {"size": size} output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output def custom_interpolate( x: torch.Tensor, size: Tuple[int, int] = None, scale_factor: float = None, mode: str = "bilinear", align_corners: bool = True, ) -> torch.Tensor: """ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
""" if size is None: size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) INT_MAX = 1610612736 input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] if input_elements > INT_MAX: chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) interpolated_chunks = [ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks ] x = torch.cat(interpolated_chunks, dim=0) return x.contiguous() else: return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) ================================================ FILE: vggt/heads/head_act.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn.functional as F def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): """ Activate pose parameters with specified activation functions. Args: pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] trans_act: Activation type for translation component quat_act: Activation type for quaternion component fl_act: Activation type for focal length component Returns: Activated pose parameters tensor """ T = pred_pose_enc[..., :3] quat = pred_pose_enc[..., 3:7] fl = pred_pose_enc[..., 7:] # or fov T = base_pose_act(T, trans_act) quat = base_pose_act(quat, quat_act) fl = base_pose_act(fl, fl_act) # or fov pred_pose_enc = torch.cat([T, quat, fl], dim=-1) return pred_pose_enc def base_pose_act(pose_enc, act_type="linear"): """ Apply basic activation function to pose parameters. 
Args: pose_enc: Tensor containing encoded pose parameters act_type: Activation type ("linear", "inv_log", "exp", "relu") Returns: Activated pose parameters """ if act_type == "linear": return pose_enc elif act_type == "inv_log": return inverse_log_transform(pose_enc) elif act_type == "exp": return torch.exp(pose_enc) elif act_type == "relu": return F.relu(pose_enc) else: raise ValueError(f"Unknown act_type: {act_type}") def activate_head(out, activation="norm_exp", conf_activation="expp1"): """ Process network output to extract 3D points and confidence values. Args: out: Network output tensor (B, C, H, W) activation: Activation type for 3D points conf_activation: Activation type for confidence values Returns: Tuple of (3D points tensor, confidence tensor) """ # Move channels from last dim to the 4th dimension => (B, H, W, C) fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected # Split into xyz (first C-1 channels) and confidence (last channel) xyz = fmap[:, :, :, :-1] conf = fmap[:, :, :, -1] if activation == "norm_exp": d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) xyz_normed = xyz / d pts3d = xyz_normed * torch.expm1(d) elif activation == "norm": pts3d = xyz / xyz.norm(dim=-1, keepdim=True) elif activation == "exp": pts3d = torch.exp(xyz) elif activation == "relu": pts3d = F.relu(xyz) elif activation == "inv_log": pts3d = inverse_log_transform(xyz) elif activation == "xy_inv_log": xy, z = xyz.split([2, 1], dim=-1) z = inverse_log_transform(z) pts3d = torch.cat([xy * z, z], dim=-1) elif activation == "sigmoid": pts3d = torch.sigmoid(xyz) elif activation == "linear": pts3d = xyz else: raise ValueError(f"Unknown activation: {activation}") if conf_activation == "expp1": conf_out = 1 + conf.exp() elif conf_activation == "expp0": conf_out = conf.exp() elif conf_activation == "sigmoid": conf_out = torch.sigmoid(conf) else: raise ValueError(f"Unknown conf_activation: {conf_activation}") return pts3d, conf_out def inverse_log_transform(y): """ Apply inverse log 
transform: sign(y) * (exp(|y|) - 1) Args: y: Input tensor Returns: Transformed tensor """ return torch.sign(y) * (torch.expm1(torch.abs(y))) ================================================ FILE: vggt/heads/track_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch.nn as nn from .dpt_head import DPTHead from .track_modules.base_track_predictor import BaseTrackerPredictor class TrackHead(nn.Module): """ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. The tracking is performed iteratively, refining predictions over multiple iterations. """ def __init__( self, dim_in, patch_size=14, features=128, iters=4, predict_conf=True, stride=2, corr_levels=7, corr_radius=4, hidden_size=384, ): """ Initialize the TrackHead module. Args: dim_in (int): Input dimension of tokens from the backbone. patch_size (int): Size of image patches used in the vision transformer. features (int): Number of feature channels in the feature extractor output. iters (int): Number of refinement iterations for tracking predictions. predict_conf (bool): Whether to predict confidence scores for tracked points. stride (int): Stride value for the tracker predictor. corr_levels (int): Number of correlation pyramid levels corr_radius (int): Radius for correlation computation, controlling the search area. hidden_size (int): Size of hidden layers in the tracker network. 
""" super().__init__() self.patch_size = patch_size # Feature extractor based on DPT architecture # Processes tokens into feature maps for tracking self.feature_extractor = DPTHead( dim_in=dim_in, patch_size=patch_size, features=features, feature_only=True, # Only output features, no activation down_ratio=2, # Reduces spatial dimensions by factor of 2 pos_embed=False, ) # Tracker module that predicts point trajectories # Takes feature maps and predicts coordinates and visibility self.tracker = BaseTrackerPredictor( latent_dim=features, # Match the output_dim of feature extractor predict_conf=predict_conf, stride=stride, corr_levels=corr_levels, corr_radius=corr_radius, hidden_size=hidden_size, ) self.iters = iters def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): """ Forward pass of the TrackHead. Args: aggregated_tokens_list (list): List of aggregated tokens from the backbone. images (torch.Tensor): Input images of shape (B, S, C, H, W) where: B = batch size, S = sequence length. patch_start_idx (int): Starting index for patch tokens. query_points (torch.Tensor, optional): Initial query points to track. If None, points are initialized by the tracker. iters (int, optional): Number of refinement iterations. If None, uses self.iters. Returns: tuple: - coord_preds (torch.Tensor): Predicted coordinates for tracked points. - vis_scores (torch.Tensor): Visibility scores for tracked points. - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). 
""" B, S, _, H, W = images.shape # Extract features from tokens # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) # Use default iterations if not specified if iters is None: iters = self.iters # Perform tracking using the extracted features coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) return coord_preds, vis_scores, conf_scores ================================================ FILE: vggt/heads/track_modules/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: vggt/heads/track_modules/base_track_predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch import torch.nn as nn from einops import rearrange, repeat from .blocks import EfficientUpdateFormer, CorrBlock from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed from .modules import Mlp class BaseTrackerPredictor(nn.Module): def __init__( self, stride=1, corr_levels=5, corr_radius=4, latent_dim=128, hidden_size=384, use_spaceatt=True, depth=6, max_scale=518, predict_conf=True, ): super(BaseTrackerPredictor, self).__init__() """ The base template to create a track predictor Modified from https://github.com/facebookresearch/co-tracker/ and https://github.com/facebookresearch/vggsfm """ self.stride = stride self.latent_dim = latent_dim self.corr_levels = corr_levels self.corr_radius = corr_radius self.hidden_size = hidden_size self.max_scale = max_scale self.predict_conf = predict_conf self.flows_emb_dim = latent_dim // 2 self.corr_mlp = Mlp( in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, hidden_features=self.hidden_size, out_features=self.latent_dim, ) self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) space_depth = depth if use_spaceatt else 0 time_depth = depth self.updateformer = EfficientUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=self.transformer_dim, hidden_size=self.hidden_size, output_dim=self.latent_dim + 2, mlp_ratio=4.0, add_space_attn=use_spaceatt, ) self.fmap_norm = nn.LayerNorm(self.latent_dim) self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) # A linear layer to update track feats at each iteration self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) if predict_conf: self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): """ query_points: B x 
N x 2, the number of batches, tracks, and xy fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. note HH and WW is the size of feature maps instead of original images """ B, N, D = query_points.shape B, S, C, HH, WW = fmaps.shape assert D == 2, "Input points must be 2D coordinates" # apply a layernorm to fmaps here fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) fmaps = fmaps.permute(0, 1, 4, 2, 3) # Scale the input query_points because we may downsample the images # by down_ratio or self.stride # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map # its query_points should be query_points/4 if down_ratio > 1: query_points = query_points / float(down_ratio) query_points = query_points / float(self.stride) # Init with coords as the query points # It means the search will start from the position of query points at the reference frames coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) # Sample/extract the features of the query points in the query frame query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) # init track feats by query feats track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C # back up the init coords coords_backup = coords.clone() fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) coord_preds = [] # Iterative Refinement for _ in range(iters): # Detach the gradients from the last iteration # (in my experience, not very important for performance) coords = coords.detach() fcorrs = fcorr_fn.corr_sample(track_feats, coords) corr_dim = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) fcorrs_ = self.corr_mlp(fcorrs_) # Movement of current coords relative to query points flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) # (In my trials, it is also okay to just add the flows_emb instead of concat) 
flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) # Concatenate them as the input for the transformers transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) # 2D positional embed # TODO: this can be much simplified pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) x = transformer_input + sampled_pos_emb # Add the query ref token to the track feats query_ref_token = torch.cat( [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 ) x = x + query_ref_token.to(x.device).to(x.dtype) # B, N, S, C x = rearrange(x, "(b n) s d -> b n s d", b=B) # Compute the delta coordinates and delta track features delta, _ = self.updateformer(x) # BN, S, C delta = rearrange(delta, " b n s d -> (b n) s d", b=B) delta_coords_ = delta[:, :, :2] delta_feats_ = delta[:, :, 2:] track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) # Update the track features track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC # B x S x N x 2 coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) # Force coord0 as query # because we assume the query points should not be changed coords[:, 0] = coords_backup[:, 0] # The predicted tracks are in the original image scale if down_ratio > 1: coord_preds.append(coords * self.stride * down_ratio) else: coord_preds.append(coords * self.stride) # B, S, N vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) if apply_sigmoid: vis_e = 
torch.sigmoid(vis_e) if self.predict_conf: conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) if apply_sigmoid: conf_e = torch.sigmoid(conf_e) else: conf_e = None if return_feat: return coord_preds, vis_e, track_feats, query_track_feat, conf_e else: return coord_preds, vis_e, conf_e ================================================ FILE: vggt/heads/track_modules/blocks.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from https://github.com/facebookresearch/co-tracker/ import math import torch import torch.nn as nn import torch.nn.functional as F from .utils import bilinear_sampler from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock class EfficientUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=6, time_depth=6, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, add_space_attn=True, num_virtual_tracks=64, ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.add_space_attn = add_space_attn # Add input LayerNorm before linear projection self.input_norm = nn.LayerNorm(input_dim) self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) # Add output LayerNorm before final projection self.output_norm = nn.LayerNorm(hidden_size) self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) self.num_virtual_tracks = num_virtual_tracks if self.add_space_attn: self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) else: self.virual_tracks = None self.time_blocks = nn.ModuleList( [ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) for _ in range(time_depth) ] ) if add_space_attn: 
self.space_virtual_blocks = nn.ModuleList( [ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) for _ in range(space_depth) ] ) self.space_point2virtual_blocks = nn.ModuleList( [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] ) self.space_virtual2point_blocks = nn.ModuleList( [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] ) assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) self.initialize_weights() def initialize_weights(self): def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) self.apply(_basic_init) def forward(self, input_tensor, mask=None): # Apply input LayerNorm input_tensor = self.input_norm(input_tensor) tokens = self.input_transform(input_tensor) init_tokens = tokens B, _, T, _ = tokens.shape if self.add_space_attn: virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) tokens = torch.cat([tokens, virtual_tokens], dim=1) _, N, _, _ = tokens.shape j = 0 for i in range(len(self.time_blocks)): time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C time_tokens = self.time_blocks[i](time_tokens) tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C point_tokens = space_tokens[:, : N - self.num_virtual_tracks] virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, 
mask=mask) space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C j += 1 if self.add_space_attn: tokens = tokens[:, : N - self.num_virtual_tracks] tokens = tokens + init_tokens # Apply output LayerNorm before final projection tokens = self.output_norm(tokens) flow = self.flow_head(tokens) return flow, None class CorrBlock: def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): """ Build a pyramid of feature maps from the input. fmaps: Tensor (B, S, C, H, W) num_levels: number of pyramid levels (each downsampled by factor 2) radius: search radius for sampling correlation multiple_track_feats: if True, split the target features per pyramid level padding_mode: passed to grid_sample / bilinear_sampler """ B, S, C, H, W = fmaps.shape self.S, self.C, self.H, self.W = S, C, H, W self.num_levels = num_levels self.radius = radius self.padding_mode = padding_mode self.multiple_track_feats = multiple_track_feats # Build pyramid: each level is half the spatial resolution of the previous self.fmaps_pyramid = [fmaps] # level 0 is full resolution current_fmaps = fmaps for i in range(num_levels - 1): B, S, C, H, W = current_fmaps.shape # Merge batch & sequence dimensions current_fmaps = current_fmaps.reshape(B * S, C, H, W) # Avg pool down by factor 2 current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) _, _, H_new, W_new = current_fmaps.shape current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) self.fmaps_pyramid.append(current_fmaps) # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. # This grid is added to the (scaled) coordinate centroids. r = self.radius dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) # delta: for every (dy,dx) displacement (i.e. 
Δx, Δy) self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) def corr_sample(self, targets, coords): """ Instead of storing the entire correlation pyramid, we compute each level's correlation volume, sample it immediately, then discard it. This saves GPU memory. Args: targets: Tensor (B, S, N, C) — features for the current targets. coords: Tensor (B, S, N, 2) — coordinates at full resolution. Returns: Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) """ B, S, N, C = targets.shape # If you have multiple track features, split them per level. if self.multiple_track_feats: targets_split = torch.split(targets, C // self.num_levels, dim=-1) out_pyramid = [] for i, fmaps in enumerate(self.fmaps_pyramid): # Get current spatial resolution H, W for this pyramid level. B, S, C, H, W = fmaps.shape # Reshape feature maps for correlation computation: # fmap2s: (B, S, C, H*W) fmap2s = fmaps.view(B, S, C, H * W) # Choose appropriate target features. fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) # Compute correlation directly corrs = compute_corr_level(fmap1, fmap2s, C) corrs = corrs.view(B, S, N, H, W) # Prepare sampling grid: # Scale down the coordinates for the current level. centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) # Make sure our precomputed delta grid is on the same device/dtype. delta_lvl = self.delta.to(coords.device).to(coords.dtype) # Now the grid for grid_sample is: # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) # Sample from the correlation volume using bilinear interpolation. # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. 
corrs_sampled = bilinear_sampler( corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode ) # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) out_pyramid.append(corrs_sampled) # Concatenate all levels along the last dimension. out = torch.cat(out_pyramid, dim=-1).contiguous() return out def compute_corr_level(fmap1, fmap2s, C): # fmap1: (B, S, N, C) # fmap2s: (B, S, C, H*W) corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) return corrs / math.sqrt(C) ================================================ FILE: vggt/heads/track_modules/modules.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    """Return a parser that converts its argument into a tuple.

    Non-string iterables are converted with ``tuple(x)`` (their length is NOT
    checked against ``n``); any other value is repeated ``n`` times.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    """Return True if ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    return val if exists(val) else d


# Converts a scalar (or iterable) into a tuple; used by Mlp for per-layer bias/dropout.
to_2tuple = _ntuple(2)


class ResidualBlock(nn.Module):
    """
    ResidualBlock: construct a block of two conv layers with residual connections.

    Each conv is followed by the selected normalization and a ReLU. The input
    (projected by ``self.downsample`` when stride != 1) is added to the result,
    and the sum goes through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        norm_fn: "group", "batch", "instance" or "none"; any other value raises
            NotImplementedError.
        stride: Stride of the first conv. When != 1, the identity path is a
            strided 1x1 conv followed by ``norm3``.
        kernel_size: Kernel size of both 3x3-style convs. NOTE: padding is fixed
            to 1, so with stride 1 the spatial size is preserved (and the
            residual addition shape-compatible) only for kernel_size=3.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
        self.relu = nn.ReLU(inplace=True)

        # GroupNorm uses planes // 8 groups; planes < 8 would give 0 groups.
        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            # Empty nn.Sequential acts as an identity.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError

        if stride == 1:
            self.downsample = None
        else:
            # Project the identity path so it matches the strided conv output.
            self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks

    Computes fc1 -> act -> dropout -> fc2 -> dropout. ``bias`` and ``drop`` may be
    scalars or 2-element iterables (one value per linear layer). With
    ``use_conv=True`` the linear layers are 1x1 ``nn.Conv2d``. ``norm_layer`` is
    accepted but never used by this implementation.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs,
    ):
        """
        Self attention block: pre-norm self-attention followed by a pre-norm MLP.

        Args:
            hidden_size: Token feature dimension.
            num_heads: Number of attention heads.
            attn_class: Attention module constructor; it is called with
                embed_dim, num_heads, batch_first=True and ``block_kwargs``.
            mlp_ratio: MLP hidden size as a multiple of ``hidden_size``.
            **block_kwargs: Extra keyword arguments forwarded to ``attn_class``.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # NOTE(review): `mask` is currently ignored; it is not passed to self.attn below.
        # Prepare the mask for PyTorch's attention (it expects a different format)
        # attn_mask = mask if mask is not None else None
        # Normalize before attention
        x = self.norm1(x)

        # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
        # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)

        attn_output, _ = self.attn(x, x, x)

        # Residual connection. Because x was overwritten by norm1 above, the residual
        # is added to the *normalized* input, not to the block's original input.
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """
        Cross attention block: queries from ``x`` attend to keys/values from ``context``.

        Args:
            hidden_size: Feature dimension of ``x`` (and, as implemented, of ``context``).
            context_dim: NOTE(review): currently unused. ``norm_context`` and the
                attention layer are both built with ``hidden_size``, so ``context``
                must have ``hidden_size`` channels regardless of this value.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden size as a multiple of ``hidden_size``.
            **block_kwargs: Extra keyword arguments forwarded to nn.MultiheadAttention.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Normalize inputs
        x = self.norm1(x)
        context = self.norm_context(context)

        # Apply cross attention; `mask` is forwarded as nn.MultiheadAttention's attn_mask.
        # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        # Residual connection on the normalized x (x was overwritten by norm1 above).
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x

================================================
FILE: vggt/heads/track_modules/utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/facebookresearch/vggsfm
# and https://github.com/facebookresearch/co-tracker/tree/main


import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Union


def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
    It is a wrapper of get_2d_sincos_pos_embed_from_grid.

    Args:
    - embed_dim: The embedding dimension. Must be divisible by 4: it is split in
      half between the two axes, and each half is split again into sin/cos.
    - grid_size: The grid size, either an int (square grid) or an (H, W) tuple.
    - return_grid: If True, also return the coordinate grid of shape (2, 1, H, W).

    Returns:
    - pos_embed: The generated 2D positional embedding of shape (1, embed_dim, H, W),
      or the tuple (pos_embed, grid) when return_grid is True.
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    # With indexing="xy" both outputs have shape (H, W); grid[0] holds the
    # column (x) coordinate and grid[1] the row (y) coordinate.
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if return_grid:
        return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be divisible by 4).
    - grid: The grid to generate the embedding from; grid[0] and grid[1] are the
      two coordinate planes (each flattened to M = H*W positions).

    Returns:
    - emb: The generated 2D positional embedding of shape (1, M, embed_dim).
    """
    assert embed_dim % 2 == 0

    # First half of the channels encodes grid[0], second half grid[1]
    # (x and y respectively when called from get_2d_sincos_pos_embed).
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (1, H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (1, H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=2)  # (1, H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Frequencies are 1 / 10000**(2i / embed_dim) for i in [0, embed_dim/2), computed
    in float64; the result is cast to float32.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The position to generate the embedding from; flattened to M values.

    Returns:
    - emb: The generated 1D positional embedding of shape (1, M, embed_dim),
      laid out as [sin | cos].
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.

    Unlike the grid embeddings above, the frequencies here grow linearly
    (k * 1000 / C for k = 0, 2, 4, ...), and sin/cos are interleaved on
    even/odd channels.

    Args:
    - xy: The coordinates to generate the embedding from, shape (B, N, 2).
    - C: The per-axis embedding size. Must be even (the frequency vector of
      length ceil(C/2) is reshaped to int(C/2)).
    - cat_coords: A flag to indicate whether to prepend the original coordinates to the embedding.

    Returns:
    - pe: The generated 2D positional embedding (float32), shape (B, N, 2*C),
      or (B, N, 2*C + 2) when cat_coords is True.
    """
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2), raw coords first
    return pe


def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention: coordinates are given in pixel units rather than in [-1, 1].

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel and :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel and :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` for the range
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
    :math:`[0,T-1]` and :math:`[0,T]`.

    The coordinates are detached and copied before use, so no gradient flows
    back into :attr:`coords` and the caller's tensor is never modified.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    # Work on a private copy: the in-place mul_/sub_ below must not touch the caller's tensor.
    coords = coords.detach().clone()
    ############################################################
    # IMPORTANT: move coords to input's device/dtype so the scale tensor and
    # grid_sample see matching types.
    coords = coords.to(input.device).to(input.dtype)
    ############################################################

    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    # Map pixel coordinates to grid_sample's normalized [-1, 1] range.
    if align_corners:
        scale = torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
        )
    else:
        scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)

    coords.mul_(scale)  # coords = coords * scale
    coords.sub_(1)  # coords = coords - 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)


def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """

    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3])  # B C R 1 -> B R C

================================================
FILE: vggt/heads/utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn


def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    The first embed_dim // 2 channels encode pos_grid[..., 0] and the remaining
    channels encode pos_grid[..., 1].

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings. Must be divisible by 4
            (each half is passed to make_sincos_pos_embed, which requires an even size).
        omega_0: Base of the frequency schedule passed to make_sincos_pos_embed.

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings (float32)
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # (H*W, D/2)
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # (H*W, D/2)

    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # (H*W, D)

    return emb.view(H, W, embed_dim)  # [H, W, D]


def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Frequencies are 1 / omega_0**(2i / embed_dim) for i in [0, embed_dim/2). They are
    computed in float64, except on MPS devices where float32 is used instead.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The position to generate the embedding from; flattened to M values.
    - omega_0: Base of the frequency schedule.

    Returns:
    - emb: The generated 1D positional embedding of shape (M, embed_dim), laid out
      as [sin | cos], cast to float32.
    """
    assert embed_dim % 2 == 0
    device = pos.device
    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()


# Inspired by https://github.com/microsoft/moge


def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The grid covers the rectangle [-span_x, span_x] x [-span_y, span_y], where
    span_x / span_y equals the aspect ratio and span_x**2 + span_y**2 == 1, so the
    plane's half-diagonal has length 1. Samples are placed at pixel centers, so the
    outermost values are +/- span * (n - 1) / n rather than +/- span.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor; [..., 0] is the horizontal (u)
        and [..., 1] the vertical (v) coordinate.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries (pixel centers of an n-pixel span)
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # Create 2D meshgrid; indexing="xy" yields (height x width) tensors. Stack into UV.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid

================================================
FILE: vggt/layers/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Package-level re-exports of the layer building blocks.
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention

================================================
FILE: vggt/layers/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F

# xFormers support is hard-disabled in this module; MemEffAttention always
# falls back to Attention.forward.
XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm and rotary position embedding.

    Args:
        dim: Token feature dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        proj_bias: Whether the output projection has a bias.
        attn_drop: Dropout probability on attention weights.
        proj_drop: Dropout probability after the output projection.
        norm_layer: Normalization applied per head to q and k when ``qk_norm`` is True.
        qk_norm: Enable per-head normalization of q and k.
        fused_attn: Use ``F.scaled_dot_product_attention`` instead of the explicit
            softmax(q @ k^T) path.
        rope: Optional module called as ``rope(t, pos)`` on q and k.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Used only by the non-fused path; the fused path relies on SDPA's default
        # scaling, which PyTorch documents as 1/sqrt(head_dim) — the same value.
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None) -> Tensor:
        """Apply self-attention to tokens ``x`` of shape (B, N, C).

        ``pos`` is only forwarded to ``self.rope``; its expected format is defined by
        that module (presumably per-token 2D positions — confirm against the rope class).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            # Attention dropout is applied only in training mode.
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, heads, N, head_dim) -> (B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    """Attention variant intended to use xFormers' memory-efficient kernel.

    Since XFORMERS_AVAILABLE is False in this module, forward always delegates to
    Attention.forward (without ``pos``) and rejects a non-None ``attn_bias``.
    NOTE(review): the xFormers branch references ``unbind`` and
    ``memory_efficient_attention``, which are not imported here; it would raise
    NameError if XFORMERS_AVAILABLE were set to True without adding those imports.
    """

    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

================================================
FILE: vggt/layers/block.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py import logging import os from typing import Callable, List, Any, Tuple, Dict import warnings import torch from torch import nn, Tensor from .attention import Attention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp XFORMERS_AVAILABLE = False class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = Attention, ffn_layer: Callable[..., nn.Module] = Mlp, qk_norm: bool = False, fused_attn: bool = True, # use F.scaled_dot_product_attention or not rope=None, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path def forward(self, x: Tensor, pos=None) -> Tensor: def attn_residual_func(x: Tensor, pos=None) -> Tensor: return self.ls1(self.attn(self.norm1(x), pos=pos)) def ffn_residual_func(x: Tensor) -> Tensor: 
return self.ls2(self.mlp(self.norm2(x))) if self.training and self.sample_drop_ratio > 0.1: # the overhead is compensated only for a drop path rate larger than 0.1 x = drop_add_residual_stochastic_depth( x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio ) x = drop_add_residual_stochastic_depth( x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio ) elif self.training and self.sample_drop_ratio > 0.0: x = x + self.drop_path1(attn_residual_func(x, pos=pos)) x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 else: x = x + attn_residual_func(x, pos=pos) x = x + ffn_residual_func(x) return x def drop_add_residual_stochastic_depth( x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None ) -> Tensor: # 1) extract subset using permutation b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] x_subset = x[brange] # 2) apply residual_func to get residual if pos is not None: # if necessary, apply rope to the subset pos = pos[brange] residual = residual_func(x_subset, pos=pos) else: residual = residual_func(x_subset) x_flat = x.flatten(1) residual = residual.flatten(1) residual_scale_factor = b / sample_subset_size # 3) add the residual x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) return x_plus_residual.view_as(x) def get_branges_scales(x, sample_drop_ratio=0.0): b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] residual_scale_factor = b / sample_subset_size return brange, residual_scale_factor def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): if scaling_vector is None: x_flat = x.flatten(1) residual = residual.flatten(1) x_plus_residual = torch.index_add(x_flat, 0, brange, 
residual.to(dtype=x.dtype), alpha=residual_scale_factor) else: x_plus_residual = scaled_index_add( x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor ) return x_plus_residual attn_bias_cache: Dict[Tuple, Any] = {} def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] for b, x in zip(batch_sizes, x_list): for _ in range(b): seqlens.append(x.shape[1]) attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) attn_bias._batch_sizes = batch_sizes attn_bias_cache[all_shapes] = attn_bias if branges is not None: cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) return attn_bias_cache[all_shapes], cat_tensors def drop_add_residual_stochastic_depth_list( x_list: List[Tensor], residual_func: Callable[[Tensor, Any], Tensor], sample_drop_ratio: float = 0.0, scaling_vector=None, ) -> Tensor: # 1) generate random set of indices for dropping samples in the batch branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] # 2) get attention bias and index+concat the tensors attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) # 3) apply residual_func to get residual, and split the result residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) return outputs class NestedTensorBlock(Block): def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: """ x_list contains a list of tensors to nest together and run """ assert isinstance(self.attn, MemEffAttention) if self.training and self.sample_drop_ratio > 0.0: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.attn(self.norm1(x), attn_bias=attn_bias) def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.mlp(self.norm2(x)) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None), ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None), ) return x_list else: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.ls2(self.mlp(self.norm2(x))) attn_bias, x = get_attn_bias_and_cat(x_list) x = x + attn_residual_func(x, attn_bias=attn_bias) x = x + ffn_residual_func(x) return attn_bias.split(x) def forward(self, x_or_x_list): if isinstance(x_or_x_list, Tensor): return super().forward(x_or_x_list) elif isinstance(x_or_x_list, list): if not XFORMERS_AVAILABLE: raise AssertionError("xFormers is required for using nested tensors") return self.forward_nested(x_or_x_list) else: raise AssertionError ================================================ FILE: vggt/layers/drop_path.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. 
# # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py from torch import nn def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: vggt/layers/layer_scale.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 from typing import Union import torch from torch import Tensor from torch import nn class LayerScale(nn.Module): def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma ================================================ FILE: vggt/layers/mlp.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py from typing import Callable, Optional from torch import Tensor, nn class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x ================================================ FILE: vggt/layers/patch_embed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. 
# # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py from typing import Callable, Optional, Tuple, Union from torch import Tensor import torch.nn as nn def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): """ 2D image to patch embedding: (B,C,H,W) -> (B,N,D) Args: img_size: Image size. patch_size: Patch token size. in_chans: Number of input image channels. embed_dim: Number of linear projection output channels. norm_layer: Normalization layer. """ def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) self.img_size = image_HW self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) # B HW C x = 
self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) # B H W C return x def flops(self) -> float: Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops ================================================ FILE: vggt/layers/rope.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # Implementation of 2D Rotary Position Embeddings (RoPE). # This module provides a clean implementation of 2D Rotary Position Embeddings, # which extends the original RoPE concept to handle 2D spatial positions. # Inspired by: # https://github.com/meta-llama/codellama/blob/main/llama/model.py # https://github.com/naver-ai/rope-vit import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Tuple class PositionGetter: """Generates and caches 2D spatial positions for patches in a grid. This class efficiently manages the generation of spatial coordinates for patches in a 2D grid, caching results to avoid redundant computations. Attributes: position_cache: Dictionary storing precomputed position tensors for different grid dimensions. """ def __init__(self): """Initializes the position generator with an empty cache.""" self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: """Generates spatial positions for a batch of patches. Args: batch_size: Number of samples in the batch. height: Height of the grid in patches. width: Width of the grid in patches. device: Target device for the position tensor. 
Returns: Tensor of shape (batch_size, height*width, 2) containing y,x coordinates for each position in the grid, repeated for each batch item. """ if (height, width) not in self.position_cache: y_coords = torch.arange(height, device=device) x_coords = torch.arange(width, device=device) positions = torch.cartesian_prod(y_coords, x_coords) self.position_cache[height, width] = positions cached_positions = self.position_cache[height, width] return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() class RotaryPositionEmbedding2D(nn.Module): """2D Rotary Position Embedding implementation. This module applies rotary position embeddings to input tokens based on their 2D spatial positions. It handles the position-dependent rotation of features separately for vertical and horizontal dimensions. Args: frequency: Base frequency for the position embeddings. Default: 100.0 scaling_factor: Scaling factor for frequency computation. Default: 1.0 Attributes: base_frequency: Base frequency for computing position embeddings. scaling_factor: Factor to scale the computed frequencies. frequency_cache: Cache for storing precomputed frequency components. """ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): """Initializes the 2D RoPE module.""" super().__init__() self.base_frequency = frequency self.scaling_factor = scaling_factor self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} def _compute_frequency_components( self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: """Computes frequency components for rotary embeddings. Args: dim: Feature dimension (must be even). seq_len: Maximum sequence length. device: Target device for computations. dtype: Data type for the computed tensors. Returns: Tuple of (cosine, sine) tensors for frequency components. 
""" cache_key = (dim, seq_len, device, dtype) if cache_key not in self.frequency_cache: # Compute frequency bands exponents = torch.arange(0, dim, 2, device=device).float() / dim inv_freq = 1.0 / (self.base_frequency**exponents) # Generate position-dependent frequencies positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) angles = torch.einsum("i,j->ij", positions, inv_freq) # Compute and cache frequency components angles = angles.to(dtype) angles = torch.cat((angles, angles), dim=-1) cos_components = angles.cos().to(dtype) sin_components = angles.sin().to(dtype) self.frequency_cache[cache_key] = (cos_components, sin_components) return self.frequency_cache[cache_key] @staticmethod def _rotate_features(x: torch.Tensor) -> torch.Tensor: """Performs feature rotation by splitting and recombining feature dimensions. Args: x: Input tensor to rotate. Returns: Rotated feature tensor. """ feature_dim = x.shape[-1] x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] return torch.cat((-x2, x1), dim=-1) def _apply_1d_rope( self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor ) -> torch.Tensor: """Applies 1D rotary position embeddings along one dimension. Args: tokens: Input token features. positions: Position indices. cos_comp: Cosine components for rotation. sin_comp: Sine components for rotation. Returns: Tokens with applied rotary position embeddings. """ # Embed positions with frequency components cos = F.embedding(positions, cos_comp)[:, None, :, :] sin = F.embedding(positions, sin_comp)[:, None, :, :] # Apply rotation return (tokens * cos) + (self._rotate_features(tokens) * sin) def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: """Applies 2D rotary position embeddings to input tokens. Args: tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). The feature dimension (dim) must be divisible by 4. 
positions: Position tensor of shape (batch_size, n_tokens, 2) containing the y and x coordinates for each token. Returns: Tensor of same shape as input with applied 2D rotary position embeddings. Raises: AssertionError: If input dimensions are invalid or positions are malformed. """ # Validate inputs assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" # Compute feature dimension for each spatial direction feature_dim = tokens.size(-1) // 2 # Get frequency components max_position = int(positions.max()) + 1 cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) # Split features for vertical and horizontal processing vertical_features, horizontal_features = tokens.chunk(2, dim=-1) # Apply RoPE separately for each dimension vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) # Combine processed features return torch.cat((vertical_features, horizontal_features), dim=-1) ================================================ FILE: vggt/layers/swiglu_ffn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
import os from typing import Callable, Optional import warnings from torch import Tensor, nn import torch.nn.functional as F class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: Tensor) -> Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None # try: # if XFORMERS_ENABLED: # from xformers.ops import SwiGLU # XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (SwiGLU)") # else: # warnings.warn("xFormers is disabled (SwiGLU)") # raise ImportError # except ImportError: SwiGLU = SwiGLUFFN XFORMERS_AVAILABLE = False # warnings.warn("xFormers is not available (SwiGLU)") class SwiGLUFFNFused(SwiGLU): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) ================================================ FILE: vggt/layers/vision_transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
# References: # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py from functools import partial import math import logging from typing import Sequence, Tuple, Union, Callable import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from torch.nn.init import trunc_normal_ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block logger = logging.getLogger("dinov2") def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) if depth_first and include_root: fn(module=module, name=name) return module class BlockChunk(nn.ModuleList): def forward(self, x): for b in self: x = b(x) return x class DinoVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, # for layerscale: None or 0 => no layerscale embed_layer=PatchEmbed, act_layer=nn.GELU, block_fn=Block, ffn_layer="mlp", block_chunks=1, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1, qk_norm=False, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True proj_bias (bool): enable bias for proj in attn if 
True ffn_bias (bool): enable bias for ffn if True drop_path_rate (float): stochastic depth rate drop_path_uniform (bool): apply uniform drop rate across blocks weight_init (str): weight init scheme init_values (float): layer-scale init values embed_layer (nn.Module): patch embedding layer act_layer (nn.Module): MLP activation layer block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap num_register_tokens: (int) number of extra cls tokens (so-called "registers") interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset self.use_reentrant = False # hardcoded to False self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) assert num_register_tokens >= 0 self.register_tokens = ( nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule if ffn_layer == "mlp": logger.info("using MLP layer as FFN") ffn_layer = Mlp elif ffn_layer == "swiglufused" or ffn_layer == 
"swiglu": logger.info("using SwiGLU layer as FFN") ffn_layer = SwiGLUFFNFused elif ffn_layer == "identity": logger.info("using Identity layer as FFN") def f(*args, **kwargs): return nn.Identity() ffn_layer = f else: raise NotImplementedError blocks_list = [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, ffn_layer=ffn_layer, init_values=init_values, qk_norm=qk_norm, ) for i in range(depth) ] if block_chunks > 0: self.chunked_blocks = True chunked_blocks = [] chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = norm_layer(embed_dim) self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) self.init_weights() def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) if self.register_tokens is not None: nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): previous_dtype = x.dtype npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed pos_embed = self.pos_embed.float() class_pos_embed = pos_embed[:, 0] patch_pos_embed = pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size M = int(math.sqrt(N)) # Recover the number of patches in each dimension assert N == M * M kwargs = {} if self.interpolate_offset: # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 # Note: still needed for 
backward-compatibility, the underlying operators are using both output size and scale factors sx = float(w0 + self.interpolate_offset) / M sy = float(h0 + self.interpolate_offset) / M kwargs["scale_factor"] = (sx, sy) else: # Simply specify an output size instead of a scale factor kwargs["size"] = (w0, h0) patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), mode="bicubic", antialias=self.interpolate_antialias, **kwargs, ) assert (w0, h0) == patch_pos_embed.shape[-2:] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape x = self.patch_embed(x) if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) if self.register_tokens is not None: x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) return x def forward_features_list(self, x_list, masks_list): x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] for blk in self.blocks: if self.training: x = checkpoint(blk, x, use_reentrant=self.use_reentrant) else: x = blk(x) all_x = x output = [] for x, masks in zip(all_x, masks_list): x_norm = self.norm(x) output.append( { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } ) return output def forward_features(self, x, masks=None): if isinstance(x, list): return self.forward_features_list(x, masks) x = self.prepare_tokens_with_masks(x, masks) for blk in self.blocks: if self.training: x = checkpoint(blk, x, use_reentrant=self.use_reentrant) else: x = blk(x) x_norm = 
self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } def _get_intermediate_layers_not_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for i, blk in enumerate(self.blocks): x = blk(x) if i in blocks_to_take: output.append(x) assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def _get_intermediate_layers_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. If it's a list, take them blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for block_chunk in self.blocks: for blk in block_chunk[i:]: # Passing the nn.Identity() x = blk(x) if i in blocks_to_take: output.append(x) i += 1 assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def get_intermediate_layers( self, x: torch.Tensor, n: Union[int, Sequence] = 1, # Layers or n last layers to take reshape: bool = False, return_class_token: bool = False, norm=True, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: if self.chunked_blocks: outputs = self._get_intermediate_layers_chunked(x, n) else: outputs = self._get_intermediate_layers_not_chunked(x, n) if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() for out in 
outputs ] if return_class_token: return tuple(zip(outputs, class_tokens)) return tuple(outputs) def forward(self, *args, is_training=True, **kwargs): ret = self.forward_features(*args, **kwargs) if is_training: return ret else: return self.head(ret["x_norm_clstoken"]) def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model ================================================ FILE: vggt/models/aggregator.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from typing import Optional, Tuple, Union, List, Dict, Any from vggt.layers import PatchEmbed from vggt.layers.block import Block from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 logger = logging.getLogger(__name__) _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] class Aggregator(nn.Module): """ The Aggregator applies alternating-attention over input frames, as described in VGGT: Visual Geometry Grounded Transformer. Remember to set model.train() to enable gradient checkpointing to reduce memory usage. Args: img_size (int): Image size in pixels. patch_size (int): Size of each patch for PatchEmbed. embed_dim (int): Dimension of the token embeddings. depth (int): Number of blocks. num_heads (int): Number of attention heads. mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. num_register_tokens (int): Number of register tokens. block_fn (nn.Module): The block type used for attention (Block by default). qkv_bias (bool): Whether to include bias in QKV projections. proj_bias (bool): Whether to include bias in the output projection. ffn_bias (bool): Whether to include bias in MLP layers. patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. qk_norm (bool): Whether to apply QK normalization. rope_freq (int): Base frequency for rotary embedding. -1 to disable. init_values (float): Init scale for layer scale. 
""" def __init__( self, img_size=518, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, num_register_tokens=4, block_fn=Block, qkv_bias=True, proj_bias=True, ffn_bias=True, patch_embed="dinov2_vitl14_reg", aa_order=["frame", "global"], aa_block_size=1, qk_norm=True, rope_freq=100, init_values=0.01, ): super().__init__() self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) # Initialize rotary position embedding if frequency > 0 self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None self.position_getter = PositionGetter() if self.rope is not None else None self.frame_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.global_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.depth = depth self.aa_order = aa_order self.patch_size = patch_size self.aa_block_size = aa_block_size # Validate that depth is divisible by aa_block_size if self.depth % self.aa_block_size != 0: raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") self.aa_block_num = self.depth // self.aa_block_size # Note: We have two camera tokens, one for the first frame and one for the rest # The same applies for register tokens self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) # The patch tokens start after the camera and register tokens self.patch_start_idx = 1 + num_register_tokens # Initialize parameters with small values nn.init.normal_(self.camera_token, std=1e-6) 
nn.init.normal_(self.register_token, std=1e-6) # Register normalization constants as buffers for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) self.use_reentrant = False # hardcoded to False def __build_patch_embed__( self, patch_embed, img_size, patch_size, num_register_tokens, interpolate_antialias=True, interpolate_offset=0.0, block_chunks=0, init_values=1.0, embed_dim=1024, ): """ Build the patch embed layer. If 'conv', we use a simple PatchEmbed conv layer. Otherwise, we use a vision transformer. """ if "conv" in patch_embed: self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) else: vit_models = { "dinov2_vitl14_reg": vit_large, "dinov2_vitb14_reg": vit_base, "dinov2_vits14_reg": vit_small, "dinov2_vitg2_reg": vit_giant2, } self.patch_embed = vit_models[patch_embed]( img_size=img_size, patch_size=patch_size, num_register_tokens=num_register_tokens, interpolate_antialias=interpolate_antialias, interpolate_offset=interpolate_offset, block_chunks=block_chunks, init_values=init_values, ) # Disable gradient updates for mask token if hasattr(self.patch_embed, "mask_token"): self.patch_embed.mask_token.requires_grad_(False) def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: """ Args: images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. B: batch size, S: sequence length, 3: RGB channels, H: height, W: width Returns: (list[torch.Tensor], int): The list of outputs from the attention blocks, and the patch_start_idx indicating where patch tokens begin. 
""" B, S, C_in, H, W = images.shape if C_in != 3: raise ValueError(f"Expected 3 input channels, got {C_in}") # Normalize images and reshape for patch embed images = (images - self._resnet_mean) / self._resnet_std # Reshape to [B*S, C, H, W] for patch embedding images = images.view(B * S, C_in, H, W) patch_tokens = self.patch_embed(images) if isinstance(patch_tokens, dict): patch_tokens = patch_tokens["x_norm_patchtokens"] _, P, C = patch_tokens.shape # Expand camera and register tokens to match batch size and sequence length camera_token = slice_expand_and_flatten(self.camera_token, B, S) register_token = slice_expand_and_flatten(self.register_token, B, S) # Concatenate special tokens with patch tokens tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) pos = None if self.rope is not None: pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) if self.patch_start_idx > 0: # do not use position embedding for special tokens (camera and register tokens) # so set pos to 0 for the special tokens pos = pos + 1 pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) pos = torch.cat([pos_special, pos], dim=1) # update P because we added special tokens _, P, C = tokens.shape frame_idx = 0 global_idx = 0 output_list = [] for _ in range(self.aa_block_num): for attn_type in self.aa_order: if attn_type == "frame": tokens, frame_idx, frame_intermediates = self._process_frame_attention( tokens, B, S, P, C, frame_idx, pos=pos ) elif attn_type == "global": tokens, global_idx, global_intermediates = self._process_global_attention( tokens, B, S, P, C, global_idx, pos=pos ) else: raise ValueError(f"Unknown attention type: {attn_type}") for i in range(len(frame_intermediates)): # concat frame and global intermediates, [B x S x P x 2C] concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) output_list.append(concat_inter) del concat_inter del 
frame_intermediates del global_intermediates return output_list, self.patch_start_idx def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): """ Process frame attention blocks. We keep tokens in shape (B*S, P, C). """ # If needed, reshape tokens or positions: if tokens.shape != (B * S, P, C): tokens = tokens.view(B, S, P, C).view(B * S, P, C) if pos is not None and pos.shape != (B * S, P, 2): pos = pos.view(B, S, P, 2).view(B * S, P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): if self.training: tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant) else: tokens = self.frame_blocks[frame_idx](tokens, pos=pos) frame_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, frame_idx, intermediates def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): """ Process global attention blocks. We keep tokens in shape (B, S*P, C). 
""" if tokens.shape != (B, S * P, C): tokens = tokens.view(B, S, P, C).view(B, S * P, C) if pos is not None and pos.shape != (B, S * P, 2): pos = pos.view(B, S, P, 2).view(B, S * P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): if self.training: tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant) else: tokens = self.global_blocks[global_idx](tokens, pos=pos) global_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, global_idx, intermediates def slice_expand_and_flatten(token_tensor, B, S): """ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: 1) Uses the first position (index=0) for the first frame only 2) Uses the second position (index=1) for all remaining frames (S-1 frames) 3) Expands both to match batch size B 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token followed by (S-1) second-position tokens 5) Flattens to (B*S, X, C) for processing Returns: torch.Tensor: Processed tokens with shape (B*S, X, C) """ # Slice out the "query" tokens => shape (1, 1, ...) query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) # Slice out the "other" tokens => shape (1, S-1, ...) others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) # Concatenate => shape (B, S, ...) combined = torch.cat([query, others], dim=1) # Finally flatten => shape (B*S, ...) combined = combined.view(B * S, *combined.shape[2:]) return combined ================================================ FILE: vggt/models/vggt.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch import torch.nn as nn from huggingface_hub import PyTorchModelHubMixin # used for model hub from vggt.models.aggregator import Aggregator from vggt.heads.camera_head import CameraHead from vggt.heads.dpt_head import DPTHead from vggt.heads.track_head import TrackHead class VGGT(nn.Module, PyTorchModelHubMixin): def __init__(self, img_size=518, patch_size=14, embed_dim=1024, enable_camera=True, enable_point=True, enable_depth=True, enable_track=True): super().__init__() self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") if enable_point else None self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None def forward(self, images: torch.Tensor, query_points: torch.Tensor = None): """ Forward pass of the VGGT model. Args: images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. B: batch size, S: sequence length, 3: RGB channels, H: height, W: width query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. Shape: [N, 2] or [B, N, 2], where N is the number of query points. 
Default: None Returns: dict: A dictionary containing the following predictions: - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] - images (torch.Tensor): Original input images, preserved for visualization If query_points is provided, also includes: - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] """ # If without batch dimension, add it if len(images.shape) == 4: images = images.unsqueeze(0) if query_points is not None and len(query_points.shape) == 2: query_points = query_points.unsqueeze(0) aggregated_tokens_list, patch_start_idx = self.aggregator(images) predictions = {} with torch.cuda.amp.autocast(enabled=False): if self.camera_head is not None: pose_enc_list = self.camera_head(aggregated_tokens_list) predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration predictions["pose_enc_list"] = pose_enc_list if self.depth_head is not None: depth, depth_conf = self.depth_head( aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx ) predictions["depth"] = depth predictions["depth_conf"] = depth_conf if self.point_head is not None: pts3d, pts3d_conf = self.point_head( aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx ) predictions["world_points"] = pts3d predictions["world_points_conf"] = pts3d_conf if self.track_head is not None and query_points is not None: track_list, vis, conf = self.track_head( 
aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points ) predictions["track"] = track_list[-1] # track of the last iteration predictions["vis"] = vis predictions["conf"] = conf if not self.training: predictions["images"] = images # store the images for visualization during inference return predictions ================================================ FILE: vggt/utils/geometry.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import torch import numpy as np from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion def unproject_depth_map_to_point_map( depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray ) -> np.ndarray: """ Unproject a batch of depth maps to 3D world coordinates. 
Args: depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) Any of the three may also be a torch.Tensor; tensors are moved to CPU and converted to NumPy first. Returns: np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) """ if isinstance(depth_map, torch.Tensor): depth_map = depth_map.cpu().numpy() if isinstance(extrinsics_cam, torch.Tensor): extrinsics_cam = extrinsics_cam.cpu().numpy() if isinstance(intrinsics_cam, torch.Tensor): intrinsics_cam = intrinsics_cam.cpu().numpy() world_points_list = [] for frame_idx in range(depth_map.shape[0]): # squeeze(-1) drops the trailing channel of (S, H, W, 1) input; NOTE(review): for (S, H, W) input it is a no-op unless W == 1 cur_world_points, _, _ = depth_to_world_coords_points( depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] ) world_points_list.append(cur_world_points) world_points_array = np.stack(world_points_list, axis=0) return world_points_array def depth_to_world_coords_points( depth_map: np.ndarray, extrinsic: np.ndarray, intrinsic: np.ndarray, eps=1e-8, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Convert a depth map to world coordinates. Args: depth_map (np.ndarray): Depth map of shape (H, W). extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). eps (float): A pixel is marked valid only if its depth is greater than eps. Returns: tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W). All three are None if depth_map is None.
""" if depth_map is None: return None, None, None # Valid depth mask point_mask = depth_map > eps # Convert depth map to camera coordinates cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) # Multiply with the inverse of extrinsic matrix to transform to world coordinates # closed_form_inverse_se3 is batched (output (N, 4, 4)), so add a batch axis and take [0] to get the 4x4 cam-to-world matrix cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] R_cam_to_world = cam_to_world_extrinsic[:3, :3] t_cam_to_world = cam_to_world_extrinsic[:3, 3] # Apply the rotation and translation to the camera coordinates world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world return world_coords_points, cam_coords_points, point_mask def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray: """ Convert a depth map to camera coordinates with a zero-skew pinhole model. Args: depth_map (np.ndarray): Depth map of shape (H, W). intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). Returns: np.ndarray: float32 camera coordinates of shape (H, W, 3); x grows with the column index, y with the row index, z is the depth. Raises: AssertionError: If intrinsic is not 3x3 or has nonzero skew (checked with assert, so skipped under python -O). """ H, W = depth_map.shape assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" # Intrinsic parameters fu, fv = intrinsic[0, 0], intrinsic[1, 1] cu, cv = intrinsic[0, 2], intrinsic[1, 2] # Generate grid of pixel coordinates u, v = np.meshgrid(np.arange(W), np.arange(H)) # Unproject to camera coordinates x_cam = (u - cu) * depth_map / fu y_cam = (v - cv) * depth_map / fv z_cam = depth_map # Stack to form camera coordinates cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) return cam_coords def closed_form_inverse_se3(se3, R=None, T=None): """ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
If `R` and `T` are provided, they must correspond to the rotation and translation components of `se3`. Otherwise, they will be extracted from `se3`. Args: se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. R (optional): Nx3x3 array or tensor of rotation matrices. T (optional): Nx3x1 array or tensor of translation vectors. Returns: Inverted SE3 matrices with the same type and device as `se3`. Shapes: se3: (N, 4, 4) R: (N, 3, 3) T: (N, 3, 1) """ # Check if se3 is a numpy array or a torch tensor is_numpy = isinstance(se3, np.ndarray) # Validate shapes if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") # Extract R and T if not provided if R is None: R = se3[:, :3, :3] # (N,3,3) if T is None: T = se3[:, :3, 3:] # (N,3,1) # Transpose R if is_numpy: # Compute the transpose of the rotation for NumPy R_transposed = np.transpose(R, (0, 2, 1)) # -R^T t for NumPy top_right = -np.matmul(R_transposed, T) inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) else: R_transposed = R.transpose(1, 2) # (N,3,3) top_right = -torch.bmm(R_transposed, T) # (N,3,1) inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) inverted_matrix[:, :3, :3] = R_transposed inverted_matrix[:, :3, 3:] = top_right return inverted_matrix # TODO: this code can be further cleaned up def project_world_points_to_camera_points_batch(world_points, cam_extrinsics): """ Transforms 3D points to 2D using extrinsic and intrinsic parameters. Args: world_points (torch.Tensor): 3D points of shape BxSxHxWx3. cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4. 
Returns: """ # TODO: merge this into project_world_points_to_cam # device = world_points.device # with torch.autocast(device_type=device.type, enabled=False): ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1) world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4) # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3) # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1) world_points_h_exp = world_points_h.unsqueeze(-1) # Now perform the matrix multiplication # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1) camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1) return camera_points def project_world_points_to_cam( world_points, cam_extrinsics, cam_intrinsics=None, distortion_params=None, default=0, only_points_cam=False, ): """ Transforms 3D points to 2D using extrinsic and intrinsic parameters. Args: world_points (torch.Tensor): 3D points of shape Px3. cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. Returns: torch.Tensor: Transformed 2D points of shape BxNx2. 
""" device = world_points.device # with torch.autocast(device_type=device.type, dtype=torch.double): with torch.autocast(device_type=device.type, enabled=False): N = world_points.shape[0] # Number of points B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras world_points_homogeneous = torch.cat( [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 ) # Nx4 # Reshape for batch processing world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( B, -1, -1 ) # BxNx4 # Step 1: Apply extrinsic parameters # Transform 3D points to camera coordinate system for all cameras cam_points = torch.bmm( cam_extrinsics, world_points_homogeneous.transpose(-1, -2) ) if only_points_cam: return None, cam_points # Step 2: Apply intrinsic parameters and (optional) distortion image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) return image_points, cam_points def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): """ Applies intrinsic parameters and optional distortion to the given 3D points. Args: cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. default (float, optional): Default value to replace NaNs in the output. Returns: pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
""" # Normalized device coordinates (NDC) cam_points = cam_points / cam_points[:, 2:3, :] ndc_xy = cam_points[:, :2, :] # Apply distortion if distortion_params are provided if distortion_params is not None: x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) else: distorted_xy = ndc_xy # Prepare cam_points for batch matrix multiplication cam_coords_homo = torch.cat( (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 ) # Bx3xN # Apply intrinsic parameters using batch matrix multiplication pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN # Extract x and y coordinates pixel_coords = pixel_coords[:, :2, :] # Bx2xN # Replace NaNs with default value pixel_coords = torch.nan_to_num(pixel_coords, nan=default) return pixel_coords.transpose(1, 2) # BxNx2 def cam_from_img(pred_tracks, intrinsics, extra_params=None): """ Normalize predicted tracks based on camera intrinsics. Args: intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. Returns: torch.Tensor: Normalized tracks tensor. 
""" # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here # otherwise we can use something like # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) tracks_normalized = (pred_tracks - principal_point) / focal_length if extra_params is not None: # Apply iterative undistortion try: tracks_normalized = iterative_undistortion( extra_params, tracks_normalized ) except: tracks_normalized = single_undistortion( extra_params, tracks_normalized ) return tracks_normalized ================================================ FILE: vggt/utils/helper.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import numpy as np def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: """ If mask has more than max_trues True values, randomly keep only max_trues of them and set the rest to False. """ # 1D positions of all True entries true_indices = np.flatnonzero(mask) # shape = (N_true,) # if already within budget, return as-is if true_indices.size <= max_trues: return mask # randomly pick which True positions to keep sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) # build new flat mask: True only at sampled positions limited_flat_mask = np.zeros(mask.size, dtype=bool) limited_flat_mask[sampled_indices] = True # restore original shape return limited_flat_mask.reshape(mask.shape) def create_pixel_coordinate_grid(num_frames, height, width): """ Creates a grid of pixel coordinates and frame indices for all frames. 
Returns: numpy.ndarray: points_xyf, a float32 array of shape (num_frames, height, width, 3) whose last axis holds (x, y, frame_index). Only this array is returned; the x, y and frame-index grids are available as points_xyf[..., 0], points_xyf[..., 1] and points_xyf[..., 2]. """ # Create coordinate grids for a single frame y_grid, x_grid = np.indices((height, width), dtype=np.float32) x_grid = x_grid[np.newaxis, :, :] y_grid = y_grid[np.newaxis, :, :] # Broadcast to all frames x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) # Create frame indices and broadcast f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) # Stack coordinates and frame indices (np.stack copies, so the result is writable) points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) return points_xyf ================================================ FILE: vggt/utils/load_fn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from PIL import Image from torchvision import transforms as TF import numpy as np def load_and_preprocess_images_square(image_path_list, target_size=1024): """ Load and preprocess images by center padding to square (black fill) and resizing to target size. Also returns the position information of original pixels after transformation. Args: image_path_list (list): List of paths to image files target_size (int, optional): Target size for both width and height. Defaults to 1024.
Returns: tuple: ( torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), torch.Tensor: float32 array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image, where (x1, y1, x2, y2) is the original image's box inside the resized target_size x target_size image and (width, height) is the original image size ) Raises: ValueError: If the input list is empty """ # Check for empty list if len(image_path_list) == 0: raise ValueError("At least 1 image is required") images = [] original_coords = [] # Renamed from position_info to be more descriptive to_tensor = TF.ToTensor() for image_path in image_path_list: # Open image img = Image.open(image_path) # If there's an alpha channel, blend onto white background if img.mode == "RGBA": background = Image.new("RGBA", img.size, (255, 255, 255, 255)) img = Image.alpha_composite(background, img) # Convert to RGB img = img.convert("RGB") # Get original dimensions width, height = img.size # Make the image square by padding the shorter dimension max_dim = max(width, height) # Calculate padding left = (max_dim - width) // 2 top = (max_dim - height) // 2 # Calculate scale factor for resizing scale = target_size / max_dim # Calculate final coordinates of original image in target space x1 = left * scale y1 = top * scale x2 = (left + width) * scale y2 = (top + height) * scale # Store the box in target space plus the original size (the scale itself is not stored) original_coords.append(np.array([x1, y1, x2, y2, width, height])) # Create a new black square image and paste original square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) square_img.paste(img, (left, top)) # Resize to target size square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC) # Convert to tensor img_tensor = to_tensor(square_img) images.append(img_tensor) # Stack all images images = torch.stack(images) original_coords = torch.from_numpy(np.array(original_coords)).float() # Add additional dimension if single image to ensure correct shape (NOTE(review): torch.stack already returns a 4D tensor here, so this branch never runs) if len(image_path_list) == 1: if images.dim() == 3: images = images.unsqueeze(0) original_coords = original_coords.unsqueeze(0) return
images, original_coords def load_and_preprocess_images(image_path_list, mode="crop"): """ A quick start function to load and preprocess images for model input. This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. Args: image_path_list (list): List of paths to image files mode (str, optional): Preprocessing mode, either "crop" or "pad". - "crop" (default): Sets width to 518px and center crops height if needed. - "pad": Preserves all pixels by making the largest dimension 518px and padding the smaller dimension to reach a square shape. Returns: torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) Raises: ValueError: If the input list is empty or if mode is invalid Notes: - RGBA images are alpha-composited onto a white background before conversion to RGB - Images with different dimensions will be padded with white (value=1.0) - A warning is printed when images have different shapes - When mode="crop": The function ensures width=518px while maintaining aspect ratio and height is center-cropped if larger than 518px - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio and the smaller dimension is padded to reach a square shape (518x518) - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements """ # Check for empty list if len(image_path_list) == 0: raise ValueError("At least 1 image is required") # Validate mode if mode not in ["crop", "pad"]: raise ValueError("Mode must be either 'crop' or 'pad'") images = [] shapes = set() to_tensor = TF.ToTensor() target_size = 518 # First process all images and collect their shapes for image_path in image_path_list: # Open image img = Image.open(image_path) # If there's an alpha channel, blend onto white background: if img.mode == "RGBA": # Create white background background = Image.new("RGBA", img.size, (255, 255, 255, 255)) # Alpha composite onto the white background img = Image.alpha_composite(background, img) # Now convert to "RGB"
(this step assigns white for transparent areas) img = img.convert("RGB") width, height = img.size if mode == "pad": # Make the largest dimension 518px while maintaining aspect ratio if width >= height: new_width = target_size new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 else: new_height = target_size new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 else: # mode == "crop" # Original behavior: set width to 518px new_width = target_size # Calculate height maintaining aspect ratio, divisible by 14 (NOTE(review): this rounds to 0 when height/width is below about 7/518, which PIL's resize presumably rejects) new_height = round(height * (new_width / width) / 14) * 14 # Resize with new dimensions (width, height) img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) img = to_tensor(img) # Convert to tensor (0, 1) # Center crop height if it's larger than 518 (only in crop mode) if mode == "crop" and new_height > target_size: start_y = (new_height - target_size) // 2 img = img[:, start_y : start_y + target_size, :] # For pad mode, pad to make a square of target_size x target_size if mode == "pad": h_padding = target_size - img.shape[1] w_padding = target_size - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top = h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left # Pad with white (value=1.0) img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 ) shapes.add((img.shape[1], img.shape[2])) images.append(img) # Check if we have different shapes # In theory our model can also work well with different shapes if len(shapes) > 1: print(f"Warning: Found images with different shapes: {shapes}") # Find maximum dimensions max_height = max(shape[0] for shape in shapes) max_width = max(shape[1] for shape in shapes) # Pad images if necessary (centered, white) so they can be stacked padded_images = [] for img in images: h_padding = max_height - img.shape[1] w_padding = max_width - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top =
h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 ) padded_images.append(img) images = padded_images images = torch.stack(images) # stack into (N, 3, H, W) # Ensure correct shape when single image (NOTE(review): torch.stack already yields (1, C, H, W) for one image, so this branch never runs) if len(image_path_list) == 1: # Verify shape is (1, C, H, W) if images.dim() == 3: images = images.unsqueeze(0) return images ================================================ FILE: vggt/utils/pose_enc.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from .rotation import quat_to_mat, mat_to_quat def extri_intri_to_pose_encoding( extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) ): """Convert camera extrinsics and intrinsics to a compact pose encoding. This function transforms camera parameters into a unified pose encoding format, which can be used for various downstream tasks like pose prediction or representation. Args: extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, where B is batch size and S is sequence length. In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. Defined in pixels, with format: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] where fx, fy are focal lengths and (cx, cy) is the principal point image_size_hw (tuple): Tuple of (height, width) of the image in pixels. Required for computing field of view values. For example: (256, 512). pose_encoding_type (str): Type of pose encoding to use.
Currently only supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). Returns: torch.Tensor: Encoded camera pose parameters with shape BxSx9, cast to float32 via .float(). For "absT_quaR_FoV" type, the 9 dimensions are: - [:3] = absolute translation vector T (3D) - [3:7] = rotation as quaternion quat (4D, scalar-last XYZW as produced by mat_to_quat) - [7:] = field of view (2D): (fov_h, fov_w) in radians Only fx and fy enter the encoding; the principal point (cx, cy) is not encoded. Raises: NotImplementedError: If pose_encoding_type is not "absT_quaR_FoV". """ # extrinsics: BxSx3x4 # intrinsics: BxSx3x3 if pose_encoding_type == "absT_quaR_FoV": R = extrinsics[:, :, :3, :3] # BxSx3x3 T = extrinsics[:, :, :3, 3] # BxSx3 quat = mat_to_quat(R) # Note the order of h and w here (image_size_hw must not be None on this path) H, W = image_size_hw fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() else: raise NotImplementedError return pose_encoding def pose_encoding_to_extri_intri( pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) ): """Convert a pose encoding back to camera extrinsics and intrinsics. This function performs the inverse operation of extri_intri_to_pose_encoding, reconstructing the full camera parameters from the compact encoding. Args: pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, where B is batch size and S is sequence length. For "absT_quaR_FoV" type, the 9 dimensions are: - [:3] = absolute translation vector T (3D) - [3:7] = rotation as quaternion quat (4D) - [7:] = field of view (2D): (fov_h, fov_w) in radians image_size_hw (tuple): Tuple of (height, width) of the image in pixels. Required for reconstructing intrinsics from field of view values. For example: (256, 512).
Returns: tuple: (extrinsics, intrinsics) - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, or None if build_intrinsics is False. Defined in pixels, with format: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] where fx, fy are focal lengths and (cx, cy) is the principal point, assumed to be at the center of the image (W/2, H/2). """ intrinsics = None if pose_encoding_type == "absT_quaR_FoV": T = pose_encoding[..., :3] quat = pose_encoding[..., 3:7] fov_h = pose_encoding[..., 7] fov_w = pose_encoding[..., 8] R = quat_to_mat(quat) extrinsics = torch.cat([R, T[..., None]], dim=-1) if build_intrinsics: H, W = image_size_hw fy = (H / 2.0) / torch.tan(fov_h / 2.0) fx = (W / 2.0) / torch.tan(fov_w / 2.0) intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) intrinsics[..., 0, 0] = fx intrinsics[..., 1, 1] = fy intrinsics[..., 0, 2] = W / 2 intrinsics[..., 1, 2] = H / 2 intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 else: raise NotImplementedError return extrinsics, intrinsics ================================================ FILE: vggt/utils/rotation.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d import torch import numpy as np import torch.nn.functional as F def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: """ Quaternion Order: XYZW or say ijkr, scalar-last Convert rotations given as quaternions to rotation matrices. 
Args: quaternions: quaternions with real part last, as tensor of shape (..., 4). They need not be unit length: two_s = 2 / |q|^2 below normalizes implicitly (an all-zero quaternion yields non-finite values). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ i, j, k, r = torch.unbind(quaternions, -1) # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part last, as tensor of shape (..., 4), standardized so the real part is non-negative. Quaternion Order: XYZW or say ijkr, scalar-last Raises: ValueError: If the last two dimensions of matrix are not 3x3. """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) q_abs = _sqrt_positive_part( torch.stack( [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_rijk = torch.stack( [ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) # Convert from rijk to ijkr out = out[..., [1, 2, 3, 0]] out = standardize_quaternion(out) return out def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. With grad enabled, sqrt is evaluated only on the positive entries via masked assignment, presumably so NaNs from negative entries cannot reach the backward pass; without grad, torch.sqrt runs on every entry (negatives give NaN) and torch.where discards those results. """ ret = torch.zeros_like(x) positive_mask = x > 0 if torch.is_grad_enabled(): ret[positive_mask] = torch.sqrt(x[positive_mask]) else: ret = torch.where(positive_mask, torch.sqrt(x), ret) return ret def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part last, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). """ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) ================================================ FILE: vggt/utils/visual_track.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import cv2 import torch import numpy as np import os def color_from_xy(x, y, W, H, cmap_name="hsv"): """ Map (x, y) -> color in (R, G, B). 1) Normalize x,y to [0,1]. 2) Combine them into a single scalar c in [0,1].
3) Use matplotlib's colormap to convert c -> (R,G,B). You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). """ import matplotlib.cm import matplotlib.colors x_norm = x / max(W - 1, 1) y_norm = y / max(H - 1, 1) # Simple combination: c = (x_norm + y_norm) / 2.0 cmap = matplotlib.cm.get_cmap(cmap_name) # cmap(c) -> (r,g,b,a) in [0,1] rgba = cmap(c) r, g, b = rgba[0], rgba[1], rgba[2] return (r, g, b) # in [0,1], RGB order def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): """ Given all tracks in one sample (b), compute a (N,3) array of RGB color values in [0,255]. The color is determined by the (x,y) position in the first visible frame for each track. Args: tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. vis_mask_b: (S, N) boolean mask; if None, assume all are visible. image_width, image_height: used for normalizing (x, y). cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). Returns: track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. 
""" S, N, _ = tracks_b.shape track_colors = np.zeros((N, 3), dtype=np.uint8) if vis_mask_b is None: # treat all as visible vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) for i in range(N): # Find first visible frame for track i visible_frames = torch.where(vis_mask_b[:, i])[0] if len(visible_frames) == 0: # track is never visible; just assign black or something track_colors[i] = (0, 0, 0) continue first_s = int(visible_frames[0].item()) # use that frame's (x,y) x, y = tracks_b[first_s, i].tolist() # map (x,y) -> (R,G,B) in [0,1] r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) # scale to [0,255] r, g, b = int(r * 255), int(g * 255), int(b * 255) track_colors[i] = (r, g, b) return track_colors def visualize_tracks_on_images( images, tracks, track_vis_mask=None, out_dir="track_visuals_concat_by_xy", image_format="CHW", # "CHW" or "HWC" normalize_mode="[0,1]", cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" frames_per_row=4, # New parameter for grid layout save_grid=True, # Flag to control whether to save the grid image ): """ Visualizes frames in a grid layout with specified frames per row. Each track's color is determined by its (x,y) position in the first visible frame (or frame 0 if always visible). Finally convert the BGR result to RGB before saving. Also saves each individual frame as a separate PNG file. Args: images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. tracks: torch.Tensor (S, N, 2), last dim = (x, y). track_vis_mask: torch.Tensor (S, N) or None. out_dir: folder to save visualizations. image_format: "CHW" or "HWC". normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 cmap_name: a matplotlib colormap name for color_from_xy. frames_per_row: number of frames to display in each row of the grid. save_grid: whether to save all frames in one grid image. Returns: None (saves images in out_dir). 
""" if len(tracks.shape) == 4: tracks = tracks.squeeze(0) images = images.squeeze(0) if track_vis_mask is not None: track_vis_mask = track_vis_mask.squeeze(0) import matplotlib matplotlib.use("Agg") # for non-interactive (optional) os.makedirs(out_dir, exist_ok=True) S = images.shape[0] _, N, _ = tracks.shape # (S, N, 2) # Move to CPU images = images.cpu().clone() tracks = tracks.cpu().clone() if track_vis_mask is not None: track_vis_mask = track_vis_mask.cpu().clone() # Infer H, W from images shape if image_format == "CHW": # e.g. images[s].shape = (3, H, W) H, W = images.shape[2], images.shape[3] else: # e.g. images[s].shape = (H, W, 3) H, W = images.shape[1], images.shape[2] # Pre-compute the color for each track i based on first visible position track_colors_rgb = get_track_colors_by_position( tracks, # shape (S, N, 2) vis_mask_b=track_vis_mask if track_vis_mask is not None else None, image_width=W, image_height=H, cmap_name=cmap_name, ) # We'll accumulate each frame's drawn image in a list frame_images = [] for s in range(S): # shape => either (3, H, W) or (H, W, 3) img = images[s] # Convert to (H, W, 3) if image_format == "CHW": img = img.permute(1, 2, 0) # (H, W, 3) # else "HWC", do nothing img = img.numpy().astype(np.float32) # Scale to [0,255] if needed if normalize_mode == "[0,1]": img = np.clip(img, 0, 1) * 255.0 elif normalize_mode == "[-1,1]": img = (img + 1.0) * 0.5 * 255.0 img = np.clip(img, 0, 255.0) # else no normalization # Convert to uint8 img = img.astype(np.uint8) # For drawing in OpenCV, convert to BGR img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Draw each visible track cur_tracks = tracks[s] # shape (N, 2) if track_vis_mask is not None: valid_indices = torch.where(track_vis_mask[s])[0] else: valid_indices = range(N) cur_tracks_np = cur_tracks.numpy() for i in valid_indices: x, y = cur_tracks_np[i] pt = (int(round(x)), int(round(y))) # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR R, G, B = track_colors_rgb[i] color_bgr = (int(B), int(G), int(R)) cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) # Convert back to RGB for consistent final saving: img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # Save individual frame frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") # Convert to BGR for OpenCV imwrite frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) cv2.imwrite(frame_path, frame_bgr) frame_images.append(img_rgb) # Only create and save the grid image if save_grid is True if save_grid: # Calculate grid dimensions num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division # Create a grid of images grid_img = None for row in range(num_rows): start_idx = row * frames_per_row end_idx = min(start_idx + frames_per_row, S) # Concatenate this row horizontally row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) # If this row has fewer than frames_per_row images, pad with black if end_idx - start_idx < frames_per_row: padding_width = (frames_per_row - (end_idx - start_idx)) * W padding = np.zeros((H, padding_width, 3), dtype=np.uint8) row_img = np.concatenate([row_img, padding], axis=1) # Add this row to the grid if grid_img is None: grid_img = row_img else: grid_img = np.concatenate([grid_img, row_img], axis=0) out_path = os.path.join(out_dir, "tracks_grid.png") # Convert back to BGR for OpenCV imwrite grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) cv2.imwrite(out_path, grid_img_bgr) print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") ================================================ FILE: visual_util.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import trimesh
import gradio as gr
import numpy as np
import matplotlib
from scipy.spatial.transform import Rotation
import copy
import cv2
import os
import requests


def predictions_to_glb(
    predictions,
    conf_thres=50.0,
    filter_by_frames="all",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    target_dir=None,
    prediction_mode="Predicted Pointmap",
) -> trimesh.Scene:
    """
    Converts VGGT predictions to a 3D scene represented as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
            When world_points is absent, or prediction_mode does not contain
            "Pointmap", world_points_from_depth / depth_conf are used instead.
        conf_thres (float): Percentage of low-confidence points to filter out
            (default: 50.0; None is treated as 10.0)
        filter_by_frames (str): Frame filter specification; "all"/"All" keeps every
            frame, otherwise the integer before the first ":" selects one frame
            (default: "all")
        mask_black_bg (bool): Mask out black background pixels (default: False)
        mask_white_bg (bool): Mask out white background pixels (default: False)
        show_cam (bool): Include camera visualization (default: True)
        mask_sky (bool): Apply sky segmentation mask; only applied when target_dir
            is given (default: False)
        target_dir (str): Directory containing images/ and where sky_masks/ are
            cached (default: None)
        prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap")

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    if conf_thres is None:
        conf_thres = 10.0

    print("Building GLB scene")
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            # Extract the index part before the colon
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass

    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]  # No batch dimension to remove
            pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
        else:
            print("Warning: world_points not found in predictions, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))

    # Get images from predictions
    images = predictions["images"]
    # Use extrinsic matrices instead of pred_extrinsic_list
    camera_matrices = predictions["extrinsic"]

    if mask_sky and target_dir is not None:
        # Get the shape of pred_world_points_conf to match
        S, H, W = (
            pred_world_points_conf.shape
            if hasattr(pred_world_points_conf, "shape")
            else (len(images), images.shape[1], images.shape[2])
        )
        sky_mask_array = _load_sky_masks(target_dir, H, W)

        # Zero out confidence wherever the mask marks sky (mask value 0)
        sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
        pred_world_points_conf = pred_world_points_conf * sky_mask_binary

    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats - check if images need transposing
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:  # Assume already in NHWC format
        colors_rgb = images
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    conf = pred_world_points_conf.reshape(-1)
    # Convert percentage threshold to actual confidence value
    if conf_thres == 0.0:
        conf_threshold = 0.0
    else:
        conf_threshold = np.percentile(conf, conf_thres)

    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    if mask_black_bg:
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask

    if mask_white_bg:
        # Filter out white background pixels (RGB values close to white)
        # Consider pixels white if all RGB values are above 240
        white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240))
        conf_mask = conf_mask & white_bg_mask

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        # Nothing survived filtering: emit a single placeholder point
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Calculate the 5th and 95th percentiles along each axis
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)

        # Calculate the diagonal length of the percentile bounding box
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)

    scene_3d.add_geometry(point_cloud_data)

    # Prepare 4x4 matrices for camera extrinsics
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1

    if show_cam:
        # Add camera models to the scene
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])

            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)

    # Align scene to the observation of the first camera
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

    print("GLB Scene built")
    return scene_3d


def _load_sky_masks(target_dir, H, W):
    """
    Build an (num_images, H, W) array of non-sky masks for target_dir/images.

    Masks cached under target_dir/sky_masks are reused when they can be read;
    otherwise they are computed with the skyseg ONNX model (downloaded on first
    use) and written to that cache. In each mask, 255 marks non-sky and 0 marks sky.
    """
    import onnxruntime

    skyseg_session = None
    target_dir_images = os.path.join(target_dir, "images")
    image_list = sorted(os.listdir(target_dir_images))

    # Download skyseg.onnx if it doesn't exist
    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx")

    sky_mask_list = []
    for image_name in image_list:
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)

        # Reuse a cached mask; cv2.imread returns None for unreadable files,
        # in which case the mask is regenerated instead of crashing below.
        sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_filepath) else None
        if sky_mask is None:
            if skyseg_session is None:
                skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)

        # Resize mask to match H×W if needed
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H))

        sky_mask_list.append(sky_mask)

    return np.array(sky_mask_list)


def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float):
    """
    Integrates a fake camera mesh into the 3D scene.

    Args:
        scene (trimesh.Scene): The 3D scene to add the camera model.
        transform (np.ndarray): Transformation matrix for camera positioning.
        face_colors (tuple): Color of the camera face.
        scene_scale (float): Scale of the scene.
    """

    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Create cone shape for camera
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    # Combine transformations
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Generate mesh for the camera
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    vertices_combined = np.concatenate(
        [
            camera_cone_shape.vertices,
            0.95 * camera_cone_shape.vertices,
            transform_points(slight_rotation, camera_cone_shape.vertices),
        ]
    )
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    mesh_faces = compute_camera_faces(camera_cone_shape)

    # Add the camera mesh to the scene
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)


def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene:
    """
    Aligns the 3D scene based on the extrinsics of the first camera.

    Args:
        scene_3d (trimesh.Scene): The 3D scene to be aligned.
        extrinsics_matrices (np.ndarray): Camera extrinsic matrices.

    Returns:
        trimesh.Scene: Aligned 3D scene.
    """
    # Set transformations for scene alignment
    opengl_conversion_matrix = get_opengl_conversion_matrix()

    # Rotation matrix for alignment (180 degrees around the y-axis)
    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()

    # Apply transformation
    initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
    scene_3d.apply_transform(initial_transformation)
    return scene_3d


def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Constructs and returns the OpenGL conversion matrix.

    Returns:
        numpy.ndarray: A 4x4 OpenGL conversion matrix.
    """
    # Create an identity matrix
    matrix = np.identity(4)

    # Flip the y and z axes
    matrix[1, 1] = -1
    matrix[2, 2] = -1

    return matrix


def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    Args:
        transformation (np.ndarray): Transformation matrix.
        points (np.ndarray): Points to be transformed.
        dim (int, optional): Dimension for reshaping the result.

    Returns:
        np.ndarray: Transformed points.
    """
    points = np.asarray(points)
    initial_shape = points.shape[:-1]
    dim = dim or points.shape[-1]

    # Apply transformation
    transformation = transformation.swapaxes(-1, -2)  # Transpose the transformation matrix
    points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]

    # Reshape the result
    result = points[..., :dim].reshape(*initial_shape, dim)
    return result


def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Computes the faces for the camera mesh.

    Args:
        cone_shape (trimesh.Trimesh): The shape of the camera cone.

    Returns:
        np.ndarray: Array of faces for the camera mesh.
    """
    # Create pseudo cameras
    faces_list = []
    num_vertices_cone = len(cone_shape.vertices)

    for face in cone_shape.faces:
        if 0 in face:
            continue
        v1, v2, v3 = face
        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

        faces_list.extend(
            [
                (v1, v2, v2_offset),
                (v1, v1_offset, v3),
                (v3_offset, v2, v3),
                (v1, v2, v2_offset_2),
                (v1, v1_offset_2, v3),
                (v3_offset_2, v2, v3),
            ]
        )

    # Add reversed-winding copies so the faces render from both sides
    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)


def segment_sky(image_path, onnx_session, mask_filename=None):
    """
    Segments sky from an image using an ONNX model.
    Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask (required)

    Returns:
        np.ndarray: Binary mask where 255 indicates non-sky regions

    Raises:
        ValueError: If the image at image_path cannot be read.
    """
    assert mask_filename is not None
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising
        raise ValueError(f"Could not read image for sky segmentation: {image_path}")

    result_map = run_skyseg(onnx_session, [320, 320], image)
    # resize the result_map to the original image size
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

    # Pixels with a normalized score below 32 are treated as non-sky (255);
    # everything else is treated as sky (0).
    output_mask = np.zeros_like(result_map_original)
    output_mask[result_map_original < 32] = 255  # Use threshold of 32

    # os.makedirs("") raises, so only create a directory when there is one
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, output_mask)
    return output_mask


def run_skyseg(onnx_session, input_size, image):
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        np.ndarray: Segmentation mask as uint8 of shape (height, width), min-max
        normalized to 0..255 (all zeros when the model output is constant)
    """
    width, height = input_size[0], input_size[1]

    # Pre process: Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
    resize_image = cv2.resize(image, dsize=(width, height))  # -> (height, width, 3)
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)  # (3, height, width)
    # Add the batch dimension without reinterpreting the spatial axes, which a
    # reshape to (..., width, height) would do for non-square input sizes.
    x = x[np.newaxis].astype("float32")

    # Inference
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    # Post process: min-max normalize to 0..255
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    value_range = max_value - min_value
    if value_range > 0:
        onnx_result = (onnx_result - min_value) / value_range
    else:
        # Constant output would give 0/0 = NaN; map it to all zeros instead
        onnx_result = np.zeros_like(onnx_result, dtype=np.float32)
    onnx_result *= 255
    onnx_result = onnx_result.astype("uint8")

    return onnx_result


def download_file_from_url(url, filename):
    """
    Downloads a file from a Hugging Face model repo, following any redirects.

    The payload is streamed into "<filename>.part" and atomically moved to
    filename only after it completes, so an interrupted download never leaves a
    truncated file that later os.path.exists checks would trust. Network errors
    are reported and swallowed (best effort); filesystem errors propagate.
    """
    tmp_filename = f"{filename}.part"
    try:
        # requests follows 301/302/307/308 redirects (e.g. to a CDN) itself
        with requests.get(url, stream=True, allow_redirects=True, timeout=60) as response:
            response.raise_for_status()  # Raise HTTPError for bad requests (4xx or 5xx)
            with open(tmp_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_filename, filename)
        print(f"Downloaded {filename} successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
    finally:
        # Remove a partial download left behind by any failure
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)