Full Code of siddk/voltron-robotics for AI

Repository: siddk/voltron-robotics
Branch: main
Commit: 1b299bf5cfa0
Files: 60
Total size: 431.8 KB

Directory structure:
gitextract_4mb9jrxp/

├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── docs/
│   └── ROADMAP.md
├── examples/
│   ├── pretrain/
│   │   ├── README.md
│   │   ├── preprocess.py
│   │   └── pretrain.py
│   ├── usage.py
│   ├── verification/
│   │   └── verify.py
│   └── xla-reference/
│       ├── README.md
│       ├── xpreprocess.py
│       └── xpretrain.py
├── pyproject.toml
├── setup.py
└── voltron/
    ├── __init__.py
    ├── conf/
    │   ├── __init__.py
    │   ├── accelerators.py
    │   ├── datasets.py
    │   ├── models.py
    │   └── tracking.py
    ├── datasets/
    │   ├── __init__.py
    │   ├── datasets.py
    │   └── v1/
    │       ├── __init__.py
    │       └── stream_datasets.py
    ├── models/
    │   ├── __init__.py
    │   ├── core/
    │   │   ├── __init__.py
    │   │   ├── vcond.py
    │   │   ├── vdual.py
    │   │   └── vgen.py
    │   ├── instantiate.py
    │   ├── materialize.py
    │   ├── reproductions/
    │   │   ├── __init__.py
    │   │   ├── vmvp.py
    │   │   ├── vr3m.py
    │   │   └── vrn3m.py
    │   └── util/
    │       ├── __init__.py
    │       ├── extraction.py
    │       ├── optimization.py
    │       └── transformer.py
    ├── overwatch/
    │   ├── __init__.py
    │   └── overwatch.py
    ├── preprocessing/
    │   ├── __init__.py
    │   ├── core.py
    │   ├── process.py
    │   ├── transforms.py
    │   └── v1/
    │       ├── __init__.py
    │       ├── process.py
    │       ├── transforms.py
    │       └── utils.py
    └── util/
        ├── __init__.py
        ├── checkpointing.py
        ├── metrics.py
        ├── utilities.py
        └── v1/
            ├── __init__.py
            ├── checkpointing.py
            ├── distributed.py
            ├── random.py
            └── xla_logger.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Ruff
.ruff_cache/

# IDE caches
.idea/
.vscode/

# Mac OS
.DS_Store

# Cache
data/
cache/

# Scratch
scratch/


================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: ".git"

repos:
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.0.252
    hooks:
      - id: ruff
        args: [ --fix, --exit-non-zero-on-fix ]

  - repo: https://github.com/psf/black
    rev: 23.1.0
    hooks:
      - id: black

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-added-large-files
        args: ["--maxkb=40000"]
      - id: check-ast
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-toml
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021-present, Siddharth Karamcheti and other contributors.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Makefile
================================================
.PHONY: help check autoformat
.DEFAULT: help

# Generates a useful overview/help message for various make features - add to this as necessary!
help:
	@echo "make check"
	@echo "    Run code style and linting (black, ruff) *without* changing files!"
	@echo "make autoformat"
	@echo "    Run code styling (black, ruff) and update in place - committing with pre-commit also does this."

check:
	black --check .
	ruff check --show-source .

autoformat:
	black .
	ruff check --fix --show-fixes .


================================================
FILE: README.md
================================================
<div align="center">
    <img src="https://raw.githubusercontent.com/siddk/voltron-robotics/main/docs/assets/voltron-banner.png" alt="Voltron Logo"/>
</div>

<div align="center">

[![arXiv](https://img.shields.io/badge/arXiv-2302.12766-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2302.12766)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0.0-EE4C2C.svg?style=for-the-badge&logo=pytorch)](https://pytorch.org/get-started/locally/)
[![Code Style: Black](https://img.shields.io/badge/Code%20Style-Black-000000?style=for-the-badge)](https://github.com/psf/black)
[![Ruff](https://img.shields.io/badge/%E2%9A%A1%EF%B8%8F-Ruff-orange?style=for-the-badge)](https://github.com/charliermarsh/ruff)
![License](https://img.shields.io/github/license/siddk/lila?color=blueviolet&style=for-the-badge)

</div>

---

# Language-Driven Representation Learning for Robotics

Package repository for Voltron: Language-Driven Representation Learning for Robotics. Provides code for loading
pretrained Voltron, R3M, and MVP representations for adaptation to downstream tasks, as well as code for pretraining
such representations on arbitrary datasets.

---

## Quickstart

This repository is built with PyTorch; while specified as a dependency for the package, we highly recommend that
you install the desired version (e.g., with accelerator support) for your given hardware and environment
manager (e.g., `conda`).

PyTorch installation instructions [can be found here](https://pytorch.org/get-started/locally/). This repository
should work with PyTorch >= 1.12. Releases before 1.1.0 have been thoroughly tested with PyTorch 1.12.0,
Torchvision 0.13.0, and Torchaudio 0.12.0. **Note**: Releases 1.1.0 and after *assume PyTorch 2.0*!

Once PyTorch has been properly installed, you can install this package via PyPI, and you're off!

```bash
pip install voltron-robotics
```
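
To confirm what actually got installed, the versions can be read from package metadata without importing either package. This is a minimal sketch using only the standard library; `installed_version` and `major_minor` are hypothetical helpers for illustration, not part of `voltron`.

```python
from importlib import metadata
from typing import Optional, Tuple


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def major_minor(version: str) -> Tuple[int, int]:
    """Parse the leading MAJOR.MINOR of a version string such as '2.0.1+cu118'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


# Report what is present in the current environment
for name in ("torch", "voltron-robotics"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Comparing `major_minor(...)` against `(2, 0)` is one way to check the PyTorch 2.0 assumption noted above for releases 1.1.0 and later.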

You can also install this package locally via an editable installation in case you want to run examples/extend the
current functionality:

```bash
git clone https://github.com/siddk/voltron-robotics
cd voltron-robotics
pip install -e .
```

## Usage

Voltron Robotics (package: `voltron`) is structured to provide easy access to pretrained Voltron models (and
reproductions), to facilitate use for various downstream tasks. Using a pretrained Voltron model is easy:

```python
from torchvision.io import read_image
from voltron import instantiate_extractor, load

# Load a frozen Voltron (V-Cond) model & configure a vector extractor
vcond, preprocess = load("v-cond", device="cuda", freeze=True)
vector_extractor = instantiate_extractor(vcond)()

# Obtain & Preprocess an image =>> can be from a dataset, or camera on a robot, etc.
#   => Feel free to add any language if you have it (Voltron models work either way!)
img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to("cuda")
lang = ["peeling a carrot"]

# Extract both multimodal AND vision-only embeddings!
multimodal_embeddings = vcond(img, lang, mode="multimodal")
visual_embeddings = vcond(img, mode="visual")

# Use the `vector_extractor` to output dense vector representations for downstream applications!
#   => Pass this representation to model of your choice (object detector, control policy, etc.)
representation = vector_extractor(multimodal_embeddings)
```

Voltron representations can be used for a variety of different applications; in the
[`voltron-evaluation`](https://github.com/siddk/voltron-evaluation) repository, you can find code for adapting Voltron
representations to various downstream tasks (segmentation, object detection, control, etc.); all the applications from
our paper.

---

## API

![Voltron Framework](https://raw.githubusercontent.com/siddk/voltron-robotics/main/docs/assets/voltron-framework.png)

The package `voltron` provides the following functionality for using and adapting existing representations:

#### `voltron.available_models()`

Returns the names of available Voltron models; right now, the following models (all models trained in the paper) are
available:

- `v-cond` – V-Cond (ViT-Small) trained on Sth-Sth; single-frame w/ language-conditioning.
- `v-dual` – V-Dual (ViT-Small) trained on Sth-Sth; dual-frame w/ language-conditioning.
- `v-gen` – V-Gen (ViT-Small) trained on Sth-Sth; dual-frame w/ language conditioning AND generation.
- `r-mvp` – R-MVP (ViT-Small); reproduction of [MVP](https://github.com/ir413/mvp) trained on Sth-Sth.
- `r-r3m-vit` – R-R3M (ViT-Small); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
- `r-r3m-rn50` – R-R3M (ResNet-50); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
- `v-cond-base` – V-Cond (ViT-Base) trained on Sth-Sth; larger (86M parameter) variant of V-Cond.
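
The identifiers above can be captured as a small lookup table, for example to validate a name before passing it to `voltron.load()`. This is a sketch only: `MODEL_BACKBONES` and `check_model_id` are hypothetical names that simply restate the list above; in practice, `voltron.available_models()` is the source of truth.

```python
# Backbone for each model identifier, transcribed from the list above (not read from the package).
MODEL_BACKBONES = {
    "v-cond": "ViT-Small",
    "v-dual": "ViT-Small",
    "v-gen": "ViT-Small",
    "r-mvp": "ViT-Small",
    "r-r3m-vit": "ViT-Small",
    "r-r3m-rn50": "ResNet-50",
    "v-cond-base": "ViT-Base",
}


def check_model_id(name: str) -> str:
    """Return the backbone for `name`, or raise a ValueError listing the valid identifiers."""
    if name not in MODEL_BACKBONES:
        raise ValueError(f"Unknown model `{name}`; expected one of {sorted(MODEL_BACKBONES)}")
    return MODEL_BACKBONES[name]
```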

#### `voltron.load(name: str, device: str, freeze: bool, cache: str = cache/)`

Returns the model and the Torchvision Transform needed by the model, where `name` is one of the strings returned
by `voltron.available_models()`; this in general follows the same API as
[OpenAI's CLIP](https://github.com/openai/CLIP).

---

Voltron models (`v-{cond, dual, gen, ...}`) returned by `voltron.load()` support the following:

#### `model(img: Tensor, lang: Optional[List[str]], mode: str = "multimodal")`

Returns a sequence of embeddings corresponding to the output of the multimodal encoder; note that `lang` can be None,
which is totally fine for Voltron models! However, if you have any language (even a coarse task description), it'll
probably be helpful!

The parameter `mode` in `["multimodal", "visual"]` controls whether the output will contain the fused image patch and
language embeddings, or only the image patch embeddings.

**Note:** For the API for the non-Voltron models (e.g., R-MVP, R-R3M), take a look at
[`examples/verification/verify.py`](examples/verification/verify.py); this file shows how representations from *every* model can be extracted.

### Adaptation

See [`examples/usage.py`](examples/usage.py) and the [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation)
repository for more examples on the various ways to adapt/use Voltron representations.

---

## Contributing

Before committing to the repository, make sure to set up your dev environment!
Here are the basic development environment setup guidelines:

+ Fork/clone the repository, performing an editable installation. Make sure to install with the development dependencies
  (e.g., `pip install -e ".[dev]"`); this will install `black`, `ruff`, and `pre-commit`.

+ Install `pre-commit` hooks (`pre-commit install`).

+ Branch for the specific feature/issue, issuing PR against the upstream repository for review.

Additional Contribution Notes:
- This project has migrated to the recommended
  [`pyproject.toml` based configuration for setuptools](https://setuptools.pypa.io/en/latest/userguide/quickstart.html).
  However, as some tools haven't yet adopted [PEP 660](https://peps.python.org/pep-0660/), we provide a
  [`setup.py` file](https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html).

- This package follows the [`flat-layout` structure](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#flat-layout)
  described in `setuptools`.

- Make sure to add any new dependencies to the `pyproject.toml` file!

---

## Repository Structure

High-level overview of repository/project file-tree:

+ `docs/` - Package documentation & assets - including project roadmap.
+ `voltron` - Package source code; has all core utilities for model specification, loading, feature extraction,
              preprocessing, etc.
+ `examples/` - Standalone examples scripts for demonstrating various functionality (e.g., extracting different types
                of representations, adapting representations in various contexts, pretraining, amongst others).
+ `.pre-commit-config.yaml` - Pre-commit configuration file (sane defaults + `black` + `ruff`).
+ `LICENSE` - Code is made available under the MIT License.
+ `Makefile` - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
+ `pyproject.toml` - Following PEP 621, this file has all project configuration details (including dependencies), as
                     well as tool configurations (for `black` and `ruff`).
+ `README.md` - You are here!

---

## Citation

Please cite [our paper](https://arxiv.org/abs/2302.12766) if using any of the Voltron models, evaluation suite, or other parts of our framework in your work.

```bibtex
@inproceedings{karamcheti2023voltron,
  title={Language-Driven Representation Learning for Robotics},
  author={Siddharth Karamcheti and Suraj Nair and Annie S. Chen and Thomas Kollar and Chelsea Finn and Dorsa Sadigh and Percy Liang},
  booktitle={Robotics: Science and Systems (RSS)},
  year={2023}
}
```


================================================
FILE: docs/ROADMAP.md
================================================
# Project Roadmap

We document the future of this project (new features to be added, issues to address) here. For the most part, any
new features/bugfixes are documented as [GitHub Issues](https://github.com/siddk/voltron-robotics/issues).

## Timeline

[X] - **February 26th, 2023**: Initial Voltron-Robotics release with support for loading/adapting all pretrained models,
                               with comprehensive verification scripts & a small adaptation example.

[X] - **April 4, 2023**:  [#1](https://github.com/siddk/voltron-robotics/issues/1) - Add `xpretrain.py` reference script,
                          mostly for completeness. Refactor/rewrite the preprocessing and pretraining pipeline to reflect
                          the Qualcomm Sth-Sth data format, as well as PyTorch DDP vs. the patched PyTorch XLA!

[X] - **April 11, 2023**: [#2](https://github.com/siddk/voltron-robotics/issues/2) - Add support and a more general API
                          for pretraining on other datasets.

[ ] - **Future**:         [#5](https://github.com/siddk/voltron-robotics/issues/5) - Add better documentation and examples
                          around using the MAP extractor (especially for adaptation tasks).


================================================
FILE: examples/pretrain/README.md
================================================
# Pretraining Voltron Models

We provide scripts for pretraining Voltron models on various datasets. Below, we provide the full pipeline from
downloading the raw Something-Something-v2 Dataset from Qualcomm, running preprocessing, then running Distributed
Data Parallel (DDP) pretraining on 1+ GPUs via `torchrun`. Adding support for new datasets should follow this same
general flow.

---

## Dataset Preprocessing

We provide end-to-end instructions for downloading, preprocessing, and serializing various pretraining datasets (and
combinations thereof). Where possible, we provide links to batch/dataset index files.

**Note:** We make a key assumption that you have enough local disk space (e.g., on your server, attached NFS volume) to
store all *raw* and *preprocessed* data; this can range from 100s of GBs to 10s of TBs! We did not have access to such
storage in the original work, necessitating the *streaming* dataloaders defined in
`voltron/datasets/v1/stream_datasets.py`. Given your resources, you might consider adopting a similar approach; feel
free to post an issue with any questions!

We currently support pretraining on the following datasets:

- [Something-Something-v2](https://developer.qualcomm.com/software/ai-datasets/something-something)

Instructions for downloading/preprocessing each dataset can be found below!

---

### Something-Something-v2

Dataset Download: [Qualcomm AI Datasets](https://developer.qualcomm.com/software/ai-datasets/something-something)

#### Obtaining the Raw Dataset

Follow the instructions [at the above link](https://developer.qualcomm.com/software/ai-datasets/something-something) to
download the dataset. Qualcomm requires that you register for a
[Qualcomm OneID Account](https://myaccount.qualcomm.com/signup?target=https%3A%2F%2Fdeveloper.qualcomm.com)
to get access to the data. Approval might take some time.

After registering for an account, make sure to download all of the following files to a directory of your choosing
(we create a directory `data/raw/something-something-v2/downloaded/`). *You will need to manually download all 22 of
the following files from the Qualcomm site*:

1. Datasheet / Instructions (PDF – optional, but useful): `20bn-something-something_download_instructions_-_091622.pdf`
2. Labels (includes language annotations): `20bn-something-something-download-package-labels.zip`
3. Chunked Videos (should be 20 `.zip` archives):
   + `20bn-something-something-v2-00.zip`
   + ...
   + `20bn-something-something-v2-19.zip`
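
Since the downloads are manual, it can help to check that all 20 video archives arrived before extracting. A minimal sketch, assuming the archives sit directly in the download directory; `EXPECTED_CHUNKS` and `missing_chunks` are hypothetical names used for illustration.

```python
from pathlib import Path
from typing import List

# The 20 chunked video archives named above: ...-v2-00.zip through ...-v2-19.zip
EXPECTED_CHUNKS = [f"20bn-something-something-v2-{idx:02d}.zip" for idx in range(20)]


def missing_chunks(download_dir: str) -> List[str]:
    """Return the expected archive names that are not present in `download_dir`."""
    root = Path(download_dir)
    return [name for name in EXPECTED_CHUNKS if not (root / name).is_file()]
```

For example, `missing_chunks("data/raw/something-something-v2/downloaded")` returns an empty list once every archive is in place.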

To extract all the given files (we extract to `data/raw/something-something-v2/`), *execute the following from inside
the `downloaded/` subdirectory*:

```bash
# Labels (annotations/language) --> creates `data/raw/something-something-v2/labels`
unzip 20bn-something-something-download-package-labels.zip -d ../

# Videos (following instructions in `20bn-something-something_download_instructions_-_091622.pdf`)
unzip "20bn-something-something-v2-*.zip" -d ../videos
cd ../videos
cat 20bn-something-something-?? | tar -xvzf -
find . -maxdepth 1 -type f -delete
cd 20bn-something-something-v2/
find . -mindepth 1 -maxdepth 1 -exec mv -t .. -- {} +
cd ..
rm -r 20bn-something-something-v2
ls | wc -l   # Should report 220847 `.webm` files!
```
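
The same final check can be done from Python, which also filters by extension. A minimal sketch, assuming the flattened layout produced by the commands above; `count_webm` is a hypothetical helper, and `EXPECTED_CLIPS` is the clip count quoted in this README.

```python
from pathlib import Path

EXPECTED_CLIPS = 220_847  # clip count stated in this README


def count_webm(videos_dir: str) -> int:
    """Count `.webm` files directly inside `videos_dir` (no recursion)."""
    return sum(1 for p in Path(videos_dir).iterdir() if p.is_file() and p.suffix == ".webm")
```

After extraction, `count_webm("data/raw/something-something-v2/videos") == EXPECTED_CLIPS` should hold.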

#### Dataset Information & Statistics

Something-Something-v2 consists of 220,847 `.webm` clips (168,913 in the `train` split) each with a height of exactly
240px, and variable width. The frames are encoded at a fixed 12 FPS.

Clips average 45 frames each (approx. 7 KB per JPEG); ~7.6M frames total (~56 GB).
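
These figures can be sanity-checked with back-of-the-envelope arithmetic. Note that the ~7.6M frame count lines up with the `train` split (168,913 clips) rather than with all 220,847 clips; that reading is an inference from the numbers, not something stated in this README.

```python
TRAIN_CLIPS, ALL_CLIPS = 168_913, 220_847
AVG_FRAMES_PER_CLIP = 45
KB_PER_FRAME = 7  # approximate size of one serialized JPEG

train_frames = TRAIN_CLIPS * AVG_FRAMES_PER_CLIP  # 7,601,085 (~7.6M)
all_frames = ALL_CLIPS * AVG_FRAMES_PER_CLIP      # 9,938,115 (~9.9M)
train_gb = train_frames * KB_PER_FRAME / 1e6      # ~53 GB (decimal), same ballpark as ~56 GB

print(f"train: {train_frames:,} frames, ~{train_gb:.0f} GB; all clips: {all_frames:,} frames")
```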

#### Video/Image Transformations --> from Video Clip to "frame" --> "tensor"

```python
from typing import List

import av
import torch
from PIL import Image, ImageOps

# Resolutions for "preprocessing" (serialize to disk) and "training"
PREPROCESS_RESOLUTION, TRAIN_RESOLUTION = 240, 224

# Define Preprocessing Transformation
def preprocess_transform(frames: List[Image.Image]) -> List[Image.Image]:
    # Assert width >= height and height >= PREPROCESS_RESOLUTION
    orig_w, orig_h = frames[0].size
    assert orig_w >= orig_h >= PREPROCESS_RESOLUTION

    # Compute scale factor --> just a function of height and PREPROCESS_RESOLUTION
    scale_factor = PREPROCESS_RESOLUTION / orig_h

    # Full Transformation --> scale (preserve aspect ratio, then get square)
    for idx in range(len(frames)):
        frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
        left = (frames[idx].size[0] - PREPROCESS_RESOLUTION) // 2
        frames[idx] = frames[idx].crop((left, 0, left + PREPROCESS_RESOLUTION, PREPROCESS_RESOLUTION))

    return frames

def train_transform(img) -> torch.Tensor:
    # Assumes square, just resizes to TRAIN_RESOLUTION via `torchvision.transforms`
    ...

def extract_frames(webm_file: str) -> None:
    container = av.open(webm_file)
    assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"

    # Extract --> then serialize via `Image.save("frame_{idx}.jpg")`
    frames = preprocess_transform([f.to_image() for f in container.decode(video=0)])
    ...
```
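
The crop geometry in `preprocess_transform` can be worked through without any image library. The sketch below only mirrors the arithmetic; `center_crop_box` is a hypothetical helper, and it assumes the scaled width is rounded to the nearest integer pixel.

```python
from typing import Tuple

PREPROCESS_RESOLUTION = 240


def center_crop_box(orig_w: int, orig_h: int, res: int = PREPROCESS_RESOLUTION) -> Tuple[int, int, int, int]:
    """Return the (left, upper, right, lower) crop box applied after scaling the height to `res`."""
    assert orig_w >= orig_h >= res
    scale_factor = res / orig_h
    # Assumption: the scaled width is rounded to the nearest integer pixel
    scaled_w = round(orig_w * scale_factor)
    left = (scaled_w - res) // 2
    return (left, 0, left + res, res)


print(center_crop_box(427, 240))  # (93, 0, 333, 240)
```

A 427x240 frame needs no rescaling, so the result is simply the central 240x240 window; a 640x480 frame is first halved to 320x240 and then cropped to `(40, 0, 280, 240)`.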


#### Citation

If you are pretraining on this dataset, make sure to cite the original research; Something-Something-v2 is the product
of two papers:

```bibtex
@inproceedings{goyal2017sthsthv1,
  author = {Raghav Goyal and Samira Ebrahimi Kahou and Vincent Michalski and Joanna Materzynska and Susanne Westphal and Heuna Kim and Valentin Haenel and Ingo Fründ and Peter N. Yianilos and Moritz Mueller-Freitag and Florian Hoppe and Christian Thurau and Ingo Bax and Roland Memisevic},
  booktitle = {International Conference on Computer Vision (ICCV)},
  title = {The ``Something Something'' Video Database for Learning and Evaluating Visual Common Sense},
  year = {2017},
}
@article{mahdisoltani2018sthsthv2,
  author={Farzaneh Mahdisoltani and Guillaume Berger and Waseem Gharbieh and David J. Fleet and Roland Memisevic},
  journal = {arXiv preprint arXiv:1804.09235},
  title={On the Effectiveness of Task Granularity for Transfer Learning},
  year={2018}
}
```

---

## PyTorch Native Pretraining Pipeline

To pretrain a Voltron model (e.g., `v-cond`) on the processed data, first read through and run `examples/pretrain/preprocess.py`; preprocessing is a one-time step that must finish before pretraining.
A sample launch command to run with the Something-Something-v2 dataset on a single node with 8 GPUs is as follows:

```bash
torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py
```
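
With `torchrun`, each of the `--nproc-per-node` processes gets its own rank, and `pretrain.py` picks its GPU as `dist.get_rank() % torch.cuda.device_count()`. That mapping can be sketched in plain Python (no PyTorch needed); `device_for_rank` is a hypothetical helper name.

```python
def device_for_rank(rank: int, gpus_per_node: int) -> int:
    """Mirror `rank % torch.cuda.device_count()`: map a global rank to a local GPU index."""
    return rank % gpus_per_node


# Single node, 8 GPUs: ranks 0..7 map one-to-one onto GPUs 0..7
print([device_for_rank(r, 8) for r in range(8)])  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The modulo only matters beyond a single node: assuming ranks are assigned contiguously per node, rank 8 on a second 8-GPU node would map back to local GPU 0.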

Make sure to check the following configuration files, and either update them manually (adding your own dataclass,
overriding [DEFAULTS](https://github.com/siddk/voltron-robotics/blob/main/examples/pretrain/pretrain.py#L38)) or
override them at the command line using Hydra semantics (e.g., `... pretrain.py dataset.path="<PATH>" ...`):

- [Accelerator Config](../../voltron/conf/accelerators.py): Depending on hardware, might need to tune `num_workers`
- [Dataset Config](../../voltron/conf/datasets.py): Make sure to override `path` and `artifact_path`
- [Tracking Config](../../voltron/conf/tracking.py): Disable Weights & Biases / change default entity/name
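
A dotted override such as `dataset.path="<PATH>"` addresses a nested field of the structured config. The toy function below illustrates that addressing on a plain dictionary; it is not Hydra's implementation, and `apply_override` plus the example values are made up for illustration.

```python
from typing import Any, Dict


def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply one `a.b.c=value` override in place; values are kept as strings."""
    dotted_key, value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node[key]  # raises KeyError if the group does not exist
    if leaf not in node:
        raise KeyError(f"Unknown config field: {dotted_key}")
    node[leaf] = value


# Hypothetical config values, for illustration only
cfg = {"dataset": {"name": "sth-sth-v2", "path": "data/raw"}, "seed": 21}
apply_override(cfg, "dataset.path=/mnt/datasets/sth-sth-v2")
print(cfg["dataset"]["path"])  # /mnt/datasets/sth-sth-v2
```

Rejecting unknown fields reflects the general idea behind structured configs: a misspelled key fails loudly instead of silently adding a new entry.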


================================================
FILE: examples/pretrain/preprocess.py
================================================
"""
preprocess.py

Centralized script for preprocessing various video/vision-language datasets for GPU pretraining, using a multi-stage,
multiprocessing approach.

Run as a standalone script, *prior* to calling `pretrain.py` =>> mostly because we want to preprocess the data once, as
a fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing import extract_frames, preprocess_language, unify_batches
from voltron.util import set_global_seed

# Grab Logger
overwatch = logging.getLogger(__file__)


# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]


@dataclass
class PreprocessingConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(
        default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
    )

    # Command Line Arguments
    seed: int = 21                                  # Random Seed (for reproducibility)
    dry_run: bool = False                           # Dry Run --> Get a sense of preprocessing/serialization footprint

    # Composable / Structured Arguments
    dataset: DatasetConfig = MISSING                # Dataset(s) for pretraining/preprocessing
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)


@hydra.main(config_path=None, config_name="config")
def preprocess(cfg: PreprocessingConfig) -> None:
    overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")

    # Set Randomness
    set_global_seed(cfg.seed)

    # Phase 1 :: Serialize Frames from Video Clips --> get `registry` (index files) for train and validation
    train_registry, val_registry, train_dir, val_dir = extract_frames(
        cfg.dataset.name,
        path=cfg.dataset.path,
        artifact_path=cfg.dataset.artifact_path,
        preprocess_resolution=cfg.dataset.preprocess_resolution,
        n_val_videos=cfg.dataset.n_val_videos,
        dry_run=cfg.dry_run,
    )

    # Phase 2 :: Normalize & Tokenize Language --> create `index.pt` and `index.json` files
    index_dir = preprocess_language(
        cfg.dataset.name,
        train_registry,
        val_registry,
        artifact_path=cfg.dataset.artifact_path,
        max_lang_len=cfg.dataset.max_lang_len,
        language_model=cfg.dataset.language_model,
        hf_cache=cfg.dataset.hf_cache,
    )

    # Phase 3 :: Assemble "Data-Locked" Batch Sets for Various Models (e.g., for single-frame/dual-frame/quintet)
    unify_batches(
        cfg.dataset.name,
        train_registry,
        val_registry,
        train_dir,
        val_dir,
        index_dir,
        batch_formats=cfg.dataset.batch_formats,
        max_epochs=cfg.dataset.max_epochs,
        initial_final_alpha=cfg.dataset.initial_final_alpha,
    )

    overwatch.info("Preprocessing Complete!")


if __name__ == "__main__":
    preprocess()


================================================
FILE: examples/pretrain/pretrain.py
================================================
"""
pretrain.py

Core pretraining script for Native PyTorch (Single/Multi-) GPU pretraining on the Something-Something-v2 dataset; this
is basically just a 1-1 reproduction of the XLA pretraining script (`examples/xla-reference/xpretrain.py`) with just
a bit of cleanup, the default PyTorch DDP semantics (`torchrun`), using PyTorch 2.0.

Other notable differences from `xpretrain.py`:
    - Loads data from the local filesystem instead of streaming from a GCP bucket (can be added back easily!)
    - No TPU/XLA specific dependencies --> just PyTorch 2.0!

Run with:
    - [Single Node Multi-GPU ($K)]: `torchrun --standalone --nnodes 1 --nproc-per-node $K examples/pretrain/pretrain.py`
"""
import logging
import os
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import hydra
import torch
import torch.distributed as dist
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig
from voltron.datasets import get_datasets
from voltron.models import get_model_optimizer
from voltron.overwatch import OverwatchRich
from voltron.util import CheckpointSaver, Metrics, ResumeableDistributedSampler, do_resume, set_global_seed

# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = [
    "_self_",
    {"model": "v-cond"},
    {"dataset": "sth-sth-v2"},
    {"accelerator": "torchrun"},
    {"tracking": "voltron-tracking"},
    {"override hydra/job_logging": "overwatch_rich"},
]


@dataclass
class PretrainConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(default_factory=lambda: {
        "run": {"dir": "runs/train/${model.identifier}+dataset-${dataset.name}"}
    })

    # Command Line Arguments
    run_id: Optional[str] = None                                        # Run ID for Logging
    seed: int = 21                                                      # Random Seed (for reproducibility)

    # Resume / Debug Behavior
    resume: bool = True                                                 # Whether to resume an existing run...
    wandb_resume_id: Optional[str] = None                               # W&B Run ID for `resume` behavior...

    # Composable / Structured Arguments
    model: ModelConfig = MISSING                                        # Model architecture for pretraining
    dataset: DatasetConfig = MISSING                                    # List of datasets for pretraining
    accelerator: AcceleratorConfig = MISSING                            # Accelerator (should always keep `torchrun`)
    tracking: TrackingConfig = MISSING                                  # Run/experiment tracking configuration
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PretrainConfig)


@hydra.main(config_path=None, config_name="config")
def pretrain(cfg: PretrainConfig) -> None:
    # Initialize Distributed Process Group --> assumes NCCL + Environment Variable Initialization (via `torchrun`)
    dist.init_process_group(backend="nccl", init_method="env://")
    device_id = dist.get_rank() % torch.cuda.device_count()
    is_rank_zero, rank, world_size = dist.get_rank() == 0, dist.get_rank(), dist.get_world_size()

    # Create Unique Run Name -- if `resume = True`, we assume the same "run_id"
    if cfg.run_id is None:
        cfg.run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-ddp-x{cfg.seed}"
    else:
        run_dir = cfg.run_id

    # Setup Logging (Rank 0 Only!) and Directory Handling
    overwatch = logging.getLogger(__file__)
    overwatch.setLevel(logging.INFO if is_rank_zero else logging.ERROR)
    overwatch.info("Voltron Training :: Assembling the Legendary Defender...")
    if is_rank_zero:
        os.makedirs(run_dir, exist_ok=True)

    # Let's Get Started!
    overwatch.info(
        '\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."'
    )

    # Set Randomness & Get Dataloader `worker_init_fn` to ensure proper randomness in augmentations (if any)
    worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True)

    # Initialize Model & Optimizer --> Wrap in DDP / Device Handling
    #   > Note :: For (Standard) DDP Training --> initializing Optimizer before DDP == initializing after!
    overwatch.info("Initializing Model, Optimizer, and Learning Rate Scheduler")
    model, optimizer, update_lr = get_model_optimizer(cfg.model, cfg.dataset)
    model = DDP(model.to(device_id), device_ids=[device_id], output_device=device_id)

    # Handle Resume / Checkpoint Loading
    resume_checkpoint, resume_epoch, resume_step = do_resume(cfg.resume, run_dir=run_dir)
    if resume_checkpoint is not None:
        # IMPORTANT --> Load weights by mapping specifically to `cuda:<device_id>`!
        resume_state = torch.load(resume_checkpoint, map_location=f"cuda:{device_id}")
        model.load_state_dict(resume_state["model_state_dict"])
        optimizer.load_state_dict(resume_state["optimizer_state_dict"])

        dist.barrier()

    # Create Checkpoint Saver and Save Initial Checkpoint
    saver = CheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, is_rank_zero=is_rank_zero)
    if resume_checkpoint is None and resume_epoch == 0:
        overwatch.info("  | Saving 0th Epoch Checkpoint (Model Initialization)")
        saver.save(
            epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None
        )
        dist.barrier()

    # Get Datasets --> Barrier after I/O Intensive Operation
    overwatch.info(f"Retrieving Dataset `{cfg.dataset.name}` prepared for Model `{cfg.model.arch}`")
    train_dataset, val_dataset = get_datasets(
        0,
        cfg.dataset.name,
        cfg.model.arch,
        cfg.dataset.artifact_path,
        cfg.model.data_modality,
        cfg.dataset.resolution,
        cfg.dataset.normalization,
        cfg.model.get("lang_dropout", None),
        cfg.model.get("gen_ratio", None),
    )
    dist.barrier()

    # Create Metrics =>> Handles on-the-fly computation, logging to JSONL and Weights & Biases
    metrics = Metrics(
        active_loggers=cfg.tracking.active_loggers,
        run_id=cfg.run_id,
        hparams=OmegaConf.to_container(cfg),
        model_arch=cfg.model.arch,
        is_rank_zero=is_rank_zero,
        tracking_cfg=cfg.tracking,
        tags=cfg.tracking.tags,
        resume=cfg.resume,
        resume_id=cfg.wandb_resume_id,
    )
    dist.barrier()

    # Configure Gradient Accumulation --> function of `effective_bsz`, `native_bsz`, and `WORLD_SIZE`
    assert (
        cfg.model.effective_bsz % (cfg.model.native_bsz * world_size) == 0
    ), "`native_bsz * world_size` must evenly divide `effective_bsz`"
    accumulate_grad_batches = cfg.model.effective_bsz // (cfg.model.native_bsz * world_size)
    overwatch.info(f"Running `{cfg.model.identifier}` Model Pretraining with Parameters =>")
    overwatch.info(f"  | Effective Batch Size = `{cfg.model.effective_bsz}`")
    overwatch.info(f"  | Per-Device Batch Size = `{cfg.model.native_bsz}`")
    overwatch.info(f"  | Distributed World Size = `{world_size}`")
    overwatch.info(f"  | Accumulation Steps = `{accumulate_grad_batches}`")

    # Start Train Loop --> Iterate through Epochs (Evaluation at end of Epoch)
    overwatch.info("Starting Training Loop")
    for epoch in range(resume_epoch, cfg.dataset.max_epochs):
        overwatch.info(f"  | [Epoch {epoch:03d}] Building Distributed Sampler & DataLoaders")
        train_dataset.set_epoch(epoch)
        dist.barrier()

        # [Custom] ResumeableDistributedSampler operates over *examples* --> start_step (full batches) * effective_bsz
        train_sampler = ResumeableDistributedSampler(
            seen_examples=resume_step * cfg.model.effective_bsz,
            resume_epoch=resume_epoch,
            dataset=train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=cfg.seed,
        )
        train_sampler.set_epoch(epoch)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)

        # Create Epoch DataLoaders
        train_dl = DataLoader(
            train_dataset,
            batch_size=cfg.model.native_bsz,
            sampler=train_sampler,
            shuffle=False,
            num_workers=cfg.accelerator.num_workers,
            drop_last=True,
            pin_memory=True,
            prefetch_factor=4,
            worker_init_fn=worker_init_fn,
        )
        val_dl = DataLoader(
            val_dataset, batch_size=cfg.model.native_bsz, sampler=val_sampler, shuffle=False, num_workers=4
        )

        # Book-Keeping =>> Set LR when `resume = True` (or starting from scratch)
        if epoch == resume_epoch or epoch == 0:
            metrics.resume_time = (
                int(re.search(r"-t=(.+?)\.pt", str(resume_checkpoint)).group(1)) if resume_checkpoint is not None else 0
            )
            metrics.commit(
                global_step=resume_step + ((len(train_dataset) // cfg.model.effective_bsz) * resume_epoch),
                lr=update_lr(resume_epoch, resume_step / (len(train_dataset) // cfg.model.effective_bsz)),
                update_step_time=True,
            )

        # === Train Epoch ===
        model.train()
        status = metrics.get_status(epoch)
        overwatch.info(f"  | [Epoch {epoch:03d}] Running Train Loop")
        with tqdm(
            total=len(train_dl) // accumulate_grad_batches, desc=status, leave=False, disable=not is_rank_zero
        ) as progress:
            for train_idx, batch in enumerate(train_dl):
                # Model-Specific Handling
                if cfg.model.arch == "v-mvp":
                    img = batch
                    loss, _, _ = model(img.to(device_id, non_blocking=True))
                    metrics.commit(reconstruction_loss=loss)

                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(
                        tcn_loss=tcn_loss,
                        reward_loss=reward_loss,
                        l1_loss=l1_loss,
                        l2_loss=l2_loss,
                        tcn_accuracy=tcn_acc,
                        reward_accuracy=rew_acc,
                    )

                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    loss, _, _ = model(
                        img.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(reconstruction_loss=loss)

                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    loss, [zero_loss, k_loss] = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(
                        reconstruction_loss=loss,
                        zero_reconstruction_loss=zero_loss,
                        k_reconstruction_loss=k_loss,
                    )

                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model(
                        imgs.to(device_id, non_blocking=True),
                        lang_con.to(device_id, non_blocking=True),
                        lang_con_mask.to(device_id, non_blocking=True),
                        lang_gen.to(device_id, non_blocking=True),
                        lang_gen_mask.to(device_id, non_blocking=True),
                        lang_gen_weight,
                    )
                    metrics.commit(
                        reconstruction_loss=reconstruction_loss,
                        zero_reconstruction_loss=zero_loss,
                        k_reconstruction_loss=k_loss,
                        lm_loss=lm_loss,
                        lm_ppl=torch.exp(lm_loss),
                    )

                else:
                    raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")

                # Commit Loss (Prior to Normalization)
                metrics.commit(loss=loss)

                # Normalize Loss to account for Gradient Accumulation --> Backward!
                normalized_loss = loss / accumulate_grad_batches
                normalized_loss.backward()

                # Step =>> Check if done w/ Gradient Accumulation
                if (train_idx + 1) % accumulate_grad_batches == 0:
                    metrics.commit(update_step_time=True)

                    # Push Metrics every `log_frequency` steps...
                    if metrics.global_step % cfg.tracking.log_frequency == 0:
                        status = metrics.push(epoch)

                    # Optimizer Step --> Increment Global Step, Learning Rate, and Checkpoint (if specified)
                    optimizer.step()
                    optimizer.zero_grad()
                    lr = update_lr(
                        epoch,
                        (resume_step + ((train_idx + 1) // accumulate_grad_batches))
                        / (len(train_dataset) // cfg.model.effective_bsz),
                    )
                    metrics.commit(global_step=metrics.global_step + 1, lr=lr)
                    saver.save(
                        epoch,
                        is_local_step=True,
                        model=model,
                        optimizer=optimizer,
                        duration=int(time.time() - metrics.start_time) + metrics.resume_time,
                        local_step=resume_step + ((train_idx + 1) // accumulate_grad_batches),
                    )

                    # Update Progress Bar
                    progress.update()
                    progress.set_description(status)

        # === After Train Epoch --> Clear Gradients and reset `resume_step` ===
        optimizer.zero_grad()
        resume_step = 0

        # === Validation ===
        overwatch.info(f"  | [Epoch {epoch:03d}] Running Validation Loop")
        model.eval()

        # Accumulate `validation_losses` in order to `all_reduce` later!
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_dl, disable=not is_rank_zero, leave=False):
                # Model-Specific Handling
                if cfg.model.arch == "v-mvp":
                    img = batch
                    val_loss, _, _ = model(img.to(device_id, non_blocking=True))

                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    val_loss, _, _, _, _, _, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    val_loss, _, _ = model(
                        img.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    val_loss, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    val_loss, _, _, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang_con.to(device_id, non_blocking=True),
                        lang_con_mask.to(device_id, non_blocking=True),
                        lang_gen.to(device_id, non_blocking=True),
                        lang_gen_mask.to(device_id, non_blocking=True),
                        lang_gen_weight,
                    )

                else:
                    raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")

                # Add to Validation Losses
                val_losses.append(val_loss)

        # All Reduce --> Push Epoch Metrics --> Checkpoint!
        validation_loss = torch.stack(val_losses).mean()
        dist.all_reduce(validation_loss)
        avg_val_loss = validation_loss / world_size
        if is_rank_zero:
            epoch_status, train_loss, training_duration = metrics.push_epoch(epoch, avg_val_loss)
            saver.save(
                epoch=epoch + 1,
                is_local_step=False,
                model=model,
                optimizer=optimizer,
                duration=training_duration,
                train_loss=train_loss.item(),
                val_loss=avg_val_loss.item(),
            )

        # === End of Epoch ===
        dist.barrier()

    # Finalize
    metrics.finalize()

    # And... we're done!
    overwatch.info("...and that's all, folks!")
    dist.barrier()


if __name__ == "__main__":
    # General Defaults --> should use Tensor Cores (kinda) if you have them!
    torch.set_float32_matmul_precision("high")
    torch.multiprocessing.set_start_method("spawn", force=True)

    pretrain()


================================================
FILE: examples/usage.py
================================================
"""
usage.py

Example script demonstrating how to load a Voltron model (`V-Cond`) and instantiate a Multiheaded Attention Pooling
extractor head for downstream tasks.

This is the basic formula/protocol for using Voltron for arbitrary downstream applications.

Run with (from root of repository): `python examples/usage.py`
"""
import torch
from torchvision.io import read_image

from voltron import instantiate_extractor, load


def usage() -> None:
    print("[*] Demonstrating Voltron Usage for Various Adaptation Applications")

    # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Voltron model --> specify `freeze`, `device` and get model (nn.Module) and preprocessor
    vcond, preprocess = load("v-cond", device=device, freeze=True)

    # Obtain and preprocess an image =>> can be from a dataset, from a camera on a robot, etc.
    img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to(device)
    lang = ["peeling a carrot"]

    # Get various representations...
    with torch.no_grad():
        multimodal_features = vcond(img, lang, mode="multimodal")  # Fused vision & language features
        visual_features = vcond(img, mode="visual")  # Vision-only features (no language)

    # Can instantiate various extractors for downstream applications
    vector_extractor = instantiate_extractor(vcond, n_latents=1, device=device)()
    seq_extractor = instantiate_extractor(vcond, n_latents=64, device=device)()

    # Assertions...
    assert list(vector_extractor(multimodal_features).shape) == [1, vcond.embed_dim], "Should return a dense vector!"
    assert list(seq_extractor(visual_features).shape) == [1, 64, vcond.embed_dim], "Should return a sequence!"


if __name__ == "__main__":
    usage()


================================================
FILE: examples/verification/verify.py
================================================
"""
verify.py

Example script demonstrating how to load all Voltron models (and reproduced models), take input image(s), and get the
various (e.g., multimodal, image-only) representations.

Also serves to verify that representation loading is working as advertised.

Run with (from root of repository): `python examples/verification/verify.py`
"""
import torch
from torchvision.io import read_image

from voltron import load

# Available Models
MODELS = ["v-cond", "v-dual", "v-gen", "r-mvp", "r-r3m-vit", "r-r3m-rn50"]

# Sample Inputs
IMG_A, IMG_B = "examples/verification/img/peel-carrot-initial.png", "examples/verification/img/peel-carrot-final.png"
LANGUAGE = "peeling a carrot"


def verify() -> None:
    print("[*] Running `verify` =>> Verifying Model Representations!")

    # Read both images (we'll use the second image for the dual-frame models)
    image_a, image_b = read_image(IMG_A), read_image(IMG_B)

    # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for model_id in MODELS:
        print(f"\t=> Loading Model ID `{model_id}` and Verifying Representation Shapes!")
        model, preprocess = load(model_id, device=device, freeze=True)

        # Preprocess image, run feature extraction --> assert on shapes!
        if model_id in {"v-cond", "v-cond-base"}:
            for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
                representation = model(preprocess(image_a)[None, ...].to(device), [LANGUAGE], mode=modality)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"

        elif model_id in {"v-dual", "v-gen"}:
            for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
                dual_img = torch.stack([preprocess(image_a), preprocess(image_b)])[None, ...].to(device)
                representation = model(dual_img, [LANGUAGE], mode=modality)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"

        elif model_id == "r-mvp":
            for mode, expected in [("patch", 196), ("cls", 1)]:
                representation = model(preprocess(image_a)[None, ...].to(device), mode=mode)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"

        elif model_id in {"r-r3m-vit", "r-r3m-rn50"}:
            representation = model(preprocess(image_a)[None, ...].to(device))
            assert representation.squeeze(dim=0).shape[0] == 1, "Shape not expected!"

        else:
            raise ValueError(f"Model {model_id} not supported!")

    # We're good!
    print("[*] All representations & shapes verified! Yay!")


if __name__ == "__main__":
    verify()


================================================
FILE: examples/xla-reference/README.md
================================================
# XLA Reference

*Note :: This code was written for the experimental PyTorch XLA build in PyTorch 1.12; no guarantees it works with later
versions!*

We trained the original Voltron models (and data-locked reproductions of R3M and MVP) on TPU v3-8 nodes generously
provided by the [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program. At the time we started
the project, PyTorch XLA still had some bumps, which were further complicated by the switch from
[TPU Nodes to TPU VMs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch).

To get things to work, we had to add some non-intuitive code to facilitate PyTorch + TPUs (vs. a standard distributed
data parallel training pipeline). As a result, `xpretrain.py` is here mostly for documentation purposes, with a fully
refactored version `pretrain.py` forthcoming.
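
As one concrete piece of a standard distributed data parallel pipeline, the number of gradient accumulation steps
follows from the effective batch size, the per-device batch size, and the world size. The snippet below is a minimal
sketch of that arithmetic, not the repository's implementation: the helper name `accumulation_steps`, the strict
divisibility check, and the example values are all assumptions made for illustration.

```python
def accumulation_steps(effective_bsz: int, native_bsz: int, world_size: int) -> int:
    """Return how many micro-batches each device processes before one optimizer step.

    Hypothetical helper (not part of the repository): `effective_bsz` is the target global batch size,
    `native_bsz` is the per-device batch size, and `world_size` is the number of devices.
    """
    # Examples consumed across all devices in a single forward/backward pass
    per_pass = native_bsz * world_size

    # Assumption: require exact divisibility so the realized batch size matches the requested one
    if effective_bsz % per_pass != 0:
        raise ValueError(
            f"`native_bsz * world_size` = {per_pass} must evenly divide `effective_bsz` = {effective_bsz}"
        )

    return effective_bsz // per_pass


# Illustrative values only: 8 devices, 32 examples per device, target global batch size of 1024
steps = accumulation_steps(effective_bsz=1024, native_bsz=32, world_size=8)
print(steps)  # 4
```

With these values each device runs four forward/backward passes, dividing each loss by four, before the optimizer
steps once. Raising on a non-divisible configuration is a design choice in this sketch; it avoids silently training
with a smaller batch size than the one requested.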

We also include the original cloud preprocessing script `xpreprocess.py` for completeness (this is more general).


================================================
FILE: examples/xla-reference/xpreprocess.py
================================================
"""
xpreprocess.py

Centralized script for preprocessing Sth-Sth-v2 for TPU/GCP pretraining, using a multi-stage, multiprocessing strategy.

Run as a standalone script, *prior* to calling `xpretrain.py` =>> mostly because we want to preprocess the data
once, as a fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing.v1 import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
from voltron.util.v1.random import set_global_seed

# Grab Logger
overwatch = logging.getLogger(__file__)


# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]


@dataclass
class PreprocessingConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(
        default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
    )

    # Command Line Arguments
    seed: int = 21                                  # Random Seed (for reproducibility)
    dry_run: bool = False                           # Dry Run --> Get a sense of preprocessing/serialization footprint

    # Composable / Structured Arguments
    dataset: DatasetConfig = MISSING                # Dataset(s) for pretraining/preprocessing
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)


@hydra.main(config_path=None, config_name="config")
def xpreprocess(cfg: PreprocessingConfig) -> None:
    overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")

    # Set Randomness
    set_global_seed(cfg.seed)

    # Phase 1 :: Serialize Frames from Video Clips --> Get `registry` for train and val (index structure)
    train_registry, val_registry, train_dir, val_dir = preprocess_videos(
        cfg.dataset.name,
        path=cfg.dataset.path,
        artifact_path=cfg.dataset.artifact_path,
        resolution=cfg.dataset.resolution,
        n_val_videos=cfg.dataset.n_val_videos,
        dry_run=cfg.dry_run,
    )

    # Phase 2 :: Normalize & Tokenize Language  --> Create `index.pt` & `index.json` files
    preprocess_language(
        cfg.dataset.name,
        train_registry,
        val_registry,
        max_lang_len=cfg.dataset.max_lang_len,
        language_model=cfg.dataset.language_model,
        hf_cache=cfg.dataset.hf_cache,
    )
    jsonify_language(train_registry, val_registry)
    index_dir = index(train_registry, val_registry, cfg.dataset.name, artifact_path=cfg.dataset.artifact_path)

    # Phase 3 :: Assemble & Unify Batch "Sets" across the Varied Dataset Formats (for each Model =>> "data-locked")
    unify_batches(
        cfg.dataset.artifact_path,
        cfg.dataset.name,
        train_registry,
        val_registry,
        train_dir,
        val_dir,
        index_dir,
        cfg.dataset.batch_formats,
        max_epochs=cfg.dataset.max_epochs,
        initial_final_alpha=cfg.dataset.initial_final_alpha,
    )


if __name__ == "__main__":
    xpreprocess()


================================================
FILE: examples/xla-reference/xpretrain.py
================================================
"""
xpretrain.py

(The `x` prefix indicates this is a script geared for XLA/TPU backends *only*)!

Reference script for PyTorch XLA (TPU-based) pretraining on the Something-Something-v2 dataset; this is
mostly for completeness =>> the hope is that the regular `pretrain.py` script is more general and maintained.

Focuses on multi-TPU (XLA) training --> but also supports single-core TPU training, as the default distributed mp.spawn
behavior just collapses into a single thread! Loads and preprocesses dataset, instantiates a model, and runs training.

Run with `python examples/xla-reference/xpretrain.py` (will use the configuration specified by `DEFAULTS` below).
"""
import os
import re
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import hydra
import jsonlines
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as parallel
import wandb
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig
from voltron.datasets.v1.stream_datasets import get_epoch_datasets
from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen
from voltron.overwatch import OverwatchRich
from voltron.util.v1.checkpointing import XLACheckpointSaver
from voltron.util.v1.distributed import ResumeableDistributedSampler
from voltron.util.v1.random import set_global_seed
from voltron.util.v1.xla_logger import (
    log_epoch_end_update,
    log_vcond_train_update,
    log_vdual_train_update,
    log_vgen_train_update,
    log_vmvp_train_update,
    log_vr3m_train_update,
    log_vrn3m_train_update,
)

# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = [
    "_self_",
    {"model": "v-cond"},
    {"dataset": "sth-sth-v2"},
    {"accelerator": "tpu-v3-8"},
    {"tracking": "voltron-tracking"},
    {"override hydra/job_logging": "overwatch_rich"},
]


@dataclass
class PretrainConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(default_factory=lambda: {
        "run": {"dir": "./runs/train/${model.identifier}+dataset-${dataset.name}"}
    })

    # Command Line Arguments
    run_id: Optional[str] = None                                        # Run ID for Logging
    seed: int = 21                                                      # Random Seed (for reproducibility)

    # Resume / Debug Behavior
    resume: bool = True                                                 # Whether to resume an existing run...
    resume_epoch: Optional[int] = None                                  # Epoch to resume (if auto-resuming)...
    checkpoint_path: Optional[str] = None                               # Path to the specific checkpoint to load!
    wandb_resume_id: Optional[str] = None                               # W&B Run ID for `resume` behavior...

    # Composable / Structured Arguments
    model: ModelConfig = MISSING                                        # Model architecture for pretraining
    dataset: DatasetConfig = MISSING                                    # List of datasets for pretraining
    accelerator: AcceleratorConfig = MISSING                            # Accelerator configuration
    tracking: TrackingConfig = MISSING                                  # Run/experiment tracking configuration
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)  # Annoying - configure logger for Hydra
cs.store(name="config", node=PretrainConfig)


# ruff: noqa: C901
def xpretrain(cfg: PretrainConfig) -> None:
    # Identify if `is_rank_zero` --> We only log from the rank zero process!
    is_rank_zero = xm.is_master_ordinal(local=False)
    xm.master_print("Voltron Training :: Assembling the Legendary Defender...")

    # Create Unique Run Name -- if `resume = True` we assume same "run_id"
    run_id = cfg.run_id
    if run_id is None:
        run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-x{cfg.seed}"
        cfg.run_id = run_id
    else:
        cfg.run_id = run_dir = run_id

    if is_rank_zero:
        os.makedirs(run_dir, exist_ok=True)

    xm.master_print(
        '\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."'
    )

    # Set Randomness, get DataLoader worker initialization function (to ensure proper randomness in any augmentations!)
    worker_init_fn = set_global_seed(cfg.seed)

    # Model Initialization Logic
    xm.master_print("Initializing Model and Placing on Different Devices...")
    if cfg.model.arch == "v-mvp":
        xm.master_print(f"Initializing MVP variant `{cfg.model.identifier}`")
        model = VMVP(
            resolution=cfg.dataset.resolution,
            patch_size=cfg.model.patch_size,
            encoder_depth=cfg.model.encoder_depth,
            encoder_embed_dim=cfg.model.encoder_embed_dim,
            encoder_n_heads=cfg.model.encoder_n_heads,
            decoder_depth=cfg.model.decoder_depth,
            decoder_embed_dim=cfg.model.decoder_embed_dim,
            decoder_n_heads=cfg.model.decoder_n_heads,
            optimizer=cfg.model.optimizer,
            schedule=cfg.model.schedule,
            base_lr=cfg.model.base_lr,
            min_lr=cfg.model.min_lr,
            effective_bsz=cfg.model.effective_bsz,
            betas=cfg.model.betas,
            weight_decay=cfg.model.weight_decay,
            warmup_epochs=cfg.dataset.warmup_epochs,
            max_epochs=cfg.dataset.max_epochs,
            mlp_ratio=cfg.model.mlp_ratio,
            norm_pixel_loss=cfg.model.norm_pixel_loss,
        )

    elif cfg.model.arch == "v-r3m":
        xm.master_print(f"Initializing R3M (ViT) Variant `{cfg.model.identifier}`")
        model = VR3M(
            resolution=cfg.dataset.resolution,
            patch_size=cfg.model.patch_size,
            depth=cfg.model.depth,
            embed_dim=cfg.model.embed_dim,
            n_heads=cfg.model.n_heads,
            language_model=cfg.model.language_model,
            hf_cache=cfg.model.hf_cache,
            language_dim=cfg.model.language_dim,
            reward_dim=cfg.model.reward_dim,
            n_negatives=cfg.model.n_negatives,
            lang_reward_weight=cfg.model.lang_reward_weight,
            tcn_weight=cfg.model.tcn_weight,
            l1_weight=cfg.model.l1_weight,
            l2_weight=cfg.model.l2_weight,
            optimizer=cfg.model.optimizer,
            schedule=cfg.model.schedule,
            lr=cfg.model.lr,
            min_lr=cfg.model.min_lr,
            warmup_epochs=cfg.dataset.warmup_epochs,
            max_epochs=cfg.dataset.max_epochs,
            mlp_ratio=cfg.model.mlp_ratio,
        )

    elif cfg.model.arch == "v-rn3m":
        xm.master_print(f"Initializing R3M (ResNet) Variant `{cfg.model.identifier}`")
        model = VRN3M(
            resolution=cfg.dataset.resolution,
            fc_dim=cfg.model.fc_dim,
            language_model=cfg.model.language_model,
            hf_cache=cfg.model.hf_cache,
            language_dim=cfg.model.language_dim,
            reward_dim=cfg.model.reward_dim,
            n_negatives=cfg.model.n_negatives,
            lang_reward_weight=cfg.model.lang_reward_weight,
            tcn_weight=cfg.model.tcn_weight,
            l1_weight=cfg.model.l1_weight,
            l2_weight=cfg.model.l2_weight,
            optimizer=cfg.model.optimizer,
            lr=cfg.model.lr,
        )

    elif cfg.model.arch == "v-cond":
        xm.master_print(f"Initializing Voltron V-Cond variant `{cfg.model.identifier}`")
        model = VCond(
            resolution=cfg.dataset.resolution,
            patch_size=cfg.model.patch_size,
            encoder_depth=cfg.model.encoder_depth,
            encoder_embed_dim=cfg.model.encoder_embed_dim,
            encoder_n_heads=cfg.model.encoder_n_heads,
            decoder_depth=cfg.model.decoder_depth,
            decoder_embed_dim=cfg.model.decoder_embed_dim,
            decoder_n_heads=cfg.model.decoder_n_heads,
            language_model=cfg.model.language_model,
            hf_cache=cfg.model.hf_cache,
            language_dim=cfg.model.language_dim,
            optimizer=cfg.model.optimizer,
            schedule=cfg.model.schedule,
            base_lr=cfg.model.base_lr,
            min_lr=cfg.model.min_lr,
            effective_bsz=cfg.model.effective_bsz,
            betas=cfg.model.betas,
            weight_decay=cfg.model.weight_decay,
            warmup_epochs=cfg.dataset.warmup_epochs,
            max_epochs=cfg.dataset.max_epochs,
            mlp_ratio=cfg.model.mlp_ratio,
            norm_pixel_loss=cfg.model.norm_pixel_loss,
        )

    elif cfg.model.arch == "v-dual":
        xm.master_print(f"Initializing Voltron V-Dual variant `{cfg.model.identifier}`")
        model = VDual(
            resolution=cfg.dataset.resolution,
            patch_size=cfg.model.patch_size,
            encoder_depth=cfg.model.encoder_depth,
            encoder_embed_dim=cfg.model.encoder_embed_dim,
            encoder_n_heads=cfg.model.encoder_n_heads,
            decoder_depth=cfg.model.decoder_depth,
            decoder_embed_dim=cfg.model.decoder_embed_dim,
            decoder_n_heads=cfg.model.decoder_n_heads,
            language_model=cfg.model.language_model,
            hf_cache=cfg.model.hf_cache,
            language_dim=cfg.model.language_dim,
            optimizer=cfg.model.optimizer,
            schedule=cfg.model.schedule,
            base_lr=cfg.model.base_lr,
            min_lr=cfg.model.min_lr,
            effective_bsz=cfg.model.effective_bsz,
            betas=cfg.model.betas,
            weight_decay=cfg.model.weight_decay,
            warmup_epochs=cfg.dataset.warmup_epochs,
            max_epochs=cfg.dataset.max_epochs,
            mlp_ratio=cfg.model.mlp_ratio,
            norm_pixel_loss=cfg.model.norm_pixel_loss,
        )

    elif cfg.model.arch == "v-gen":
        xm.master_print(f"Initializing Voltron V-Gen variant `{cfg.model.identifier}`")
        model = VGen(
            resolution=cfg.dataset.resolution,
            patch_size=cfg.model.patch_size,
            encoder_depth=cfg.model.encoder_depth,
            encoder_embed_dim=cfg.model.encoder_embed_dim,
            encoder_n_heads=cfg.model.encoder_n_heads,
            decoder_depth=cfg.model.decoder_depth,
            decoder_embed_dim=cfg.model.decoder_embed_dim,
            decoder_n_heads=cfg.model.decoder_n_heads,
            language_model=cfg.model.language_model,
            hf_cache=cfg.model.hf_cache,
            language_dim=cfg.model.language_dim,
            max_lang_len=cfg.dataset.max_lang_len,
            vocab_size=cfg.model.vocab_size,
            mae_weight=cfg.model.mae_weight,
            lm_weight=cfg.model.lm_weight,
            optimizer=cfg.model.optimizer,
            schedule=cfg.model.schedule,
            base_lr=cfg.model.base_lr,
            min_lr=cfg.model.min_lr,
            effective_bsz=cfg.model.effective_bsz,
            betas=cfg.model.betas,
            weight_decay=cfg.model.weight_decay,
            warmup_epochs=cfg.dataset.warmup_epochs,
            max_epochs=cfg.dataset.max_epochs,
            mlp_ratio=cfg.model.mlp_ratio,
            norm_pixel_loss=cfg.model.norm_pixel_loss,
        )

    else:
        raise NotImplementedError(f"Model Architecture `{cfg.model.arch}` is not supported!")

    # We use gradient accumulation to honor the effective batch size specified...
    assert cfg.model.effective_bsz % cfg.model.device_bsz == 0, "Device bsz must evenly divide effective bsz!"
    accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.device_bsz // xm.xrt_world_size()
    xm.master_print(
        f"Running `{cfg.model.identifier}` model w/ Effective Batch Size of `{cfg.model.effective_bsz}`, "
        f"Per-Device Batch Size of `{cfg.model.device_bsz}`, "
        f"Distributed World Size of `{xm.xrt_world_size()}` and `{accumulate_grad_batches}` Accumulation Steps"
    )

    # If Resuming =>> Load Model from Checkpoint
    start_checkpoint, start_epoch, start_step = None, 0, 0
    if cfg.resume:
        # **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks:
        #   - `accumulate_grad_batches` is exactly the same when resuming; this means:
        #       + `cfg.model.effective_bsz`, `cfg.model.device_bsz`, & `cfg.accelerator.num_accelerators` are the same!
        #   - The Weights & Biases directory `run_dir/wandb` only contains a *single run*
        #   - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes!
        #       + This means that (and generally should be true for resuming altogether) the architecture is the same!
        #   - The `cfg.seed` should be the same (again, should generally be true...)
        if cfg.checkpoint_path is None:
            xm.master_print("Resuming :: Attempting to Automatically Load Checkpoint -- Searching!")
            checkpoint_path = Path(run_dir) / "checkpoints"
            if checkpoint_path.exists() and any(checkpoint_path.iterdir()):
                # Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints...
                checkpoints = list(checkpoint_path.iterdir())
                complete_checkpoint, complete_epoch = max(
                    [
                        (c, int(re.search("epoch=(.+?)-train", c.name).group(1)))
                        for c in checkpoints
                        if "local-epoch=" not in str(c)
                    ],
                    key=lambda x: x[1],
                )

                # Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints...
                local = [
                    (
                        c,
                        int(re.search("local-epoch=(.+?)-step", c.name).group(1)),
                        int(re.search("step=(.+?)[.-]", c.name).group(1)),
                    )
                    for c in checkpoints
                    if "local-epoch=" in str(c)
                ]
                if len(local) > 0:
                    # Parse out (epoch, "highest" step) + assert no greater "full epoch" checkpoint exists!
                    start_checkpoint, start_epoch, start_step = max(local, key=lambda x: x[1:])
                    assert start_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!"

                # Case 2 :: Otherwise, we're just going to start with the last "complete" epoch...
                else:
                    start_checkpoint, start_epoch = complete_checkpoint, complete_epoch

            else:
                xm.master_print("No Checkpoints Found -- Starting Run from Scratch!")

        else:
            xm.master_print(f"Resuming :: Loading from Checkpoint `{cfg.checkpoint_path}`...")
            start_checkpoint = cfg.checkpoint_path

        # Actually Load the Checkpoint State!
        if start_checkpoint is not None:
            xm.master_print(f"Resuming :: Loading Model & Optimizer State Dictionaries from `{start_checkpoint}`")
            checkpoint = torch.load(str(start_checkpoint))
            model_state_dict, optimizer_state_dict = checkpoint
            model.load_state_dict(model_state_dict)

    # Logging / W&B Handling
    if is_rank_zero:
        xm.master_print("Initializing Weights & Biases + JSONL + Checkpoint Saver on Rank Zero ONLY...")
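        # Default tags apply only when none are configured; otherwise honor `cfg.tracking.tags`
        tags = [cfg.model.identifier, cfg.dataset.name, "pretraining"]
        if cfg.tracking.tags is not None:
            tags = list(cfg.tracking.tags)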

        # W&B Initialize & Log all Hyperparameters (Only on ordinal 0)
        wandb_resume_id = None
        if cfg.resume and cfg.wandb_resume_id is None:
            xm.master_print("Resuming :: Attempting to Automatically Load W&B Resume ID -- Searching!")
            wandb_path = Path("wandb")
            if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()):
                # Parse out the unique resume_id from the `.wandb` file...
                wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if str(f).endswith(".wandb")]
                assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!"

                # Regex match on `run-{id}.wandb`...
                wandb_resume_id = re.search(r"run-(.+?)\.wandb", wandb_fns[0]).group(1)

            # Otherwise, assert that we're starting from scratch!
            else:
                assert start_checkpoint is None, "Trying to restart a run from checkpoint without a valid W&B ID!"

        elif cfg.resume:
            xm.master_print(f"Resuming :: Using Specified W&B Resume ID = `{cfg.wandb_resume_id}`")
            wandb_resume_id = cfg.wandb_resume_id

        # Initialize Weights & Biases
        xm.master_print(f"W&B Resume is {cfg.resume} w/ W&B Resume ID = {wandb_resume_id}!")
        wandb.init(
            project=cfg.tracking.project,
            entity=cfg.tracking.entity,
            config=cfg,
            name=run_id,
            dir=f"{os.getcwd()}" if cfg.tracking.directory is None else cfg.tracking.directory,
            tags=tags,
            notes=cfg.tracking.notes,
            resume="allow" if start_checkpoint is not None else False,
            id=wandb_resume_id,
            # Weird line because PT-TPU VMs don't come with a working install of Tensorflow...
            settings=wandb.Settings(_disable_stats=True),
        )

        # Initialize JSONL Logger (append only mode) --> last "global step" will always take precedence.
        with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
            js_logger.write(
                {
                    "run_id": run_id,
                    "start_time": datetime.now().strftime("%m-%d-%H:%M"),
                    "hparams": OmegaConf.to_container(cfg),
                }
            )

    # Rank Zero Node will take time to spin up the loggers & checkpointer... might as well rendezvous?
    xm.rendezvous("Logging...")

    # === Here Be Dragons ===
    # Time to handle device placement -- Note - example code doesn't specify device idx - why not?
    #   > https://github.com/pytorch/xla/blob/3c0d68da07702995a592ea70f27868cd76fa0755/test/test_train_mp_mnist.py#L114
    #   > Results in printing [xla:0] and [xla:1] a bunch... no [xla:2-7]? This feels bad...?
    #
    #   |=> Debugging Try: `xm.xla_device(n=xm.get_ordinal())` ---> hangs completely?
    #   +=> *ANSWER*: https://github.com/pytorch/xla/issues/2345#issuecomment-657114819
    #       >> "Make no assumptions and don't try to build them manually..."
    device = xm.xla_device()
    model = model.train().to(device)
    optimizer, update_lr = model.configure_optimizer()
    global_step, train_losses, lrs, start_time, resume_time = 0, deque(maxlen=128), [], time.time(), 0

    # If resuming (valid `start_checkpoint`) -- patch the optimizer state dictionary, and load!
    if start_checkpoint is not None:
        patched_optimizer_state_dict = {
            "state": optimizer_state_dict,
            "param_groups": optimizer.state_dict()["param_groups"],
        }
        optimizer.load_state_dict(patched_optimizer_state_dict)

    # Create step timing...
    step_times, step_start_time = deque(maxlen=128), time.time()

    # Create Model/Architecture-Specific Trackers...
    if cfg.model.arch == "v-mvp":
        reconstruction_losses = deque(maxlen=128)

    elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
        tcn_losses, reward_losses, l1_losses, l2_losses = [deque(maxlen=128) for _ in range(4)]
        tcn_accuracies, reward_accuracies = [deque(maxlen=128) for _ in range(2)]

    elif cfg.model.arch == "v-cond":
        reconstruction_losses = deque(maxlen=128)

    elif cfg.model.arch == "v-dual":
        reconstruction_losses = deque(maxlen=128)
        zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128)

    elif cfg.model.arch == "v-gen":
        reconstruction_losses, lm_losses, lm_ppl = deque(maxlen=128), deque(maxlen=128), deque(maxlen=128)
        zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128)

    else:
        raise NotImplementedError(f"Trackers for Model `{cfg.model.arch}` not implemented!")

    # 0th Checkpoint - Pull out optimizer state explicitly (`groups` are not serializable & can easily be replicated)
    saver = XLACheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, cfg.accelerator.accelerator)
    if start_checkpoint is None and start_epoch == 0:
        xm.master_print("Saving 0th Epoch Checkpoint...")
        saver.save(
            epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None
        )

    # Run on all processes --> retrieve "0th epoch" dataset!
    #   =>> Important, ensures data is locked across models, for the given epoch!
    xm.master_print(f"Retrieving Dataset `{cfg.dataset.name}` prepared for `{cfg.model.arch}`!")
    train_dataset, val_dataset = get_epoch_datasets(
        0,
        cfg.dataset.name,
        cfg.dataset.normalization,
        cfg.model.arch,
        cfg.dataset.stream,
        cfg.dataset.artifact_path,
        cfg.dataset.stream_prefix,
        cfg.model.data_modality,
        cfg.model.get("lang_dropout", None),
        cfg.model.get("gen_ratio", None),
    )

    # Loading Datasets might take time... rendezvous to be safe
    xm.rendezvous("Retrieved Datasets...")

    # Iterate through Epochs, Evaluating at the end of each Training Epoch!
    #   >> Everything in this loop should happen across all workers, except for the logging (ordinal 0)!
    xm.master_print("Starting Training Loop...")
    for epoch in range(start_epoch, cfg.dataset.max_epochs):
        xm.master_print(f"\t[Epoch {epoch}] Building Distributed Sampler & DataLoaders...")
        train_dataset.set_epoch(epoch)

        # ResumeableDistributedSampler operates over *examples* --> start_step (full_batches) * effective_bsz
        seen_examples = start_step * cfg.model.effective_bsz
        train_sampler = ResumeableDistributedSampler(
            seen_examples,
            start_epoch,
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True,
            seed=cfg.seed,
        )

        # Set epoch appropriately for the `train_sampler` --> necessary to trigger "normal" logic!
        train_sampler.set_epoch(epoch)

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=cfg.model.device_bsz,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=cfg.accelerator.num_workers,
            drop_last=True,
            worker_init_fn=worker_init_fn,
            prefetch_factor=4,
        )

        # NOTE :: We're not sharding the Validation set --> *everybody* will run forward passes on the *same* data!
        #   > We will have to mesh_reduce() later... unclear why, but the torch_xla folks seem keen on it; might lead to
        #   > weird rendezvous/hang issues if Validation is big enough...
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=cfg.model.device_bsz,
            shuffle=False,
            num_workers=4,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )

        # Initializing the Dataloaders might take time depending on process...
        xm.rendezvous("Initialized Dataloaders...")

        # Leverage the *special* <XLA ParallelLoader> API that handles synchronizing TPU cores across batches!
        #   > NOTE: This is super important!
        xm.master_print("\tSetting up Parallel MpDeviceLoaders...")
        train_device_loader = parallel.MpDeviceLoader(train_dataloader, device)
        val_device_loader = parallel.MpDeviceLoader(val_dataloader, device)

        # Book-keeping & LR setting when `resuming` --> only do this on start_epoch!
        if epoch == start_epoch:
            if start_checkpoint is not None:
                global_step = start_step + ((len(train_dataset) // cfg.model.effective_bsz) * start_epoch)
                resume_time = int(re.search(r"-t=(.+?)\.pt", str(start_checkpoint)).group(1))
                lrs.append(update_lr(start_epoch, start_step / (len(train_dataset) // cfg.model.effective_bsz)))
            else:
                lrs.append(update_lr(start_epoch, 0))

        # Iterate...
        step_start_time = time.time()
        with tqdm(total=len(train_device_loader) // accumulate_grad_batches, disable=not is_rank_zero) as progress:
            for train_idx, batch in enumerate(train_device_loader):
                if cfg.model.arch == "v-mvp":
                    # Run a forward pass through the MAE... other return vals are reconstructions (pixel norm) & mask
                    loss, _, _ = model(batch)
                    reconstruction_losses.append(loss)

                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(imgs, lang, lang_mask)

                    # Add to trackers
                    tcn_losses.append(tcn_loss)
                    reward_losses.append(reward_loss)
                    l1_losses.append(l1_loss)
                    l2_losses.append(l2_loss)
                    tcn_accuracies.append(tcn_acc)
                    reward_accuracies.append(rew_acc)

                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    loss, _, _ = model(img, lang, lang_mask)
                    reconstruction_losses.append(loss)

                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    loss, [zero_loss, k_loss] = model(imgs, lang, lang_mask)

                    # Add to trackers
                    reconstruction_losses.append(loss)
                    zero_reconstruction.append(zero_loss)
                    k_reconstruction.append(k_loss)

                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model(
                        imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight
                    )

                    # Add to trackers
                    reconstruction_losses.append(reconstruction_loss)
                    lm_losses.append(lm_loss)
                    lm_ppl.append(torch.exp(lm_loss))
                    zero_reconstruction.append(zero_loss)
                    k_reconstruction.append(k_loss)

                else:
                    raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch}` not implemented!")

                # Write Loss to Loggers (prior to accumulation normalization)
                train_losses.append(loss)

                # Normalize loss to account for accumulation
                loss = loss / accumulate_grad_batches
                loss.backward()

                # Gradient Accumulation =>> Note: skip any errant batches at the end...
                if (train_idx + 1) % accumulate_grad_batches == 0:
                    xm.optimizer_step(optimizer)  # Note call to xm.optimizer_step() -- has implicit mark_step!
                    optimizer.zero_grad()

                    # Add to `step_times`
                    step_times.append(time.time() - step_start_time)

                    # Logging --> Because there is no guarantee processes will be in sync, we need a `closure`
                    #   > Ref: https://pytorch.org/xla/release/1.11/index.html#torch_xla.core.xla_model.add_step_closure
                    if is_rank_zero and global_step % cfg.tracking.log_frequency == 0:
                        if cfg.model.arch == "v-mvp":
                            xm.add_step_closure(
                                log_vmvp_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    reconstruction_losses,
                                    step_times,
                                ),
                            )

                        elif cfg.model.arch == "v-r3m":
                            xm.add_step_closure(
                                log_vr3m_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    tcn_losses,
                                    reward_losses,
                                    l1_losses,
                                    l2_losses,
                                    tcn_accuracies,
                                    reward_accuracies,
                                    step_times,
                                ),
                            )

                        elif cfg.model.arch == "v-rn3m":
                            xm.add_step_closure(
                                log_vrn3m_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    tcn_losses,
                                    reward_losses,
                                    l1_losses,
                                    l2_losses,
                                    tcn_accuracies,
                                    reward_accuracies,
                                    step_times,
                                ),
                            )

                        elif cfg.model.arch == "v-cond":
                            xm.add_step_closure(
                                log_vcond_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    reconstruction_losses,
                                    step_times,
                                ),
                            )

                        elif cfg.model.arch == "v-dual":
                            xm.add_step_closure(
                                log_vdual_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    reconstruction_losses,
                                    zero_reconstruction,
                                    k_reconstruction,
                                    step_times,
                                ),
                            )

                        elif cfg.model.arch == "v-gen":
                            xm.add_step_closure(
                                log_vgen_train_update,
                                args=(
                                    epoch,
                                    global_step,
                                    run_id,
                                    train_losses,
                                    lrs[-1],
                                    reconstruction_losses,
                                    lm_losses,
                                    lm_ppl,
                                    zero_reconstruction,
                                    k_reconstruction,
                                    step_times,
                                ),
                            )

                        else:
                            raise NotImplementedError(f"Log Update for Model `{cfg.model.arch}` not implemented!")

                    # Increment Global Step _after_ logging!
                    global_step += 1

                    # Save checkpoint subject to *local_step = (train_idx + 1) // accumulate_grad_batches*
                    saver.save(
                        epoch=epoch,
                        is_local_step=True,
                        model=model,
                        optimizer=optimizer,
                        duration=int(time.time() - start_time) + resume_time,
                        local_step=start_step + ((train_idx + 1) // accumulate_grad_batches),
                    )

                    # Update LR every `accumulation_steps` iterations...
                    lrs.append(
                        update_lr(
                            epoch,
                            (start_step + ((train_idx + 1) // accumulate_grad_batches))
                            / (len(train_dataset) // cfg.model.effective_bsz),
                        )
                    )

                    # Reset `step_start_time`
                    step_start_time = time.time()

                    # Update `progress` each time we take a gradient step!
                    progress.update()

                # After each forward pass, mark a step, to compile XLA graph for a single forward pass!
                #   =>> Note :: this is important, with gradient accumulation, the graph can get massive otherwise!
                xm.mark_step()

            else:
                # Clear gradients and reset start step (regardless) at end of the loop
                optimizer.zero_grad()
                start_step = 0

        # Redundant, but Synchronous Validation Epoch...
        xm.master_print("Validating...")
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_device_loader, disable=not is_rank_zero):
                if cfg.model.arch == "v-mvp":
                    loss, _, _ = model(batch)
                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    loss, _, _, _, _, _, _ = model(imgs, lang, lang_mask)
                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    loss, _, _ = model(img, lang, lang_mask)
                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    loss, _ = model(imgs, lang, lang_mask)
                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    loss, _, _, _ = model(imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight)
                else:
                    raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch}` not implemented!")

                # Just append to val_losses...
                val_losses.append(loss)

            # Compute Val Loss & *mesh reduce* --> Why? :: the XLA people said so!
            val_loss = torch.stack(val_losses).mean().item()
            val_loss = xm.mesh_reduce("val_loss", val_loss, np.mean)  # All replicas should just return the same thing?

            # Logging --> add another `closure` for end-of-epoch cleanup --> compute `duration` as well...
            duration = int(time.time() - start_time) + resume_time
            if is_rank_zero:
                xm.add_step_closure(
                    log_epoch_end_update,
                    args=(
                        cfg.model.arch,
                        epoch,
                        global_step,
                        run_id,
                        duration,
                        train_losses,
                        val_loss,
                        lrs[-1],
                        step_times,
                    ),
                )

        # Save Checkpoint (at end of Epoch)
        saver.save(
            epoch=epoch + 1,
            is_local_step=False,
            model=model,
            optimizer=optimizer,
            duration=duration,
            train_loss=train_losses[-1].item(),
            val_loss=val_loss,
        )

    # Dump TPU Debugging Metrics...
    if is_rank_zero:
        with open("tpu-debug-metrics.log", "w") as f:
            f.write(met.metrics_report())

    # Exiting w/ Multiprocessing is a Nightmare... try to join?
    xm.master_print("...and that's all, folks!")
    xm.rendezvous("Cheers!")

    # Sleep for 2.5 minutes (150 seconds)... get W&B to finish syncing logs
    wandb.finish()
    time.sleep(150)


def mp_fn(_: int, cfg: PretrainConfig) -> None:
    torch.set_default_tensor_type("torch.FloatTensor")

    # Let's Start Pretraining!
    xpretrain(cfg)


@hydra.main(config_path=None, config_name="config")
def main(cfg: PretrainConfig) -> None:
    import torch_xla.distributed.xla_multiprocessing as xmp

    # Call XMP Spawn w/ the Config as the sole argument...
    xmp.spawn(mp_fn, args=(cfg,), nprocs=cfg.accelerator.num_accelerators, start_method="spawn")


if __name__ == "__main__":
    main()
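
The gradient-accumulation arithmetic in `xpretrain` can be checked in isolation. The following is a minimal sketch, assuming a hypothetical helper `compute_accumulation_steps` that is not part of the repository. It mirrors `effective_bsz // device_bsz // world_size`, but adds a stricter check: the original assertion only verifies that `device_bsz` divides `effective_bsz`, so a world size that does not divide the quotient would silently floor the result (possibly to zero).

```python
def compute_accumulation_steps(effective_bsz: int, device_bsz: int, world_size: int) -> int:
    """Hypothetical helper: micro-batches each device accumulates per optimizer step."""
    if min(effective_bsz, device_bsz, world_size) <= 0:
        raise ValueError("Batch sizes and world size must be positive integers!")

    # Examples consumed across *all* devices in a single forward/backward pass
    examples_per_pass = device_bsz * world_size
    if effective_bsz % examples_per_pass != 0:
        raise ValueError(
            f"`device_bsz * world_size` = {examples_per_pass} must evenly divide "
            f"`effective_bsz` = {effective_bsz}!"
        )

    # For positive integers this equals `effective_bsz // device_bsz // world_size`
    return effective_bsz // examples_per_pass


if __name__ == "__main__":
    # Illustrative values only (not taken from any configuration in the repository)
    print(compute_accumulation_steps(effective_bsz=1024, device_bsz=32, world_size=8))
```

With the illustrative values above, each of the 8 devices processes 32 examples per pass, so 4 passes are accumulated before every call to `xm.optimizer_step`.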

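The automatic resume logic in `xpretrain` picks the most advanced "local step" checkpoint when one exists, and otherwise falls back to the latest "complete" epoch checkpoint. Below is a minimal, standalone sketch of that selection, assuming a hypothetical function `select_resume_checkpoint` and hypothetical filenames chosen only to match the regular expressions used above (the real naming scheme lives in `XLACheckpointSaver`, which is not shown here). The epoch-consistency assertion from the original is omitted for brevity.

```python
import re
from typing import List, Optional, Tuple


def select_resume_checkpoint(names: List[str]) -> Tuple[Optional[str], int, int]:
    """Hypothetical helper: return (checkpoint name, start_epoch, start_step) for resuming."""
    # "Complete" epoch checkpoints --> parse the epoch index
    complete = [
        (n, int(re.search(r"epoch=(\d+)-train", n).group(1)))
        for n in names
        if "local-epoch=" not in n
    ]

    # "Local step" checkpoints --> parse both the epoch and the within-epoch step
    local = [
        (
            n,
            int(re.search(r"local-epoch=(\d+)-step", n).group(1)),
            int(re.search(r"step=(\d+)[.-]", n).group(1)),
        )
        for n in names
        if "local-epoch=" in n
    ]

    # Case 1 :: local step checkpoints always take precedence (highest epoch, then highest step)
    if local:
        return max(local, key=lambda x: x[1:])

    # Case 2 :: otherwise resume from the latest complete epoch, at step 0
    if complete:
        name, epoch = max(complete, key=lambda x: x[1])
        return name, epoch, 0

    # Case 3 :: nothing to resume from --> start from scratch
    return None, 0, 0


if __name__ == "__main__":
    # Hypothetical filenames, shaped only to satisfy the regular expressions above
    example = [
        "epoch=0-train=inf-val=inf-t=0.pt",
        "epoch=2-train=0.1234-val=0.2345-t=3600.pt",
        "local-epoch=2-step=50-t=3800.pt",
        "local-epoch=2-step=150-t=4000.pt",
    ]
    print(select_resume_checkpoint(example))
```

Keeping the selection free of filesystem and XLA calls makes it straightforward to unit test the precedence rules before launching a multi-host run.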

================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "voltron-robotics"
authors = [
    {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}
]
description = "Voltron: Language-Driven Representation Learning for Robotics."
version = "1.1.0"
readme = "README.md"
requires-python = ">=3.8"
keywords = ["robotics", "representation learning", "natural language processing", "machine learning"]
license = {file = "LICENSE"}
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Education",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
    "av",
    "einops",
    "gdown",
    "google-cloud-storage",
    "h5py",
    "hurry.filesize",
    "hydra-core==1.1.1",    # Lock Hydra =>> future versions break!
    "jsonlines",
    "omegaconf==2.1.2",     # Lock OmegaConf =>> future versions break!
    "opencv-python",
    "pandas",
    "rich",
    "torch>=2.0.0",         # Native PyTorch Code (Release 2.0.0) uses PyTorch 2.0!
    "torchvision>=0.15.0",
    "transformers",
    "wandb",
]

[project.optional-dependencies]
dev = [
    "black",
    "ipython",
    "pre-commit",
    "ruff",
]

[project.urls]
homepage = "https://github.com/siddk/voltron-robotics"
repository = "https://github.com/siddk/voltron-robotics"
documentation = "https://github.com/siddk/voltron-robotics"

[tool.black]
line-length = 121
target-version = ["py38", "py39", "py310"]
preview = true

[tool.ruff]
line-length = 121
target-version = "py38"
select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401"]

[tool.setuptools.packages.find]
where = ["."]
exclude = ["cache"]


================================================
FILE: setup.py
================================================
"""
setup.py

PEP 621 switches most of Packaging to `pyproject.toml` -- yet keep a "dummy" setup.py for external code that has not
yet upgraded.
"""
from setuptools import setup

setup()


================================================
FILE: voltron/__init__.py
================================================
from .models.materialize import available_models, load
from .models.util import instantiate_extractor


================================================
FILE: voltron/conf/__init__.py
================================================
from .accelerators import AcceleratorConfig
from .datasets import DatasetConfig
from .models import ModelConfig
from .tracking import TrackingConfig


================================================
FILE: voltron/conf/accelerators.py
================================================
"""
accelerators.py

Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure.
"""
import os
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

# === Vanilla Accelerators (Deprecated; mostly for XLA code) ===


@dataclass
class AcceleratorConfig:
    accelerator: str = MISSING
    num_accelerators: int = MISSING
    num_workers: int = MISSING


@dataclass
class TPUv2OneConfig(AcceleratorConfig):
    accelerator = "tpu"
    num_accelerators = 1
    num_workers = 4


@dataclass
class TPUv2EightConfig(AcceleratorConfig):
    accelerator = "tpu"
    num_accelerators = 8
    num_workers = 4


@dataclass
class TPUv3OneConfig(AcceleratorConfig):
    accelerator = "tpu"
    num_accelerators = 1
    num_workers = 8


@dataclass
class TPUv3EightConfig(AcceleratorConfig):
    accelerator = "tpu"
    num_accelerators = 8
    num_workers = 8


# === GPU Default Config --> just set `num_workers`; `torchrun` takes care of the rest! ===
#   > Note :: Defaults to 1 GPU if WORLD_SIZE not set (e.g., not running with `torchrun`)


@dataclass
class TorchRunDefaultConfig(AcceleratorConfig):
    accelerator = "gpu"
    num_accelerators = int(os.environ.get("WORLD_SIZE", 1))
    num_workers = 8


# Create a configuration group `accelerator` and populate with the above...
cs = ConfigStore.instance()
cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig)
cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig)
cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig)
cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig)

cs.store(group="accelerator", name="torchrun", node=TorchRunDefaultConfig)
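

# --- Illustrative sketch (not part of the original module) ---
# The subclasses above override defaults with bare class attributes (no type annotations). The stand-in classes below
# (`_SketchBase` / `_SketchChild` are hypothetical names) show what plain `dataclasses` does with that pattern: the
# class attribute changes, but the inherited field default applied by `__init__` does not. Worth keeping in mind if
# these configs are ever instantiated directly rather than handed to the ConfigStore as classes.
def _sketch_unannotated_override():
    from dataclasses import dataclass, fields

    @dataclass
    class _SketchBase:
        accelerator: str = "???"  # Stand-in for `omegaconf.MISSING` (which is the string "???")

    @dataclass
    class _SketchChild(_SketchBase):
        accelerator = "tpu"  # No annotation --> plain class attribute, not a new dataclass field

    class_level = _SketchChild.accelerator  # Reads the overriding class attribute
    instance_level = _SketchChild().accelerator  # `__init__` still applies the inherited field default
    field_default = fields(_SketchChild)[0].default  # The only field is the one inherited from `_SketchBase`
    return class_level, instance_level, field_default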


================================================
FILE: voltron/conf/datasets.py
================================================
"""
datasets.py

Base Hydra Structured Config for defining various pretraining datasets and appropriate configurations. Uses a simple,
single inheritance structure.
"""
from dataclasses import dataclass
from typing import Any, Tuple

from hydra.core.config_store import ConfigStore
from hydra.utils import to_absolute_path
from omegaconf import MISSING


@dataclass
class DatasetConfig:
    name: str = MISSING
    path: str = MISSING
    artifact_path: str = MISSING

    # Streaming Parameters (assumes fully preprocessed dataset lives at `stream_prefix/...`)
    #   =>> Deprecated as of `v2`
    stream: bool = True
    stream_prefix: str = "data/processed"

    # Dataset-Specific Parameters
    resolution: int = 224
    normalization: Tuple[Any, Any] = MISSING

    # For preprocessing --> maximum size of saved frames (assumed square)
    preprocess_resolution: int = MISSING

    # Validation Parameters
    n_val_videos: int = MISSING

    # Language Modeling Parameters
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")

    # Maximum Length for truncating language inputs... should be computed after the fact (set to -1 to compute!)
    max_lang_len: int = MISSING

    # Dataset sets the number of pretraining epochs (general rule :: warmup should be ~5% of full)
    warmup_epochs: int = MISSING
    max_epochs: int = MISSING

    # Plausible Formats --> These are instantiations each "batch" could take, with a small DSL
    #   > Note: Assumes final element of the list is the "most expressive" --> used to back-off
    batch_formats: Any = (
        ("state", ("state_i",)),
        ("state+language", ("state_i", "language")),
        ("state+ok", ("state_initial", "state_i", "language")),
        ("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")),
    )

    # Preprocessing :: Frame-Sampling Parameters
    initial_final_alpha: float = 0.2
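

# --- Illustrative sketch (not part of the original module) ---
# One way to read the `batch_formats` DSL above: each entry pairs a format name with the keys an example must provide,
# and an example stored in the most expressive (final) format can be "backed off" to any earlier format by keeping only
# that format's keys. `_SKETCH_FORMATS` mirrors the default above; `_sketch_back_off` is a hypothetical helper, and the
# actual preprocessing code may implement back-off differently.
_SKETCH_FORMATS = (
    ("state", ("state_i",)),
    ("state+language", ("state_i", "language")),
    ("state+ok", ("state_initial", "state_i", "language")),
    ("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")),
)


def _sketch_back_off(example, target, formats=_SKETCH_FORMATS):
    """Project `example` (a dict in the most expressive format) onto the keys required by the `target` format."""
    keys = dict(formats)[target]
    return {key: example[key] for key in keys}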


@dataclass
class SthSthv2Config(DatasetConfig):
    # fmt: off
    name: str = "sth-sth-v2"
    path: str = to_absolute_path("data/raw/sth-sth-v2")
    artifact_path: str = to_absolute_path("data/processed/sth-sth-v2")

    # Dataset Specific arguments
    normalization: Tuple[Any, Any] = (                              # Mean & Standard Deviation (default :: ImageNet)
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
    )

    # Sth-Sth-v2 Videos have a fixed height of 240; we'll crop to square at this resolution!
    preprocess_resolution: int = 240

    # Validation Parameters
    n_val_videos: int = 1000                                        # Number of Validation Clips (fast evaluation!)

    # Epochs for Dataset
    warmup_epochs: int = 20
    max_epochs: int = 400

    # Language Modeling Parameters
    max_lang_len: int = 20
    # fmt: on


# Create a configuration group `dataset` and populate with the above...
#   =>> Note :: this is meant to be extendable --> add arbitrary datasets & mixtures!
cs = ConfigStore.instance()
cs.store(group="dataset", name="sth-sth-v2", node=SthSthv2Config)


================================================
FILE: voltron/conf/models.py
================================================
"""
models.py

Base Hydra Structured Config for defining various pretraining model architectures and appropriate configurations. Uses a
simple single inheritance structure.
"""
from dataclasses import dataclass
from typing import Tuple

from hydra.core.config_store import ConfigStore
from hydra.utils import to_absolute_path
from omegaconf import MISSING


@dataclass
class ModelConfig:
    arch: str = MISSING
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = MISSING

    # Default Vision Transformer Configuration
    patch_size: int = 16
    mlp_ratio: float = 4.0

    # Effective batch size --> total number of examples before gradient update
    effective_bsz: int = MISSING

    # Number of examples one can safely fit on an accelerator w/ this model!
    device_bsz: int = MISSING  # For backwards compatibility, only use device_bsz for XLA/TPU pretraining...
    native_bsz: int = MISSING  # For backwards compatibility, define a separate `native_bsz`...
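

# --- Illustrative sketch (not part of the original module) ---
# How `effective_bsz` and the per-device batch size plausibly relate: if each of `world_size` accelerators processes
# `per_device_bsz` examples per forward/backward pass, gradients are accumulated over several passes until
# `effective_bsz` examples have been seen. `_sketch_accumulation_steps` is a hypothetical helper (an assumption about
# usage, not taken from the training scripts), which may compute this differently.
def _sketch_accumulation_steps(effective_bsz: int, per_device_bsz: int, world_size: int) -> int:
    per_step = per_device_bsz * world_size  # Examples consumed across all devices in one forward/backward pass
    if effective_bsz % per_step != 0:
        raise ValueError(f"effective_bsz={effective_bsz} is not divisible by {per_device_bsz} x {world_size}")
    return effective_bsz // per_step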


# @Data-Locked Reproductions --- Encompasses MVP (MAE) + R3M


# MVP (Base Masked Autoencoder)
@dataclass
class MVPConfig(ModelConfig):
    arch: str = "v-mvp"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "state"

    # Base MAE Parameters
    mask_ratio: float = 0.75

    # Architecture Parameters
    encoder_depth: int = MISSING
    encoder_embed_dim: int = MISSING
    encoder_n_heads: int = MISSING

    decoder_depth: int = MISSING
    decoder_embed_dim: int = MISSING
    decoder_n_heads: int = MISSING

    # MAE Loss/Objective Configuration
    norm_pixel_loss: bool = True
    effective_bsz: int = 1024
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adamw"
    schedule: str = "linear-warmup+cosine-decay"
    base_lr: float = 1.5e-4
    min_lr: float = 0.0
    betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.05


@dataclass
class MVPSmallConfig(MVPConfig):
    identifier = "r-mvp"

    # Architecture Parameters -- should match ViT Small Architecture to the letter!
    #   Note: Small is defined in TIMM:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
    encoder_depth = 12
    encoder_embed_dim = 384
    encoder_n_heads = 6

    decoder_depth = 6
    decoder_embed_dim = 192
    decoder_n_heads = 6

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 128 per device.
    device_bsz = 128
    native_bsz = 128


# R3M Models --> Just different visual encoders, roughly following the above!
@dataclass
class R3MConfig(ModelConfig):
    arch: str = "v-r3m"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "quintet+language"

    # ViT Architecture Parameters
    depth: int = MISSING
    embed_dim: int = MISSING
    n_heads: int = MISSING

    # Effective Batch Size
    effective_bsz: int = 1024

    # Language Model Parameters
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")
    language_dim: int = 768
    vocab_size: int = 30522
    reward_dim: int = 1024

    # Loss/Objective Configuration
    lang_reward_weight: float = 1.0
    tcn_weight: float = 1.0
    l1_weight: float = 1e-5
    l2_weight: float = 1e-5
    n_negatives: int = 3
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adam"
    schedule: str = "linear-warmup+cosine-decay"
    lr: float = 1e-4
    min_lr: float = 0.0


@dataclass
class R3MSmallConfig(R3MConfig):
    identifier = "r-r3m-vit"

    # Architecture Parameters -- should match ViT Small Architecture to the letter!
    #   Note: Small is defined in TIMM:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
    depth = 12
    embed_dim = 384
    n_heads = 6

    # Device Batch Size
    device_bsz = 32
    native_bsz = 128


# R3M -- ResNet50 Encoder (instead of ViT)
@dataclass
class ResNet3MConfig(ModelConfig):
    arch: str = "v-rn3m"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "quintet+language"

    # Effective Batch Size
    effective_bsz: int = 1024

    # Architecture Parameters
    fc_dim: int = MISSING

    # Language Model Parameters
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")
    language_dim: int = 768
    vocab_size: int = 30522
    reward_dim: int = 1024

    # Loss/Objective Configuration
    lang_reward_weight: float = 1.0
    tcn_weight: float = 1.0
    l1_weight: float = 1e-5
    l2_weight: float = 1e-5
    n_negatives: int = 3
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adam"
    lr: float = 1e-4


@dataclass
class RN3M50Config(ResNet3MConfig):
    identifier = "r-r3m-rn50"

    # Architecture Parameters
    fc_dim = 2048

    # Device Batch Size
    device_bsz = 32
    native_bsz = 128


# @Voltron Models -- VCond, VDual, VGen


# VCond -- Single Frame + Language Conditioning
@dataclass
class VCondConfig(ModelConfig):
    arch: str = "v-cond"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "state+language"

    # Base MAE Parameters
    mask_ratio: float = 0.75

    # Base Language Parameters --> full sentence dropout only...
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")
    language_dim: int = 768
    vocab_size: int = 30522
    lang_dropout: float = MISSING

    # Architecture Parameters
    encoder_depth: int = MISSING
    encoder_embed_dim: int = MISSING
    encoder_n_heads: int = MISSING

    decoder_depth: int = MISSING
    decoder_embed_dim: int = MISSING
    decoder_n_heads: int = MISSING

    use_cls_token: bool = True

    # MAE Loss/Objective Configuration
    norm_pixel_loss: bool = True
    effective_bsz: int = 1024
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adamw"
    schedule: str = "linear-warmup+cosine-decay"
    base_lr: float = 1.5e-4
    min_lr: float = 0.0
    betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.05


@dataclass
class VCondSmallConfig(VCondConfig):
    identifier = "v-cond"

    # No language dropout...
    lang_dropout = 0.0

    # Architecture Parameters -- should match ViT Small Architecture to the letter!
    #   Note: Small is defined in TIMM:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
    encoder_depth = 12
    encoder_embed_dim = 384
    encoder_n_heads = 6

    decoder_depth = 6
    decoder_embed_dim = 192
    decoder_n_heads = 6

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 128 per device
    #   > GPU w/ 32G of RAM: max of 128 per device!
    device_bsz = 128
    native_bsz = 128


@dataclass
class VCondBaseConfig(VCondConfig):
    identifier = "v-cond-base"

    # No language dropout...
    lang_dropout = 0.0

    # Architecture Parameters -- should match ViT Base Architecture to the letter!
    #   Note: Base is defined in TIMM & Original MAE Repository:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
    #       > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
    encoder_depth = 12
    encoder_embed_dim = 768
    encoder_n_heads = 12

    decoder_depth = 8
    decoder_embed_dim = 512
    decoder_n_heads = 16

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 128 per device!
    #   > GPU w/ 32G of RAM: max of 128 per device!
    device_bsz = 128
    native_bsz = 128


# VDual - Dual Frame (0th Frame + Kth frame) + Language Conditioning
@dataclass
class VDualConfig(ModelConfig):
    arch: str = "v-dual"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "state+ok"

    # Base MAE Parameters
    mae_weight: float = 1.0
    mask_ratio: float = 0.75

    # Base Language Parameters --> full sentence dropout only...
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")
    language_dim: int = 768
    vocab_size: int = 30522
    lang_dropout: float = MISSING

    # Architecture Parameters
    encoder_depth: int = MISSING
    encoder_embed_dim: int = MISSING
    encoder_n_heads: int = MISSING

    decoder_depth: int = MISSING
    decoder_embed_dim: int = MISSING
    decoder_n_heads: int = MISSING

    use_cls_token: bool = True

    # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example!
    norm_pixel_loss: bool = True
    effective_bsz: int = 1024
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adamw"
    schedule: str = "linear-warmup+cosine-decay"
    base_lr: float = 1.5e-4
    min_lr: float = 0.0
    betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.05


@dataclass
class VDualSmallConfig(VDualConfig):
    identifier = "v-dual"

    # No language dropout...
    lang_dropout = 0.0

    # Architecture Parameters -- should match ViT Small Architecture to the letter!
    #   Note: Small is defined in TIMM:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
    encoder_depth = 12
    encoder_embed_dim = 384
    encoder_n_heads = 6

    decoder_depth = 6
    decoder_embed_dim = 192
    decoder_n_heads = 6

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 128 per device!
    #   > GPU w/ 32G of RAM: max of 128 per device!
    device_bsz = 128
    native_bsz = 128


@dataclass
class VDualBaseConfig(VDualConfig):
    identifier = "v-dual-base"

    # No language dropout...
    lang_dropout = 0.0

    # Architecture Parameters -- should match ViT Base Architecture to the letter!
    #   Note: Base is defined in TIMM & Original MAE Repository:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
    #       > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
    encoder_depth = 12
    encoder_embed_dim = 768
    encoder_n_heads = 12

    decoder_depth = 8
    decoder_embed_dim = 512
    decoder_n_heads = 16

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 128 per device!
    #   > GPU w/ 32G of RAM: max of 64 per device!
    device_bsz = 128
    native_bsz = 64


# VGen - Dual Frame with Language Conditioning AND Language Generation
@dataclass
class VGenConfig(ModelConfig):
    arch: str = "v-gen"
    identifier: str = MISSING

    # Dataset Modality
    data_modality: str = "state+ok"

    # Base MAE & LM Parameters --> LM Weight is set such that mae & lang loss are ~same order of magnitude
    mae_weight: float = 1.0
    lm_weight: float = 0.5
    mask_ratio: float = 0.75
    gen_ratio: float = MISSING

    # Base Language Parameters
    language_model: str = "distilbert-base-uncased"
    hf_cache: str = to_absolute_path("data/hf-cache")
    language_dim: int = 768
    vocab_size: int = 30522

    # Architecture Parameters
    encoder_depth: int = MISSING
    encoder_embed_dim: int = MISSING
    encoder_n_heads: int = MISSING

    decoder_depth: int = MISSING
    decoder_embed_dim: int = MISSING
    decoder_n_heads: int = MISSING

    use_cls_token: bool = True

    # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example!
    norm_pixel_loss: bool = True
    effective_bsz: int = 1024
    device_bsz: int = MISSING
    native_bsz: int = MISSING

    # Optimization Parameters
    optimizer: str = "adamw"
    schedule: str = "linear-warmup+cosine-decay"
    base_lr: float = 1.5e-4
    min_lr: float = 0.0
    betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.05


@dataclass
class VGen50SmallConfig(VGenConfig):
    identifier = "v-gen"

    # LM Parameters --> control % of examples that are for "language generation" (no conditioning)
    gen_ratio = 0.50

    # Architecture Parameters -- should match ViT Small Architecture to the letter!
    #   Note: Small is defined in TIMM:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
    encoder_depth = 12
    encoder_embed_dim = 384
    encoder_n_heads = 6

    decoder_depth = 6
    decoder_embed_dim = 192
    decoder_n_heads = 6

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 64 per device!
    #   > GPU w/ 32G of RAM: max of 64 per device!
    device_bsz = 64
    native_bsz = 64


@dataclass
class VGen50BaseConfig(VGenConfig):
    identifier = "v-gen-base"

    # LM Parameters --> control % of examples that are for "language generation" (no conditioning)
    gen_ratio = 0.50

    # Architecture Parameters -- should match ViT Base Architecture to the letter!
    #   Note: Base is defined in TIMM & Original MAE Repository:
    #       > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
    #       > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
    encoder_depth = 12
    encoder_embed_dim = 768
    encoder_n_heads = 12

    decoder_depth = 8
    decoder_embed_dim = 512
    decoder_n_heads = 16

    # Number of examples one can safely fit on an accelerator w/ this model!
    #   > TPU-v3: max of 32 per device!
    #   > GPU w/ 32G of RAM: max of 32 per device!
    device_bsz = 32
    native_bsz = 32


# Create a configuration group `model` and populate with the above...
cs = ConfigStore.instance()

# === @Data-Locked Reproductions ===

# Image-Only MAE/MVP Architectures
cs.store(group="model", name="r-mvp", node=MVPSmallConfig)

# R3M Architectures - ViT & ResNet50
cs.store(group="model", name="r-r3m-vit", node=R3MSmallConfig)
cs.store(group="model", name="r-r3m-rn50", node=RN3M50Config)

# === @Voltron ===

# VCond Architectures
cs.store(group="model", name="v-cond", node=VCondSmallConfig)
cs.store(group="model", name="v-cond-base", node=VCondBaseConfig)

# VDual
cs.store(group="model", name="v-dual", node=VDualSmallConfig)
cs.store(group="model", name="v-dual-base", node=VDualBaseConfig)

# VGen
cs.store(group="model", name="v-gen", node=VGen50SmallConfig)
cs.store(group="model", name="v-gen-base", node=VGen50BaseConfig)


================================================
FILE: voltron/conf/tracking.py
================================================
"""
tracking.py

Base Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via Weights & Biases.
Uses a simple single inheritance structure.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


@dataclass
class TrackingConfig:
    # Active Loggers --> List of Loggers
    active_loggers: List[str] = field(default_factory=lambda: ["jsonl", "wandb"])

    # Generic Logging Frequency --> Matters more for XLA/TPUs... set this to be as large as you can stomach!
    log_frequency: int = 100

    # Checkpointing Strategy --> Save each epoch, keep most recent `idx[0]` checkpoints & *every* `idx[1]` checkpoints
    #   Additionally, save (locally) a checkpoint every `idx[2]` steps for the current epoch (-1).
    checkpoint_strategy: Tuple[int, int, int] = (1, 1, 1500)

    # Weights & Biases Setup
    project: str = "voltron-pretraining"
    entity: str = "voltron-robotics"

    # Notes & Tags are at the discretion of the user... see below
    notes: str = MISSING
    tags: Optional[List[str]] = None

    # Directory to save W&B Metadata & Logs in General -- if None, defaults to `logs/` in the Hydra CWD
    directory: Optional[str] = None


@dataclass
class VoltronTrackingConfig(TrackingConfig):
    # Note: I really like using notes to keep track of things, so will crash unless specified with run.
    #   > For `tags` I like to populate based on other args in the script, so letting it remain None
    notes: str = MISSING


# Create a configuration group `tracking` and populate with the above...
cs = ConfigStore.instance()
cs.store(group="tracking", name="voltron-tracking", node=VoltronTrackingConfig)


================================================
FILE: voltron/datasets/__init__.py
================================================
from .datasets import get_datasets


================================================
FILE: voltron/datasets/datasets.py
================================================
"""
datasets.py

Core PyTorch Dataset implementations for the various "data flavors" used by the different representation learning models
(Voltron and data-locked reproductions). Crucially, each dataset loads from the corresponding serialized batch files that
define the exact data (and order of iterating through the data) to see during each epoch.

Notably, these serialized files control exactly what data is seen by *all* methods *across epochs*; using these files is
critical to reproducibility & comparisons.

The file contains logic for a "standard" Dataset; all files (batch index files, image/video/language files) are stored
on local disk, assuming storage conducive to fast random reads. For a "streaming" dataset (loading data directly from
GCP Buckets/Amazon S3), see `v1/stream_datasets.py`.
"""
from pathlib import Path
from typing import Any, Optional, Tuple

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import Compose

from voltron.preprocessing.transforms import get_online_transform


class PretrainDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.epoch, self.h5, self.vid, self.states = 0, None, None, None
        self.index_path, self.language_path, self.language = None, None, None

    def hydrate(self, path: Path) -> None:
        # Create Open HDF5 Handle
        self.h5 = h5py.File(path, "r")
        self.vid, self.states = self.h5["vid"].asstr(), self.h5["states"].asstr()

        # Load Language Index
        if self.language_path is not None:
            self.language = torch.load(self.index_path / self.language_path)

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")

    def __len__(self) -> int:
        raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")


class StateDataset(PretrainDataset):
    def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return processed image frame as a Tensor."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        return self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))

    def __len__(self) -> int:
        return self.n_examples


class StateLanguageDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+language" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+language" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

        # Assemble Dropout Indices
        n_drop = int(self.lang_dropout * self.n_examples)
        self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return processed image frame and language, decomposed as the input_ids and attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frame!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

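        # Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.dropout_idxs:
            # Clone first --> `lang` & `lang_mask` alias the cached language index; zeroing in-place would corrupt it!
            lang, lang_mask = lang.clone(), lang_mask.clone()

            # Initial language token is *always* <CLS> = `101` --> last token always <SEP> = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0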

        # Retrieve Frame & Return
        img = self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))
        return img, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


class StateOKDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

        # Assemble Dropout Indices
        n_drop = int(self.lang_dropout * self.n_examples)
        self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return processed dual frames and language, decomposed as the input_ids and attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frames!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

        # Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.dropout_idxs:
            # Initial language token is *always* <CLS> = `101` --> last token always <SEP> = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frames & Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


class GenStateOKDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        gen_ratio: float,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.gen_ratio, self.gen_idxs = gen_ratio, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

        # Assemble Generation Indices
        n_gen = int(self.gen_ratio * self.n_examples)
        self.gen_idxs = set(np.random.choice(self.n_examples, n_gen, replace=False))

    def __getitem__(
        self, idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """Return dual frames, conditioning language, language to generate, decomposed as input_ids/attention mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language to condition on / generate
        vid = self.vid[idx]
        # Clone so the in-place zeroing below does not mutate the cached language index
        lang_con = self.language[vid]["input_ids"].clone()
        lang_con_mask = self.language[vid]["attention_mask"].clone()
        lang_gen, lang_gen_mask, lang_gen_weight = lang_con.clone(), lang_con_mask.clone(), None

        # Generate / Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.gen_idxs:
            # When Generating --> just condition on the <CLS> token and generate the rest!
            lang_con[1:] *= 0
            lang_con_mask[1:] *= 0
            lang_gen_weight = 1

        else:
            # When Conditioning -> just generate the <CLS> token (so things don't crash) but set weight to 0
            lang_gen[1:] *= 0
            lang_gen_mask[1:] *= 0
            lang_gen_weight = 0

        # Retrieve Frames and Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])

        return imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight

    def __len__(self) -> int:
        return self.n_examples


class QuintetDataset(PretrainDataset):
    def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "quintet+language" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all five processed frames and language, decomposed as the input_ids/attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frames!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

        # Retrieve Frames & Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


def get_datasets(
    epoch: int,
    dataset_name: str,
    model_arch: str,
    artifact_path: str,
    data_modality: str,
    resolution: int,
    normalization: Tuple[Any, Any],
    lang_dropout: Optional[float] = None,
    gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
    index = Path(artifact_path) / dataset_name / "index"
    img_transform = get_online_transform(dataset_name, model_arch, resolution, normalization)

    # Switch on `data_modality` --> differs based on `model_arch` (e.g., MVP --> img, V-Cond --> img, language)
    if data_modality == "state":
        train_ds = StateDataset(epoch, index, img_transform)
        val_ds = StateDataset(epoch, index, img_transform, is_val=True)

    elif data_modality == "state+language":
        train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout)
        val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, is_val=True)

    elif data_modality == "state+ok":
        # V-Dual --> don't return language modeling elements (causal attention mask, suffix language, etc.)
        if gen_ratio is None:
            train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout)
            val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, is_val=True)

        # V-Gen --> add language modeling elements!
        else:
            train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio)
            val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, is_val=True)

    elif data_modality == "quintet+language":
        train_ds = QuintetDataset(epoch, index, img_transform)
        val_ds = QuintetDataset(epoch, index, img_transform, is_val=True)

    else:
        raise ValueError(f"Data Modality `{data_modality}` is not supported!")

    return train_ds, val_ds
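

The language-dropout bookkeeping used by the datasets above (pick `int(lang_dropout * n_examples)` indices without replacement, then keep only the leading `<CLS>` token for those examples) can be sketched in isolation. This is a minimal, hypothetical illustration rather than the repository's implementation: `pick_dropout_indices` and `drop_language` are names invented here, plain Python lists stand in for the `torch.Tensor` inputs, and `random.Random.sample` stands in for `np.random.choice(..., replace=False)`.

```python
import random
from typing import List, Set, Tuple


def pick_dropout_indices(n_examples: int, lang_dropout: float, seed: int = 0) -> Set[int]:
    """Choose int(lang_dropout * n_examples) distinct example indices (hypothetical helper)."""
    n_drop = int(lang_dropout * n_examples)
    # Sampling without replacement; a stand-in for np.random.choice(n_examples, n_drop, replace=False)
    return set(random.Random(seed).sample(range(n_examples), n_drop))


def drop_language(input_ids: List[int], attention_mask: List[int]) -> Tuple[List[int], List[int]]:
    """Keep only the first (<CLS>) position and zero the rest, returning fresh copies."""
    dropped_ids = input_ids[:1] + [0] * (len(input_ids) - 1)
    dropped_mask = attention_mask[:1] + [0] * (len(attention_mask) - 1)
    return dropped_ids, dropped_mask


if __name__ == "__main__":
    # With 10 examples and a 35% dropout rate, int(3.5) = 3 examples lose their language
    idxs = pick_dropout_indices(n_examples=10, lang_dropout=0.35, seed=1)
    print(len(idxs))

    # 101 / 102 are the <CLS> / <SEP> ids mentioned in the comments above; 0 is padding
    ids, mask = drop_language([101, 2023, 2003, 102, 0], [1, 1, 1, 1, 0])
    print(ids, mask)
```

Returning copies instead of zeroing in place keeps the shared language index intact, so a later example drawn from the same video still sees its full caption.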


================================================
FILE: voltron/datasets/v1/__init__.py
================================================


================================================
FILE: voltron/datasets/v1/stream_datasets.py
================================================
"""
stream_datasets.py

Core PyTorch Datasets for the various "flavors" of data used by the various models under study. Crucially, each dataset
loads from the corresponding "batch" serialized files, that define the exact data to use.

Notably, these serialized files control exactly what data is seen by *all* methods **across epochs.** Using them is
fairly critical to reproducibility & fair comparison.

This specific file contains logic for a "streaming" Dataset; data is fetched (within the dataloader, by each
worker) via an open connection over the network to a GCS bucket, materializing data as raw BytesIO objects fed to
PIL.Image constructors.
"""
import json
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Tuple

import numpy as np
import torch
from google.api_core.exceptions import NotFound
from google.auth.exceptions import TransportError
from google.cloud import storage
from google.resumable_media._helpers import _LOGGER
from PIL import Image, UnidentifiedImageError
from torch.utils.data import Dataset, get_worker_info
from torchvision.io import read_image
from torchvision.transforms import Compose
from torchvision.transforms.functional import pil_to_tensor

from voltron.preprocessing.v1.transforms import get_online_transform
from voltron.util.distributed import get_rank

# NOTE --> IF STREAMING JPEGS, WE NEED TO USE PILLOW TO READ FILES (w/o extracting locally...)
#   =>> Instead of `read_image(file)` assume we have "fname" and open fileobj (as BytesIO) -- remember to `seek(0)`
#
# > from PIL import Image
# > from torchvision.transforms.functional import pil_to_tensor
# > tensor = pil_to_tensor(Image.open(fileobj))
#       |--> This returns a `torch.uint8` Tensor of shape [3, 224, 224] --> *verified* equivalent to `read_image`

# Create Global GCS Client...
#   =>> Multiprocessing.spawn() does not inherit from base shell --> need to set service account key...
#   =>> TODO :: Figure out how to fetch num_accelerators & num_workers programmatically...
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/mnt/home/auth/gcp-auth.json"
N_CORES, BUCKETS = 8, [storage.Client().bucket("voltron-ANONYMIZED") for _ in range(8 * 8)]

# Suppress Google Cloud Loggers
_LOGGER.propagate = False
storage.blob._logger.propagate = False


class PretrainDataset(Dataset):
    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


class StateDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "state" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "state" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["state"] = "/".join(element["state"].split("/")[-2:])

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / f"train-epoch={epoch}-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["state"] = "/".join(element["state"].split("/")[-2:])

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return single frame as torch Tensor."""
        if not self.stream:
            return self.transform(read_image(self.elements[index]["state"]))
        else:
            # Multiplex w/ num_worker idx...
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same file!)
            frame_path = self.elements[index]["state"]
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open()
                    if self.is_val:
                        blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO()
                    else:
                        blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO()

                    # Error Handling --> we've already run several dry-run/verification trials, but roughly 0.004% of
                    #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                    #                    with multiple runs happening at the same time.
                    #
                    #                    To address this, we're adopting the simplest possible "retry" strategy that
                    #                    immediately tries to re-download the same file (and crashes if not possible).
                    #                    This ensures reproducibility, but puts some extra effort onto the user...

                    # File download...
                    blob.download_to_file(fobj)
                    fobj.seek(0)

                    # Image loading...
                    img_tensor = pil_to_tensor(Image.open(fobj))

                    # Return transformed image...
                    return self.transform(img_tensor)

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {frame_path}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class StateLanguageDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
        self.dropout_indices = set()
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "state+language" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "state+language" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+language" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["state"] = "/".join(element["state"].split("/")[-2:])

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "state+language" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["state"] = "/".join(element["state"].split("/")[-2:])

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the frame and language, decomposed as the input_ids, and attention_mask."""
        vid = self.elements[index]["vid"]
        lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if index in self.dropout_indices:
            # Initial language token is *always* <CLS> = `101` --> last token always <SEP> = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Single Frame
        if not self.stream:
            img = self.transform(read_image(self.elements[index]["state"]))
            return img, lang, lang_mask
        else:
            # Multiplex w/ num_worker idx...
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same file!)
            frame_path = self.elements[index]["state"]
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open()
                    if self.is_val:
                        blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO()
                    else:
                        blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO()

                    # Error Handling --> we've already run several dry-run/verification trials, but roughly 0.004% of
                    #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                    #                    with multiple runs happening at the same time.
                    #
                    #                    To address this, we're adopting the simplest possible "retry" strategy that
                    #                    immediately tries to re-download the same file (and crashes if not possible).
                    #                    This ensures reproducibility, but puts some extra effort onto the user...

                    # File download...
                    blob.download_to_file(fobj)
                    fobj.seek(0)

                    # Image loading...
                    img_tensor = pil_to_tensor(Image.open(fobj))

                    # Assemble transformed image and return...
                    img = self.transform(img_tensor)
                    return img, lang, lang_mask

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {frame_path}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class StateOKDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        stream: bool = False,
        prefix: Optional[Path] = None,
        no_lang: bool = False,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.no_lang, self.lang_dropout = no_lang, 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
        self.dropout_indices = set()
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # Load Language Index & Retrieve Epoch 0 Batches
        if not self.no_lang:
            language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
            if not self.stream:
                with open(self.index_path / language_path, "r") as f:
                    self.language_index = json.load(f)
            else:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
                self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "state+ok" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path... remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

    # ruff: noqa: C901
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return both states/frames and language, decomposed as the input_ids and attention_mask."""
        vid = self.elements[index]["vid"]

        # Fetch language if desired...
        if not self.no_lang:
            lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
            lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

            # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
            if index in self.dropout_indices:
                # Initial language token is *always* <CLS> = `101` --> last token always <SEP> = `102`
                lang[1:] *= 0
                lang_mask[1:] *= 0

        # Retrieve Frames
        if not self.stream:
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])

            # Return --> based on `self.no_lang`
            if not self.no_lang:
                return imgs, lang, lang_mask
            else:
                return imgs

        else:
            # Multiplex w/ num_worker idx...
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same files!)
            frame_paths, current_frame = list(self.elements[index]["states"]), None
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
                    imgs = []
                    for _current_idx, current_frame in enumerate(frame_paths):
                        if self.is_val:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
                        else:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()

                        # Error Handling --> we've already run several dry-run/verification trials, but roughly 0.004% of
                        #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                        #                    with multiple runs happening at the same time.
                        #
                        #                    To address this, we're adopting the simplest possible "retry" strategy that
                        #                    immediately tries to re-download the same file (crashes if not possible).
                        #                    This ensures reproducibility, but puts some extra effort onto the user...

                        # File download...
                        blob.download_to_file(fobj)
                        fobj.seek(0)

                        # Image loading...
                        img_tensor = pil_to_tensor(Image.open(fobj))
                        imgs.append(self.transform(img_tensor))

                    # Stack...
                    assert len(imgs) == 2, "Something went awry with try/except in StateOK Dataset..."
                    imgs = torch.stack(imgs)

                    # Return --> based on `self.no_lang`
                    if not self.no_lang:
                        return imgs, lang, lang_mask
                    else:
                        return imgs

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {current_frame}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class GenStateOKDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        gen_ratio: float,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.gen_ratio, self.gen_indices = gen_ratio, set()
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of generation indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "state+ok" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of generation indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to a disk path; strip everything but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of generation indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to a disk path; strip everything but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of generation indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

    def __getitem__(
        self, index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """Return both states/frames, language to condition on, language to generate, decomposed as input_ids/mask."""
        vid = self.elements[index]["vid"]

        # Fetch language to condition on / generate...
        lang_condition = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_condition_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
        lang_gen = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_gen_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Generate/Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if index in self.gen_indices:
            # If generating, just condition on the <CLS> token (always the initial...), but generate everything!
            lang_condition[1:] *= 0
            lang_condition_mask[1:] *= 0
            lang_gen_weight = 1

        else:
            # If conditioning, generate the <CLS> token (dummy so things don't crash), but set weight to 0
            lang_gen[1:] *= 0
            lang_gen_mask[1:] *= 0
            lang_gen_weight = 0

        # Retrieve Frames
        if not self.stream:
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])

            # Return...
            return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight

        else:
            # Multiplex w/ num_worker idx...
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same files!)
            frame_paths, current_frame = list(self.elements[index]["states"]), None
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
                    imgs = []
                    for _current_idx, current_frame in enumerate(frame_paths):
                        if self.is_val:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
                        else:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()

                        # Error Handling --> we've already run several dry-run/verification trials, but roughly 0.004% of
                        #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                        #                    with multiple runs happening at the same time.
                        #
                        #                    To address this, we're adopting the simplest possible "retry" strategy that
                        #                    immediately tries to re-download the same file (crashes if not possible).
                        #                    This ensures reproducibility, but puts some extra effort onto the user...

                        # File download...
                        blob.download_to_file(fobj)
                        fobj.seek(0)

                        # Image loading...
                        img_tensor = pil_to_tensor(Image.open(fobj))
                        imgs.append(self.transform(img_tensor))

                    # Stack...
                    assert len(imgs) == 2, "Something went awry with try/except in GenStateOK Dataset..."
                    imgs = torch.stack(imgs)

                    # Return...
                    return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {current_frame}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class QuintetDataset(PretrainDataset):
    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
        self.dropout_indices = set()
        self.r = N_CORES * get_rank()

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "quintet+language" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "quintet+language" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to a disk path; strip everything but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "quintet+language" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to a disk path; strip everything but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all 5 states/frames, and language, decomposed as the input_ids and attention_mask."""
        vid = self.elements[index]["vid"]
        lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Dropout Language Check --> (Naive Zeroing leads to NaN --> keep just the <CLS> token, zero the rest)
        if index in self.dropout_indices:
            # Initial language token is *always* <CLS> = `101` --> last token always <SEP> = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frames
        if not self.stream:
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])
        else:
            # Multiplex w/ num_worker idx...
            worker_info, imgs = get_worker_info(), []
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
            for state in self.elements[index]["states"]:
                if self.is_val:
                    blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / state)), BytesIO()
                else:
                    blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / state)), BytesIO()

                # Download into FileObj & Rewind...
                blob.download_to_file(fobj)
                fobj.seek(0)

                # Add to imgs...
                imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))

            # Stack...
            imgs = torch.stack(imgs)

        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return len(self.elements)


def get_epoch_datasets(
    epoch: int,
    name: str,
    normalization: Tuple[Any, Any],
    model_arch: str,
    stream: bool,
    artifact_path: str,
    stream_prefix: str,
    data_modality: str,
    lang_dropout: Optional[float] = None,
    gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
    """Retrieve the custom Dataset classes for the train & val set for the given dataset & data modality."""
    index, img_transform = Path(artifact_path) / name / "index", get_online_transform(name, model_arch, normalization)
    prefix = Path(stream_prefix) / name if stream else stream_prefix

    # Switch on `data_modality`
    if data_modality == "state":
        train_ds = StateDataset(epoch, index, img_transform, stream, prefix)
        val_ds = StateDataset(epoch, index, img_transform, stream, prefix, is_val=True)

    elif data_modality == "state+language":
        train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
        val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)

    elif data_modality == "state+ok":
        if gen_ratio is None:
            nl = model_arch == "v-dual"
            train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl)
            val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl, is_val=True)
        else:
            # Special Generative Language Dataset...
            train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix)
            val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix, is_val=True)

    elif data_modality == "quintet+language":
        train_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
        val_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)

    else:
        raise NotImplementedError(f"Support for data modality `{data_modality}` not yet implemented!")

    return train_ds, val_ds


================================================
FILE: voltron/models/__init__.py
================================================
from .instantiate import VMVP, VR3M, VRN3M, VCond, VDual, VGen, get_model_optimizer


================================================
FILE: voltron/models/core/__init__.py
================================================


================================================
FILE: voltron/models/core/vcond.py
================================================
"""
vcond.py

PyTorch Module defining the Voltron `V-Cond` variant (single-frame with language-conditioning). In general, follows the
MAE recipe, with the architectural modifications described in the paper:
    - RMSNorm, for stability/performance ("Do Transformer Modifications Transfer...")
    - SwishGLU activations in the Attention Block Feed-Forward MLP (gated linear units) as used in PaLM
    - LayerScale with a default value of 0.1 (from Mistral/CaIT)

References:
    - https://github.com/facebookresearch/mae
    - https://github.com/lucidrains/x-transformers
"""
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat

from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings

# Suppress Transformers Logging
transformers.logging.set_verbosity_error()


class VCond(nn.Module):
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        encoder_depth: int,
        encoder_embed_dim: int,
        encoder_n_heads: int,
        decoder_depth: int,
        decoder_embed_dim: int,
        decoder_n_heads: int,
        language_model: str,
        hf_cache: str,
        language_dim: int,
        optimizer: str,
        schedule: str,
        base_lr: float,
        min_lr: float,
        effective_bsz: float,
        betas: Tuple[float, float],
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        mask_ratio: float = 0.75,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        norm_pixel_loss: bool = True,
        use_cls_token: bool = False,
    ) -> None:
        """
        Initialize a VCond model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
        :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
        :param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
        :param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
        :param language_model: Language model to freeze for encoding narrations/utterances.
        :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
        :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
        :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
        :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
        :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
        :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
        :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
        :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
        :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
        :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
        :param max_epochs: Total number of training epochs to be run.
        :param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
        :param mlp_ratio: Ratio for embedding size to Position-wise Feed-Forward MLP (gets shrunk back down).
        :param in_channels: Default number of channels in the base image -- almost always 3.
        :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
        :param use_cls_token: Add <CLS> token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
        """
        super().__init__()
        self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
        self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
        self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
        self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
        self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
        self.use_cls_token = use_cls_token
        self.language_dim = language_dim

        # Encoder/Decoder Parameters
        self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
        self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
        self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads

        # General Parameters (for downstream adaptation)
        self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads

        # (Optional) <CLS> Token Handling
        if self.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # MAE Encoder Parameters
        self.patch2embed = PatchEmbed(
            self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
        )
        self.encoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
            requires_grad=False,
        )
        self.encoder_blocks = nn.ModuleList(
            [
                Block(
                    self.encoder_embed_dim,
                    self.encoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.encoder_depth)
            ]
        )
        self.encoder_norm = RMSNorm(self.encoder_embed_dim)

        # Projection from Language Embedding to Encoder
        self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)

        # Projection from Encoder to Decoder
        self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)

        # MAE Decoder Parameters -- Remember the CLS Token (if specified)!
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim))
        self.decoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
            requires_grad=False,
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    self.decoder_embed_dim,
                    self.decoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.decoder_depth)
            ]
        )
        self.decoder_norm = RMSNorm(self.decoder_embed_dim)
        self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)

        # VCond -- Add "Image" and "Language" Modifier Tokens...
        self.img_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
        self.lang_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # Initialize all Weights
        self.initialize_weights()

        # @AFTER INITIALIZATION -- Create Tokenizer & Language Model --> LM has requires_grad = False
        #   > For BERT models, our "embedding" is just going to be the last hidden state
        #   > Assumes inputs to forward pass are pre-tokenized!
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm.eval()

        # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
        assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"

        # Freeze the LM
        for _, param in self.lm.named_parameters():
            param.requires_grad = False

    def initialize_weights(self) -> None:
        # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
        enc_pe = get_2D_position_embeddings(
            self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
        dec_pe = get_2D_position_embeddings(
            self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Initialize PatchEmbedding as a Linear...
        nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))

        # Initialize Mask Token, Img Token, Lang Token (+ optional <CLS> Token) w/ Normal (std = 0.02)
        nn.init.normal_(self.mask_token, std=0.02)
        nn.init.normal_(self.img_token, std=0.02)
        nn.init.normal_(self.lang_token, std=0.02)
        if self.use_cls_token:
            nn.init.normal_(self.cls_token, std=0.02)

        # Default Transformer initializer on everything else...
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        # Use `xavier_uniform` following Jax ViT
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
        self.lm.eval()
        with torch.no_grad():
            transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
        return transformer_embeddings

    def mask(
        self, patches: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches."""
        bsz, n_patches, embed_dim = patches.shape
        if mask_ratio is not None:
            n_keep = int(n_patches * (1 - mask_ratio))
        else:
            n_keep = int(n_patches * (1 - self.mask_ratio))

        # Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
        #   > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
        shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
        restore_idxs = torch.argsort(shuffle_idxs, dim=1)

        # Get "keep" (visible) patches
        visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))

        # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following MAE convention)
        mask = torch.ones(bsz, n_patches, device=patches.device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=restore_idxs)

        return visible_patches, mask, restore_idxs

    def get_representations(
        self, img: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
    ) -> torch.Tensor:
        """
        Given either a singleton (img, language) pair or a batch of images and language, extract representations
        subject to the specified mode in < multimodal | visual >.

        :param img: Processed batch of images :: [bsz, 3, 224, 224]
        :param language: Input language as `List[str] | Tuple[str] | None`
        :param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)

        :return: Extracted representations given (img, language) input as sequence.
        """
        assert img.ndim == 4 and (
            language is None or isinstance(language, list) or isinstance(language, tuple)
        ), "Invalid input to `get_representations()`"
        assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"

        # Tokenize Language --> note max length is 20!
        if language is None:
            lang, lang_mask = [torch.zeros(img.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
            lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
        else:
            tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
            lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)

            # Tile Language & Language Mask if mismatch with # images!
            if len(lang) != len(img):
                lang = repeat(lang, "b seq -> (bsz b) seq", bsz=img.size(0))
                lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=img.size(0))

        # Extract desired representations...
        representations = self.encode(img, lang, lang_mask)
        return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]

    def encode(self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Default representation extraction function, given a batch of images and tokenized language."""
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_language = self.lang2encoder(lang_embeddings)

        # Patchify
        patches = self.patch2embed(img)

        # (Optional) <CLS> Token Handling
        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
            patches = torch.cat([cls_tokens, patches], dim=1)

        # Position Encoding
        patches_pe = patches + self.encoder_pe

        # Add "modality" embeddings to patches & language
        img_embeddings, lang_embeddings = patches_pe + self.img_token, projected_language + self.lang_token

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
        multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
        multimodal_embeddings = self.encoder_norm(multimodal_embeddings)

        # Return the full sequence of multimodal embeddings...
        return multimodal_embeddings

    def forward_encoder(
        self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
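    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode visible (unmasked) patches jointly with language; returns (visible_patches, mask, restore_idxs)."""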
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_lang = self.lang2encoder(lang_embeddings)

        # Patchify + Position Embedding (without <CLS> Token!)
        patches = self.patch2embed(img)
        patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])

        # Create mask (and go ahead and mask out patches at the same time)
        visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)

        # (Optional) <CLS> Token Handling
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :]
            cls_tokens = cls_token_pe.expand(img.shape[0], -1, -1)
            visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)

        # Add "modality" embeddings to patches & language
        visible_patches, projected_lang = visible_patches + self.img_token, projected_lang + self.lang_token

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_mask.dtype)
        multimodal_embedding = torch.cat([visible_patches, projected_lang], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([visible_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embedding = block(multimodal_embedding, multimodal_mask)
        multimodal_embedding = self.encoder_norm(multimodal_embedding)

        # Split multimodal embedding, remove language and return only the visible patches (+ optional <CLS> token)!
        visible_patches = multimodal_embedding[:, : -lang_mask.shape[-1], ...]
        return visible_patches, mask, restore_idxs

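    def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
        """Re-insert mask tokens, unshuffle patches to their original order, and predict a reconstruction per patch."""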
        # Project patches into decoder embedding dimension
        projected_patches = self.encoder2decoder(visible_patches)

        # Add Mask Tokens to Sequence & Unshuffle
        mask_tokens = self.mask_token.repeat(
            projected_patches.shape[0],
            restore_idxs.shape[1] - visible_patches.shape[1] + (1 if self.use_cls_token else 0),
            1,
        )

        # (Optional) <CLS> Token Handling
        if self.use_cls_token:
            # Remove & add back CLS Token as part of the "unshuffling"
            concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1)  # Skip CLS Token
            no_cls_unshuffled_patches = torch.gather(
                concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
            )
            unshuffled_patches = torch.cat([projected_patches[:, :1, :], no_cls_unshuffled_patches], dim=1)
        else:
            concatenated_patches = torch.cat([projected_patches, mask_tokens], dim=1)
            unshuffled_patches = torch.gather(
                concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
            )

        # Add Position Embeddings
        decoder_patches = unshuffled_patches + self.decoder_pe

        # Apply Transformer Blocks...
        for block in self.decoder_blocks:
            decoder_patches = block(decoder_patches)
        decoder_patches = self.decoder_norm(decoder_patches)

        # Run final projection & return --> note <CLS> token handling!
        decoder_prediction = self.decoder_prediction(decoder_patches)
        return decoder_prediction if not self.use_cls_token else decoder_prediction[:, 1:, :]

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images into flattened patches [bsz, n_patches, patch_h * patch_w * c] by naive reshaping."""
        return rearrange(
            imgs,
            "bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
            patch_h=self.patch_size,
            patch_w=self.patch_size,
        )

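    def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute MSE against per-patch normalized pixel targets, averaged over the removed (masked) patches only."""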
        assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
        targets = self.patchify(imgs)

        # Normalize targets...
        mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
        targets = (targets - mu) / ((var + 1e-6) ** 0.5)

        # Compute mean loss per patch first...
        mse = (reconstructions - targets) ** 2
        avg_loss_per_patch = mse.mean(dim=-1)

        # Compute mean loss only on *removed* patches and return
        return (avg_loss_per_patch * mask).sum() / mask.sum()

    def forward(
        self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
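    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a full masked-autoencoding pass (encode, decode, loss); returns (loss, reconstructions, mask)."""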
        visible_patches, mask, restore_idxs = self.forward_encoder(img, lang, lang_mask, mask_ratio)
        reconstructions = self.forward_decoder(visible_patches, restore_idxs)
        loss = self.compute_loss(img, reconstructions, mask)
        return loss, reconstructions, mask

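    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        """Build AdamW with decay/no-decay parameter groups, plus the LR update function from `get_lr_update`."""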
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adamw"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")

        # Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed...
        #   > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)

        # Build Parameter Groups
        groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

        # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
        #   > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
        self.lr = self.base_lr * (self.effective_bsz / 256)

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
        return optimizer, update_lr
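
The unshuffling in `forward_decoder` is easier to trace on a toy sequence. A minimal pure-Python sketch follows, using plain lists in place of tensors. The helper names (`shuffle_and_mask`, `unshuffle`) and the `MASK` placeholder are hypothetical. The sketch assumes `restore_idxs` is the argsort of the shuffle order, as in the MAE recipe, since the body of `mask` is not shown in this section.

```python
import random
from typing import List, Tuple

MASK = "<mask>"  # hypothetical stand-in for the learned mask token


def shuffle_and_mask(patches: List[str], mask_ratio: float, seed: int = 0) -> Tuple[List[str], List[int], List[int]]:
    """Keep a random subset of patches; return (visible, mask, restore_idxs)."""
    rng = random.Random(seed)
    n = len(patches)
    n_keep = int(n * (1 - mask_ratio))

    # Random permutation: shuffle_idxs[j] is the original index sitting at shuffled position j
    shuffle_idxs = list(range(n))
    rng.shuffle(shuffle_idxs)

    # Argsort of the permutation: restore_idxs[i] is the shuffled position of original patch i
    restore_idxs = sorted(range(n), key=lambda j: shuffle_idxs[j])

    # The first `n_keep` shuffled positions stay visible; the rest are dropped
    visible = [patches[i] for i in shuffle_idxs[:n_keep]]

    # Binary mask in *original* order: 0 = visible, 1 = removed
    shuffled_mask = [0] * n_keep + [1] * (n - n_keep)
    mask = [shuffled_mask[j] for j in restore_idxs]
    return visible, mask, restore_idxs


def unshuffle(visible: List[str], restore_idxs: List[int]) -> List[str]:
    """Append mask tokens, then index with restore_idxs to recover the original patch order."""
    concatenated = visible + [MASK] * (len(restore_idxs) - len(visible))
    return [concatenated[j] for j in restore_idxs]


if __name__ == "__main__":
    patches = [f"p{i}" for i in range(8)]
    visible, mask, restore_idxs = shuffle_and_mask(patches, mask_ratio=0.75)
    print(unshuffle(visible, restore_idxs))
```

Indexing the concatenated `[visible, mask tokens]` list with `restore_idxs` plays the role of `torch.gather(..., dim=1, index=restore_idxs)` in the method above: visible patches return to their original positions and every removed position holds a mask token.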

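The normalized-pixel loss in `compute_loss` can be checked by hand on tiny inputs. The following is a hedged pure-Python sketch, not the module's implementation: it takes already-patchified targets as nested lists, and the function name `norm_pixel_masked_loss` is hypothetical. It mirrors the same three steps: normalize each target patch by its own mean and unbiased variance, average the squared error within each patch, then average over masked patches only.

```python
from statistics import mean, variance
from typing import List


def norm_pixel_masked_loss(
    targets: List[List[float]], reconstructions: List[List[float]], mask: List[int], eps: float = 1e-6
) -> float:
    """MSE against per-patch normalized targets, averaged over masked (mask == 1) patches only."""
    total = 0.0
    for target, recon, is_masked in zip(targets, reconstructions, mask):
        # Normalize each target patch by its own mean and unbiased (sample) variance
        mu, var = mean(target), variance(target)
        normed = [(t - mu) / (var + eps) ** 0.5 for t in target]

        # Mean squared error within the patch; visible patches (mask == 0) contribute nothing
        patch_mse = mean((r - n) ** 2 for r, n in zip(recon, normed))
        total += patch_mse * is_masked

    # Divide by the number of masked patches, not the total number of patches
    return total / sum(mask)
```

Because the sum is divided by `sum(mask)`, errors on visible patches never affect the loss, no matter how poor the reconstruction is there.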

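The parameter grouping and linear scaling rule in `configure_optimizer` reduce to two small pieces of logic. The sketch below restates them in pure Python under stated assumptions: parameters are represented by a hypothetical mapping from name to shape instead of `named_parameters()`, and `split_decay_groups` and `scaled_lr` are illustrative names, not part of the repository.

```python
from typing import Dict, List, Tuple


def split_decay_groups(named_shapes: Dict[str, Tuple[int, ...]]) -> Tuple[List[str], List[str]]:
    """Split parameter names into (decay, no_decay) using the same rule as `configure_optimizer`."""
    decay, no_decay = [], []
    for name, shape in named_shapes.items():
        # Parameters with at most one dimension (norm scales, biases) or named `*.bias` skip weight decay
        if len(shape) <= 1 or name.endswith(".bias"):
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


def scaled_lr(base_lr: float, effective_bsz: int) -> float:
    """Linear scaling rule: lr = base_lr * (effective_bsz / 256)."""
    return base_lr * (effective_bsz / 256)


if __name__ == "__main__":
    # Hypothetical parameter names and shapes, for illustration only
    shapes = {
        "encoder.attn.qkv.weight": (1152, 384),
        "encoder.attn.qkv.bias": (1152,),
        "encoder.norm.weight": (384,),
        "cls_token": (1, 1, 384),
    }
    print(split_decay_groups(shapes))
    print(scaled_lr(1.5e-4, 1024))
```

Note that under this rule a three-dimensional parameter such as the hypothetical `cls_token` above lands in the decayed group, because only the dimension count and the `.bias` suffix are checked.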
================================================
FILE: voltron/models/core/vdual.py
================================================
"""
vdual.py

PyTorch Module defining the Voltron `V-Dual` variant (dual-frame with language-conditioning). In general, follows the
MAE recipe, with the same modifications described in `vcond.py`.

When masking visual patches, the same patches are elided for both the 0th frame and the Kth frame to avoid cheating!

References:
    - https://github.com/huggingface/m4/blob/main/m4/modeling/pretraining/video/videomae.py
    - https://github.com/lucidrains/vit-pytorch
    - https://github.com/MCG-NJU/VideoMAE
"""
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat

from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings

# Suppress Transformers Logging
transformers.logging.set_verbosity_error()


class VDual(nn.Module):
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        encoder_depth: int,
        encoder_embed_dim: int,
        encoder_n_heads: int,
        decoder_depth: int,
        decoder_embed_dim: int,
        decoder_n_heads: int,
        language_model: str,
        hf_cache: str,
        language_dim: int,
        optimizer: str,
        schedule: str,
        base_lr: float,
        min_lr: float,
        effective_bsz: int,
        betas: Tuple[float, float],
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        mask_ratio: float = 0.75,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        norm_pixel_loss: bool = True,
        use_cls_token: bool = False,
    ) -> None:
        """
        Initialize a VDual model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
        :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
        :param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
        :param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
        :param language_model: Language model to freeze for encoding narrations/utterances.
        :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
        :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
        :par
gitextract_4mb9jrxp/

├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── docs/
│   └── ROADMAP.md
├── examples/
│   ├── pretrain/
│   │   ├── README.md
│   │   ├── preprocess.py
│   │   └── pretrain.py
│   ├── usage.py
│   ├── verification/
│   │   └── verify.py
│   └── xla-reference/
│       ├── README.md
│       ├── xpreprocess.py
│       └── xpretrain.py
├── pyproject.toml
├── setup.py
└── voltron/
    ├── __init__.py
    ├── conf/
    │   ├── __init__.py
    │   ├── accelerators.py
    │   ├── datasets.py
    │   ├── models.py
    │   └── tracking.py
    ├── datasets/
    │   ├── __init__.py
    │   ├── datasets.py
    │   └── v1/
    │       ├── __init__.py
    │       └── stream_datasets.py
    ├── models/
    │   ├── __init__.py
    │   ├── core/
    │   │   ├── __init__.py
    │   │   ├── vcond.py
    │   │   ├── vdual.py
    │   │   └── vgen.py
    │   ├── instantiate.py
    │   ├── materialize.py
    │   ├── reproductions/
    │   │   ├── __init__.py
    │   │   ├── vmvp.py
    │   │   ├── vr3m.py
    │   │   └── vrn3m.py
    │   └── util/
    │       ├── __init__.py
    │       ├── extraction.py
    │       ├── optimization.py
    │       └── transformer.py
    ├── overwatch/
    │   ├── __init__.py
    │   └── overwatch.py
    ├── preprocessing/
    │   ├── __init__.py
    │   ├── core.py
    │   ├── process.py
    │   ├── transforms.py
    │   └── v1/
    │       ├── __init__.py
    │       ├── process.py
    │       ├── transforms.py
    │       └── utils.py
    └── util/
        ├── __init__.py
        ├── checkpointing.py
        ├── metrics.py
        ├── utilities.py
        └── v1/
            ├── __init__.py
            ├── checkpointing.py
            ├── distributed.py
            ├── random.py
            └── xla_logger.py
SYMBOL INDEX (305 symbols across 37 files)

FILE: examples/pretrain/preprocess.py
  class PreprocessingConfig (line 32) | class PreprocessingConfig:
  function preprocess (line 55) | def preprocess(cfg: PreprocessingConfig) -> None:

FILE: examples/pretrain/pretrain.py
  class PretrainConfig (line 49) | class PretrainConfig:
  function pretrain (line 79) | def pretrain(cfg: PretrainConfig) -> None:

FILE: examples/usage.py
  function usage (line 17) | def usage() -> None:

FILE: examples/verification/verify.py
  function verify (line 24) | def verify() -> None:

FILE: examples/xla-reference/xpreprocess.py
  class PreprocessingConfig (line 31) | class PreprocessingConfig:
  function xpreprocess (line 54) | def xpreprocess(cfg: PreprocessingConfig) -> None:

FILE: examples/xla-reference/xpretrain.py
  class PretrainConfig (line 65) | class PretrainConfig:
  function xpretrain (line 97) | def xpretrain(cfg: PretrainConfig) -> None:
  function mp_fn (line 819) | def mp_fn(_: int, cfg: PretrainConfig) -> None:
  function main (line 827) | def main(cfg: PretrainConfig) -> None:

FILE: voltron/conf/accelerators.py
  class AcceleratorConfig (line 16) | class AcceleratorConfig:
  class TPUv2OneConfig (line 23) | class TPUv2OneConfig(AcceleratorConfig):
  class TPUv2EightConfig (line 30) | class TPUv2EightConfig(AcceleratorConfig):
  class TPUv3OneConfig (line 37) | class TPUv3OneConfig(AcceleratorConfig):
  class TPUv3EightConfig (line 44) | class TPUv3EightConfig(AcceleratorConfig):
  class TorchRunDefaultConfig (line 55) | class TorchRunDefaultConfig(AcceleratorConfig):

FILE: voltron/conf/datasets.py
  class DatasetConfig (line 16) | class DatasetConfig:
  class SthSthv2Config (line 61) | class SthSthv2Config(DatasetConfig):

FILE: voltron/conf/models.py
  class ModelConfig (line 16) | class ModelConfig:
  class MVPConfig (line 40) | class MVPConfig(ModelConfig):
  class MVPSmallConfig (line 75) | class MVPSmallConfig(MVPConfig):
  class R3MConfig (line 97) | class R3MConfig(ModelConfig):
  class R3MSmallConfig (line 136) | class R3MSmallConfig(R3MConfig):
  class ResNet3MConfig (line 153) | class ResNet3MConfig(ModelConfig):
  class RN3M50Config (line 187) | class RN3M50Config(ResNet3MConfig):
  class VCondConfig (line 203) | class VCondConfig(ModelConfig):
  class VCondSmallConfig (line 247) | class VCondSmallConfig(VCondConfig):
  class VCondBaseConfig (line 272) | class VCondBaseConfig(VCondConfig):
  class VDualConfig (line 299) | class VDualConfig(ModelConfig):
  class VDualSmallConfig (line 344) | class VDualSmallConfig(VDualConfig):
  class VDualBaseConfig (line 369) | class VDualBaseConfig(VDualConfig):
  class VGenConfig (line 396) | class VGenConfig(ModelConfig):
  class VGen50SmallConfig (line 442) | class VGen50SmallConfig(VGenConfig):
  class VGen50BaseConfig (line 467) | class VGen50BaseConfig(VGenConfig):

FILE: voltron/conf/tracking.py
  class TrackingConfig (line 15) | class TrackingConfig:
  class VoltronTrackingConfig (line 39) | class VoltronTrackingConfig(TrackingConfig):

FILE: voltron/datasets/datasets.py
  class PretrainDataset (line 28) | class PretrainDataset(Dataset):
    method __init__ (line 29) | def __init__(self) -> None:
    method hydrate (line 34) | def hydrate(self, path: Path) -> None:
    method set_epoch (line 43) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 46) | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
    method __len__ (line 49) | def __len__(self) -> int:
  class StateDataset (line 53) | class StateDataset(PretrainDataset):
    method __init__ (line 54) | def __init__(self, epoch: int, index_path: Path, img_transform: Compos...
    method set_epoch (line 62) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 78) | def __getitem__(self, idx: int) -> torch.Tensor:
    method __len__ (line 85) | def __len__(self) -> int:
  class StateLanguageDataset (line 89) | class StateLanguageDataset(PretrainDataset):
    method __init__ (line 90) | def __init__(
    method set_epoch (line 109) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 129) | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, t...
    method __len__ (line 148) | def __len__(self) -> int:
  class StateOKDataset (line 152) | class StateOKDataset(PretrainDataset):
    method __init__ (line 153) | def __init__(
    method set_epoch (line 172) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 192) | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, t...
    method __len__ (line 212) | def __len__(self) -> int:
  class GenStateOKDataset (line 216) | class GenStateOKDataset(PretrainDataset):
    method __init__ (line 217) | def __init__(
    method set_epoch (line 236) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 256) | def __getitem__(
    method __len__ (line 287) | def __len__(self) -> int:
  class QuintetDataset (line 291) | class QuintetDataset(PretrainDataset):
    method __init__ (line 292) | def __init__(self, epoch: int, index_path: Path, img_transform: Compos...
    method set_epoch (line 303) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 319) | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, t...
    method __len__ (line 333) | def __len__(self) -> int:
  function get_datasets (line 337) | def get_datasets(

FILE: voltron/datasets/v1/stream_datasets.py
  class PretrainDataset (line 54) | class PretrainDataset(Dataset):
    method set_epoch (line 55) | def set_epoch(self, epoch: int) -> None:
  class StateDataset (line 59) | class StateDataset(PretrainDataset):
    method __init__ (line 60) | def __init__(
    method set_epoch (line 80) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 115) | def __getitem__(self, index: int) -> torch.Tensor:
    method __len__ (line 163) | def __len__(self) -> int:
  class StateLanguageDataset (line 167) | class StateLanguageDataset(PretrainDataset):
    method __init__ (line 168) | def __init__(
    method set_epoch (line 200) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 253) | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor,...
    method __len__ (line 314) | def __len__(self) -> int:
  class StateOKDataset (line 318) | class StateOKDataset(PretrainDataset):
    method __init__ (line 319) | def __init__(
    method set_epoch (line 353) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 407) | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor,...
    method __len__ (line 487) | def __len__(self) -> int:
  class GenStateOKDataset (line 491) | class GenStateOKDataset(PretrainDataset):
    method __init__ (line 492) | def __init__(
    method set_epoch (line 523) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 576) | def __getitem__(
    method __len__ (line 660) | def __len__(self) -> int:
  class QuintetDataset (line 664) | class QuintetDataset(PretrainDataset):
    method __init__ (line 665) | def __init__(
    method set_epoch (line 694) | def set_epoch(self, epoch: int) -> None:
    method __getitem__ (line 747) | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor,...
    method __len__ (line 787) | def __len__(self) -> int:
  function get_epoch_datasets (line 791) | def get_epoch_datasets(

FILE: voltron/models/core/vcond.py
  class VCond (line 28) | class VCond(nn.Module):
    method __init__ (line 29) | def __init__(
    method initialize_weights (line 179) | def initialize_weights(self) -> None:
    method transformer_initializer (line 204) | def transformer_initializer(m: nn.Module) -> None:
    method encode_language (line 214) | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor)...
    method mask (line 221) | def mask(
    method get_representations (line 246) | def get_representations(
    method encode (line 281) | def encode(self, img: torch.Tensor, lang: torch.Tensor, lang_mask: tor...
    method forward_encoder (line 313) | def forward_encoder(
    method forward_decoder (line 349) | def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs:...
    method patchify (line 386) | def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    method compute_loss (line 395) | def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tens...
    method forward (line 410) | def forward(
    method configure_optimizer (line 418) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/core/vdual.py
  class VDual (line 28) | class VDual(nn.Module):
    method __init__ (line 29) | def __init__(
    method initialize_weights (line 185) | def initialize_weights(self) -> None:
    method transformer_initializer (line 210) | def transformer_initializer(m: nn.Module) -> None:
    method encode_language (line 220) | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor)...
    method mask (line 227) | def mask(
    method get_representations (line 255) | def get_representations(
    method encode (line 292) | def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: to...
    method forward_encoder (line 333) | def forward_encoder(
    method forward_decoder (line 372) | def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs:...
    method patchify (line 424) | def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    method compute_loss (line 433) | def compute_loss(
    method forward (line 454) | def forward(
    method configure_optimizer (line 477) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/core/vgen.py
  class VGen (line 32) | class VGen(nn.Module):
    method __init__ (line 33) | def __init__(
    method initialize_weights (line 210) | def initialize_weights(self) -> None:
    method transformer_initializer (line 235) | def transformer_initializer(m: nn.Module) -> None:
    method embed_language (line 245) | def embed_language(self, lang: torch.Tensor) -> torch.Tensor:
    method encode_language (line 253) | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor)...
    method mask (line 260) | def mask(
    method get_representations (line 288) | def get_representations(
    method encode (line 325) | def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: to...
    method score (line 369) | def score(self, imgs: torch.Tensor, langs: torch.Tensor, lang_masks: t...
    method forward_encoder (line 479) | def forward_encoder(
    method forward_decoder (line 524) | def forward_decoder(
    method patchify (line 601) | def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    method compute_loss (line 610) | def compute_loss(
    method forward (line 677) | def forward(
    method configure_optimizer (line 718) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/instantiate.py
  function get_model_optimizer (line 22) | def get_model_optimizer(

FILE: voltron/models/materialize.py
  function available_models (line 66) | def available_models() -> List[str]:
  function load (line 70) | def load(

FILE: voltron/models/reproductions/vmvp.py
  class VMVP (line 21) | class VMVP(nn.Module):
    method __init__ (line 22) | def __init__(
    method initialize_weights (line 113) | def initialize_weights(self) -> None:
    method transformer_initializer (line 131) | def transformer_initializer(m: nn.Module) -> None:
    method mask (line 141) | def mask(
    method get_representations (line 166) | def get_representations(self, img: torch.Tensor, mode: str = "patch") ...
    method encode (line 184) | def encode(self, img: torch.Tensor) -> torch.Tensor:
    method forward_encoder (line 199) | def forward_encoder(
    method forward_decoder (line 221) | def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs:...
    method patchify (line 250) | def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    method compute_loss (line 259) | def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tens...
    method forward (line 274) | def forward(
    method configure_optimizer (line 283) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/reproductions/vr3m.py
  class VR3M (line 24) | class VR3M(nn.Module):
    method __init__ (line 25) | def __init__(
    method initialize_weights (line 134) | def initialize_weights(self) -> None:
    method transformer_initializer (line 146) | def transformer_initializer(m: nn.Module) -> None:
    method get_representations (line 156) | def get_representations(self, img: torch.Tensor) -> torch.Tensor:
    method encode_images (line 174) | def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
    method encode_language (line 187) | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor)...
    method get_reward (line 194) | def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang:...
    method forward (line 197) | def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: t...
    method time_similarity (line 236) | def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_...
    method get_time_contrastive_loss (line 241) | def get_time_contrastive_loss(
    method get_reward_loss (line 281) | def get_reward_loss(
    method configure_optimizer (line 341) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/reproductions/vrn3m.py
  class VRN3M (line 24) | class VRN3M(nn.Module):
    method __init__ (line 25) | def __init__(
    method get_representations (line 102) | def get_representations(self, img: torch.Tensor) -> torch.Tensor:
    method encode_images (line 113) | def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
    method encode_language (line 117) | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor)...
    method get_reward (line 124) | def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang:...
    method extract_features (line 127) | def extract_features(self, img: torch.Tensor) -> torch.Tensor:
    method forward (line 131) | def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: t...
    method time_similarity (line 170) | def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_...
    method get_time_contrastive_loss (line 175) | def get_time_contrastive_loss(
    method get_reward_loss (line 215) | def get_reward_loss(
    method configure_optimizer (line 275) | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable...

FILE: voltron/models/util/extraction.py
  class MAPAttention (line 22) | class MAPAttention(nn.Module):
    method __init__ (line 23) | def __init__(self, embed_dim: int, n_heads: int) -> None:
    method forward (line 33) | def forward(self, seed: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
  class MAPBlock (line 51) | class MAPBlock(nn.Module):
    method __init__ (line 52) | def __init__(
    method forward (line 88) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function instantiate_extractor (line 96) | def instantiate_extractor(backbone: nn.Module, n_latents: int = 1) -> Ca...

FILE: voltron/models/util/optimization.py
  function get_lr_update (line 19) | def get_lr_update(

FILE: voltron/models/util/transformer.py
  function get_1D_sine_cosine (line 22) | def get_1D_sine_cosine(dim: int, pos: np.ndarray) -> np.ndarray:
  function get_1D_position_embeddings (line 31) | def get_1D_position_embeddings(embed_dim: int, length: int) -> np.ndarray:
  function get_2D_position_embeddings (line 37) | def get_2D_position_embeddings(embed_dim: int, grid_size: int, cls_token...
  class PatchEmbed (line 57) | class PatchEmbed(nn.Module):
    method __init__ (line 58) | def __init__(
    method forward (line 73) | def forward(self, patches: torch.Tensor) -> torch.Tensor:
  class LayerScale (line 84) | class LayerScale(nn.Module):
    method __init__ (line 85) | def __init__(self, dim: int, init_values: float = 0.1) -> None:  # CaI...
    method forward (line 89) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class RMSNorm (line 94) | class RMSNorm(nn.Module):
    method __init__ (line 95) | def __init__(self, dim: int, eps: float = 1e-8) -> None:
    method forward (line 100) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class SwishGLU (line 106) | class SwishGLU(nn.Module):
    method __init__ (line 107) | def __init__(self, in_dim: int, out_dim: int) -> None:
    method forward (line 111) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Attention (line 119) | class Attention(nn.Module):
    method __init__ (line 120) | def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.0)...
    method forward (line 131) | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None...
  class Block (line 158) | class Block(nn.Module):
    method __init__ (line 159) | def __init__(
    method forward (line 204) | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None...

FILE: voltron/overwatch/overwatch.py
  class OverwatchRich (line 16) | class OverwatchRich:
  class OverwatchStandard (line 47) | class OverwatchStandard:

FILE: voltron/preprocessing/core.py
  function get_path (line 34) | def get_path(save_dir: Path, v: str, i: int, relpath: bool = False) -> str:
  function do_dry_run (line 41) | def do_dry_run(
  function process_clip (line 110) | def process_clip(
  function serialize_epoch (line 150) | def serialize_epoch(

FILE: voltron/preprocessing/process.py
  function extract_frames (line 32) | def extract_frames(
  function preprocess_language (line 107) | def preprocess_language(
  function unify_batches (line 214) | def unify_batches(

FILE: voltron/preprocessing/transforms.py
  function identity (line 20) | def identity(x: torch.Tensor) -> torch.Tensor:
  function scaled_center_crop (line 24) | def scaled_center_crop(target_resolution: int, frames: List[Image.Image]...
  function get_preprocess_transform (line 40) | def get_preprocess_transform(
  function get_online_transform (line 50) | def get_online_transform(

FILE: voltron/preprocessing/v1/process.py
  function preprocess_videos (line 35) | def preprocess_videos(
  function preprocess_language (line 118) | def preprocess_language(
  function jsonify_language (line 186) | def jsonify_language(train_registry: Path, val_registry: Path) -> None:
  function index (line 224) | def index(train_registry: Path, val_registry: Path, name: str, artifact_...
  function unify_batches (line 244) | def unify_batches(

FILE: voltron/preprocessing/v1/transforms.py
  class ComposeMix (line 15) | class ComposeMix:
    method __init__ (line 16) | def __init__(self, transforms):
    method __call__ (line 19) | def __call__(self, imgs):
  class RandomCropVideo (line 31) | class RandomCropVideo:
    method __init__ (line 32) | def __init__(self, size):
    method __call__ (line 35) | def __call__(self, imgs):
  class Scale (line 44) | class Scale:
    method __init__ (line 45) | def __init__(self, size):
    method __call__ (line 48) | def __call__(self, img):
  function identity (line 52) | def identity(x):
  function get_pre_transform (line 57) | def get_pre_transform(dataset: str, resolution: int, scale_factor: float...
  function get_online_transform (line 79) | def get_online_transform(dataset: str, model_arch: str, normalization: T...

FILE: voltron/preprocessing/v1/utils.py
  function get_path (line 31) | def get_path(save_dir: Path, v: str, i: int) -> str:
  function do_dry_run (line 35) | def do_dry_run(
  function process_video (line 100) | def process_video(
  function precompute_epoch (line 138) | def precompute_epoch(

FILE: voltron/util/checkpointing.py
  class FixedDeck (line 26) | class FixedDeck(deque):
    method __init__ (line 27) | def __init__(self, maxlen: int) -> None:
    method append (line 30) | def append(self, x: Any) -> Any:
  class CheckpointSaver (line 40) | class CheckpointSaver:
    method __init__ (line 41) | def __init__(self, strategy: Tuple[int, int, int], run_dir: str, is_ra...
    method save (line 66) | def save(
  function do_resume (line 114) | def do_resume(resume: bool, run_dir: str) -> Tuple[Optional[Path], int, ...

FILE: voltron/util/metrics.py
  class Logger (line 25) | class Logger(ABC):
    method __init__ (line 26) | def __init__(self, run_id: str, hparams: Dict[str, Any], is_rank_zero:...
    method write_hyperparameters (line 30) | def write_hyperparameters(self) -> None:
    method write (line 34) | def write(self, global_step: int, metrics: Dict[str, Union[int, float]...
    method finalize (line 37) | def finalize(self) -> None:
  class JSONLinesLogger (line 41) | class JSONLinesLogger(Logger):
    method write_hyperparameters (line 42) | def write_hyperparameters(self) -> None:
    method write (line 56) | def write(self, global_step: int, metrics: Dict[str, Union[int, float]...
  class WeightsBiasesLogger (line 65) | class WeightsBiasesLogger(Logger):
    method __init__ (line 66) | def __init__(
    method initialize (line 97) | def initialize(self) -> None:
    method write_hyperparameters (line 115) | def write_hyperparameters(self) -> None:
    method write (line 122) | def write(self, global_step: int, metrics: Dict[str, Union[int, float]...
    method finalize (line 129) | def finalize(self) -> None:
  class Metrics (line 137) | class Metrics:
    method __init__ (line 138) | def __init__(
    method itemize (line 220) | def itemize(self) -> Dict[str, torch.Tensor]:
    method log (line 228) | def log(self, global_step: int, metrics: Dict[str, Union[int, float]])...
    method finalize (line 232) | def finalize(self) -> None:
    method get_status (line 236) | def get_status(self, epoch: int, loss: Optional[torch.Tensor] = None) ...
    method commit (line 244) | def commit(
    method push (line 275) | def push(self, epoch: int) -> str:
    method push_epoch (line 369) | def push_epoch(self, epoch: int, val_loss: torch.Tensor) -> Tuple[str,...

FILE: voltron/util/utilities.py
  function worker_init_function (line 43) | def worker_init_function(worker_id: int) -> None:
  function set_global_seed (line 77) | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Opti...
  class ResumeableDistributedSampler (line 93) | class ResumeableDistributedSampler(DistributedSampler):
    method __init__ (line 94) | def __init__(
    method __iter__ (line 111) | def __iter__(self) -> Iterator[T_co]:
    method __len__ (line 120) | def __len__(self) -> int:
    method set_epoch (line 127) | def set_epoch(self, epoch: int) -> None:

FILE: voltron/util/v1/checkpointing.py
  class FixedDeck (line 20) | class FixedDeck(deque):
    method __init__ (line 21) | def __init__(self, maxlen: int) -> None:
    method append (line 24) | def append(self, x: Any) -> Any:
  class XLACheckpointSaver (line 34) | class XLACheckpointSaver:
    method __init__ (line 35) | def __init__(self, strategy: Tuple[int, int, int], run_dir: str) -> None:
    method save (line 61) | def save(

FILE: voltron/util/v1/distributed.py
  class ResumeableDistributedSampler (line 27) | class ResumeableDistributedSampler(DistributedSampler):
    method __init__ (line 28) | def __init__(
    method __iter__ (line 45) | def __iter__(self) -> Iterator[T_co]:
    method __len__ (line 54) | def __len__(self) -> int:
    method set_epoch (line 61) | def set_epoch(self, epoch: int) -> None:
  function xla_available (line 69) | def xla_available() -> bool:
  function get_rank (line 76) | def get_rank() -> int:

FILE: voltron/util/v1/random.py
  function set_global_seed (line 24) | def set_global_seed(seed: int) -> Callable[[int], None]:
  function worker_init_function (line 37) | def worker_init_function(worker_id: int) -> None:

FILE: voltron/util/v1/xla_logger.py
  function log_epoch_end_update (line 17) | def log_epoch_end_update(
  function log_vmvp_train_update (line 67) | def log_vmvp_train_update(
  function log_vr3m_train_update (line 99) | def log_vr3m_train_update(
  function log_vrn3m_train_update (line 145) | def log_vrn3m_train_update(
  function log_vcond_train_update (line 192) | def log_vcond_train_update(
  function log_vdual_train_update (line 224) | def log_vdual_train_update(
  function log_vgen_train_update (line 262) | def log_vgen_train_update(
Condensed preview: 60 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1912,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 715,
    "preview": "# See https://pre-commit.com for more information\n# See https://pre-commit.com/hooks.html for more hooks\nexclude: \".git\""
  },
  {
    "path": "LICENSE",
    "chars": 1110,
    "preview": "MIT License\n\nCopyright (c) 2021-present, Siddharth Karamcheti and other contributors.\n\nPermission is hereby granted, fre"
  },
  {
    "path": "Makefile",
    "chars": 491,
    "preview": ".PHONY: help check autoformat\n.DEFAULT: help\n\n# Generates a useful overview/help message for various make features - add"
  },
  {
    "path": "README.md",
    "chars": 8810,
    "preview": "<div align=\"center\">\n    <img src=\"https://raw.githubusercontent.com/siddk/voltron-robotics/main/docs/assets/voltron-ban"
  },
  {
    "path": "docs/ROADMAP.md",
    "chars": 1231,
    "preview": "# Project Roadmap\n\nWe document the future of this project (new features to be added, issues to address) here. For the mo"
  },
  {
    "path": "examples/pretrain/README.md",
    "chars": 7098,
    "preview": "# Pretraining Voltron Models\n\nWe provide scripts for pretraining Voltron models on various datasets. Below, we provide t"
  },
  {
    "path": "examples/pretrain/preprocess.py",
    "chars": 3397,
    "preview": "\"\"\"\npreprocess.py\n\nCentralized script for preprocessing various video/vision-language datasets for GPU pretraining, usin"
  },
  {
    "path": "examples/pretrain/pretrain.py",
    "chars": 18529,
    "preview": "\"\"\"\npretrain.py\n\nCore pretraining script for Native PyTorch (Single/Multi-) GPU pretraining on the Something-Something-v"
  },
  {
    "path": "examples/usage.py",
    "chars": 1856,
    "preview": "\"\"\"\nusage.py\n\nExample script demonstrating how to load a Voltron model (`V-Cond`) and instantiate a Multiheaded Attentio"
  },
  {
    "path": "examples/verification/verify.py",
    "chars": 2795,
    "preview": "\"\"\"\nverify.py\n\nExample script demonstrating how to load all Voltron models (and reproduced models), take input image(s),"
  },
  {
    "path": "examples/xla-reference/README.md",
    "chars": 971,
    "preview": "# XLA Reference\n\n*Note :: This code was written for the experimental PyTorch XLA build in PyTorch 1.12; no guarantees it"
  },
  {
    "path": "examples/xla-reference/xpreprocess.py",
    "chars": 3472,
    "preview": "\"\"\"\nxpreprocess.py\n\nCentralized script for preprocessing Sth-Sth-v2 for TPU/GCP pretraining, using a multi-stage, multip"
  },
  {
    "path": "examples/xla-reference/xpretrain.py",
    "chars": 38007,
    "preview": "\"\"\"\nxpretrain.py\n\n(The `x` prefix indicates this is a script geared for XLA/TPU backends *only*)!\n\nReference script for "
  },
  {
    "path": "pyproject.toml",
    "chars": 2127,
    "preview": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"voltron-robotics\"\nau"
  },
  {
    "path": "setup.py",
    "chars": 187,
    "preview": "\"\"\"\nsetup.py\n\nPEP 621 switches most of Packaging to `pyproject.toml` -- yet keep a \"dummy\" setup.py for external code th"
  },
  {
    "path": "voltron/__init__.py",
    "chars": 102,
    "preview": "from .models.materialize import available_models, load\nfrom .models.util import instantiate_extractor\n"
  },
  {
    "path": "voltron/conf/__init__.py",
    "chars": 149,
    "preview": "from .accelerators import AcceleratorConfig\nfrom .datasets import DatasetConfig\nfrom .models import ModelConfig\nfrom .tr"
  },
  {
    "path": "voltron/conf/accelerators.py",
    "chars": 1798,
    "preview": "\"\"\"\naccelerator.py\n\nBase Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inherit"
  },
  {
    "path": "voltron/conf/datasets.py",
    "chars": 3095,
    "preview": "\"\"\"\ndatasets.py\n\nBase Hydra Structured Config for defining various pretraining datasets and appropriate configurations. "
  },
  {
    "path": "voltron/conf/models.py",
    "chars": 14747,
    "preview": "\"\"\"\nmodels.py\n\nBase Hydra Structured Config for defining various pretraining model architectures and appropriate configu"
  },
  {
    "path": "voltron/conf/tracking.py",
    "chars": 1772,
    "preview": "\"\"\"\ntracking.py\n\nBase Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via W"
  },
  {
    "path": "voltron/datasets/__init__.py",
    "chars": 35,
    "preview": "from .datasets import get_datasets\n"
  },
  {
    "path": "voltron/datasets/datasets.py",
    "chars": 15867,
    "preview": "\"\"\"\ndatasets.py\n\nCore Pytorch Dataset implementations for the various \"data flavors\" used by the different representatio"
  },
  {
    "path": "voltron/datasets/v1/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "voltron/datasets/v1/stream_datasets.py",
    "chars": 40177,
    "preview": "\"\"\"\nstream_datasets.py\n\nCore PyTorch Datasets for the various \"flavors\" of data used by the various models under study. "
  },
  {
    "path": "voltron/models/__init__.py",
    "chars": 84,
    "preview": "from .instantiate import VMVP, VR3M, VRN3M, VCond, VDual, VGen, get_model_optimizer\n"
  },
  {
    "path": "voltron/models/core/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "voltron/models/core/vcond.py",
    "chars": 22688,
    "preview": "\"\"\"\nvcond.py\n\nPyTorch Module defining the Voltron `V-Cond` variant (single-frame with language-conditioning). In general"
  },
  {
    "path": "voltron/models/core/vdual.py",
    "chars": 27404,
    "preview": "\"\"\"\nvdual.py\n\nPyTorch Module defining the Voltron `V-Dual` variant (dual-frame with language-conditioning). In general, "
  },
  {
    "path": "voltron/models/core/vgen.py",
    "chars": 40463,
    "preview": "\"\"\"\nvgen.py\n\nPyTorch Module defining the Voltron `V-Gen` variant (dual-frame with language-conditioning AND language-gen"
  },
  {
    "path": "voltron/models/instantiate.py",
    "chars": 7301,
    "preview": "\"\"\"\ninstantiate.py\n\nSimple wrapping script for instantiating a core Voltron/reproduction model and configuring the torch"
  },
  {
    "path": "voltron/models/materialize.py",
    "chars": 5380,
    "preview": "\"\"\"\nmaterialize.py\n\nCore functionality for using pretrained models; defines the package-level `load` functionality for d"
  },
  {
    "path": "voltron/models/reproductions/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "voltron/models/reproductions/vmvp.py",
    "chars": 15788,
    "preview": "\"\"\"\nvmvp.py\n\nPyTorch Module defining a basic MAE a la Masked Visual Pretraining for Motor Control (MVP), with the requis"
  },
  {
    "path": "voltron/models/reproductions/vr3m.py",
    "chars": 17436,
    "preview": "\"\"\"\nvr3m.py\n\nPyTorch Module defining an R3M model (with a ViT encoder), with the remainder as described in Nair et. al. "
  },
  {
    "path": "voltron/models/reproductions/vrn3m.py",
    "chars": 13930,
    "preview": "\"\"\"\nvrn3m.py\n\nPyTorch Module defining an R3M model (with a ResNet 50 encoder), exactly as described in Nair et. al. 2021"
  },
  {
    "path": "voltron/models/util/__init__.py",
    "chars": 46,
    "preview": "from .extraction import instantiate_extractor\n"
  },
  {
    "path": "voltron/models/util/extraction.py",
    "chars": 4095,
    "preview": "\"\"\"\nextraction.py\n\nGeneral Extraction module definitions & associated utilities.\n\nReferences:\n    - Set Transformers (MA"
  },
  {
    "path": "voltron/models/util/optimization.py",
    "chars": 1990,
    "preview": "\"\"\"\noptimization.py\n\nGeneral utilities for optimization, e.g., schedulers such as Linear Warmup w/ Cosine Decay for Tran"
  },
  {
    "path": "voltron/models/util/transformer.py",
    "chars": 8787,
    "preview": "\"\"\"\ntransformer.py\n\nGeneral Transformer modules & utilities.\n\nReferences:\n    - https://github.com/facebookresearch/mae\n"
  },
  {
    "path": "voltron/overwatch/__init__.py",
    "chars": 37,
    "preview": "from .overwatch import OverwatchRich\n"
  },
  {
    "path": "voltron/overwatch/overwatch.py",
    "chars": 2207,
    "preview": "\"\"\"\noverwatch.py\n\nUtility class for creating a centralized/standardized logger (to pass to Hydra), with a sane default f"
  },
  {
    "path": "voltron/preprocessing/__init__.py",
    "chars": 72,
    "preview": "from .process import extract_frames, preprocess_language, unify_batches\n"
  },
  {
    "path": "voltron/preprocessing/core.py",
    "chars": 10599,
    "preview": "\"\"\"\nutils.py\n\nPreprocessing utilities, including dry-run and single-video (single-example) processing. This file effecti"
  },
  {
    "path": "voltron/preprocessing/process.py",
    "chars": 12749,
    "preview": "\"\"\"\nprocess.py\n\nUtility functions for preprocessing large-scale video/vision-language datasets in multiple passes, using"
  },
  {
    "path": "voltron/preprocessing/transforms.py",
    "chars": 2944,
    "preview": "\"\"\"\ntransforms.py\n\nDefault video/image transforms for Voltron preprocessing and training. Provides utilities for definin"
  },
  {
    "path": "voltron/preprocessing/v1/__init__.py",
    "chars": 100,
    "preview": "from .process import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches\n"
  },
  {
    "path": "voltron/preprocessing/v1/process.py",
    "chars": 15156,
    "preview": "\"\"\"\nprocess.py\n\nUtility functions for serializing datasets in multiple passes, using multiprocessing for efficient paral"
  },
  {
    "path": "voltron/preprocessing/v1/transforms.py",
    "chars": 2990,
    "preview": "\"\"\"\ntransforms.py\n\nDefault image/video transformations for various datasets.\n\"\"\"\nfrom typing import Any, Tuple\n\nimport c"
  },
  {
    "path": "voltron/preprocessing/v1/utils.py",
    "chars": 9079,
    "preview": "\"\"\"\nutils.py\n\nPreprocessing utilities, including functions for dry-runs and processing a single video (helpers for multi"
  },
  {
    "path": "voltron/util/__init__.py",
    "chars": 152,
    "preview": "from .checkpointing import CheckpointSaver, do_resume\nfrom .metrics import Metrics\nfrom .utilities import ResumeableDist"
  },
  {
    "path": "voltron/util/checkpointing.py",
    "chars": 7070,
    "preview": "\"\"\"\ncheckpointing.py\n\nCore utility class for handling model/optimizer serialization & checkpointing -- including resume "
  },
  {
    "path": "voltron/util/metrics.py",
    "chars": 15408,
    "preview": "\"\"\"\nmetrics.py\n\nUtility classes defining Metrics containers with model-specific logging to various endpoints (JSONL loca"
  },
  {
    "path": "voltron/util/utilities.py",
    "chars": 5698,
    "preview": "\"\"\"\nutilities.py\n\nGeneral utilities for randomness, distributed training, and miscellaneous checks in PyTorch.\n\n=== Rand"
  },
  {
    "path": "voltron/util/v1/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "voltron/util/v1/checkpointing.py",
    "chars": 4561,
    "preview": "\"\"\"\ncheckpointing.py\n\nXLA-specific utility class for handling model/optimizer serialization & checkpointing.\n\nSupport th"
  },
  {
    "path": "voltron/util/v1/distributed.py",
    "chars": 3405,
    "preview": "\"\"\"\ndistributed.py\n\nKey distributed utilities; notably provides a standard API for getting relevant data from either CPU"
  },
  {
    "path": "voltron/util/v1/random.py",
    "chars": 3012,
    "preview": "\"\"\"\nrandom.py\n\nUtilities for dealing with randomness for PyTorch, across devices (CPU, GPU, TPU).\n\nLoosely inspired by f"
  },
  {
    "path": "voltron/util/v1/xla_logger.py",
    "chars": 11056,
    "preview": "\"\"\"\nxla_logger.py\n\nUtility class defining various XLA logging methods (called within marked closures), for logging metri"
  }
]

About this extraction

This page contains the full source code of the siddk/voltron-robotics GitHub repository, extracted and formatted as plain text. The extraction includes 60 files (431.8 KB), approximately 104.9k tokens, and a symbol index with 305 extracted functions, classes, methods, constants, and types.
