Repository: siddk/voltron-robotics Branch: main Commit: 1b299bf5cfa0 Files: 60 Total size: 431.8 KB Directory structure: gitextract_4mb9jrxp/ ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── docs/ │ └── ROADMAP.md ├── examples/ │ ├── pretrain/ │ │ ├── README.md │ │ ├── preprocess.py │ │ └── pretrain.py │ ├── usage.py │ ├── verification/ │ │ └── verify.py │ └── xla-reference/ │ ├── README.md │ ├── xpreprocess.py │ └── xpretrain.py ├── pyproject.toml ├── setup.py └── voltron/ ├── __init__.py ├── conf/ │ ├── __init__.py │ ├── accelerators.py │ ├── datasets.py │ ├── models.py │ └── tracking.py ├── datasets/ │ ├── __init__.py │ ├── datasets.py │ └── v1/ │ ├── __init__.py │ └── stream_datasets.py ├── models/ │ ├── __init__.py │ ├── core/ │ │ ├── __init__.py │ │ ├── vcond.py │ │ ├── vdual.py │ │ └── vgen.py │ ├── instantiate.py │ ├── materialize.py │ ├── reproductions/ │ │ ├── __init__.py │ │ ├── vmvp.py │ │ ├── vr3m.py │ │ └── vrn3m.py │ └── util/ │ ├── __init__.py │ ├── extraction.py │ ├── optimization.py │ └── transformer.py ├── overwatch/ │ ├── __init__.py │ └── overwatch.py ├── preprocessing/ │ ├── __init__.py │ ├── core.py │ ├── process.py │ ├── transforms.py │ └── v1/ │ ├── __init__.py │ ├── process.py │ ├── transforms.py │ └── utils.py └── util/ ├── __init__.py ├── checkpointing.py ├── metrics.py ├── utilities.py └── v1/ ├── __init__.py ├── checkpointing.py ├── distributed.py ├── random.py └── xla_logger.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # 
PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # Ruff .ruff_cache/ # IDE caches .idea/ .vscode/ # Mac OS .DS_Store # Cache data/ cache/ # Scratch scratch/ ================================================ FILE: .pre-commit-config.yaml ================================================ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks exclude: ".git" repos: - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.252 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black rev: 23.1.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: check-added-large-files args: ["--maxkb=40000"] - id: check-ast - id: check-case-conflict - id: check-merge-conflict - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021-present, Siddharth Karamcheti and other contributors. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ .PHONY: help check autoformat .DEFAULT: help # Generates a useful overview/help message for various make features - add to this as necessary! help: @echo "make check" @echo " Run code style and linting (black, ruff) *without* changing files!" @echo "make autoformat" @echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this." check: black --check . ruff check --show-source . autoformat: black . ruff check --fix --show-fixes . ================================================ FILE: README.md ================================================
Voltron Logo
[![arXiv](https://img.shields.io/badge/arXiv-2302.12766-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2302.12766) [![PyTorch](https://img.shields.io/badge/PyTorch-2.0.0-EE4C2C.svg?style=for-the-badge&logo=pytorch)](https://pytorch.org/get-started/locally/) [![Code Style: Black](https://img.shields.io/badge/Code%20Style-Black-000000?style=for-the-badge)](https://github.com/psf/black) [![Ruff](https://img.shields.io/badge/%E2%9A%A1%EF%B8%8F-Ruff-orange?style=for-the-badge)](https://github.com/charliermarsh/ruff) ![License](https://img.shields.io/github/license/siddk/lila?color=blueviolet&style=for-the-badge)
--- # Language-Driven Representation Learning for Robotics Package repository for Voltron: Language-Driven Representation Learning for Robotics. Provides code for loading pretrained Voltron, R3M, and MVP representations for adaptation to downstream tasks, as well as code for pretraining such representations on arbitrary datasets. --- ## Quickstart This repository is built with PyTorch; while specified as a dependency for the package, we highly recommend that you install the desired version (e.g., with accelerator support) for your given hardware and environment manager (e.g., `conda`). PyTorch installation instructions [can be found here](https://pytorch.org/get-started/locally/). This repository should work with PyTorch >= 1.12. Releases before 1.1.0 have been thoroughly tested with PyTorch 1.12.0, Torchvision 0.13.0, and Torchaudio 0.12.0. **Note**: Releases 1.1.0 and after *assume PyTorch 2.0*! Once PyTorch has been properly installed, you can install this package via PyPI, and you're off! ```bash pip install voltron-robotics ``` You can also install this package locally via an editable installation in case you want to run examples/extend the current functionality: ```bash git clone https://github.com/siddk/voltron-robotics cd voltron-robotics pip install -e . ``` ## Usage Voltron Robotics (package: `voltron`) is structured to provide easy access to pretrained Voltron models (and reproductions), to facilitate use for various downstream tasks. Using a pretrained Voltron model is easy: ```python from torchvision.io import read_image from voltron import instantiate_extractor, load # Load a frozen Voltron (V-Cond) model & configure a vector extractor vcond, preprocess = load("v-cond", device="cuda", freeze=True) vector_extractor = instantiate_extractor(vcond)() # Obtain & Preprocess an image =>> can be from a dataset, or camera on a robot, etc. # => Feel free to add any language if you have it (Voltron models work either way!) 
img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to("cuda") lang = ["peeling a carrot"] # Extract both multimodal AND vision-only embeddings! multimodal_embeddings = vcond(img, lang, mode="multimodal") visual_embeddings = vcond(img, mode="visual") # Use the `vector_extractor` to output dense vector representations for downstream applications! # => Pass this representation to model of your choice (object detector, control policy, etc.) representation = vector_extractor(multimodal_embeddings) ``` Voltron representations can be used for a variety of different applications; in the [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation) repository, you can find code for adapting Voltron representations to various downstream tasks (segmentation, object detection, control, etc.); all the applications from our paper. --- ## API ![Voltron Framework](https://raw.githubusercontent.com/siddk/voltron-robotics/main/docs/assets/voltron-framework.png) The package `voltron` provides the following functionality for using and adapting existing representations: #### `voltron.available_models()` Returns the name of available Voltron models; right now, the following models (all models trained in the paper) are available: - `v-cond` – V-Cond (ViT-Small) trained on Sth-Sth; single-frame w/ language-conditioning. - `v-dual` – V-Dual (ViT-Small) trained on Sth-Sth; dual-frame w/ language-conditioning. - `v-gen` – V-Gen (ViT-Small) trained on Sth-Sth; dual-frame w/ language conditioning AND generation. - `r-mvp` – R-MVP (ViT-Small); reproduction of [MVP](https://github.com/ir413/mvp) trained on Sth-Sth. - `r-r3m-vit` – R-R3M (ViT-Small); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth. - `r-r3m-rn50` – R-R3M (ResNet-50); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth. - `v-cond-base` – V-Cond (ViT-Base) trained on Sth-Sth; larger (86M parameter) variant of V-Cond. 
#### `voltron.load(name: str, device: str, freeze: bool, cache: str = "cache/")` Returns the model and the Torchvision Transform needed by the model, where `name` is one of the strings returned by `voltron.available_models()`; this in general follows the same API as [OpenAI's CLIP](https://github.com/openai/CLIP). --- Voltron models (`v-{cond, dual, gen, ...}`) returned by `voltron.load()` support the following: #### `model(img: Tensor, lang: Optional[List[str]], mode: str = "multimodal")` Returns a sequence of embeddings corresponding to the output of the multimodal encoder; note that `lang` can be None, which is totally fine for Voltron models! However, if you have any language (even a coarse task description), it'll probably be helpful! The parameter `mode` in `["multimodal", "visual"]` controls whether the output will contain the fused image patch and language embeddings, or only the image patch embeddings. **Note:** For the API for the non-Voltron models (e.g., R-MVP, R-R3M), take a look at [`examples/verify.py`](examples/verify.py); this file shows how representations from *every* model can be extracted. ### Adaptation See [`examples/usage.py`](examples/usage.py) and the [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation) repository for more examples on the various ways to adapt/use Voltron representations. --- ## Contributing Before committing to the repository, make sure to set up your dev environment! Here are the basic development environment setup guidelines: + Fork/clone the repository, performing an editable installation. Make sure to install with the development dependencies (e.g., `pip install -e ".[dev]"`); this will install `black`, `ruff`, and `pre-commit`. + Install `pre-commit` hooks (`pre-commit install`). + Branch for the specific feature/issue, issuing PR against the upstream repository for review.
Additional Contribution Notes: - This project has migrated to the recommended [`pyproject.toml` based configuration for setuptools](https://setuptools.pypa.io/en/latest/userguide/quickstart.html). However, as some tools haven't yet adopted [PEP 660](https://peps.python.org/pep-0660/), we provide a [`setup.py` file](https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html). - This package follows the [`flat-layout` structure](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#flat-layout) described in `setuptools`. - Make sure to add any new dependencies to the `pyproject.toml` file! --- ## Repository Structure High-level overview of repository/project file-tree: + `docs/` - Package documentation & assets - including project roadmap. + `voltron` - Package source code; has all core utilities for model specification, loading, feature extraction, preprocessing, etc. + `examples/` - Standalone examples scripts for demonstrating various functionality (e.g., extracting different types of representations, adapting representations in various contexts, pretraining, amongst others). + `.pre-commit-config.yaml` - Pre-commit configuration file (sane defaults + `black` + `ruff`). + `LICENSE` - Code is made available under the MIT License. + `Makefile` - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed. + `pyproject.toml` - Following PEP 621, this file has all project configuration details (including dependencies), as well as tool configurations (for `black` and `ruff`). + `README.md` - You are here! --- ## Citation Please cite [our paper](https://arxiv.org/abs/2302.12766) if using any of the Voltron models, evaluation suite, or other parts of our framework in your work. ```bibtex @inproceedings{karamcheti2023voltron, title={Language-Driven Representation Learning for Robotics}, author={Siddharth Karamcheti and Suraj Nair and Annie S.
Chen and Thomas Kollar and Chelsea Finn and Dorsa Sadigh and Percy Liang}, booktitle={Robotics: Science and Systems (RSS)}, year={2023} } ``` ================================================ FILE: docs/ROADMAP.md ================================================ # Project Roadmap We document the future of this project (new features to be added, issues to address) here. For the most part, any new features/bugfixes are documented as [Github Issues](https://github.com/siddk/voltron-robotics/issues). ## Timeline [X] - **February 26th, 2023**: Initial Voltron-Robotics release with support for loading/adapting all pretrained models, with comprehensive verification scripts & a small adaptation example. [X] - **April 4, 2023**: [#1](https://github.com/siddk/voltron-robotics/issues/1) - Add `xpretrain.py` reference script, mostly for completeness. Refactor/rewrite the preprocessing and pretraining pipeline to reflect the Qualcomm Sth-Sth data format, as well as PyTorch DDP vs. the patched PyTorch XLA! [X] - **April 11, 2023**: [#2](https://github.com/siddk/voltron-robotics/issues/2) - Add support and a more general API for pretraining on other datasets. [ ] - **Future**: [#5](https://github.com/siddk/voltron-robotics/issues/5) - Add better documentation and examples around using the MAP extractor (especially for adaptation tasks). ================================================ FILE: examples/pretrain/README.md ================================================ # Pretraining Voltron Models We provide scripts for pretraining Voltron models on various datasets. Below, we provide the full pipeline from downloading the raw Something-Something-v2 Dataset from Qualcomm, running preprocessing, then running Distributed Data Parallel (DDP) pretraining on 1+ GPUs via `torchrun`. Adding support for new datasets should follow this same general flow. 
--- ## Dataset Preprocessing We provide end-to-end instructions for downloading, preprocessing, and serializing various pretraining datasets (and combinations thereof). Where possible, we provide links to batch/dataset index files. **Note:** We make a key assumption that you have enough local disk space (e.g., on your server, attached NFS volume) to store all *raw* and *preprocessed* data; this can range from 100s of GBs to 10s of TBs! We did not have access to such storage in the original work, necessitating the *streaming* dataloaders defined in `voltron/datasets/v1/stream_datasets.py`. Given your resources, you might consider adopting a similar approach; feel free to post an issue with any questions! We currently support pretraining on the following datasets: - [Something-Something-v2](https://developer.qualcomm.com/software/ai-datasets/something-something) Instructions for downloading/preprocessing each dataset can be found below! --- ### Something-Something-v2 Dataset Download: [Qualcomm AI Datasets](https://developer.qualcomm.com/software/ai-datasets/something-something) #### Obtaining the Raw Dataset Follow the instructions [at the above link](https://developer.qualcomm.com/software/ai-datasets/something-something) to download the dataset. Qualcomm requires that you register for a [Qualcomm OneID Account](https://myaccount.qualcomm.com/signup?target=https%3A%2F%2Fdeveloper.qualcomm.com) to get access to the data. Approval might take some time. After registering for an account, make sure to download all of the following files to a directory of your choosing (we create a directory `data/raw/something-something-v2/downloaded/`). *You will need to manually download all 22 of the following files from the Qualcomm site*: 1. Datasheet / Instructions (PDF – optional, but useful): `20bn-something-something_download_instructions_-_091622.pdf` 2. Labels (includes language annotations): `20bn-something-something_download-package-labels.zip` 3. 
Chunked Videos (should be 20 `.zip` archives): + `20bn-something-something-v2-00.zip` + ... + `20bn-something-something-v2-19.zip` To extract all the given files (we extract to `data/raw/something-something-v2/`) - *execute the following from inside the `downloaded/` subdirectory)*: ```bash # Labels (annotations/language) --> creates `data/raw/something-something-v2/labels` unzip 20bn-something-something-download-package-labels.zip -d ../ # Videos (following instructions in `20-bn-something-something_download_instructions_-_091622.pdf`) unzip "20bn-something-something-v2-*.zip" -d ../videos cd ../videos cat 20bn-something-something-?? | tar -xvzf - find . -maxdepth 1 -type f -delete cd 20bn-something-something-v2/ find . -mindepth 1 -maxdepth 1 -exec mv -t .. -- {} + cd .. rm -r 20bn-something-something-v2 ls | wc # Should have 220847 `.webm` files! ``` #### Dataset Information & Statistics Something-Something-v2 consists of 220,847 `.webm` clips (168,913 in the `train` split) each with a height of exactly 240px, and variable width. The frames are encoded at a fixed 12 FPS. There are an average of 45 frames per clip (approx ~7 KB per jpeg); ~7.6M frames total (~56 GB). 
#### Video/Image Transformations --> from Video Clip to "frame" --> "tensor" ```python import av import torch from typing import List from PIL import Image, ImageOps # Resolutions for "preprocessing" (serialize to disk) and "training" PREPROCESS_RESOLUTION, TRAIN_RESOLUTION = 240, 224 # Define Preprocessing Transformation def preprocess_transform(frames: List[Image.Image]) -> List[Image.Image]: # Assert width >= height and height >= PREPROCESS_RESOLUTION orig_w, orig_h = frames[0].size assert orig_w >= orig_h >= PREPROCESS_RESOLUTION # Compute scale factor --> just a function of height and PREPROCESS_RESOLUTION scale_factor = PREPROCESS_RESOLUTION / orig_h # Full Transformation --> scale (preserve aspect ratio, then get square) for idx in range(len(frames)): frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor) left = (frames[idx].size[0] - PREPROCESS_RESOLUTION) // 2 frames[idx] = frames[idx].crop((left, 0, left + PREPROCESS_RESOLUTION, PREPROCESS_RESOLUTION)) return frames def train_transform(img) -> torch.Tensor: # Assumes square, just resizes to TRAIN_RESOLUTION via `torchvision.transforms` ... def extract_frames(webm_file: str) -> None: container = av.open(webm_file) assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!" # Extract --> then serialize via `Image.save("frame_{idx}.jpg")` frames = preprocess_transform([f.to_image() for f in container.decode(video=0)]) ... ``` #### Citation If you are pretraining on this dataset, make sure to cite the original research; Something-Something-v2 is the product of two papers: ```bibtex @inproceedings{goyal2017sthsthv1, author = {Raghav Goyal and Samira Ebrahimi Kahou and Vincent Michalski and Joanna Materzynska and Susanne Westphal and Heuna Kim and Valentin Haenel and Ingo Fründ and Peter N.
Yianilos and Moritz Mueller-Freitag and Florian Hoppe and Christian Thurau and Ingo Bax and Roland Memisevic}, booktitle = {International Conference on Computer Vision (ICCV)}, title = {The ``Something Something'' Video Database for Learning and Evaluating Visual Common Sense}, year = {2017}, } @article{mahidisoltani2018sthsthv2, author={Farzaneh Mahdisoltani and Guillaume Berger and Waseem Gharbieh and David J. Fleet and Roland Memisevic}, journal = {arXiv preprint arXiv:1804.09235}, title={On the Effectiveness of Task Granularity for Transfer Learning}, year={2018} } ``` --- ## PyTorch Native Pretraining Pipeline To pretrain a Voltron model (e.g., `v-cond`) on the processed data, make sure to read `examples/pretrain/preprocess.py`. A sample launch command to run with the Something-Something-v2 dataset on a single node with 8 GPUs is as follows: ```bash torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py ``` Make sure to check the following configuration files and either update them manually (adding your own dataclass, overriding [DEFAULTS](https://github.com/siddk/voltron-robotics/blob/main/examples/pretrain/pretrain.py#L38)), or by using Hydra semantics to override them at the command line (e.g., `... pretrain.py dataset.path="" ...`): - [Accelerator Config](../../voltron/conf/accelerators.py): Depending on hardware, might need to tune `num_workers` - [Dataset Config](../../voltron/conf/datasets.py): Make sure to override `path` and `artifact_path` - [Tracking Config](../../voltron/conf/tracking.py): Disable Weights & Biases / change default entity/name ================================================ FILE: examples/pretrain/preprocess.py ================================================ """ preprocess.py Centralized script for preprocessing various video/vision-language datasets for GPU pretraining, using a multi-stage, multiprocessing approach. 
Run as a standalone script, *prior* to calling `pretrain.py` =>> mostly because we want to preprocess the data once, as a
fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing import extract_frames, preprocess_language, unify_batches
from voltron.util import set_global_seed

# Grab Logger
overwatch = logging.getLogger(__file__)

# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]


@dataclass
class PreprocessingConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(
        default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
    )

    # Command Line Arguments
    seed: int = 21                          # Random Seed (for reproducibility)
    dry_run: bool = False                   # Dry Run --> Get a sense of preprocessing/serialization footprint

    # Composable / Structured Arguments
    dataset: DatasetConfig = MISSING        # Dataset(s) for pretraining/preprocessing
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)


@hydra.main(config_path=None, config_name="config")
def preprocess(cfg: PreprocessingConfig) -> None:
    """Run the three sequential preprocessing phases: frame extraction, language tokenization, and batch unification."""
    overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")

    # Set Randomness
    set_global_seed(cfg.seed)

    # Phase 1 :: Serialize Frames from Video Clips --> get `registry` (index files) for train and validation
    train_registry, val_registry, train_dir, val_dir = extract_frames(
        cfg.dataset.name,
        path=cfg.dataset.path,
        artifact_path=cfg.dataset.artifact_path,
        preprocess_resolution=cfg.dataset.preprocess_resolution,
        n_val_videos=cfg.dataset.n_val_videos,
        dry_run=cfg.dry_run,
    )

    # Phase 2 :: Normalize & Tokenize Language --> create `index.pt` and `index.json` files
    index_dir = preprocess_language(
        cfg.dataset.name,
        train_registry,
        val_registry,
        artifact_path=cfg.dataset.artifact_path,
        max_lang_len=cfg.dataset.max_lang_len,
        language_model=cfg.dataset.language_model,
        hf_cache=cfg.dataset.hf_cache,
    )

    # Phase 3 :: Assemble "Data-Locked" Batch Sets for Various Models (e.g., for single-frame/dual-frame/quintet)
    unify_batches(
        cfg.dataset.name,
        train_registry,
        val_registry,
        train_dir,
        val_dir,
        index_dir,
        batch_formats=cfg.dataset.batch_formats,
        max_epochs=cfg.dataset.max_epochs,
        initial_final_alpha=cfg.dataset.initial_final_alpha,
    )

    overwatch.info("Preprocessing Complete!")


if __name__ == "__main__":
    preprocess()

================================================
FILE: examples/pretrain/pretrain.py
================================================
"""
pretrain.py

Core pretraining script for Native PyTorch (Single/Multi-) GPU pretraining on the Something-Something-v2 dataset; this is
basically just a 1-1 reproduction of the XLA pretraining script (`examples/xla-reference/xpretrain.py`) with just a bit of
cleanup, the default PyTorch DDP semantics (`torchrun`), using PyTorch 2.0.

Other notable differences from `xpretrain.py`:
    - Loads data from the local filesystem instead of streaming from a GCP bucket (can be added back easily!)
    - No TPU/XLA specific dependencies --> just PyTorch 2.0!
Run with: - [Single Node Multi-GPU ($K)]: `torchrun --standalone --nnodes 1 --nproc-per-node $K examples/pretrain/pretrain.py` """ import logging import os import re import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional import hydra import torch import torch.distributed as dist from hydra.core.config_store import ConfigStore from omegaconf import MISSING, OmegaConf from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig from voltron.datasets import get_datasets from voltron.models import get_model_optimizer from voltron.overwatch import OverwatchRich from voltron.util import CheckpointSaver, Metrics, ResumeableDistributedSampler, do_resume, set_global_seed # Set Defaults (Hydra w/ Structured Configs) DEFAULTS = [ "_self_", {"model": "v-cond"}, {"dataset": "sth-sth-v2"}, {"accelerator": "torchrun"}, {"tracking": "voltron-tracking"}, {"override hydra/job_logging": "overwatch_rich"}, ] @dataclass class PretrainConfig: # fmt: off defaults: List[Any] = field(default_factory=lambda: DEFAULTS) hydra: Dict[str, Any] = field(default_factory=lambda: { "run": {"dir": "runs/train/${model.identifier}+dataset-${dataset.name}"} }) # Command Line Arguments run_id: Optional[str] = None # Run ID for Logging seed: int = 21 # Random Seed (for reproducibility) # Resume / Debug Behavior resume: bool = True # Whether to resume an existing run... wandb_resume_id: Optional[str] = None # W&B Run ID for `resume` behavior... 
    # Composable / Structured Arguments
    model: ModelConfig = MISSING                                # Model architecture for pretraining
    dataset: DatasetConfig = MISSING                            # List of datasets for pretraining
    accelerator: AcceleratorConfig = MISSING                    # Accelerator (should always keep `torchrun`)
    tracking: TrackingConfig = MISSING                          # Run/experiment tracking configuration
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PretrainConfig)


@hydra.main(config_path=None, config_name="config")
def pretrain(cfg: PretrainConfig) -> None:
    """
    Run end-to-end distributed (DDP) Voltron pretraining.

    Initializes the NCCL process group (expects `torchrun`-style environment variables), builds the
    model/optimizer, optionally resumes from a prior checkpoint, then iterates epochs of
    gradient-accumulated training followed by an all-reduced validation pass. Rank 0 exclusively
    handles logging, metrics pushes, and checkpoint writes.
    """
    # Initialize Distributed Process Group --> assumes NCCL + Environment Variable Initialization (via `torchrun`)
    dist.init_process_group(backend="nccl", init_method="env://")
    device_id = dist.get_rank() % torch.cuda.device_count()  # local GPU index on this node
    is_rank_zero, rank, world_size = dist.get_rank() == 0, dist.get_rank(), dist.get_world_size()

    # Create Unique Run Name -- `resume = True` we assume the same "run_id"
    if cfg.run_id is None:
        cfg.run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-ddp-x{cfg.seed}"
    else:
        run_dir = cfg.run_id

    # Setup Logging (Rank 0 Only!) and Directory Handling --> non-zero ranks only emit ERROR-level logs
    overwatch = logging.getLogger(__file__)
    overwatch.setLevel(logging.INFO if is_rank_zero else logging.ERROR)
    overwatch.info("Voltron Training :: Assembling the Legendary Defender...")
    if is_rank_zero:
        os.makedirs(run_dir, exist_ok=True)

    # Let's Get Started!
    overwatch.info(
        '\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."'
    )

    # Set Randomness & Get Dataloader `worker_init_fn` to ensure proper randomness in augmentations (if any)
    worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True)

    # Initialize Model & Optimizer --> Wrap in DDP / Device Handling
    # > Note :: For (Standard) DDP Training --> initializing Optimizer before DDP == initializing after!
    overwatch.info("Initializing Model, Optimizer, and Learning Rate Scheduler")
    model, optimizer, update_lr = get_model_optimizer(cfg.model, cfg.dataset)
    model = DDP(model.to(device_id), device_ids=[device_id], output_device=device_id)

    # Handle Resume / Checkpoint Loading
    resume_checkpoint, resume_epoch, resume_step = do_resume(cfg.resume, run_dir=run_dir)
    if resume_checkpoint is not None:
        # IMPORTANT --> Load weights by mapping specifically to `cuda:{device_id}` (avoids loading onto rank 0's GPU)!
        resume_state = torch.load(resume_checkpoint, map_location=f"cuda:{device_id}")
        model.load_state_dict(resume_state["model_state_dict"])
        optimizer.load_state_dict(resume_state["optimizer_state_dict"])
        dist.barrier()

    # Create Checkpoint Saver and Save Initial Checkpoint
    saver = CheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, is_rank_zero=is_rank_zero)
    if resume_checkpoint is None and resume_epoch == 0:
        overwatch.info(" | Saving 0th Epoch Checkpoint (Model Initialization)")
        saver.save(
            epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None
        )
        dist.barrier()

    # Get Datasets --> Barrier after I/O Intensive Operation
    overwatch.info(f"Retrieving Dataset `{cfg.dataset.name}` prepared for Model `{cfg.model.arch}`")
    train_dataset, val_dataset = get_datasets(
        0,
        cfg.dataset.name,
        cfg.model.arch,
        cfg.dataset.artifact_path,
        cfg.model.data_modality,
        cfg.dataset.resolution,
        cfg.dataset.normalization,
        cfg.model.get("lang_dropout", None),
        cfg.model.get("gen_ratio", None),
    )
    dist.barrier()

    # Create Metrics =>> Handles on-the-fly computation, logging to JSONL and Weights & Biases
    metrics = Metrics(
        active_loggers=cfg.tracking.active_loggers,
        run_id=cfg.run_id,
        hparams=OmegaConf.to_container(cfg),
        model_arch=cfg.model.arch,
        is_rank_zero=is_rank_zero,
        tracking_cfg=cfg.tracking,
        tags=cfg.tracking.tags,
        resume=cfg.resume,
        resume_id=cfg.wandb_resume_id,
    )
    dist.barrier()

    # Configure Gradient Accumulation --> function of `effective_bsz`, `native_bsz`, and `WORLD_SIZE`
    assert cfg.model.effective_bsz % cfg.model.native_bsz == 0, "Device `native_bsz` must evenly divide `effective_bsz`"
    accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.native_bsz // world_size
    overwatch.info(f"Running `{cfg.model.identifier}` Model Pretraining with Parameters =>")
    overwatch.info(f" | Effective Batch Size = `{cfg.model.effective_bsz}`")
    overwatch.info(f" | Per-Device Batch Size = `{cfg.model.native_bsz}`")
    overwatch.info(f" | Distributed World Size = `{world_size}`")
    overwatch.info(f" | Accumulation Steps = `{accumulate_grad_batches}`")

    # Start Train Loop --> Iterate through Epochs (Evaluation at end of Epoch)
    overwatch.info("Starting Training Loop")
    for epoch in range(resume_epoch, cfg.dataset.max_epochs):
        overwatch.info(f" | [Epoch {epoch:03d}] Building Distributed Sampler & DataLoaders")
        train_dataset.set_epoch(epoch)
        dist.barrier()

        # [Custom] ResumeableDistributedSampler operates over *examples* --> start_step (full batches) * effective_bsz
        train_sampler = ResumeableDistributedSampler(
            seen_examples=resume_step * cfg.model.effective_bsz,
            resume_epoch=resume_epoch,
            dataset=train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=cfg.seed,
        )
        train_sampler.set_epoch(epoch)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)

        # Create Epoch DataLoaders --> `shuffle=False` because the samplers own the shuffling
        train_dl = DataLoader(
            train_dataset,
            batch_size=cfg.model.native_bsz,
            sampler=train_sampler,
            shuffle=False,
            num_workers=cfg.accelerator.num_workers,
            drop_last=True,
            pin_memory=True,
            prefetch_factor=4,
            worker_init_fn=worker_init_fn,
        )
        val_dl = DataLoader(
            val_dataset, batch_size=cfg.model.native_bsz, sampler=val_sampler, shuffle=False, num_workers=4
        )

        # Book-Keeping =>> Set LR when `resume = True` (or starting from scratch)
        # > Parses prior run duration (seconds) out of the checkpoint filename's `-t=<seconds>.pt` suffix
        if epoch == resume_epoch or epoch == 0:
            metrics.resume_time = (
                int(re.search("-t=(.+?).pt", str(resume_checkpoint)).group(1)) if resume_checkpoint is not None else 0
            )
            metrics.commit(
                global_step=resume_step + ((len(train_dataset) // cfg.model.effective_bsz) * resume_epoch),
                lr=update_lr(resume_epoch, resume_step / (len(train_dataset) // cfg.model.effective_bsz)),
                update_step_time=True,
            )

        # === Train Epoch ===
        model.train()
        status = metrics.get_status(epoch)
        overwatch.info(f" | [Epoch {epoch:03d}] Running Train Loop")
        with tqdm(
            total=len(train_dl) // accumulate_grad_batches, desc=status, leave=False, disable=not is_rank_zero
        ) as progress:
            for train_idx, batch in enumerate(train_dl):
                # Model-Specific Handling --> each architecture unpacks a differently-shaped batch & loss tuple
                if cfg.model.arch == "v-mvp":
                    img = batch
                    loss, _, _ = model(img.to(device_id, non_blocking=True))
                    metrics.commit(reconstruction_loss=loss)

                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(
                        tcn_loss=tcn_loss,
                        reward_loss=reward_loss,
                        l1_loss=l1_loss,
                        l2_loss=l2_loss,
                        tcn_accuracy=tcn_acc,
                        reward_accuracy=rew_acc,
                    )

                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    loss, _, _ = model(
                        img.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(reconstruction_loss=loss)

                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    loss, [zero_loss, k_loss] = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )
                    metrics.commit(
                        reconstruction_loss=loss,
                        zero_reconstruction_loss=zero_loss,
                        k_reconstruction_loss=k_loss,
                    )

                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model(
                        imgs.to(device_id, non_blocking=True),
                        lang_con.to(device_id, non_blocking=True),
                        lang_con_mask.to(device_id, non_blocking=True),
                        lang_gen.to(device_id, non_blocking=True),
                        lang_gen_mask.to(device_id, non_blocking=True),
                        lang_gen_weight,
                    )
                    metrics.commit(
                        reconstruction_loss=reconstruction_loss,
                        zero_reconstruction_loss=zero_loss,
                        k_reconstruction_loss=k_loss,
                        lm_loss=lm_loss,
                        lm_ppl=torch.exp(lm_loss),  # perplexity = exp(cross-entropy LM loss)
                    )

                else:
                    raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")

                # Commit Loss (Prior to Normalization)
                metrics.commit(loss=loss)

                # Normalize Loss to account for Gradient Accumulation --> Backward!
                normalized_loss = loss / accumulate_grad_batches
                normalized_loss.backward()

                # Step =>> Check if done w/ Gradient Accumulation
                if (train_idx + 1) % accumulate_grad_batches == 0:
                    metrics.commit(update_step_time=True)

                    # Push Metrics every `log_frequency` steps...
                    if metrics.global_step % cfg.tracking.log_frequency == 0:
                        status = metrics.push(epoch)

                    # Optimizer Step --> Increment Global Step, Learning Rate, and Checkpoint (if specified)
                    optimizer.step()
                    optimizer.zero_grad()
                    lr = update_lr(
                        epoch,
                        (resume_step + ((train_idx + 1) // accumulate_grad_batches))
                        / (len(train_dataset) // cfg.model.effective_bsz),
                    )
                    metrics.commit(global_step=metrics.global_step + 1, lr=lr)
                    saver.save(
                        epoch,
                        is_local_step=True,
                        model=model,
                        optimizer=optimizer,
                        duration=int(time.time() - metrics.start_time) + metrics.resume_time,
                        local_step=resume_step + ((train_idx + 1) // accumulate_grad_batches),
                    )

                    # Update Progress Bar
                    progress.update()
                    progress.set_description(status)

        # === After Train Epoch --> Clear Gradients and reset `resume_step` ===
        optimizer.zero_grad()
        resume_step = 0

        # === Validation ===
        overwatch.info(f" | [Epoch {epoch:03d}] Running Validation Loop")
        model.eval()

        # Accumulate `validation_losses` in order to `all_reduce` later!
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_dl, disable=not is_rank_zero, leave=False):
                # Model-Specific Handling (mirrors the train-loop branching; only the scalar loss is kept)
                if cfg.model.arch == "v-mvp":
                    img = batch
                    val_loss, _, _ = model(img.to(device_id, non_blocking=True))

                elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
                    imgs, lang, lang_mask = batch
                    val_loss, _, _, _, _, _, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-cond":
                    img, lang, lang_mask = batch
                    val_loss, _, _ = model(
                        img.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-dual":
                    imgs, lang, lang_mask = batch
                    val_loss, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang.to(device_id, non_blocking=True),
                        lang_mask.to(device_id, non_blocking=True),
                    )

                elif cfg.model.arch == "v-gen":
                    imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
                    val_loss, _, _, _ = model(
                        imgs.to(device_id, non_blocking=True),
                        lang_con.to(device_id, non_blocking=True),
                        lang_con_mask.to(device_id, non_blocking=True),
                        lang_gen.to(device_id, non_blocking=True),
                        lang_gen_mask.to(device_id, non_blocking=True),
                        lang_gen_weight,
                    )

                else:
                    raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")

                # Add to Validation Losses
                val_losses.append(val_loss)

        # All Reduce --> Push Epoch Metrics --> Checkpoint!
        # > `all_reduce` defaults to SUM across ranks --> divide by `world_size` for the global mean
        validation_loss = torch.stack(val_losses).mean()
        dist.all_reduce(validation_loss)
        avg_val_loss = validation_loss / world_size
        if is_rank_zero:
            epoch_status, train_loss, training_duration = metrics.push_epoch(epoch, avg_val_loss)
            saver.save(
                epoch=epoch + 1,
                is_local_step=False,
                model=model,
                optimizer=optimizer,
                duration=training_duration,
                train_loss=train_loss.item(),
                val_loss=avg_val_loss.item(),
            )

        # === End of Epoch ===
        dist.barrier()

    # Finalize
    metrics.finalize()
    # And... we're done!
    overwatch.info("...and that's all, folks!")
    dist.barrier()


if __name__ == "__main__":
    # General Defaults --> should use Tensor Cores (kinda) if you have them!
    torch.set_float32_matmul_precision("high")
    torch.multiprocessing.set_start_method("spawn", force=True)
    pretrain()

================================================
FILE: examples/usage.py
================================================
"""
usage.py

Example script demonstrating how to load a Voltron model (`V-Cond`) and instantiate a Multiheaded Attention Pooling
extractor head for downstream tasks. This is the basic formula/protocol for using Voltron for arbitrary downstream
applications.

Run with (from root of repository): `python examples/usage.py`
"""
import torch
from torchvision.io import read_image

from voltron import instantiate_extractor, load


def usage() -> None:
    """Load `v-cond`, featurize a sample image + instruction, and spot-check extractor output shapes."""
    print("[*] Demonstrating Voltron Usage for Various Adaptation Applications")

    # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Voltron model --> specify `freeze`, `device` and get model (nn.Module) and preprocessor
    vcond, preprocess = load("v-cond", device=device, freeze=True)

    # Obtain and preprocess an image =>> can be from a dataset, from a camera on a robot, etc.
    img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to(device)
    lang = ["peeling a carrot"]

    # Get various representations...
    with torch.no_grad():
        multimodal_features = vcond(img, lang, mode="multimodal")  # Fused vision & language features
        visual_features = vcond(img, mode="visual")  # Vision-only features (no language)

    # Can instantiate various extractors for downstream applications
    vector_extractor = instantiate_extractor(vcond, n_latents=1, device=device)()
    seq_extractor = instantiate_extractor(vcond, n_latents=64, device=device)()

    # Assertions...
    assert list(vector_extractor(multimodal_features).shape) == [1, vcond.embed_dim], "Should return a dense vector!"
    assert list(seq_extractor(visual_features).shape) == [1, 64, vcond.embed_dim], "Should return a sequence!"


if __name__ == "__main__":
    usage()

================================================
FILE: examples/verification/verify.py
================================================
"""
verify.py

Example script demonstrating how to load all Voltron models (and reproduced models), take input image(s), and get the
various (e.g., multimodal, image-only) representations. Also serves to verify that representation loading is working
as advertised.

Run with (from root of repository): `python examples/verification/verify.py`
"""
import torch
from torchvision.io import read_image

from voltron import load

# Available Models
MODELS = ["v-cond", "v-dual", "v-gen", "r-mvp", "r-r3m-vit", "r-r3m-rn50"]

# Sample Inputs
IMG_A, IMG_B = "examples/verification/img/peel-carrot-initial.png", "examples/verification/img/peel-carrot-final.png"
LANGUAGE = "peeling a carrot"


def verify() -> None:
    """Load each released model, run feature extraction on sample inputs, and assert output sequence lengths."""
    print("[*] Running `verify` =>> Verifying Model Representations!")

    # Read both images (we'll use the second image for the dual-frame models)
    image_a, image_b = read_image(IMG_A), read_image(IMG_B)

    # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for model_id in MODELS:
        print(f"\t=> Loading Model ID `{model_id}` and Verifying Representation Shapes!")
        model, preprocess = load(model_id, device=device, freeze=True)

        # Preprocess image, run feature extraction --> assert on shapes!
        # > Expected lengths: 196 visual patch tokens (+ 20 language tokens in "multimodal" mode) -- per these asserts
        if model_id in {"v-cond", "v-cond-base"}:
            for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
                representation = model(preprocess(image_a)[None, ...].to(device), [LANGUAGE], mode=modality)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
        # Dual-frame models take a stacked (2-frame) clip =>> shape (1, 2, C, H, W)
        elif model_id in {"v-dual", "v-gen"}:
            for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
                dual_img = torch.stack([preprocess(image_a), preprocess(image_b)])[None, ...].to(device)
                representation = model(dual_img, [LANGUAGE], mode=modality)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"

        # Reproduced MVP exposes per-patch ("patch") and pooled ("cls") modes
        elif model_id == "r-mvp":
            for mode, expected in [("patch", 196), ("cls", 1)]:
                representation = model(preprocess(image_a)[None, ...].to(device), mode=mode)
                assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"

        # Reproduced R3M variants emit a single pooled vector
        elif model_id in {"r-r3m-vit", "r-r3m-rn50"}:
            representation = model(preprocess(image_a)[None, ...].to(device))
            assert representation.squeeze(dim=0).shape[0] == 1, "Shape not expected!"

        else:
            raise ValueError(f"Model {model_id} not supported!")

    # We're good!
    print("[*] All representations & shapes verified! Yay!")


if __name__ == "__main__":
    verify()

================================================
FILE: examples/xla-reference/README.md
================================================
# XLA Reference

*Note :: This code was written for the experimental PyTorch XLA build in PyTorch 1.12; no guarantees it works with later versions!*

We trained the original Voltron models (and data-locked reproductions of R3M and MVP) on TPU v3-8 nodes generously provided by the [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program.

At the time we started the project, PyTorch XLA still had some bumps, which was further complicated by the switch from [TPU Nodes to TPU VMs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch). To get things to work, we had to add some non-intuitive code to facilitate PyTorch + TPUs (vs. a standard distributed data parallel training pipeline).

As a result, `xpretrain.py` is here mostly for documentation purposes, with a fully refactored version `pretrain.py` forthcoming.
We also include the original cloud preprocessing script `xpreprocess.py` for completeness (this is more general).

================================================
FILE: examples/xla-reference/xpreprocess.py
================================================
"""
xpreprocess.py

Centralized script for preprocessing Sth-Sth-v2 for TPU/GCP pretraining, using a multi-stage, multiprocessing strategy.

Run as a standalone script, *prior* to calling `xpretrain.py` =>> mostly because we want to preprocess the data once,
as a fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing.v1 import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
from voltron.util.v1.random import set_global_seed

# Grab Logger
overwatch = logging.getLogger(__file__)


# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]


@dataclass
class PreprocessingConfig:
    # fmt: off
    defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
    hydra: Dict[str, Any] = field(
        default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
    )

    # Command Line Arguments
    seed: int = 21                      # Random Seed (for reproducibility)
    dry_run: bool = False               # Dry Run --> Get a sense of preprocessing/serialization footprint

    # Composable / Structured Arguments
    dataset: DatasetConfig = MISSING    # Dataset(s) for pretraining/preprocessing
    # fmt: on


# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)


@hydra.main(config_path=None, config_name="config")
def xpreprocess(cfg: PreprocessingConfig) -> None:
    """
    Run the three-phase preprocessing pipeline over the configured dataset.

    Phase 1 serializes per-video frames, Phase 2 normalizes/tokenizes language & builds index files,
    and Phase 3 assembles unified per-model batch sets (so all models train on "data-locked" epochs).
    """
    overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")

    # Set Randomness
    set_global_seed(cfg.seed)

    # Phase 1 :: Serialize Frames from Video Clips --> Get `registry` for train and val (index structure)
    train_registry, val_registry, train_dir, val_dir = preprocess_videos(
        cfg.dataset.name,
        path=cfg.dataset.path,
        artifact_path=cfg.dataset.artifact_path,
        resolution=cfg.dataset.resolution,
        n_val_videos=cfg.dataset.n_val_videos,
        dry_run=cfg.dry_run,
    )

    # Phase 2 :: Normalize & Tokenize Language --> Create `index.pt` & `index.json` files
    preprocess_language(
        cfg.dataset.name,
        train_registry,
        val_registry,
        max_lang_len=cfg.dataset.max_lang_len,
        language_model=cfg.dataset.language_model,
        hf_cache=cfg.dataset.hf_cache,
    )
    jsonify_language(train_registry, val_registry)
    index_dir = index(train_registry, val_registry, cfg.dataset.name, artifact_path=cfg.dataset.artifact_path)

    # Phase 3 :: Assemble & Unify Batch "Sets" across the Varied Dataset Formats (for each Model =>> "data-locked")
    unify_batches(
        cfg.dataset.artifact_path,
        cfg.dataset.name,
        train_registry,
        val_registry,
        train_dir,
        val_dir,
        index_dir,
        cfg.dataset.batch_formats,
        max_epochs=cfg.dataset.max_epochs,
        initial_final_alpha=cfg.dataset.initial_final_alpha,
    )


if __name__ == "__main__":
    xpreprocess()

================================================
FILE: examples/xla-reference/xpretrain.py
================================================
"""
xpretrain.py (The `x` prefix indicates this is a script geared for XLA/TPU backends *only*)!

Reference script for PyTorch XLA (TPU-based) pretraining on the Something-Something-v2 dataset; this is mostly for
completeness =>> the hope is that the regular `pretrain.py` script is more general and maintained.

Focuses on multi-TPU (XLA) training --> but also supports single-core TPU training, as the default distributed
mp.spawn behavior just collapses into a single thread!
Loads and preprocesses dataset, instantiates a model, and runs training. Run with `python examples/xla-reference/xpretrain.py` (will use the configuration specified by `DEFAULTS` below). """ import os import re import time from collections import deque from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional import hydra import jsonlines import numpy as np import torch import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as parallel import wandb from hydra.core.config_store import ConfigStore from omegaconf import MISSING, OmegaConf from torch.utils.data import DataLoader from tqdm import tqdm from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig from voltron.datasets.v1.stream_datasets import get_epoch_datasets from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen from voltron.overwatch import OverwatchRich from voltron.util.v1.checkpointing import XLACheckpointSaver from voltron.util.v1.distributed import ResumeableDistributedSampler from voltron.util.v1.random import set_global_seed from voltron.util.v1.xla_logger import ( log_epoch_end_update, log_vcond_train_update, log_vdual_train_update, log_vgen_train_update, log_vmvp_train_update, log_vr3m_train_update, log_vrn3m_train_update, ) # Set Defaults (Hydra w/ Structured Configs) DEFAULTS = [ "_self_", {"model": "v-cond"}, {"dataset": "sth-sth-v2"}, {"accelerator": "tpu-v3-8"}, {"tracking": "voltron-tracking"}, {"override hydra/job_logging": "overwatch_rich"}, ] @dataclass class PretrainConfig: # fmt: off defaults: List[Any] = field(default_factory=lambda: DEFAULTS) hydra: Dict[str, Any] = field(default_factory=lambda: { "run": {"dir": "./runs/train/${model.identifier}+dataset-${dataset.name}"} }) # Command Line Arguments run_id: Optional[str] = None # Run ID for Logging seed: int = 21 # Random Seed (for reproducibility) 
# Resume / Debug Behavior resume: bool = True # Whether to resume an existing run... resume_epoch: Optional[int] = None # Epoch to resume (if auto-resuming)... checkpoint_path: Optional[str] = None # Path to the specific checkpoint to load! wandb_resume_id: Optional[str] = None # W&B Run ID for `resume` behavior... # Composable / Structured Arguments model: ModelConfig = MISSING # Model architecture for pretraining dataset: DatasetConfig = MISSING # List of datasets for pretraining accelerator: AcceleratorConfig = MISSING # Accelerator configuration tracking: TrackingConfig = MISSING # Run/experiment tracking configuration # fmt: on # Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components cs = ConfigStore.instance() cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich) # Annoying - configure logger for Hydra cs.store(name="config", node=PretrainConfig) # ruff: noqa: C901 def xpretrain(cfg: PretrainConfig) -> None: # Identify if `is_rank_zero` --> We only log from the rank zero process! is_rank_zero = xm.is_master_ordinal(local=False) xm.master_print("Voltron Training :: Assembling the Legendary Defender...") # Create Unique Run Name -- if `resume = True` we assume same "run_id" run_id = cfg.run_id if run_id is None: run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-x{cfg.seed}" cfg.run_id = run_id else: cfg.run_id = run_dir = run_id if is_rank_zero: os.makedirs(run_dir, exist_ok=True) xm.master_print( '\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."' ) # Set Randomness, get DataLoader worker initialization function (to ensure any random augmentations!) 
worker_init_fn = set_global_seed(cfg.seed) # Model Initialization Logic xm.master_print("Initializing Model and Placing on Different Devices...") if cfg.model.arch == "v-mvp": xm.master_print(f"Initializing MVP variant `{cfg.model.identifier}`") model = VMVP( resolution=cfg.dataset.resolution, patch_size=cfg.model.patch_size, encoder_depth=cfg.model.encoder_depth, encoder_embed_dim=cfg.model.encoder_embed_dim, encoder_n_heads=cfg.model.encoder_n_heads, decoder_depth=cfg.model.decoder_depth, decoder_embed_dim=cfg.model.decoder_embed_dim, decoder_n_heads=cfg.model.decoder_n_heads, optimizer=cfg.model.optimizer, schedule=cfg.model.schedule, base_lr=cfg.model.base_lr, min_lr=cfg.model.min_lr, effective_bsz=cfg.model.effective_bsz, betas=cfg.model.betas, weight_decay=cfg.model.weight_decay, warmup_epochs=cfg.dataset.warmup_epochs, max_epochs=cfg.dataset.max_epochs, mlp_ratio=cfg.model.mlp_ratio, norm_pixel_loss=cfg.model.norm_pixel_loss, ) elif cfg.model.arch == "v-r3m": xm.master_print(f"Initializing R3M (ViT) Variant `{cfg.model.identifier}`") model = VR3M( resolution=cfg.dataset.resolution, patch_size=cfg.model.patch_size, depth=cfg.model.depth, embed_dim=cfg.model.embed_dim, n_heads=cfg.model.n_heads, language_model=cfg.model.language_model, hf_cache=cfg.model.hf_cache, language_dim=cfg.model.language_dim, reward_dim=cfg.model.reward_dim, n_negatives=cfg.model.n_negatives, lang_reward_weight=cfg.model.lang_reward_weight, tcn_weight=cfg.model.tcn_weight, l1_weight=cfg.model.l1_weight, l2_weight=cfg.model.l2_weight, optimizer=cfg.model.optimizer, schedule=cfg.model.schedule, lr=cfg.model.lr, min_lr=cfg.model.min_lr, warmup_epochs=cfg.dataset.warmup_epochs, max_epochs=cfg.dataset.max_epochs, mlp_ratio=cfg.model.mlp_ratio, ) elif cfg.model.arch == "v-rn3m": xm.master_print(f"Intializing R3M (ResNet) Variant `{cfg.model.identifier}`") model = VRN3M( resolution=cfg.dataset.resolution, fc_dim=cfg.model.fc_dim, language_model=cfg.model.language_model, 
hf_cache=cfg.model.hf_cache, language_dim=cfg.model.language_dim, reward_dim=cfg.model.reward_dim, n_negatives=cfg.model.n_negatives, lang_reward_weight=cfg.model.lang_reward_weight, tcn_weight=cfg.model.tcn_weight, l1_weight=cfg.model.l1_weight, l2_weight=cfg.model.l2_weight, optimizer=cfg.model.optimizer, lr=cfg.model.lr, ) elif cfg.model.arch == "v-cond": xm.master_print(f"Initializing Voltron V-Cond variant `{cfg.model.identifier}`") model = VCond( resolution=cfg.dataset.resolution, patch_size=cfg.model.patch_size, encoder_depth=cfg.model.encoder_depth, encoder_embed_dim=cfg.model.encoder_embed_dim, encoder_n_heads=cfg.model.encoder_n_heads, decoder_depth=cfg.model.decoder_depth, decoder_embed_dim=cfg.model.decoder_embed_dim, decoder_n_heads=cfg.model.decoder_n_heads, language_model=cfg.model.language_model, hf_cache=cfg.model.hf_cache, language_dim=cfg.model.language_dim, optimizer=cfg.model.optimizer, schedule=cfg.model.schedule, base_lr=cfg.model.base_lr, min_lr=cfg.model.min_lr, effective_bsz=cfg.model.effective_bsz, betas=cfg.model.betas, weight_decay=cfg.model.weight_decay, warmup_epochs=cfg.dataset.warmup_epochs, max_epochs=cfg.dataset.max_epochs, mlp_ratio=cfg.model.mlp_ratio, norm_pixel_loss=cfg.model.norm_pixel_loss, ) elif cfg.model.arch == "v-dual": xm.master_print(f"Initializing Voltron V-Dual variant `{cfg.model.identifier}`") model = VDual( resolution=cfg.dataset.resolution, patch_size=cfg.model.patch_size, encoder_depth=cfg.model.encoder_depth, encoder_embed_dim=cfg.model.encoder_embed_dim, encoder_n_heads=cfg.model.encoder_n_heads, decoder_depth=cfg.model.decoder_depth, decoder_embed_dim=cfg.model.decoder_embed_dim, decoder_n_heads=cfg.model.decoder_n_heads, language_model=cfg.model.language_model, hf_cache=cfg.model.hf_cache, language_dim=cfg.model.language_dim, optimizer=cfg.model.optimizer, schedule=cfg.model.schedule, base_lr=cfg.model.base_lr, min_lr=cfg.model.min_lr, effective_bsz=cfg.model.effective_bsz, betas=cfg.model.betas, 
weight_decay=cfg.model.weight_decay, warmup_epochs=cfg.dataset.warmup_epochs, max_epochs=cfg.dataset.max_epochs, mlp_ratio=cfg.model.mlp_ratio, norm_pixel_loss=cfg.model.norm_pixel_loss, ) elif cfg.model.arch == "v-gen": xm.master_print(f"Initializing Voltron V-Gen variant `{cfg.model.identifier}`") model = VGen( resolution=cfg.dataset.resolution, patch_size=cfg.model.patch_size, encoder_depth=cfg.model.encoder_depth, encoder_embed_dim=cfg.model.encoder_embed_dim, encoder_n_heads=cfg.model.encoder_n_heads, decoder_depth=cfg.model.decoder_depth, decoder_embed_dim=cfg.model.decoder_embed_dim, decoder_n_heads=cfg.model.decoder_n_heads, language_model=cfg.model.language_model, hf_cache=cfg.model.hf_cache, language_dim=cfg.model.language_dim, max_lang_len=cfg.dataset.max_lang_len, vocab_size=cfg.model.vocab_size, mae_weight=cfg.model.mae_weight, lm_weight=cfg.model.lm_weight, optimizer=cfg.model.optimizer, schedule=cfg.model.schedule, base_lr=cfg.model.base_lr, min_lr=cfg.model.min_lr, effective_bsz=cfg.model.effective_bsz, betas=cfg.model.betas, weight_decay=cfg.model.weight_decay, warmup_epochs=cfg.dataset.warmup_epochs, max_epochs=cfg.dataset.max_epochs, mlp_ratio=cfg.model.mlp_ratio, norm_pixel_loss=cfg.model.norm_pixel_loss, ) else: raise NotImplementedError(f"Model Architecture `{cfg.model.arch}` is not supported!") # We use gradient accumulation to honor the effective batch size specified... assert cfg.model.effective_bsz % cfg.model.device_bsz == 0, "Device bsz must evenly divide effective bsz!" 
accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.device_bsz // xm.xrt_world_size() xm.master_print( f"Running `{cfg.model.identifier}` model w/ Effective Batch Size of `{cfg.model.effective_bsz}`, " f"Per-Device Batch Size of `{cfg.model.device_bsz}`, " f"Distributed World Size of `{xm.xrt_world_size()}` and `{accumulate_grad_batches}` Accumulation Steps" ) # If Resuming =>> Load Model from Checkpoint start_checkpoint, start_epoch, start_step = None, 0, 0 if cfg.resume: # **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks: # - `accumulate_grad_batches` is exactly the same when resuming; this means: # + `cfg.model.effective_bsz`, `cfg.model.device_bsz`, & `cfg.accelerator.num_accelerators` are the same! # - The Weights & Biases directory `run_dir/wandb` only contains a *single run* # - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes! # + This means that (and generally should be true for resuming altogether) the architecture is the same! # - The `cfg.seed` should be the same (again, should generally be true...) if cfg.checkpoint_path is None: xm.master_print("Resuming :: Attempting to Automatically Load Checkpoint -- Searching!") checkpoint_path = Path(run_dir) / "checkpoints" if checkpoint_path.exists() and any(checkpoint_path.iterdir()): # Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints... checkpoints = list(checkpoint_path.iterdir()) complete_checkpoint, complete_epoch = max( [ (c, int(re.search("epoch=(.+?)-train", c.name).group(1))) for c in checkpoints if "local-epoch=" not in str(c) ], key=lambda x: x[1], ) # Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints... 
local = [ ( c, int(re.search("local-epoch=(.+?)-step", c.name).group(1)), int(re.search("step=(.+?)[.-]", c.name).group(1)), ) for c in checkpoints if "local-epoch=" in str(c) ] if len(local) > 0: # Parse out (epoch, "highest" step) + assert no great "full epoch" checkpoint exists! start_checkpoint, start_epoch, start_step = max(local, key=lambda x: x[1:]) assert start_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!" # Case 2 :: Otherwise, we're just going to start with the last "complete" epoch... else: start_checkpoint, start_epoch = complete_checkpoint, complete_epoch else: xm.master_print("No Checkpoints Found -- Starting Run from Scratch!") else: xm.master_print(f"Resuming :: Loading from Checkpoint `{cfg.checkpoint_path}`...") start_checkpoint = cfg.checkpoint_path # Actually Load the Checkpoint State! if start_checkpoint is not None: xm.master_print(f"Resuming :: Loading Model & Optimizer State Dictionaries from `{start_checkpoint}`") checkpoint = torch.load(str(start_checkpoint)) model_state_dict, optimizer_state_dict = checkpoint model.load_state_dict(model_state_dict) # Logging / W&B Handling if is_rank_zero: xm.master_print("Initializing Weights & Biases + JSONL + Checkpoint Saver on Rank Zero ONLY...") tags = None if cfg.tracking.tags is None: tags = [cfg.model.identifier, cfg.dataset.name, "pretraining"] # W&B Initialize & Log all Hyperparameters (Only on ordinal 0) wandb_resume_id = None if cfg.resume and cfg.wandb_resume_id is None: xm.master_print("Resuming :: Attempting to Automatically Load W&B Resume ID -- Searching!") wandb_path = Path("wandb") if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()): # Parse out the unique resume_id from the `.wandb` file... wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if str(f).endswith(".wandb")] assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!" # Regex match on `run-{id}.wandb`... 
wandb_resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1) # Otherwise, assert that we're starting from scratch! else: assert start_checkpoint is None, "Trying to restart a run from checkpoint without a valid W&B ID!" elif cfg.resume: xm.master_print(f"Resuming :: Using Specified W&B Resume ID = `{cfg.wandb_resume_id}`") wandb_resume_id = cfg.wandb_resume_id # Initialize Weights & Biases xm.master_print(f"W&B Resume is {cfg.resume} w/ W&B Resume ID = {wandb_resume_id}!") wandb.init( project=cfg.tracking.project, entity=cfg.tracking.entity, config=cfg, name=run_id, dir=f"{os.getcwd()}" if cfg.tracking.directory is None else cfg.tracking.directory, tags=tags, notes=cfg.tracking.notes, resume="allow" if start_checkpoint is not None else False, id=wandb_resume_id, # Weird line because PT-TPU VMs don't come with a working install of Tensorflow... settings=wandb.Settings(_disable_stats=True), ) # Initialize JSONL Logger (append only mode) --> last "global step" will always take precedence. with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write( { "run_id": run_id, "start_time": datetime.now().strftime("%m-%d-%H:%M"), "hparams": OmegaConf.to_container(cfg), } ) # Rank Zero Node will take time to spin up the loggers & checkpointer... might as well rendezvous? xm.rendezvous("Logging...") # === Here Be Dragons === # Time to handle device placement -- Note - example code doesn't specify device idx - why not? # > https://github.com/pytorch/xla/blob/3c0d68da07702995a592ea70f27868cd76fa0755/test/test_train_mp_mnist.py#L114 # > Results in printing [xla:0] and [xla:1] a bunch... no [xla:2-7]? This feels bad...? # # |=> Debugging Try: `xm.xla_device(n=xm.get_ordinal()) ---> hangs completely? # +=> *ANSWER*: https://github.com/pytorch/xla/issues/2345#issuecomment-657114819 # >> "Make no assumptions and don't try to build them manually..." 
device = xm.xla_device() model = model.train().to(device) optimizer, update_lr = model.configure_optimizer() global_step, train_losses, lrs, start_time, resume_time = 0, deque(maxlen=128), [], time.time(), 0 # If resuming (valid `start_checkpoint`) -- patch the optimizer state dictionary, and load! if start_checkpoint is not None: patched_optimizer_state_dict = { "state": optimizer_state_dict, "param_groups": optimizer.state_dict()["param_groups"], } optimizer.load_state_dict(patched_optimizer_state_dict) # Create step timing... step_times, step_start_time = deque(maxlen=128), time.time() # Create Model/Architecture-Specific Trackers... if cfg.model.arch == "v-mvp": reconstruction_losses = deque(maxlen=128) elif cfg.model.arch in {"v-r3m", "v-rn3m"}: tcn_losses, reward_losses, l1_losses, l2_losses = [deque(maxlen=128) for _ in range(4)] tcn_accuracies, reward_accuracies = [deque(maxlen=128) for _ in range(2)] elif cfg.model.arch == "v-cond": reconstruction_losses = deque(maxlen=128) elif cfg.model.arch == "v-dual": reconstruction_losses = deque(maxlen=128) zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128) elif cfg.model.arch == "v-gen": reconstruction_losses, lm_losses, lm_ppl = deque(maxlen=128), deque(maxlen=128), deque(maxlen=128) zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128) else: raise NotImplementedError(f"Trackers for Model `{cfg.model.arch}` not implemented!") # 0th Checkpoint - Pull out optimizer state explicitly (`groups` are not serializable & can easily be replicated) saver = XLACheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, cfg.accelerator.accelerator) if start_checkpoint is None and start_epoch == 0: xm.master_print("Saving 0th Epoch Checkpoint...") saver.save( epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None ) # Run on all processes --> retrieve "0th epoch" dataset! 
# =>> Important, ensures data is locked across models, for the given epoch! xm.master_print(f"Retrieving Dataset `{cfg.dataset.name}` prepared for `{cfg.model.arch}`!") train_dataset, val_dataset = get_epoch_datasets( 0, cfg.dataset.name, cfg.dataset.normalization, cfg.model.arch, cfg.dataset.stream, cfg.dataset.artifact_path, cfg.dataset.stream_prefix, cfg.model.data_modality, cfg.model.get("lang_dropout", None), cfg.model.get("gen_ratio", None), ) # Loading Datasets might take time... rendezvous to be safe xm.rendezvous("Retrieved Datasets...") # Iterate through Epochs, Evaluating at the end of each Training Epoch! # >> Everything in this loop should happen across all workers, except for the logging (ordinal 0)! xm.master_print("Starting Training Loop...") for epoch in range(start_epoch, cfg.dataset.max_epochs): xm.master_print(f"\t[Epoch {epoch}] Building Distributed Sampler & DataLoaders...") train_dataset.set_epoch(epoch) # ResumeableDistributedSampler operates at over *examples* --> start_step (full_batches) * effective_bsz seen_examples = start_step * cfg.model.effective_bsz train_sampler = ResumeableDistributedSampler( seen_examples, start_epoch, train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True, seed=cfg.seed, ) # Set epoch appropriately for the `train_sampler` --> necessary to trigger "normal" logic! train_sampler.set_epoch(epoch) train_dataloader = DataLoader( train_dataset, batch_size=cfg.model.device_bsz, sampler=train_sampler, shuffle=False if train_sampler else True, num_workers=cfg.accelerator.num_workers, drop_last=True, worker_init_fn=worker_init_fn, prefetch_factor=4, ) # NOTE :: We're not sharding the Validation set --> *everybody* will run forward passes on the *same* data! # > We will have to reduce_mesh() later... unclear why, but the torch_xla folks seem keen on it; might lead to # > weird rendezvous/hang issues if Validation is big enough... 
val_dataloader = DataLoader( val_dataset, batch_size=cfg.model.device_bsz, shuffle=False, num_workers=4, drop_last=True, worker_init_fn=worker_init_fn, ) # Initializing the Dataloaders might take time depending on process... xm.rendezvous("Initialized Dataloaders...") # Leverage the *special* API that handles synchronizing TPU cores across batches! # > NOTE: This is super important! xm.master_print("\tSetting up Parallel MpDeviceLoaders...") train_device_loader = parallel.MpDeviceLoader(train_dataloader, device) val_device_loader = parallel.MpDeviceLoader(val_dataloader, device) # Book-keeping & LR setting when `resuming` --> only do this on start_epoch! if epoch == start_epoch: if start_checkpoint is not None: global_step = start_step + ((len(train_dataset) // cfg.model.effective_bsz) * start_epoch) resume_time = int(re.search("-t=(.+?).pt", str(start_checkpoint)).group(1)) lrs.append(update_lr(start_epoch, start_step / (len(train_dataset) // cfg.model.effective_bsz))) else: lrs.append(update_lr(start_epoch, 0)) # Iterate... step_start_time = time.time() with tqdm(total=len(train_device_loader) // accumulate_grad_batches, disable=not is_rank_zero) as progress: for train_idx, batch in enumerate(train_device_loader): if cfg.model.arch == "v-mvp": # Run a forward pass through the MAE... 
other return vals are reconstructions (pixel norm) & mask loss, _, _ = model(batch) reconstruction_losses.append(loss) elif cfg.model.arch in {"v-r3m", "v-rn3m"}: imgs, lang, lang_mask = batch loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(imgs, lang, lang_mask) # Add to trackers tcn_losses.append(tcn_loss) reward_losses.append(reward_loss) l1_losses.append(l1_loss) l2_losses.append(l2_loss) tcn_accuracies.append(tcn_acc) reward_accuracies.append(rew_acc) elif cfg.model.arch == "v-cond": img, lang, lang_mask = batch loss, _, _ = model(img, lang, lang_mask) reconstruction_losses.append(loss) elif cfg.model.arch == "v-dual": imgs, lang, lang_mask = batch loss, [zero_loss, k_loss] = model(imgs, lang, lang_mask) # Add to trackers reconstruction_losses.append(loss) zero_reconstruction.append(zero_loss) k_reconstruction.append(k_loss) elif cfg.model.arch == "v-gen": imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model( imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight ) # Add to trackers reconstruction_losses.append(reconstruction_loss) lm_losses.append(lm_loss) lm_ppl.append(torch.exp(lm_loss)) zero_reconstruction.append(zero_loss) k_reconstruction.append(k_loss) else: raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch}` not implemented!") # Write Loss to Loggers (prior to accumulation normalization) train_losses.append(loss) # Normalize loss to account for accumulation loss = loss / accumulate_grad_batches loss.backward() # Gradient Accumulation =>> Note: skip any errant batches at the end... if (train_idx + 1) % accumulate_grad_batches == 0: xm.optimizer_step(optimizer) # Note call to xm.optimizer_step() -- has implicit mark_step! 
optimizer.zero_grad() # Add to `step_times` step_times.append(time.time() - step_start_time) # Logging --> Because there is no guarantee processes will be in sync, we need a `closure` # > Ref: https://pytorch.org/xla/release/1.11/index.html#torch_xla.core.xla_model.add_step_closure if is_rank_zero and global_step % cfg.tracking.log_frequency == 0: if cfg.model.arch == "v-mvp": xm.add_step_closure( log_vmvp_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], reconstruction_losses, step_times, ), ) elif cfg.model.arch == "v-r3m": xm.add_step_closure( log_vr3m_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], tcn_losses, reward_losses, l1_losses, l2_losses, tcn_accuracies, reward_accuracies, step_times, ), ) elif cfg.model.arch == "v-rn3m": xm.add_step_closure( log_vrn3m_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], tcn_losses, reward_losses, l1_losses, l2_losses, tcn_accuracies, reward_accuracies, step_times, ), ) elif cfg.model.arch == "v-cond": xm.add_step_closure( log_vcond_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], reconstruction_losses, step_times, ), ) elif cfg.model.arch == "v-dual": xm.add_step_closure( log_vdual_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], reconstruction_losses, zero_reconstruction, k_reconstruction, step_times, ), ) elif cfg.model.arch == "v-gen": xm.add_step_closure( log_vgen_train_update, args=( epoch, global_step, run_id, train_losses, lrs[-1], reconstruction_losses, lm_losses, lm_ppl, zero_reconstruction, k_reconstruction, step_times, ), ) else: raise NotImplementedError(f"Log Update for Model `{cfg.model.arch}` not implemented!") # Increment Global Step _after_ logging! 
global_step += 1 # Save checkpoint subject to *local_step = (train_idx + 1) // accumulate_grad_batches* saver.save( epoch=epoch, is_local_step=True, model=model, optimizer=optimizer, duration=int(time.time() - start_time) + resume_time, local_step=start_step + ((train_idx + 1) // accumulate_grad_batches), ) # Update LR every `accumulation_steps` iterations... lrs.append( update_lr( epoch, (start_step + ((train_idx + 1) // accumulate_grad_batches)) / (len(train_dataset) // cfg.model.effective_bsz), ) ) # Reset `step_start_time` step_start_time = time.time() # Update `progress` each time we take a gradient step! progress.update() # After each forward pass, mark a step, to compile XLA graph for a single forward pass! # =>> Note :: this is important, with gradient accumulation, the graph can get massive otherwise! xm.mark_step() else: # Clear gradients and reset start step (regardless) at end of the loop optimizer.zero_grad() start_step = 0 # Redundant, but Synchronous Validation Epoch... xm.master_print("Validating...") val_losses = [] with torch.no_grad(): for batch in tqdm(val_device_loader, disable=not is_rank_zero): if cfg.model.arch == "v-mvp": loss, _, _ = model(batch) elif cfg.model.arch in {"v-r3m", "v-rn3m"}: imgs, lang, lang_mask = batch loss, _, _, _, _, _, _ = model(imgs, lang, lang_mask) elif cfg.model.arch == "v-cond": img, lang, lang_mask = batch loss, _, _ = model(img, lang, lang_mask) elif cfg.model.arch == "v-dual": imgs, lang, lang_mask = batch loss, _ = model(imgs, lang, lang_mask) elif cfg.model.arch == "v-gen": imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch loss, _, _, _ = model(imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight) else: raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch} not implemented!") # Just append to val_losses... val_losses.append(loss) # Compute Val Loss & *mesh reduce* --> Why? :: the XLA people said so! 
val_loss = torch.stack(val_losses).mean().item() val_loss = xm.mesh_reduce("val_loss", val_loss, np.mean) # All replicas should just return the same thing? # Logging --> add another `closure` for end-of-epoch cleanup --> compute `duration` as well... duration = int(time.time() - start_time) + resume_time if is_rank_zero: xm.add_step_closure( log_epoch_end_update, args=( cfg.model.arch, epoch, global_step, run_id, duration, train_losses, val_loss, lrs[-1], step_times, ), ) # Save Checkpoint (at end of Epoch) saver.save( epoch=epoch + 1, is_local_step=False, model=model, optimizer=optimizer, duration=duration, train_loss=train_losses[-1].item(), val_loss=val_loss, ) # Dump TPU Debugging Metrics... if is_rank_zero: with open("tpu-debug-metrics.log", "w") as f: f.write(met.metrics_report()) # Exiting w/ Multiprocessing is a Nightmare... try to join? xm.master_print("...and that's all, folks!") xm.rendezvous("Cheers!") # Sleep for like 3 minutes... get W&B to finish syncing logs wandb.finish() time.sleep(150) def mp_fn(_: int, cfg: PretrainConfig) -> None: torch.set_default_tensor_type("torch.FloatTensor") # Let's Start Pretraining! xpretrain(cfg) @hydra.main(config_path=None, config_name="config") def main(cfg: PretrainConfig) -> None: import torch_xla.distributed.xla_multiprocessing as xmp # Call XMP Spawn w/ the Config as the sole argument... xmp.spawn(mp_fn, args=(cfg,), nprocs=cfg.accelerator.num_accelerators, start_method="spawn") if __name__ == "__main__": main() ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] name = "voltron-robotics" authors = [ {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"} ] description = "Voltron: Language-Driven Representation Learning for Robotics." 
version = "1.1.0" readme = "README.md" requires-python = ">=3.8" keywords = ["robotics", "representation learning", "natural language processing", "machine learning"] license = {file = "LICENSE"} classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3 :: Only", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ "av", "einops", "gdown", "google-cloud-storage", "h5py", "hurry.filesize", "hydra-core==1.1.1", # Lock Hydra =>> future versions break! "jsonlines", "omegaconf==2.1.2", # Lock OmegaConf =>> future versions break! "opencv-python", "pandas", "rich", "torch>=2.0.0", # Native PyTorch Code (Release 2.0.0) uses PyTorch 2.0! "torchvision>=0.15.0", "transformers", "wandb", ] [project.optional-dependencies] dev = [ "black", "ipython", "pre-commit", "ruff", ] [project.urls] homepage = "https://github.com/siddk/voltron-robotics" repository = "https://github.com/siddk/voltron-robotics" documentation = "https://github.com/siddk/voltron-robotics" [tool.black] line-length = 121 target-version = ["py38", "py39", "py310"] preview = true [tool.ruff] line-length = 121 target-version = "py38" select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"] [tool.ruff.per-file-ignores] "__init__.py" = ["E402", "F401"] [tool.setuptools.packages.find] where = ["."] exclude = ["cache"] ================================================ FILE: setup.py ================================================ """ setup.py PEP 621 switches most of Packaging to `pyproject.toml` -- yet keep a "dummy" setup.py for external code that has not yet upgraded. 
""" from setuptools import setup setup() ================================================ FILE: voltron/__init__.py ================================================ from .models.materialize import available_models, load from .models.util import instantiate_extractor ================================================ FILE: voltron/conf/__init__.py ================================================ from .accelerators import AcceleratorConfig from .datasets import DatasetConfig from .models import ModelConfig from .tracking import TrackingConfig ================================================ FILE: voltron/conf/accelerators.py ================================================ """ accelerator.py Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure. """ import os from dataclasses import dataclass from hydra.core.config_store import ConfigStore from omegaconf import MISSING # === Vanilla Accelerators (Deprecated; mostly for XLA code) === @dataclass class AcceleratorConfig: accelerator: str = MISSING num_accelerators: int = MISSING num_workers: int = MISSING @dataclass class TPUv2OneConfig(AcceleratorConfig): accelerator = "tpu" num_accelerators = 1 num_workers = 4 @dataclass class TPUv2EightConfig(AcceleratorConfig): accelerator = "tpu" num_accelerators = 8 num_workers = 4 @dataclass class TPUv3OneConfig(AcceleratorConfig): accelerator = "tpu" num_accelerators = 1 num_workers = 8 @dataclass class TPUv3EightConfig(AcceleratorConfig): accelerator = "tpu" num_accelerators = 8 num_workers = 8 # === GPU Default Config --> just set `num_workers`; `torchrun` takes care of the rest! 
=== # > Note :: Defaults to 1 GPU if WORLD_SIZE not set (e.g., not running with `torchrun`) @dataclass class TorchRunDefaultConfig(AcceleratorConfig): accelerator = "gpu" num_accelerators = int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else 1) num_workers = 8 # Create a configuration group `accelerator` and populate with the above... cs = ConfigStore.instance() cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig) cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig) cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig) cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig) cs.store(group="accelerator", name="torchrun", node=TorchRunDefaultConfig) ================================================ FILE: voltron/conf/datasets.py ================================================ """ datasets.py Base Hydra Structured Config for defining various pretraining datasets and appropriate configurations. Uses a simple, single inheritance structure. """ from dataclasses import dataclass from typing import Any, Tuple from hydra.core.config_store import ConfigStore from hydra.utils import to_absolute_path from omegaconf import MISSING @dataclass class DatasetConfig: name: str = MISSING path: str = MISSING artifact_path: str = MISSING # Streaming Parameters (assumes fully preprocessed dataset lives at `stream_prefix/...`) # =>> Deprecated as of `v2` stream: bool = True stream_prefix: str = "data/processed" # Dataset-Specific Parameters resolution: int = 224 normalization: Tuple[Any, Any] = MISSING # For preprocessing --> maximum size of saved frames (assumed square) preprocess_resolution: int = MISSING # Validation Parameters n_val_videos: int = MISSING # Language Modeling Parameters language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") # Maximum Length for truncating language inputs... should be computed after the fact (set to -1 to compute!) 
max_lang_len: int = MISSING # Dataset sets the number of pretraining epochs (general rule :: warmup should be ~5% of full) warmup_epochs: int = MISSING max_epochs: int = MISSING # Plausible Formats --> These are instantiations each "batch" could take, with a small DSL # > Note: Assumes final element of the list is the "most expressive" --> used to back-off batch_formats: Any = ( ("state", ("state_i",)), ("state+language", ("state_i", "language")), ("state+ok", ("state_initial", "state_i", "language")), ("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")), ) # Preprocessing :: Frame-Sampling Parameters initial_final_alpha: float = 0.2 @dataclass class SthSthv2Config(DatasetConfig): # fmt: off name: str = "sth-sth-v2" path: str = to_absolute_path("data/raw/sth-sth-v2") artifact_path: str = to_absolute_path("data/processed/sth-sth-v2") # Dataset Specific arguments normalization: Tuple[Any, Any] = ( # Mean & Standard Deviation (default :: ImageNet) (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), ) # Sth-Sth-v2 Videos have a fixed height of 240; we'll crop to square at this resolution! preprocess_resolution: int = 240 # Validation Parameters n_val_videos: int = 1000 # Number of Validation Clips (fast evaluation!) # Epochs for Dataset warmup_epochs: int = 20 max_epochs: int = 400 # Language Modeling Parameters max_lang_len: int = 20 # fmt: on # Create a configuration group `dataset` and populate with the above... # =>> Note :: this is meant to be extendable --> add arbitrary datasets & mixtures! cs = ConfigStore.instance() cs.store(group="dataset", name="sth-sth-v2", node=SthSthv2Config) ================================================ FILE: voltron/conf/models.py ================================================ """ models.py Base Hydra Structured Config for defining various pretraining model architectures and appropriate configurations. Uses a simple single inheritance structure. 
""" from dataclasses import dataclass from typing import Tuple from hydra.core.config_store import ConfigStore from hydra.utils import to_absolute_path from omegaconf import MISSING @dataclass class ModelConfig: arch: str = MISSING identifier: str = MISSING # Dataset Modality data_modality: str = MISSING # Default Vision Transformer Configuration patch_size: int = 16 mlp_ratio: float = 4.0 # Effective batch size --> total number of examples before gradient update effective_bsz: int = MISSING # Number of examples one can safely fit on an accelerator w/ this model! device_bsz: int = MISSING # For backwards compatibility, only use device_bsz for XLA/TPU pretraining... native_bsz: int = MISSING # For backwards compatibility, define a separate `native_bsz`... # @Data-Locked Reproductions --- Encompasses MVP (MAE) + R3M # MVP (Base Masked Autoencoder) @dataclass class MVPConfig(ModelConfig): arch: str = "v-mvp" identifier: str = MISSING # Dataset Modality data_modality: str = "state" # Base MAE Parameters mask_ratio: float = 0.75 # Architecture Parameters encoder_depth: int = MISSING encoder_embed_dim: int = MISSING encoder_n_heads: int = MISSING decoder_depth: int = MISSING decoder_embed_dim: int = MISSING decoder_n_heads: int = MISSING # MAE Loss/Objective Configuration norm_pixel_loss: bool = True effective_bsz: int = 1024 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adamw" schedule: str = "linear-warmup+cosine-decay" base_lr: float = 1.5e-4 min_lr: float = 0.0 betas: Tuple[float, float] = (0.9, 0.95) weight_decay: float = 0.05 @dataclass class MVPSmallConfig(MVPConfig): identifier = "r-mvp" # Architecture Parameters -- should match ViT Small Architecture to the letter! 
# Note: Small is defined in TIMM: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 encoder_depth = 12 encoder_embed_dim = 384 encoder_n_heads = 6 decoder_depth = 6 decoder_embed_dim = 192 decoder_n_heads = 6 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 128 per device. device_bsz = 128 native_bsz = 128 # R3M Models --> Just different visual encoders, roughly following the above! @dataclass class R3MConfig(ModelConfig): arch: str = "v-r3m" identifier: str = MISSING # Dataset Modality data_modality: str = "quintet+language" # ViT Architecture Parameters depth: int = MISSING embed_dim: int = MISSING n_heads: int = MISSING # Effective Batch Size effective_bsz: int = 1024 # Language Model Parameters language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") language_dim: int = 768 vocab_size: int = 30522 reward_dim: int = 1024 # Loss/Objective Configuration lang_reward_weight: float = 1.0 tcn_weight: float = 1.0 l1_weight: float = 1e-5 l2_weight: float = 1e-5 n_negatives: int = 3 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adam" schedule: str = "linear-warmup+cosine-decay" lr: float = 1e-4 min_lr: float = 0.0 @dataclass class R3MSmallConfig(R3MConfig): identifier = "r-r3m-vit" # Architecture Parameters -- should match ViT Small Architecture to the letter! 
# Note: Small is defined in TIMM: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 depth = 12 embed_dim = 384 n_heads = 6 # Device Batch Size device_bsz = 32 native_bsz = 128 # R3M -- ResNet50 Encoder (instead of ViT) @dataclass class ResNet3MConfig(ModelConfig): arch: str = "v-rn3m" identifier: str = MISSING # Dataset Modality data_modality: str = "quintet+language" # Effective Batch Size effective_bsz: int = 1024 # Architecture Parameters fc_dim: int = MISSING # Language Model Parameters language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") language_dim: int = 768 vocab_size: int = 30522 reward_dim: int = 1024 # Loss/Objective Configuration lang_reward_weight: float = 1.0 tcn_weight: float = 1.0 l1_weight: float = 1e-5 l2_weight: float = 1e-5 n_negatives: int = 3 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adam" lr: float = 1e-4 class RN3M50Config(ResNet3MConfig): identifier = "r-r3m-rn50" # Architecture Parameters fc_dim = 2048 # Device Batch Size device_bsz = 32 native_bsz = 128 # @Voltron Models -- VCond, VDual, VGen # VCond -- Single Frame + Language Conditioning @dataclass class VCondConfig(ModelConfig): arch: str = "v-cond" identifier: str = MISSING # Dataset Modality data_modality: str = "state+language" # Base MAE Parameters mask_ratio: float = 0.75 # Base Language Parameters --> full sentence dropout only... 
language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") language_dim: int = 768 vocab_size: int = 30522 lang_dropout: float = MISSING # Architecture Parameters encoder_depth: int = MISSING encoder_embed_dim: int = MISSING encoder_n_heads: int = MISSING decoder_depth: int = MISSING decoder_embed_dim: int = MISSING decoder_n_heads: int = MISSING use_cls_token: bool = True # MAE Loss/Objective Configuration norm_pixel_loss: bool = True effective_bsz: int = 1024 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adamw" schedule: str = "linear-warmup+cosine-decay" base_lr: float = 1.5e-4 min_lr: float = 0.0 betas: Tuple[float, float] = (0.9, 0.95) weight_decay: float = 0.05 @dataclass class VCondSmallConfig(VCondConfig): identifier = "v-cond" # No language dropout... lang_dropout = 0.0 # Architecture Parameters -- should match ViT Small Architecture to the letter! # Note: Small is defined in TIMM: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 encoder_depth = 12 encoder_embed_dim = 384 encoder_n_heads = 6 decoder_depth = 6 decoder_embed_dim = 192 decoder_n_heads = 6 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 128 per device # > GPU w/ 32G of RAM: max of 128 per device! device_bsz = 128 native_bsz = 128 @dataclass class VCondBaseConfig(VCondConfig): identifier = "v-cond-base" # No language dropout... lang_dropout = 0.0 # Architecture Parameters -- should match ViT Base Architecture to the letter! 
# Note: Base is defined in TIMM & Original MAE Repository: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723 # > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223 encoder_depth = 12 encoder_embed_dim = 768 encoder_n_heads = 12 decoder_depth = 8 decoder_embed_dim = 512 decoder_n_heads = 16 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 128 per device! # > GPU w/ 32G of RAM: max of 128 per device! device_bsz = 128 native_bsz = 128 # VDual - Dual Frame (0th Frame + Kth frame) + Language Conditioning @dataclass class VDualConfig(ModelConfig): arch: str = "v-dual" identifier: str = MISSING # Dataset Modality data_modality: str = "state+ok" # Base MAE Parameters mae_weight: float = 1.0 mask_ratio: float = 0.75 # Base Language Parameters --> full sentence dropout only... language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") language_dim: int = 768 vocab_size: int = 30522 lang_dropout: float = MISSING # Architecture Parameters encoder_depth: int = MISSING encoder_embed_dim: int = MISSING encoder_n_heads: int = MISSING decoder_depth: int = MISSING decoder_embed_dim: int = MISSING decoder_n_heads: int = MISSING use_cls_token: bool = True # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example! norm_pixel_loss: bool = True effective_bsz: int = 1024 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adamw" schedule: str = "linear-warmup+cosine-decay" base_lr: float = 1.5e-4 min_lr: float = 0.0 betas: Tuple[float, float] = (0.9, 0.95) weight_decay: float = 0.05 @dataclass class VDualSmallConfig(VDualConfig): identifier = "v-dual" # No language dropout... lang_dropout = 0.0 # Architecture Parameters -- should match ViT Small Architecture to the letter! 
# Note: Small is defined in TIMM: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 encoder_depth = 12 encoder_embed_dim = 384 encoder_n_heads = 6 decoder_depth = 6 decoder_embed_dim = 192 decoder_n_heads = 6 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 128 per device! # > GPU w/ 32G of RAM: max of 128 per device! device_bsz = 128 native_bsz = 128 @dataclass class VDualBaseConfig(VDualConfig): identifier = "v-dual-base" # No language dropout... lang_dropout = 0.0 # Architecture Parameters -- should match ViT Base Architecture to the letter! # Note: Base is defined in TIMM & Original MAE Repository: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723 # > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223 encoder_depth = 12 encoder_embed_dim = 768 encoder_n_heads = 12 decoder_depth = 8 decoder_embed_dim = 512 decoder_n_heads = 16 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 128 per device! # > GPU w/ 32G of RAM: max of 64 per device! 
device_bsz = 128 native_bsz = 64 # VGen - Dual Frame with Language Conditioning AND Language Generation @dataclass class VGenConfig(ModelConfig): arch: str = "v-gen" identifier: str = MISSING # Dataset Modality data_modality: str = "state+ok" # Base MAE & LM Parameters --> LM Weight is set such that mae & lang loss are ~same order of magnitude mae_weight: float = 1.0 lm_weight: float = 0.5 mask_ratio: float = 0.75 gen_ratio: float = MISSING # Base Language Parameters language_model: str = "distilbert-base-uncased" hf_cache: str = to_absolute_path("data/hf-cache") language_dim: int = 768 vocab_size: int = 30522 # Architecture Parameters encoder_depth: int = MISSING encoder_embed_dim: int = MISSING encoder_n_heads: int = MISSING decoder_depth: int = MISSING decoder_embed_dim: int = MISSING decoder_n_heads: int = MISSING use_cls_token: bool = True # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example! norm_pixel_loss: bool = True effective_bsz: int = 1024 device_bsz: int = MISSING native_bsz: int = MISSING # Optimization Parameters optimizer: str = "adamw" schedule: str = "linear-warmup+cosine-decay" base_lr: float = 1.5e-4 min_lr: float = 0.0 betas: Tuple[float, float] = (0.9, 0.95) weight_decay: float = 0.05 @dataclass class VGen50SmallConfig(VGenConfig): identifier = "v-gen" # LM Parameters --> control % of examples that are for "language generation" (no conditioning) gen_ratio = 0.50 # Architecture Parameters -- should match ViT Small Architecture to the letter! # Note: Small is defined in TIMM: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 encoder_depth = 12 encoder_embed_dim = 384 encoder_n_heads = 6 decoder_depth = 6 decoder_embed_dim = 192 decoder_n_heads = 6 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 64 per device! # > GPU w/ 32G of RAM: max of 64 per device! 
device_bsz = 64 native_bsz = 64 @dataclass class VGen50BaseConfig(VGenConfig): identifier = "v-gen-base" # LM Parameters --> control % of examples that are for "language generation" (no conditioning) gen_ratio = 0.50 # Architecture Parameters -- should match ViT Base Architecture to the letter! # Note: Base is defined in TIMM & Original MAE Repository: # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723 # > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223 encoder_depth = 12 encoder_embed_dim = 768 encoder_n_heads = 12 decoder_depth = 8 decoder_embed_dim = 512 decoder_n_heads = 16 # Number of examples one can safely fit on an accelerator w/ this model! # > TPU-v3: max of 32 per device! # > GPU w/ 32G of RAM: max of 32 per device! device_bsz = 32 native_bsz = 32 # Create a configuration group `model` and populate with the above... cs = ConfigStore.instance() # === @Data-Locked Reproductions === # Image-Only MAE/MVP Architectures cs.store(group="model", name="r-mvp", node=MVPSmallConfig) # R3M Architectures - ViT & ResNet50 cs.store(group="model", name="r-r3m-vit", node=R3MSmallConfig) cs.store(group="model", name="r-r3m-rn50", node=RN3M50Config) # === @Voltron === # VCond Architectures cs.store(group="model", name="v-cond", node=VCondSmallConfig) cs.store(group="model", name="v-cond-base", node=VCondBaseConfig) # VDual cs.store(group="model", name="v-dual", node=VDualSmallConfig) cs.store(group="model", name="v-dual-base", node=VDualBaseConfig) # VGen cs.store(group="model", name="v-gen", node=VGen50SmallConfig) cs.store(group="model", name="v-gen-base", node=VGen50BaseConfig) ================================================ FILE: voltron/conf/tracking.py ================================================ """ tracking.py Base Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via Weights & Biases. Uses a simple single inheritance structure. 
""" from dataclasses import dataclass, field from typing import List, Optional, Tuple from hydra.core.config_store import ConfigStore from omegaconf import MISSING @dataclass class TrackingConfig: # Active Loggers --> List of Loggers active_loggers: List[str] = field(default_factory=lambda: ["jsonl", "wandb"]) # Generic Logging Frequency --> Matters more for XLA/TPUs... set this to be as large as you can stomach! log_frequency: int = 100 # Checkpointing Strategy --> Save each epoch, keep most recent `idx[0]` checkpoints & *every* `idx[1]` checkpoints # Additionally, save (locally) a checkpoint every `idx[2]` steps for the current epoch (-1). checkpoint_strategy: Tuple[int, int, int] = (1, 1, 1500) # Weights & Biases Setup project: str = "voltron-pretraining" entity: str = "voltron-robotics" # Notes & Tags are at the discretion of the user... see below notes: str = MISSING tags: Optional[List[str]] = None # Directory to save W&B Metadata & Logs in General -- if None, defaults to `logs/` in the Hydra CWD directory: Optional[str] = None @dataclass class VoltronTrackingConfig(TrackingConfig): # Note: I really like using notes to keep track of things, so will crash unless specified with run. # > For `tags` I like to populate based on other args in the script, so letting it remain None notes: str = MISSING # Create a configuration group `trackers` and populate with the above... cs = ConfigStore.instance() cs.store(group="tracking", name="voltron-tracking", node=VoltronTrackingConfig) ================================================ FILE: voltron/datasets/__init__.py ================================================ from .datasets import get_datasets ================================================ FILE: voltron/datasets/datasets.py ================================================ """ datasets.py Core Pytorch Dataset implementations for the various "data flavors" used by the different representation learning models (Voltron and data-locked reproductions). 
Crucially, each dataset loads from the corresponding serialized batch files that define the exact data (and order of
iterating through the data) to see during each epoch. Notably, these serialized files control exactly what data is
seen by *all* methods *across epochs*; using these files is critical to reproducibility & comparisons.

The file contains logic for a "standard" Dataset; all files (batch index files, image/video/language files) are
stored on local disk, assuming storage conducive to fast random reads. For a "streaming" dataset (loading data
directly from GCP Buckets/Amazon S3), see `v1/stream_datasets.py`.
"""
from pathlib import Path
from typing import Any, Optional, Tuple

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import Compose

from voltron.preprocessing.transforms import get_online_transform


class PretrainDataset(Dataset):
    """Abstract base for all pretraining datasets; handles lazy HDF5 hydration & the per-epoch batch protocol."""

    def __init__(self) -> None:
        super().__init__()
        # `h5` / `vid` / `states` stay None until `hydrate` --> so each DataLoader worker opens its own handle
        self.epoch, self.h5, self.vid, self.states = 0, None, None, None
        self.index_path, self.language_path, self.language = None, None, None

    def hydrate(self, path: Path) -> None:
        """Open the epoch's HDF5 batch file (and the language index, if configured) in the current process."""
        # Create Open HDF5 Handle
        self.h5 = h5py.File(path, "r")
        self.vid, self.states = self.h5["vid"].asstr(), self.h5["states"].asstr()

        # Load Language Index
        if self.language_path is not None:
            self.language = torch.load(self.index_path / self.language_path)

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")

    def __len__(self) -> int:
        raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")


class StateDataset(PretrainDataset):
    """Single-frame dataset (e.g., for MVP/MAE-style image-only pretraining)."""

    def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        """Point at the epoch's batch file & count examples; validation batches are only ever loaded once."""
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return processed image frame as a Tensor."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        return self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))

    def __len__(self) -> int:
        return self.n_examples


class StateLanguageDataset(PretrainDataset):
    """Single frame + tokenized language dataset (e.g., for V-Cond), with optional language dropout."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        """Load the epoch's batch file, then draw a fresh set of language-dropout indices for that epoch."""
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+language" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+language" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Assemble Dropout Indices
            n_drop = int(self.lang_dropout * self.n_examples)
            self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return processed image frame and language, decomposed as the input_ids and attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frame!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

        # Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.dropout_idxs:
            # Initial language token is *always* = `101` --> last token always = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frame & Return
        img = self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))
        return img, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


class StateOKDataset(PretrainDataset):
    """Dual-frame ("state + OK") + language dataset (e.g., for V-Dual), with optional language dropout."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        """Load the epoch's batch file, then draw a fresh set of language-dropout indices for that epoch."""
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Assemble Dropout Indices
            n_drop = int(self.lang_dropout * self.n_examples)
            self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return processed dual frames and language, decomposed as the input_ids and attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frames!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

        # Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.dropout_idxs:
            # Initial language token is *always* = `101` --> last token always = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frames & Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


class GenStateOKDataset(PretrainDataset):
    """Dual-frame dataset for V-Gen; each example is marked as either "condition on" or "generate" language."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        gen_ratio: float,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
        self.gen_ratio, self.gen_idxs = gen_ratio, set()

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        """Load the epoch's batch file, then draw the epoch's set of "language generation" example indices."""
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Assemble Generation Indices
            n_gen = int(self.gen_ratio * self.n_examples)
            self.gen_idxs = set(np.random.choice(self.n_examples, n_gen, replace=False))

    def __getitem__(
        self, idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """Return dual frames, conditioning language, language to generate, decomposed as input_ids/attention mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language to condition on / generate
        vid = self.vid[idx]
        lang_con, lang_con_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]
        lang_gen, lang_gen_mask, lang_gen_weight = lang_con.clone(), lang_con_mask.clone(), None

        # Generate / Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if idx in self.gen_idxs:
            # When Generating --> just condition on the <CLS> token and generate the rest!
            lang_con[1:] *= 0
            lang_con_mask[1:] *= 0
            lang_gen_weight = 1
        else:
            # When Conditioning -> just generate the <CLS> token (so things don't crash) but set weight to 0
            lang_gen[1:] *= 0
            lang_gen_mask[1:] *= 0
            lang_gen_weight = 0

        # Retrieve Frames and Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
        return imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight

    def __len__(self) -> int:
        return self.n_examples


class QuintetDataset(PretrainDataset):
    """Five-frame + language dataset (e.g., for R3M-style temporal contrastive objectives)."""

    def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
        super().__init__()
        self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
        self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None

        # Set Language Path
        self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"

        # === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
        self.set_epoch(epoch)

    def set_epoch(self, epoch: int) -> None:
        """Point at the epoch's batch file & count examples; validation batches are only ever loaded once."""
        # Load Validation Batches
        if self.is_val and not self.val_loaded:
            self.hdf5_path = self.index_path / "quintet+language" / "validation-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

            # Set `val_loaded`
            self.val_loaded = True

        # Load Train Batches
        elif not self.is_val:
            self.hdf5_path = self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.hdf5"
            with h5py.File(self.hdf5_path, "r") as h5:
                self.n_examples = len(h5["states"])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all five processed frames and language, decomposed as the input_ids/attention_mask."""
        if self.h5 is None:
            self.hydrate(self.hdf5_path)

        # Get Vid ID --> parse out language, transform frames!
        vid = self.vid[idx]
        lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]

        # Retrieve Frames & Return
        imgs = self.states[idx]
        imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return self.n_examples


def get_datasets(
    epoch: int,
    dataset_name: str,
    model_arch: str,
    artifact_path: str,
    data_modality: str,
    resolution: int,
    normalization: Tuple[Any, Any],
    lang_dropout: Optional[float] = None,
    gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
    """Build the (train, validation) dataset pair for the given model architecture & data modality."""
    index = Path(artifact_path) / dataset_name / "index"
    img_transform = get_online_transform(dataset_name, model_arch, resolution, normalization)

    # Switch on `data_modality` --> differs based on `model_arch` (e.g., MVP --> img, V-Cond --> img, language)
    if data_modality == "state":
        train_ds = StateDataset(epoch, index, img_transform)
        val_ds = StateDataset(epoch, index, img_transform, is_val=True)
    elif data_modality == "state+language":
        train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout)
        val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, is_val=True)
    elif data_modality == "state+ok":
        # V-Dual --> don't return language modeling elements (causal attention mask, suffix language, etc.)
        if gen_ratio is None:
            train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout)
            val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, is_val=True)

        # V-Gen --> add language modeling elements!
        else:
            train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio)
            val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, is_val=True)
    elif data_modality == "quintet+language":
        train_ds = QuintetDataset(epoch, index, img_transform)
        val_ds = QuintetDataset(epoch, index, img_transform, is_val=True)
    else:
        raise ValueError(f"Data Modality `{data_modality}` is not supported!")

    return train_ds, val_ds


================================================ FILE: voltron/datasets/v1/__init__.py ================================================


================================================ FILE: voltron/datasets/v1/stream_datasets.py ================================================
"""
stream_datasets.py

Core PyTorch Datasets for the various "flavors" of data used by the various models under study. Crucially, each
dataset loads from the corresponding "batch" serialized files, that define the exact data to use.

Notably, these serialized files control exactly what data is seen by *all* methods **across epochs.** Using them is
fairly critical to reproducibility & fair comparison.

This specific file contains logic for a "streaming" Dataset; data is fetched (within the dataloader, by each worker)
via an open connection over the network to a GCS bucket, materializing data as raw BytesIO objects fed to PIL.Image
constructors.
"""
import json
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Tuple

import numpy as np
import torch
from google.api_core.exceptions import NotFound
from google.auth.exceptions import TransportError
from google.cloud import storage
from google.resumable_media._helpers import _LOGGER
from PIL import Image, UnidentifiedImageError
from torch.utils.data import Dataset, get_worker_info
from torchvision.io import read_image
from torchvision.transforms import Compose
from torchvision.transforms.functional import pil_to_tensor

from voltron.preprocessing.v1.transforms import get_online_transform
from voltron.util.distributed import get_rank

# NOTE --> IF STREAMING JPEGS, WE NEED TO USE PILLOW TO READ FILES (w/o extracting locally...)
#   =>> Instead of `read_image(file)` assume we have "fname" and open fileobj (as BytesIO) -- remember to `seek(0)`
#
#   > from PIL import Image
#   > from torchvision.transforms.functional import pil_to_tensor
#   > tensor = pil_to_tensor(Image.open(fileobj)
#       |--> This returns a `torch.uint8` Tensor of shape [3, 224, 224] --> *verified* equivalent to `read_image`

# Create Global GCS Client...
#   =>> Multiprocessing.spawn() does not inherit from base shell --> need to set service account key...
#   =>> TODO :: Figure out how to fetch num_accelerators & num_workers programatically...
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/mnt/home/auth/gcp-auth.json" N_CORES, BUCKETS = 8, [storage.Client().bucket("voltron-ANONYMIZED") for _ in range(8 * 8)] # Suppress Google Cloud Loggers _LOGGER.propagate = False storage.blob._logger.propagate = False class PretrainDataset(Dataset): def set_epoch(self, epoch: int) -> None: self.epoch = epoch class StateDataset(PretrainDataset): def __init__( self, epoch: int, index_path: Path, img_transform: Compose, stream: bool = False, prefix: Optional[Path] = None, is_val: bool = False, do_retry: bool = True, n_retries: int = 3, ) -> None: super().__init__() self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix self.r = N_CORES * get_rank() self.do_retry, self.n_retries = do_retry, n_retries # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() === self.set_epoch(self.epoch) def set_epoch(self, epoch: int) -> None: # Not Streaming --> Read from local disk... if not self.stream: if self.is_val and not self.val_loaded: with open(self.index_path / "state" / "validation-batches.json", "r") as f: self.elements = json.load(f) # Set `val_loaded` and move on... self.val_loaded = True elif not self.is_val: with open(self.index_path / "state" / f"train-epoch={epoch}-batches.json", "r") as f: self.elements = json.load(f) # Streaming --> Beam directly from Bucket else: if self.is_val and not self.val_loaded: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / "validation-batches.json")) self.elements = json.loads(blob.download_as_string()) # `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg` for element in self.elements: element["state"] = "/".join(element["state"].split("/")[-2:]) # Set `val_loaded` and move on... 
self.val_loaded = True elif not self.is_val: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / f"train-epoch={epoch}-batches.json")) self.elements = json.loads(blob.download_as_string()) # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg` for element in self.elements: element["state"] = "/".join(element["state"].split("/")[-2:]) def __getitem__(self, index: int) -> torch.Tensor: """Return single frame as torch Tensor.""" if not self.stream: return self.transform(read_image(self.elements[index]["state"])) else: # Multiplex w/ num_worker idx... worker_info = get_worker_info() r = (self.r + worker_info.id) if worker_info is not None else self.r # Streaming + Retry Logic (in case of a bad connection -- retry same file!) frame_path = self.elements[index]["state"] for _i in range(self.n_retries): try: # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open() if self.is_val: blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO() else: blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO() # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of # the time, we'll hit some sort of TCP/Transport error; this might even go up # with multiple runs happening at the same time. # # To address this, we're adopting the simplest possible "retry" strategy that # immediately tries to re-download the same file (and crashes if not possible). # This ensures reproducibility, but puts some extra effort onto the user... # File download... blob.download_to_file(fobj) fobj.seek(0) # Image loading... img_tensor = pil_to_tensor(Image.open(fobj)) # Return transformed image... return self.transform(img_tensor) except (NotFound, TransportError, UnidentifiedImageError, OSError) as e: # At the minimum --> print the broken file (obnoxiously!) 
print(f"=>> BROKEN FILE :: {frame_path}") if not self.do_retry: raise e else: continue # If we've exhausted our retries --> raise an informative ValueError raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...") def __len__(self) -> int: return len(self.elements) class StateLanguageDataset(PretrainDataset): def __init__( self, epoch: int, index_path: Path, img_transform: Compose, lang_dropout: Optional[float] = None, stream: bool = False, prefix: Optional[Path] = None, is_val: bool = False, do_retry: bool = True, n_retries: int = 3, ) -> None: super().__init__() self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout self.dropout_indices = set() self.r = N_CORES * get_rank() self.do_retry, self.n_retries = do_retry, n_retries # Load Language Index & Retrieve Epoch 0 Batches language_path = "val-language-index.json" if self.is_val else "train-language-index.json" if not self.stream: with open(self.index_path / language_path, "r") as f: self.language_index = json.load(f) else: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path)) self.language_index = json.loads(blob.download_as_string()) # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() === self.set_epoch(self.epoch) def set_epoch(self, epoch: int) -> None: # Not Streaming --> Read from local disk... if not self.stream: if self.is_val and not self.val_loaded: with open(self.index_path / "state+language" / "validation-batches.json", "r") as f: self.elements = json.load(f) # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Set `val_loaded` and move on... 
self.val_loaded = True elif not self.is_val: with open(self.index_path / "state+language" / f"train-epoch={epoch}-batches.json", "r") as f: self.elements = json.load(f) # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Streaming --> Beam directly from Bucket else: if self.is_val and not self.val_loaded: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+language" / "validation-batches.json")) self.elements = json.loads(blob.download_as_string()) # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg` for element in self.elements: element["state"] = "/".join(element["state"].split("/")[-2:]) # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Set `val_loaded` and move on... self.val_loaded = True elif not self.is_val: blob = BUCKETS[self.r].blob( str(self.prefix / "index" / "state+language" / f"train-epoch={epoch}-batches.json") ) self.elements = json.loads(blob.download_as_string()) # `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg` for element in self.elements: element["state"] = "/".join(element["state"].split("/")[-2:]) # Assemble the set of dropout indices for the given epoch... 
n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the frame and language, decomposed as the input_ids, and attention_mask.""" vid = self.elements[index]["vid"] lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64) lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64) # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token) if index in self.dropout_indices: # Initial language token is *always* = `101` --> last token always = `102` lang[1:] *= 0 lang_mask[1:] *= 0 # Retrieve Single Frame if not self.stream: img = self.transform(read_image(self.elements[index]["state"])) return img, lang, lang_mask else: # Multiplex w/ num_worker idx... worker_info = get_worker_info() r = (self.r + worker_info.id) if worker_info is not None else self.r # Streaming + Retry Logic (in case of a bad connection -- retry same file!) frame_path = self.elements[index]["state"] for _i in range(self.n_retries): try: # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open() if self.is_val: blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO() else: blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO() # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of # the time, we'll hit some sort of TCP/Transport error; this might even go up # with multiple runs happening at the same time. # # To address this, we're adopting the simplest possible "retry" strategy that # immediately tries to re-download the same file (and crashes if not possible). # This ensures reproducibility, but puts some extra effort onto the user... # File download... blob.download_to_file(fobj) fobj.seek(0) # Image loading... 
img_tensor = pil_to_tensor(Image.open(fobj)) # Assemble transformed image and return... img = self.transform(img_tensor) return img, lang, lang_mask except (NotFound, TransportError, UnidentifiedImageError, OSError) as e: # At the minimum --> print the broken file (obnoxiously!) print(f"=>> BROKEN FILE :: {frame_path}") if not self.do_retry: raise e else: continue # If we've exhausted our retries --> raise an informative ValueError raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...") def __len__(self) -> int: return len(self.elements) class StateOKDataset(PretrainDataset): def __init__( self, epoch: int, index_path: Path, img_transform: Compose, lang_dropout: Optional[float] = None, stream: bool = False, prefix: Optional[Path] = None, no_lang: bool = False, is_val: bool = False, do_retry: bool = True, n_retries: int = 3, ) -> None: super().__init__() self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix self.no_lang, self.lang_dropout = no_lang, 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout self.dropout_indices = set() self.r = N_CORES * get_rank() self.do_retry, self.n_retries = do_retry, n_retries # Load Language Index & Retrieve Epoch 0 Batches if not self.no_lang: language_path = "val-language-index.json" if self.is_val else "train-language-index.json" if not self.stream: with open(self.index_path / language_path, "r") as f: self.language_index = json.load(f) else: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path)) self.language_index = json.loads(blob.download_as_string()) # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() === self.set_epoch(self.epoch) def set_epoch(self, epoch: int) -> None: # Not Streaming --> Read from local disk... 
if not self.stream: if self.is_val and not self.val_loaded: with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f: self.elements = json.load(f) # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Set `val_loaded` and move on... self.val_loaded = True elif not self.is_val: with open(self.index_path / "state + ok" / f"train-epoch={epoch}-batches.json", "r") as f: self.elements = json.load(f) # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Streaming --> Beam directly from Bucket else: if self.is_val and not self.val_loaded: blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json")) self.elements = json.loads(blob.download_as_string()) # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg` for element in self.elements: element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]] # Assemble the set of dropout indices for the given epoch... n_drop = int(self.lang_dropout * len(self.elements)) self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) # Set `val_loaded` and move on... self.val_loaded = True elif not self.is_val: blob = BUCKETS[self.r].blob( str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json") ) self.elements = json.loads(blob.download_as_string()) # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg` for element in self.elements: element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]] # Assemble the set of dropout indices for the given epoch... 
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

    # ruff: noqa: C901
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return both states/frames and language, decomposed as the input_ids and attention_mask."""
        vid = self.elements[index]["vid"]

        # Fetch language if desired...
        if not self.no_lang:
            lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
            lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

            # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
            if index in self.dropout_indices:
                # Initial language token is *always* = `101` --> last token always = `102`
                # (only index 0 survives; everything after it -- ids and mask -- is zeroed)
                lang[1:] *= 0
                lang_mask[1:] *= 0

        # Retrieve Frames
        if not self.stream:
            # Local disk --> decode each frame via torchvision `read_image`, then stack --> [n_frames, C, H, W]
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])

            # Return --> based on `self.no_lang`
            if not self.no_lang:
                return imgs, lang, lang_mask
            else:
                return imgs
        else:
            # Multiplex w/ num_worker idx... (each DataLoader worker gets its own bucket client)
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same files!)
            frame_paths, current_frame = list(self.elements[index]["states"]), None
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
                    imgs = []
                    for _current_idx, current_frame in enumerate(frame_paths):
                        if self.is_val:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
                        else:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()

                        # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of
                        #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                        #                    with multiple runs happening at the same time.
                        #
                        # To address this, we're adopting the simplest possible "retry" strategy that
                        # immediately tries to re-download the same file (crashes if not possible).
                        # This ensures reproducibility, but puts some extra effort onto the user...

                        # File download...
                        blob.download_to_file(fobj)
                        fobj.seek(0)

                        # Image loading...
                        img_tensor = pil_to_tensor(Image.open(fobj))
                        imgs.append(self.transform(img_tensor))

                    # Stack... (this dataset always serves *pairs* of frames)
                    assert len(imgs) == 2, "Something went awry with try/except in StateOK Dataset..."
                    imgs = torch.stack(imgs)

                    # Return --> based on `self.no_lang`
                    if not self.no_lang:
                        return imgs, lang, lang_mask
                    else:
                        return imgs

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {current_frame}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class GenStateOKDataset(PretrainDataset):
    """Dual-frame "state+ok" dataset that additionally marks a `gen_ratio` subset of batches for language generation."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        gen_ratio: float,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        # `gen_indices` --> element indices whose language should be *generated* (vs. conditioned on)
        self.gen_ratio, self.gen_indices = gen_ratio, set()
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of *generation* indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "state+ok" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of *generation* indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of *generation* indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of *generation* indices for the given epoch...
                n_gen = int(self.gen_ratio * len(self.elements))
                self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

    def __getitem__(
        self, index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """Return both states/frames, language to condition on, language to generate, decomposed as input_ids/mask."""
        vid = self.elements[index]["vid"]

        # Fetch language to condition on / generate... (same utterance is used for both roles)
        lang_condition = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_condition_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
        lang_gen = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_gen_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Generate/Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if index in self.gen_indices:
            # If generating, just condition on the token (always the initial...), but generate everything!
            lang_condition[1:] *= 0
            lang_condition_mask[1:] *= 0
            lang_gen_weight = 1
        else:
            # If conditioning, generate the token (dummy so things don't crash), but set weight to 0
            lang_gen[1:] *= 0
            lang_gen_mask[1:] *= 0
            lang_gen_weight = 0

        # Retrieve Frames
        if not self.stream:
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])

            # Return...
            return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight
        else:
            # Multiplex w/ num_worker idx...
            worker_info = get_worker_info()
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Streaming + Retry Logic (in case of a bad connection -- retry same files!)
            frame_paths, current_frame = list(self.elements[index]["states"]), None
            for _i in range(self.n_retries):
                try:
                    # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
                    imgs = []
                    for _current_idx, current_frame in enumerate(frame_paths):
                        if self.is_val:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
                        else:
                            blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()

                        # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of
                        #                    the time, we'll hit some sort of TCP/Transport error; this might even go up
                        #                    with multiple runs happening at the same time.
                        #
                        # To address this, we're adopting the simplest possible "retry" strategy that
                        # immediately tries to re-download the same file (crashes if not possible).
                        # This ensures reproducibility, but puts some extra effort onto the user...

                        # File download...
                        blob.download_to_file(fobj)
                        fobj.seek(0)

                        # Image loading...
                        img_tensor = pil_to_tensor(Image.open(fobj))
                        imgs.append(self.transform(img_tensor))

                    # Stack... (this dataset always serves *pairs* of frames)
                    assert len(imgs) == 2, "Something went awry with try/except in GenStateOK Dataset..."
                    imgs = torch.stack(imgs)

                    # Return...
                    return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight

                except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
                    # At the minimum --> print the broken file (obnoxiously!)
                    print(f"=>> BROKEN FILE :: {current_frame}")
                    if not self.do_retry:
                        raise e
                    else:
                        continue

            # If we've exhausted our retries --> raise an informative ValueError
            raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        return len(self.elements)


class QuintetDataset(PretrainDataset):
    """Dataset returning all five states/frames of an element plus its language annotation."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
    ) -> None:
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        # Normalize "no dropout" (None or 0) to 0.0 so the ratio math below is always well-defined
        self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
        self.dropout_indices = set()
        self.r = N_CORES * get_rank()

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        # Not Streaming --> Read from local disk...
        if not self.stream:
            if self.is_val and not self.val_loaded:
                with open(self.index_path / "quintet+language" / "validation-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                with open(self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.json", "r") as f:
                    self.elements = json.load(f)

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

        # Streaming --> Beam directly from Bucket
        else:
            if self.is_val and not self.val_loaded:
                blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "quintet+language" / "validation-batches.json"))
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

                # Set `val_loaded` and move on...
                self.val_loaded = True

            elif not self.is_val:
                blob = BUCKETS[self.r].blob(
                    str(self.prefix / "index" / "quintet+language" / f"train-epoch={epoch}-batches.json")
                )
                self.elements = json.loads(blob.download_as_string())

                # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
                for element in self.elements:
                    element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

                # Assemble the set of dropout indices for the given epoch...
                n_drop = int(self.lang_dropout * len(self.elements))
                self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all 5 states/frames, and language, decomposed as the input_ids and attention_mask."""
        vid = self.elements[index]["vid"]
        lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        if index in self.dropout_indices:
            # Initial language token is *always* = `101` --> last token always = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frames
        if not self.stream:
            imgs = self.elements[index]["states"]
            imgs = torch.stack([self.transform(read_image(s)) for s in imgs])
        else:
            # Multiplex w/ num_worker idx...
            worker_info, imgs = get_worker_info(), []
            r = (self.r + worker_info.id) if worker_info is not None else self.r

            # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
            # NOTE(review): unlike `StateOKDataset`/`GenStateOKDataset`, there is no retry logic here --
            #               a transient transport error will surface directly to the caller.
            for state in self.elements[index]["states"]:
                if self.is_val:
                    blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / state)), BytesIO()
                else:
                    blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / state)), BytesIO()

                # Download into FileObj & Rewind...
                blob.download_to_file(fobj)
                fobj.seek(0)

                # Add to imgs...
                imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))

            # Stack...
            imgs = torch.stack(imgs)

        return imgs, lang, lang_mask

    def __len__(self) -> int:
        return len(self.elements)


def get_epoch_datasets(
    epoch: int,
    name: str,
    normalization: Tuple[Any, Any],
    model_arch: str,
    stream: bool,
    artifact_path: str,
    stream_prefix: str,
    data_modality: str,
    lang_dropout: Optional[float] = None,
    gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
    """Retrieve the custom Dataset classes for the train & val set for the given dataset & data modality."""
    index, img_transform = Path(artifact_path) / name / "index", get_online_transform(name, model_arch, normalization)

    # NOTE(review): when not streaming, `prefix` stays the raw `stream_prefix` string (the non-streaming
    #               code paths never touch it) -- confirm before relying on it being a `Path`.
    prefix = Path(stream_prefix) / name if stream else stream_prefix

    # Switch on `data_modality`
    if data_modality == "state":
        train_ds = StateDataset(epoch, index, img_transform, stream, prefix)
        val_ds = StateDataset(epoch, index, img_transform, stream, prefix, is_val=True)

    elif data_modality == "state+language":
        train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
        val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)

    elif data_modality == "state+ok":
        if gen_ratio is None:
            # V-Dual operates on frames only --> drop language entirely for that architecture
            nl = model_arch == "v-dual"
            train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl)
            val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl, is_val=True)
        else:
            # Special Generative Language Dataset...
            train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix)
            val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix, is_val=True)

    elif data_modality == "quintet+language":
        train_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
        val_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)

    else:
        raise NotImplementedError(f"Support for data modality `{data_modality}` not yet implemented!")

    return train_ds, val_ds


================================================
FILE: voltron/models/__init__.py
================================================
from .instantiate import VMVP, VR3M, VRN3M, VCond, VDual, VGen, get_model_optimizer


================================================
FILE: voltron/models/core/__init__.py
================================================


================================================
FILE: voltron/models/core/vcond.py
================================================
"""
vcond.py

PyTorch Module defining the Voltron `V-Cond` variant (single-frame with language-conditioning).
In general, follows the MAE recipe, with the architectural modifications described in the paper:
    - RMSNorm, for stability/performance ("Do Transformer Modifications Transfer...")
    - SwishGLU activations in the Attention Block Feed-Forward MLP (gated linear units) as used in PaLM
    - LayerScale with a default value of 0.1 (from Mistral/CaIT)

References:
    - https://github.com/facebookresearch/mae
    - https://github.com/lucidrains/x-transformers
"""
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat

from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings

# Suppress Transformers Logging
transformers.logging.set_verbosity_error()


class VCond(nn.Module):
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        encoder_depth: int,
        encoder_embed_dim: int,
        encoder_n_heads: int,
        decoder_depth: int,
        decoder_embed_dim: int,
        decoder_n_heads: int,
        language_model: str,
        hf_cache: str,
        language_dim: int,
        optimizer: str,
        schedule: str,
        base_lr: float,
        min_lr: float,
        effective_bsz: float,
        betas: Tuple[float, float],
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        mask_ratio: float = 0.75,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        norm_pixel_loss: bool = True,
        use_cls_token: bool = False,
    ) -> None:
        """
        Initialize a VCond model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
        :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
        :param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
        :param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
        :param language_model: Language model to freeze for encoding narrations/utterances.
        :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
        :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
        :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
        :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
        :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
        :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
        :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
        :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`. Prevents early loss spiking.
        :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
        :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
        :param max_epochs: Total number of training epochs to be run.
        :param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
        :param mlp_ratio: Ratio for embedding size to Position-wise Feed-Forward MLP (gets shrunk back down).
        :param in_channels: Default number of channels in the base image -- almost always 3.
        :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
        :param use_cls_token: Add classification token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
""" super().__init__() self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs self.use_cls_token = use_cls_token self.language_dim = language_dim # Encoder/Decoder Parameters self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads # General Parameters (for downstream adaptation) self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads # (Optional) Token Handling if self.use_cls_token: self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim)) # MAE Encoder Parameters self.patch2embed = PatchEmbed( self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels ) self.encoder_pe = nn.Parameter( torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim), requires_grad=False, ) self.encoder_blocks = nn.ModuleList( [ Block( self.encoder_embed_dim, self.encoder_n_heads, self.mlp_ratio, do_rms_norm=True, do_swish_glu=True, do_layer_scale=True, ) for _ in range(self.encoder_depth) ] ) self.encoder_norm = RMSNorm(self.encoder_embed_dim) # Projection from Language Embedding to Decoder self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim) # Projection from Encoder to Decoder self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim) # MAE Decoder Parameters -- Remember the CLS Token (if specified)! 
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim)) self.decoder_pe = nn.Parameter( torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim), requires_grad=False, ) self.decoder_blocks = nn.ModuleList( [ Block( self.decoder_embed_dim, self.decoder_n_heads, self.mlp_ratio, do_rms_norm=True, do_swish_glu=True, do_layer_scale=True, ) for _ in range(self.decoder_depth) ] ) self.decoder_norm = RMSNorm(self.decoder_embed_dim) self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True) # VCond -- Add "Image" and "Language" Modifier Tokens... self.img_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim)) self.lang_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim)) # Initialize all Weights self.initialize_weights() # @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False # > For BERT models, our "embedding" is just going to be the last hidden state # > Assumes inputs to forward pass are pre-tokenized! self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache) self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache) self.lm.eval() # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension! assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!" 
# Freeze the LM for _, param in self.lm.named_parameters(): param.requires_grad = False def initialize_weights(self) -> None: # Position Encoding -- Fixed 2D Sine-Cosine Embeddings enc_pe = get_2D_position_embeddings( self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token ) self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0)) dec_pe = get_2D_position_embeddings( self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token ) self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0)) # Initialize PatchEmbedding as a Linear... nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1])) # Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal nn.init.normal_(self.mask_token, std=0.02) nn.init.normal_(self.img_token, std=0.02) nn.init.normal_(self.lang_token, std=0.02) if self.use_cls_token: nn.init.normal_(self.cls_token, std=0.02) # Default Transformer initializer on everything else... 
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        # Use `xavier_uniform` following Jax ViT
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            # (the `isinstance` re-check here is redundant -- kept as-is)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
        self.lm.eval()
        with torch.no_grad():
            transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
        return transformer_embeddings

    def mask(
        self, patches: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches."""
        bsz, n_patches, embed_dim = patches.shape
        # Explicit `mask_ratio` overrides the instance-level default
        if mask_ratio is not None:
            n_keep = int(n_patches * (1 - mask_ratio))
        else:
            n_keep = int(n_patches * (1 - self.mask_ratio))

        # Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
        #   > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
        shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
        restore_idxs = torch.argsort(shuffle_idxs, dim=1)

        # Get "keep" (visible) patches
        visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))

        # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following MAE convention)
        mask = torch.ones(bsz, n_patches, device=patches.device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=restore_idxs)

        return visible_patches, mask, restore_idxs

    def get_representations(
        self, img: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
    ) -> torch.Tensor:
        """
        Given either a singleton (img, language) pair or a batch of images and language, extract representations
        subject to the specified mode in < multimodal | visual >.

        :param img: Processed batch of images :: [bsz, 3, 224, 224]
        :param language: Input language as `List[str] | Tuple[str] | None`
        :param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)

        :return: Extracted representations given (img, language) input as sequence.
        """
        assert img.ndim == 4 and (
            language is None or isinstance(language, list) or isinstance(language, tuple)
        ), "Invalid input to `get_representations()`"
        assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"

        # Tokenize Language --> note max length is 20!
        if language is None:
            # No language --> feed a sequence that is all padding except the leading [CLS] token
            lang, lang_mask = [torch.zeros(img.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
            lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
        else:
            tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
            lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)

            # Tile Language & Language Mask if mismatch with # images!
            # NOTE(review): this `repeat` yields `len(lang) * bsz` rows, which only equals `bsz` when a
            #               single utterance is broadcast across the batch -- confirm for other mismatches.
            if not len(lang) == len(img):
                lang = repeat(lang, "b seq -> (bsz b) seq", bsz=img.size(0))
                lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=img.size(0))

        # Extract desired representations...
        representations = self.encode(img, lang, lang_mask)
        # `visual` mode --> strip the trailing language positions from the joint sequence
        return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]

    def encode(self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Default representation extraction function, given a batch of images and tokenized language."""
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_language = self.lang2encoder(lang_embeddings)

        # Patchify
        patches = self.patch2embed(img)

        # (Optional) Token Handling
        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
            patches = torch.cat([cls_tokens, patches], dim=1)

        # Position Encoding
        patches_pe = patches + self.encoder_pe

        # Add "modality" embeddings to patches & language
        img_embeddings, lang_embeddings = patches_pe + self.img_token, projected_language + self.lang_token

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
        multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
        multimodal_embeddings = self.encoder_norm(multimodal_embeddings)

        # Return the full sequence of multimodal embeddings...
        return multimodal_embeddings

    def forward_encoder(
        self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_lang = self.lang2encoder(lang_embeddings)

        # Patchify + Position Embedding (without Token!)
        patches = self.patch2embed(img)
        patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])

        # Create mask (and go ahead and mask out patches at the same time)
        visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)

        # (Optional) Token Handling -- the classification token is *never* masked out
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :]
            cls_tokens = cls_token_pe.expand(img.shape[0], -1, -1)
            visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)

        # Add "modality" embeddings to patches & language
        visible_patches, projected_lang = visible_patches + self.img_token, projected_lang + self.lang_token

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_mask.dtype)
        multimodal_embedding = torch.cat([visible_patches, projected_lang], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([visible_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embedding = block(multimodal_embedding, multimodal_mask)
        multimodal_embedding = self.encoder_norm(multimodal_embedding)

        # Split multimodal embedding, remove language and return only the visible patches (+ optional token)!
        visible_patches = multimodal_embedding[:, : -lang_mask.shape[-1], ...]
        return visible_patches, mask, restore_idxs

    def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
        # Project patches into decoder embedding dimension
        projected_patches = self.encoder2decoder(visible_patches)

        # Add Mask Tokens to Sequence & Unshuffle
        mask_tokens = self.mask_token.repeat(
            projected_patches.shape[0],
            restore_idxs.shape[1] - visible_patches.shape[1] + (1 if self.use_cls_token else 0),
            1,
        )

        # (Optional) Token Handling
        if self.use_cls_token:
            # Remove & add back CLS Token as part of the "unshuffling"
            concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1)  # Skip CLS Token
            no_cls_unshuffled_patches = torch.gather(
                concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
            )
            unshuffled_patches = torch.cat([projected_patches[:, :1, :], no_cls_unshuffled_patches], dim=1)
        else:
            concatenated_patches = torch.cat([projected_patches, mask_tokens], dim=1)
            unshuffled_patches = torch.gather(
                concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
            )

        # Add Position Embeddings
        decoder_patches = unshuffled_patches + self.decoder_pe

        # Apply Transformer Blocks...
        for block in self.decoder_blocks:
            decoder_patches = block(decoder_patches)
        decoder_patches = self.decoder_norm(decoder_patches)

        # Run final projection & return --> note token handling!
        decoder_prediction = self.decoder_prediction(decoder_patches)
        return decoder_prediction if not self.use_cls_token else decoder_prediction[:, 1:, :]

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images to their patched equivalents, by naive reshaping"""
        return rearrange(
            imgs,
            "bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
            patch_h=self.patch_size,
            patch_w=self.patch_size,
        )

    def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the normalized-pixel MSE reconstruction loss, averaged over *masked* patches only."""
        assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
        targets = self.patchify(imgs)

        # Normalize targets... (per-patch mean/variance normalization)
        mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
        targets = (targets - mu) / ((var + 1e-6) ** 0.5)

        # Compute mean loss per patch first...
        mse = (reconstructions - targets) ** 2
        avg_loss_per_patch = mse.mean(dim=-1)

        # Compute mean loss only on *removed* patches and return (`mask` == 1 for removed patches)
        return (avg_loss_per_patch * mask).sum() / mask.sum()

    def forward(
        self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full MAE encode --> decode --> loss pipeline; returns (loss, reconstructions, mask)."""
        visible_patches, mask, restore_idxs = self.forward_encoder(img, lang, lang_mask, mask_ratio)
        reconstructions = self.forward_decoder(visible_patches, restore_idxs)
        loss = self.compute_loss(img, reconstructions, mask)
        return loss, reconstructions, mask

    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        """Build the AdamW optimizer (with decay/no-decay parameter groups) and the LR update callable."""
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adamw"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")

        # Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed...
        #   > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)

        # Build Parameter Groups
        groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

        # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
        #   > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
        self.lr = self.base_lr * (self.effective_bsz / 256)

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)

        return optimizer, update_lr


================================================
FILE: voltron/models/core/vdual.py
================================================
"""
vdual.py

PyTorch Module defining the Voltron `V-Dual` variant (dual-frame with language-conditioning).

In general, follows the MAE recipe, with the same modifications described in `vcond.py`. When masking visual patches,
the same patches are elided for both the 0th frame and the Kth frame to avoid cheating!
References:
    - https://github.com/huggingface/m4/blob/main/m4/modeling/pretraining/video/videomae.py
    - https://github.com/lucidrains/vit-pytorch
    - https://github.com/MCG-NJU/VideoMAE
"""
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat

from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings

# Suppress Transformers Logging
transformers.logging.set_verbosity_error()


class VDual(nn.Module):
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        encoder_depth: int,
        encoder_embed_dim: int,
        encoder_n_heads: int,
        decoder_depth: int,
        decoder_embed_dim: int,
        decoder_n_heads: int,
        language_model: str,
        hf_cache: str,
        language_dim: int,
        optimizer: str,
        schedule: str,
        base_lr: float,
        min_lr: float,
        effective_bsz: int,
        betas: Tuple[float, float],
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        mask_ratio: float = 0.75,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        norm_pixel_loss: bool = True,
        use_cls_token: bool = False,
    ) -> None:
        """
        Initialize a VDual model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
        :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
        :param decoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param decoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param language_model: Language model to freeze for encoding narrations/utterances.
        :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
        :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
        :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
        :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
        :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
        :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
        :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
        :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`. Prevents early loss spiking.
        :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
        :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
        :param max_epochs: Total number of training epochs to be run.
        :param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
        :param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
        :param in_channels: Default number of channels in the base image -- almost always 3.
        :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
        :param use_cls_token: Add token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
        """
        super().__init__()
        self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
        self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
        self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
        # `self.lr` stays None until `configure_optimizer()` applies the linear scaling rule
        self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
        self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
        self.use_cls_token = use_cls_token
        self.language_dim = language_dim

        # Encoder/Decoder Parameters
        self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
        self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
        self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads

        # General Parameters (for downstream adaptation)
        self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads

        # (Optional) Token Handling
        if self.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # MAE Encoder Parameters
        self.patch2embed = PatchEmbed(
            self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
        )
        # Fixed (non-learnable) sine-cosine position embeddings -- filled in by `initialize_weights()`
        self.encoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
            requires_grad=False,
        )
        self.encoder_blocks = nn.ModuleList(
            [
                Block(
                    self.encoder_embed_dim,
                    self.encoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.encoder_depth)
            ]
        )
        self.encoder_norm = RMSNorm(self.encoder_embed_dim)

        # Projection from Language Embedding to Decoder
        self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)

        # Projection from Encoder to Decoder
        self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)

        # MAE Decoder Parameters -- Remember the CLS Token (if specified)!
        #   => Note: mask_token is 4D (bsz, ctx, seq, embed) to broadcast across the 2-frame context
        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.decoder_embed_dim))
        self.decoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
            requires_grad=False,
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    self.decoder_embed_dim,
                    self.decoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.decoder_depth)
            ]
        )
        self.decoder_norm = RMSNorm(self.decoder_embed_dim)
        # Per-patch pixel predictor: decoder embedding -> (patch_size^2 * channels) raw pixel values
        self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)

        # VDual -- Add "Image" and "Language" Modifier Tokens...
        self.img_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder_embed_dim))
        self.lang_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # VDual -- Learnable "ctx" position embeddings --> initialize via `randn` following original ViT & @lucidrains
        #   =>> Ref: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L99
        #   =>> Note that n_context = 2 (0th frame + Kth frame)
        self.ctx_enc_pe = nn.Parameter(torch.randn(1, 2, 1, self.encoder_embed_dim))
        self.ctx_dec_pe = nn.Parameter(torch.randn(1, 2, 1, self.decoder_embed_dim))

        # Initialize all Weights
        self.initialize_weights()

        # @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
        #   > For BERT models, our "embedding" is just going to be the last hidden state
        #   > Assumes inputs to forward pass are pre-tokenized!
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm.eval()

        # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
        #   NOTE(review): `config.dim` looks DistilBERT-specific; other HF configs use `hidden_size` -- confirm
        assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"

        # Freeze the LM
        for _, param in self.lm.named_parameters():
            param.requires_grad = False

    def initialize_weights(self) -> None:
        """Initialize position embeddings (fixed sin-cos), patch embedding, and special tokens in-place."""
        # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
        enc_pe = get_2D_position_embeddings(
            self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
        dec_pe = get_2D_position_embeddings(
            self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Initialize PatchEmbedding as a Linear...
        nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))

        # Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal
        nn.init.normal_(self.mask_token, std=0.02)
        nn.init.normal_(self.img_token, std=0.02)
        nn.init.normal_(self.lang_token, std=0.02)
        if self.use_cls_token:
            nn.init.normal_(self.cls_token, std=0.02)

        # Everything else...
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        """Module-wise initializer applied via `self.apply` -- xavier for Linear, unit/zero for LayerNorm."""
        if isinstance(m, nn.Linear):
            # Use xavier_uniform following Jax ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
        self.lm.eval()
        with torch.no_grad():
            transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
        return transformer_embeddings

    def mask(
        self, ctx_patches: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-context random masking by shuffling :: uses argsort random noise to identify masked patches."""
        bsz, ctx_len, n_patches, embed_dim = ctx_patches.shape
        if mask_ratio is not None:
            n_keep = int(n_patches * (1 - mask_ratio))
        else:
            n_keep = int(n_patches * (1 - self.mask_ratio))

        # Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
        # > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
        # > Note that shuffle_idxs is defined solely as a function of *n_patches* and **not** context! Same mask!
        shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=ctx_patches.device), dim=1)
        restore_idxs = torch.argsort(shuffle_idxs, dim=1)

        # Get "keep" (visible) patches --> make sure to get _same_ patches *across* context length!
        visible_patches = torch.gather(
            ctx_patches, dim=2, index=shuffle_idxs[:, None, :n_keep, None].repeat(1, ctx_len, 1, embed_dim)
        )

        # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following MAE convention)
        mask = torch.ones(bsz, n_patches, device=ctx_patches.device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=restore_idxs)

        return visible_patches, mask, restore_idxs

    def get_representations(
        self, imgs: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
    ) -> torch.Tensor:
        """
        Given either a singleton (dual-imgs, language) pair or batch of dual-imgs and language, extract representations
        subject to the specified mode in < multimodal | visual >.

        :param imgs: Processed batch of images :: [bsz, 2, 3, 224, 224]
        :param language: Input language as `List[str] | Tuple[str] | None`
        :param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)

        :return: Extracted representations given (imgs, language) input as sequence.
        """
        assert (
            imgs.ndim == 5
            and imgs.shape[1] == 2
            and (language is None or isinstance(language, list) or isinstance(language, tuple))
        ), "Invalid input to `get_representations()`"
        assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"

        # Tokenize Language --> note max length is 20!
        if language is None:
            # No language given: build an "empty" sequence with just a CLS token marked valid in the mask
            lang, lang_mask = [torch.zeros(imgs.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
            lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
        else:
            tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
            lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)

            # Tile Language & Language Mask if mismatch with # images!
            if not len(lang) == len(imgs):
                lang = repeat(lang, "b seq -> (bsz b) seq", bsz=imgs.size(0))
                lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=imgs.size(0))

        # Extract desired representations...
        representations = self.encode(imgs, lang, lang_mask)
        # For `visual` mode, drop the trailing language positions from the returned sequence
        return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]

    def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Default representation extraction function, given a batch of dual-images and tokenized language."""
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_language = self.lang2encoder(lang_embeddings)

        # Patchify, broadcast position embedding ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
        patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
        patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe)
        ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
        ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]

        # Add "modality" embeddings to patches & language & flatten out context patches...
        img_ctx_embeddings, lang_embeddings = ctx_patches_pe + self.img_token, projected_language + self.lang_token
        img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) Token Handling
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_token[:, 0, :, :]
            cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1)
            img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1)

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
        multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
        multimodal_embeddings = self.encoder_norm(multimodal_embeddings)

        # Return the full sequence of multimodal embeddings (but ignore 0th frame) => the `~` denote what to remove!
        #   => [CLS] + ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
        #   => ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
        if self.use_cls_token:
            return torch.cat(
                [multimodal_embeddings[:, :1, :], multimodal_embeddings[:, 1 + self.patch2embed.num_patches :, :]], dim=1
            )
        else:
            return multimodal_embeddings[:, self.patch2embed.num_patches :]

    def forward_encoder(
        self, img_ctx: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mask & encode the dual-frame context with language; returns (visible patches, mask, restore indices)."""
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_lang = self.lang2encoder(lang_embeddings)

        # Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
        patches = self.patch2embed(rearrange(img_ctx, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
        patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])
        ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
        ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]

        # Create mask (and go ahead and mask out patches at the same time)
        visible_ctx_patches, mask, restore_idxs = self.mask(ctx_patches_pe, mask_ratio)

        # Add "modality" embeddings to patches & language & flatten out context patches...
        #   NOTE(review): `lang` is rebound here from token IDs to *embeddings* -- intentional but easy to misread
        visible_ctx_patches, lang = visible_ctx_patches + self.img_token, projected_lang + self.lang_token
        visible_patches = rearrange(visible_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) Token Handling
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_token[:, 0, :, :]
            cls_tokens = cls_token_pe.expand(img_ctx.shape[0], -1, -1)
            visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer...
        visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_mask.dtype)
        multimodal_embedding = torch.cat([visible_patches, lang], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([visible_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embedding = block(multimodal_embedding, multimodal_mask)
        multimodal_embedding = self.encoder_norm(multimodal_embedding)

        # Split multimodal embedding, remove language, return the visible ctx (0th + Kth frame) patches (+ )!
        visible_patches = multimodal_embedding[:, : -lang_mask.shape[-1], ...]
        return visible_patches, mask, restore_idxs

    def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
        """Reconstruct both frames from the encoded visible patches; returns [bsz, ctx=2, seq, patch-pixel] output."""
        # Project patches into decoder embedding dimension (visible_ctx_patches :: [bsz, (CLS) + 2 * seq, enc_embed])
        projected_patches = self.encoder2decoder(visible_patches)
        visible_per_frame = (projected_patches.shape[1] - (1 if self.use_cls_token else 0)) // 2

        # Add Mask Tokens to Sequence and Unshuffle
        mask_tokens = self.mask_token.repeat(projected_patches.shape[0], 2, restore_idxs.shape[1] - visible_per_frame, 1)

        # (Optional) Token Handling
        if self.use_cls_token:
            # Remove CLS Token as part of "unshuffling"
            projected_ctx_patches = rearrange(
                projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2
            )
            no_cls_concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
            unshuffled_ctx_patches = torch.gather(
                no_cls_concatenated_ctx_patches,
                dim=2,
                index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
            )
        else:
            projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
            concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
            # Same `restore_idxs` applied to both frames (the same patches were masked in each)
            unshuffled_ctx_patches = torch.gather(
                concatenated_ctx_patches,
                dim=2,
                index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
            )

        # Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer...
        decoder_ctx_patches_pe = unshuffled_ctx_patches + (
            self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :]
        )
        decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...]
        decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) Token Handling
        if self.use_cls_token:
            # Add back Token from `projected_patches[:, :1, :]`
            cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :]
            decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1)

        # Apply Transformer Blocks...
        for block in self.decoder_blocks:
            decoder_patches = block(decoder_patches)
        decoder_patches = self.decoder_norm(decoder_patches)

        # Run final projection & return "unflattened" patches --> note token handling!
        decoder_prediction = self.decoder_prediction(decoder_patches)
        reconstructions = decoder_prediction if not self.use_cls_token else decoder_prediction[:, 1:, :]
        return rearrange(reconstructions, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert a batch of (0th + Kth frame) images to their patched equivalents by naive reshaping."""
        return rearrange(
            imgs,
            "bsz ctx c (height patch_h) (width patch_w) -> bsz ctx (height width) (patch_h patch_w c)",
            patch_h=self.patch_size,
            patch_w=self.patch_size,
        )

    def compute_loss(
        self, imgs: torch.Tensor, ctx_reconstructions: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-frame masked MSE reconstruction losses; returns (0th-frame loss, Kth-frame loss)."""
        assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
        targets = self.patchify(imgs)

        # Normalize targets...
        mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
        targets = (targets - mu) / ((var + 1e-6) ** 0.5)

        # Split targets into 0 and K --> do the same for ctx_reconstructions
        zero_target, k_target = targets[:, 0, ...], targets[:, 1, ...]
        zero_reconstruction, k_reconstruction = ctx_reconstructions[:, 0, ...], ctx_reconstructions[:, 1, ...]

        # Compute mean losses per patch first...
        zero_mse, k_mse = (zero_reconstruction - zero_target) ** 2, (k_reconstruction - k_target) ** 2
        zero_avg_loss_per_patch, k_avg_loss_per_patch = zero_mse.mean(dim=-1), k_mse.mean(dim=-1)

        # Compute mean loss only on *removed* patches and return...
        return (zero_avg_loss_per_patch * mask).sum() / mask.sum(), (k_avg_loss_per_patch * mask).sum() / mask.sum()

    def forward(
        self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Run a forward pass through the model, computing the language-conditioned MAE reconstruction loss on the
        0th + Kth frame temporal context, given language prefix.

        :param imgs: A [bsz, 2, in_channels, resolution, resolution] tensor of (0th frame, Kth frame) sequences.
        :param lang: A [bsz, seq_len] tensor of language context to condition on.
        :param lang_mask: A [bsz, seq_len] binary mask tensor to indicate padding locations in the lang tensor.
        :param mask_ratio: Optional masking ratio to use instead of the default.

        :return Tuple of losses and intermediates, as follows:
            > (combined loss, [reconstruction loss per frame in {0, K}])
        """
        visible_ctx_patches, mask, restore_idxs = self.forward_encoder(imgs, lang, lang_mask, mask_ratio)
        ctx_reconstructions = self.forward_decoder(visible_ctx_patches, restore_idxs)
        zero_loss, k_loss = self.compute_loss(imgs, ctx_reconstructions, mask)

        # Return average reconstruction loss, individual losses...
        loss = (zero_loss + k_loss) / 2
        return loss, [zero_loss, k_loss]

    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        """Build an AdamW optimizer (with selective weight decay) and the matching LR-update callable."""
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adamw"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")

        # Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
        # > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)

        # Build Parameter Groups
        groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

        # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
        # > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
        self.lr = self.base_lr * (self.effective_bsz / 256)

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)

        return optimizer, update_lr

================================================ FILE: voltron/models/core/vgen.py ================================================
"""
vgen.py

PyTorch Module defining the Voltron `V-Gen` variant (dual-frame with language-conditioning AND language-generation).
This model adds the ability to *both* condition on language context or (XOR) generate language given masked frame
context (with a hyperparameter (`alpha` in the paper) controlling the `gen_ratio` -- the ratio of examples for which
to generate language).

The objective this model seeks to optimize is the sum of the MAE reconstruction error (when conditioning on language)
and the log-likelihood of predicting the next token given prior tokens and the entire learned image representation.

Follows same dual-frame encoding structure as VDual, and architectural modifications from VCond.
References: - https://github.com/lucidrains/x-transformers """ from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import transformers from einops import rearrange, repeat from voltron.models.util.optimization import get_lr_update from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings # Suppress Transformers Logging transformers.logging.set_verbosity_error() class VGen(nn.Module): def __init__( self, resolution: int, patch_size: int, encoder_depth: int, encoder_embed_dim: int, encoder_n_heads: int, decoder_depth: int, decoder_embed_dim: int, decoder_n_heads: int, language_model: str, hf_cache: str, language_dim: int, max_lang_len: int, vocab_size: int, mae_weight: float, lm_weight: float, optimizer: str, schedule: str, base_lr: float, min_lr: float, effective_bsz: float, betas: Tuple[float, float], weight_decay: float, warmup_epochs: int, max_epochs: int, mask_ratio: float = 0.75, mlp_ratio: float = 4.0, in_channels: int = 3, norm_pixel_loss: bool = True, use_cls_token: bool = False, eps: float = 1e-8, ) -> None: """ Initialize a VGen model with the requisite architecture parameters. :param resolution: Base image resolution -- usually 224 (ImageNet size). :param patch_size: Height/Width of each patch in pixels -- usually 16. :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder. :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone. :param encoder_n_heads: Number of heads for encoder multi-headed self-attention. :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow. :param decoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone. :param decoder_n_heads: Number of heads for encoder multi-headed self-attention. 
:param language_model: Language model to freeze for encoding narrations/utterances. :param hf_cache: Cache directory to store pretrained models, for safe distributed training. :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM. :param max_lang_len: Maximum length of input sequence (in tokens). :param vocab_size: Vocabulary size for final cross-entropy loss over token prediction. :param mae_weight: Weighting term for the MAE loss -- usually 1.0 (borrowed from M3AE paper as *rough* guide) :param lm_weight: Weighting term for the LM loss -- usually 0.5 (borrowed from the M3AE paper as *rough* guide) :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`) :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended! :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws). :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0) :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr. :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`. Prevents early loss spiking. :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers). :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule. :param max_epochs: Total number of training epochs to be run. :param mask_ratio: Ratio for number of patches AND tokens to mask out for M3AE -- should be fairly high! :param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down). :param in_channels: Default number of channels in the base image -- almost always 3. :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable). :param use_cls_token: Add token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!) 
        :param eps: Epsilon for preventing divide by zero.
        """
        super().__init__()
        self.resolution, self.patch_size, self.mask_ratio, self.eps = resolution, patch_size, mask_ratio, eps
        self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
        self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
        # `lr` is None until `configure_optimizer()` applies the linear scaling rule (base_lr * effective_bsz / 256)
        self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
        self.mae_weight, self.lm_weight, self.language_dim = mae_weight, lm_weight, language_dim
        self.max_lang_len, self.vocab_size = max_lang_len, vocab_size
        self.use_cls_token = use_cls_token
        self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs

        # Encoder/Decoder Parameters
        self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
        self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
        self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads

        # General Parameters (for downstream adaptation)
        self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads

        # (Optional) CLS Token Handling
        if self.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # MAE Encoder Parameters
        self.patch2embed = PatchEmbed(
            self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
        )
        # Fixed (requires_grad=False) sine-cosine position embeddings, filled in by `initialize_weights()`
        self.encoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
            requires_grad=False,
        )
        self.encoder_blocks = nn.ModuleList(
            [
                Block(
                    self.encoder_embed_dim,
                    self.encoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.encoder_depth)
            ]
        )
        self.encoder_norm = RMSNorm(self.encoder_embed_dim)

        # Projection from Language Embedding to Decoder
        self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)
        self.lang2decoder = nn.Linear(self.language_dim, self.decoder_embed_dim)

        # Projection from Encoder to Decoder
        self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)

        # MAE Decoder Parameters -- Remember the CLS Token (if specified)!
        #   =>> mask_token is [1, 1, 1, embed] so it broadcasts over both the context (frame) and patch dimensions
        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.decoder_embed_dim))
        self.decoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
            requires_grad=False,
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    self.decoder_embed_dim,
                    self.decoder_n_heads,
                    self.mlp_ratio,
                    do_rms_norm=True,
                    do_swish_glu=True,
                    do_layer_scale=True,
                )
                for _ in range(self.decoder_depth)
            ]
        )
        self.decoder_norm = RMSNorm(self.decoder_embed_dim)
        self.decoder_patch_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
        self.decoder_lang_prediction = nn.Linear(self.decoder_embed_dim, self.vocab_size, bias=True)

        # VGen -- Add "Image" and "Language" Modifier Tokens for Encoder & Decoder...
        self.img_enc_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder_embed_dim))
        self.lang_enc_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))

        # VGen -- Learnable "ctx" position embeddings --> initialize via `randn` following original ViT & @lucidrains
        #   =>> Ref: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L99
        #   =>> Note that n_context = 2 (0th frame + Kth frame)
        self.ctx_enc_pe = nn.Parameter(torch.randn(1, 2, 1, self.encoder_embed_dim))
        self.ctx_dec_pe = nn.Parameter(torch.randn(1, 2, 1, self.decoder_embed_dim))

        # Register Prefix Mask --> Lower Triangular ==> set prefix to 1
        #   =>> the image-patch prefix is fully visible (bidirectional); only the language suffix is causal
        n_patches, total_seq = 2 * self.patch2embed.num_patches, (2 * self.patch2embed.num_patches) + self.max_lang_len
        prefix_mask = torch.tril(torch.ones((total_seq, total_seq), dtype=torch.uint8))
        prefix_mask[:n_patches, :n_patches] = 1

        # Register this once... we'll multiply by padding masks prior to feeding to Transformer
        self.register_buffer("prefix_mask", prefix_mask.view(1, 1, total_seq, total_seq))

        # Initialize all Weights
        self.initialize_weights()

        # @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
        #   > For BERT models, our "embedding" is just going to be the last hidden state
        #   > Assumes inputs to forward pass are pre-tokenized!
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm.eval()

        # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
        #   NOTE(review): `config.dim` is the DistilBERT-style attribute name; other HF configs call this
        #   `hidden_size` -- presumably `language_model` is always a DistilBERT variant here, verify against configs.
        assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"

        # Freeze the LM
        for _, param in self.lm.named_parameters():
            param.requires_grad = False

    def initialize_weights(self) -> None:
        """Initialize position embeddings (fixed 2D sin-cos), special tokens, and all Transformer weights."""
        # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
        enc_pe = get_2D_position_embeddings(
            self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
        dec_pe = get_2D_position_embeddings(
            self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
        )
        self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Initialize PatchEmbedding as a Linear...
        nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))

        # Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal
        nn.init.normal_(self.mask_token, std=0.02)
        nn.init.normal_(self.img_enc_token, std=0.02)
        nn.init.normal_(self.lang_enc_token, std=0.02)
        if self.use_cls_token:
            nn.init.normal_(self.cls_token, std=0.02)

        # Everything else...
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        """Initialize Linear (xavier_uniform + zero bias) and LayerNorm (unit weight, zero bias) modules."""
        if isinstance(m, nn.Linear):
            # Use xavier_uniform following Jax ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def embed_language(self, lang: torch.Tensor) -> torch.Tensor:
        """Only feed language through the pretrained *embedding* matrix (no bidirectional cheating)."""
        self.lm.eval()
        with torch.no_grad():
            # Note :: These have position_encodings included... no need for separate `decoder_lang_pe`
            token_embeddings = self.lm.embeddings(lang)
        return token_embeddings

    def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
        self.lm.eval()
        with torch.no_grad():
            transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
        return transformer_embeddings

    def mask(
        self, ctx_patches: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-context random masking by shuffling :: uses argsort random noise to identify masked patches."""
        bsz, ctx_len, n_patches, embed_dim = ctx_patches.shape
        if mask_ratio is not None:
            n_keep = int(n_patches * (1 - mask_ratio))
        else:
            n_keep = int(n_patches * (1 - self.mask_ratio))

        # Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
        #   > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
        #   > Note that shuffle_idxs is defined solely as a function of *n_patches* and **not** context! Same mask!
        shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=ctx_patches.device), dim=1)
        restore_idxs = torch.argsort(shuffle_idxs, dim=1)

        # Get "keep" (visible) patches --> make sure to get _same_ patches *across* context length!
        visible_patches = torch.gather(
            ctx_patches, dim=2, index=shuffle_idxs[:, None, :n_keep, None].repeat(1, ctx_len, 1, embed_dim)
        )

        # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following FAIR MAE convention)
        mask = torch.ones(bsz, n_patches, device=ctx_patches.device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=restore_idxs)

        return visible_patches, mask, restore_idxs

    def get_representations(
        self, imgs: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
    ) -> torch.Tensor:
        """
        Given either a singleton (dual-imgs, language) pair or batch of dual-imgs and language, extract representations
        subject to the specified mode in < multimodal | visual >.

        :param imgs: Processed batch of images :: [bsz, 2, 3, 224, 224]
        :param language: Input language as `List[str] | Tuple[str] | None`
        :param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)

        :return: Extracted representations given (imgs, language) input as sequence.
        """
        assert (
            imgs.ndim == 5
            and imgs.shape[1] == 2
            and (language is None or isinstance(language, list) or isinstance(language, tuple))
        ), "Invalid input to `get_representations()`"
        assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"

        # Tokenize Language --> note max length is 20!
        if language is None:
            # No language given --> feed a "null" sequence: just a CLS token followed by padding
            lang, lang_mask = [torch.zeros(imgs.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
            lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
        else:
            tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
            lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)

            # Tile Language & Language Mask if mismatch with # images!
            #   NOTE(review): the `repeat` pattern tiles the full language batch per image --> assumes len(lang)
            #   evenly relates to len(imgs) (e.g., a single instruction broadcast over the batch); verify callers.
            if not len(lang) == len(imgs):
                lang = repeat(lang, "b seq -> (bsz b) seq", bsz=imgs.size(0))
                lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=imgs.size(0))

        # Extract desired representations...
        representations = self.encode(imgs, lang, lang_mask)
        return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]

    def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Default representation extraction function, given a batch of dual-images and tokenized language."""
        lang_embeddings = self.encode_language(lang, lang_mask)
        projected_language = self.lang2encoder(lang_embeddings)

        # Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
        patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
        patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe)
        ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
        ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]

        # Add "modality" embeddings to patches & language & flatten out context patches...
        img_ctx_embeddings, lang_embeddings = (
            ctx_patches_pe + self.img_enc_token,
            projected_language + self.lang_enc_token,
        )
        img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) CLS Token Handling
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :]
            cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1)
            img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1)

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
        patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
        multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
        multimodal_embeddings = self.encoder_norm(multimodal_embeddings)

        # Return the full sequence of multimodal embeddings (but ignore 0th frame) => the `~` denote what to remove!
        #   => [CLS] + ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
        #   => ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
        if self.use_cls_token:
            return torch.cat(
                [multimodal_embeddings[:, :1, :], multimodal_embeddings[:, 1 + self.patch2embed.num_patches :, :]],
                dim=1,
            )
        else:
            return multimodal_embeddings[:, self.patch2embed.num_patches :]

    def score(self, imgs: torch.Tensor, langs: torch.Tensor, lang_masks: torch.Tensor) -> torch.Tensor:
        """
        Given an example 0-K pair and a set of k language instructions, output scores under the generative language
        model for each instruction.

        :param imgs: 0-K pairs --> [1, 2, 3, 224, 224]
        :param langs: Tokenized language input --> [1, k, seq]
        :param lang_masks: Language padding masks --> [1, k, seq]

        :return: [1, k] Tensor of LM probabilities given imgs.
""" # Blank out the "encoder" language --> just [ = 101, 0 ...] blank_lang = torch.zeros(1, self.max_lang_len, dtype=torch.int64, device=imgs.device) blank_lang_mask = torch.zeros(1, self.max_lang_len, dtype=torch.int64, device=imgs.device) blank_lang[0][0], blank_lang_mask[0][0] = 101, 1 # === Encoder Forward === lang_embeddings = self.encode_language(blank_lang, blank_lang_mask) projected_language = self.lang2encoder(lang_embeddings) # Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings! patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2")) patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe) ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2) ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...] # Add "modality" embeddings to patches & language & flatten out context patches... img_ctx_embeddings, lang_embeddings = ( ctx_patches_pe + self.img_enc_token, projected_language + self.lang_enc_token, ) img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed") # (Optional) Token Handling if self.use_cls_token: cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :] cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1) img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1) # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=blank_lang_mask.dtype) multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1) # Merge on sequence length... multimodal_mask = torch.cat([patches_mask, blank_lang_mask], dim=1) # Merge on sequence length... # Apply Transformer Blocks... 
for block in self.encoder_blocks: multimodal_embeddings = block(multimodal_embeddings, multimodal_mask) multimodal_embeddings = self.encoder_norm(multimodal_embeddings) # Split multimodal embedding, remove language, and return only the (CLS +) 0th + Kth frame patches enc_patches = multimodal_embeddings[:, : -blank_lang_mask.shape[-1], ...] # === Encoder =>> Decoder Hand-Off === enc_patches = repeat(enc_patches, "b cseq embed -> (bsz b) cseq embed", bsz=langs.size(0)) lang_gen_embeddings = self.embed_language(langs) # === Decoder Forward === projected_patches = self.encoder2decoder(enc_patches) projected_lang_gen = self.lang2decoder(lang_gen_embeddings) # (Optional) Token Handling if self.use_cls_token: projected_ctx_patches = rearrange( projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2 ) else: projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2) # Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer... decoder_ctx_patches_pe = projected_ctx_patches + ( self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :] ) decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...] decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed") # (Optional) Token Handling if self.use_cls_token: # Add back Token from `projected_patches[:, :1, :]` cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :] decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1) # Add language -> create "mask" by multiply padding by self.prefix_mask decoder_patches_mask = torch.ones_like(decoder_patches[..., -1], dtype=lang_masks.dtype) multimodal_embedding = torch.cat([decoder_patches, projected_lang_gen], dim=1) # Merge on sequence length... multimodal_mask = torch.cat([decoder_patches_mask, lang_masks], dim=1) # Merge on sequence length... 
# Compute prefix_padded_mask prefix_padded_mask = rearrange(multimodal_mask, "bsz seq -> bsz 1 seq 1") * self.prefix_mask # Apply Transformer Blocks... for block in self.decoder_blocks: multimodal_embedding = block(multimodal_embedding, prefix_padded_mask) multimodal_embedding = self.decoder_norm(multimodal_embedding) # Split multimodal embedding into *just* the language + project! lang = multimodal_embedding[:, -lang_masks.shape[-1] :, ...] generations = self.decoder_lang_prediction(lang) # Compute cross-entropy loss (multiply by -1 for "final scoring") --> log-likelihood! bsz, seq = langs.shape lang_logits = rearrange(generations[:, :-1, ...], "bsz seq vocab -> (bsz seq) vocab") lang_targets = rearrange(langs[:, 1:], "bsz seq -> (bsz seq)") lang_loss_mask = lang_masks[:, :-1] # Defined where valid... ce_loss = F.cross_entropy(lang_logits, lang_targets, reduction="none") per_token_loss = rearrange(ce_loss, "(bsz seq) -> bsz seq", bsz=bsz, seq=seq - 1) # -1 because causal mask... # Compute loss only on *non-padded* and *non-ignored* tokens... lang_example_loss = (per_token_loss * lang_loss_mask).sum(dim=-1) / lang_loss_mask.sum(dim=-1) return -1 * lang_example_loss.detach() def forward_encoder( self, img_ctx: torch.Tensor, lang_con: torch.Tensor, lang_con_mask: torch.Tensor, mask_ratio: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: lang_embeddings = self.encode_language(lang_con, lang_con_mask) projected_lang = self.lang2encoder(lang_embeddings) # Reshape image context to apply masking *identically* # Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings! 
        patches = self.patch2embed(rearrange(img_ctx, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
        patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])
        ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
        ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]

        # Create mask (and go ahead and mask out patches at the same time)
        visible_ctx_patches, mask, restore_idxs = self.mask(ctx_patches_pe, mask_ratio)

        # Add "modality" embeddings to patches & language & flatten out context patches...
        visible_ctx_patches, lang = visible_ctx_patches + self.img_enc_token, projected_lang + self.lang_enc_token
        visible_patches = rearrange(visible_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) CLS Token Handling
        if self.use_cls_token:
            cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :]
            cls_tokens = cls_token_pe.expand(img_ctx.shape[0], -1, -1)
            visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)

        # Create "dummy" visible mask, concatenate image patches & language, feed to Transformer...
        visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_con_mask.dtype)
        multimodal_embedding = torch.cat([visible_patches, lang], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([visible_mask, lang_con_mask], dim=1)  # Merge on sequence length...

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            multimodal_embedding = block(multimodal_embedding, multimodal_mask)
        multimodal_embedding = self.encoder_norm(multimodal_embedding)

        # Split multimodal embedding, remove language, return the visible ctx (0th + Kth frame) patches (+ CLS)!
        visible_patches = multimodal_embedding[:, : -lang_con_mask.shape[-1], ...]

        return visible_patches, mask, restore_idxs

    def forward_decoder(
        self,
        visible_patches: torch.Tensor,
        restore_idxs: torch.Tensor,
        lang_gen: torch.Tensor,
        lang_gen_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode masked patches + generate language --> returns (patch reconstructions, language logits)."""
        # Project patches & lang_gen into decoder dimension (visible_patches :: [bsz, (CLS) + 2 * seq, enc_embed])
        projected_patches = self.encoder2decoder(visible_patches)
        projected_lang_gen = self.lang2decoder(lang_gen)
        # Number of visible patches *per frame* -- the 2x context dimension was flattened by the encoder
        visible_per_frame = (projected_patches.shape[1] - (1 if self.use_cls_token else 0)) // 2

        # Add Mask Tokens to Sequence and Unshuffle
        mask_tokens = self.mask_token.repeat(
            projected_patches.shape[0], 2, restore_idxs.shape[1] - visible_per_frame, 1
        )

        # (Optional) CLS Token Handling
        if self.use_cls_token:
            # Remove CLS Token as part of "unshuffling"
            projected_ctx_patches = rearrange(
                projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2
            )
            no_cls_concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
            unshuffled_ctx_patches = torch.gather(
                no_cls_concatenated_ctx_patches,
                dim=2,
                index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
            )
        else:
            projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
            concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
            unshuffled_ctx_patches = torch.gather(
                concatenated_ctx_patches,
                dim=2,
                index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
            )

        # Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer...
        decoder_ctx_patches_pe = unshuffled_ctx_patches + (
            self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :]
        )
        decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...]
        decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")

        # (Optional) CLS Token Handling
        if self.use_cls_token:
            # Add back CLS Token from `projected_patches[:, :1, :]`
            cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :]
            decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1)

        # Add language -> create "mask" by multiply padding by self.prefix_mask
        decoder_patches_mask = torch.ones_like(decoder_patches[..., -1], dtype=lang_gen_mask.dtype)
        multimodal_embedding = torch.cat([decoder_patches, projected_lang_gen], dim=1)  # Merge on sequence length...
        multimodal_mask = torch.cat([decoder_patches_mask, lang_gen_mask], dim=1)  # Merge on sequence length...

        # Compute prefix_padded_mask
        prefix_padded_mask = rearrange(multimodal_mask, "bsz seq -> bsz 1 seq 1") * self.prefix_mask

        # Apply Transformer Blocks...
        for block in self.decoder_blocks:
            multimodal_embedding = block(multimodal_embedding, prefix_padded_mask)
        multimodal_embedding = self.decoder_norm(multimodal_embedding)

        # Split multimodal embedding into patches and language...
        patches_ctx = multimodal_embedding[:, : -lang_gen_mask.shape[-1], ...]
        patches = rearrange(
            patches_ctx if not self.use_cls_token else patches_ctx[:, 1:, :],
            "bsz (ctx seq) embed -> bsz ctx seq embed",
            ctx=2,
        )
        lang = multimodal_embedding[:, -lang_gen_mask.shape[-1] :, ...]

        # Project each up to the output space...
        reconstructions = self.decoder_patch_prediction(patches)
        generations = self.decoder_lang_prediction(lang)

        return reconstructions, generations

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert a batch of (0th + Kth frame) images to their patched equivalents by naive reshaping."""
        return rearrange(
            imgs,
            "bsz ctx c (height patch_h) (width patch_w) -> bsz ctx (height width) (patch_h patch_w c)",
            patch_h=self.patch_size,
            patch_w=self.patch_size,
        )

    def compute_loss(
        self,
        imgs: torch.Tensor,
        ctx_reconstructions: torch.Tensor,
        mask: torch.Tensor,
        lang: torch.Tensor,
        generated_language: torch.Tensor,
        lang_gen_mask: torch.Tensor,
        lang_gen_weight: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the weighted MAE reconstruction + causal LM loss; returns (loss, mae, lm, zero, k) tensors."""
        # Reconstruction Loss...
        assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
        targets = self.patchify(imgs)

        # Normalize targets... (per-patch normalization, following the FAIR MAE `norm_pix_loss` recipe)
        mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
        targets = (targets - mu) / ((var + 1e-6) ** 0.5)

        # Split targets into 0 and K --> do the same for ctx_reconstructions
        zero_target, k_target = targets[:, 0, ...], targets[:, 1, ...]
        zero_reconstruction, k_reconstruction = ctx_reconstructions[:, 0, ...], ctx_reconstructions[:, 1, ...]

        # Compute mean losses per patch first...
        zero_mse, k_mse = (zero_reconstruction - zero_target) ** 2, (k_reconstruction - k_target) ** 2
        zero_avg_loss_per_patch, k_avg_loss_per_patch = zero_mse.mean(dim=-1), k_mse.mean(dim=-1)

        # Compute reconstruction losses... (`mask` is 1 on *removed* patches, so loss is on masked patches only)
        zero_loss = (zero_avg_loss_per_patch * mask).sum() / mask.sum()
        k_loss = (k_avg_loss_per_patch * mask).sum() / mask.sum()
        reconstruction_loss = (zero_loss + k_loss) / 2

        # Language Loss...
        bsz, seq = lang.shape
        lang_logits = rearrange(generated_language[:, :-1, ...], "bsz seq vocab -> (bsz seq) vocab")
        lang_targets = rearrange(lang[:, 1:], "bsz seq -> (bsz seq)")
        lang_loss_mask = lang_gen_mask[:, :-1]  # Defined where valid...
        ce_loss = F.cross_entropy(lang_logits, lang_targets, reduction="none")
        per_token_loss = rearrange(ce_loss, "(bsz seq) -> bsz seq", bsz=bsz, seq=seq - 1)  # -1 because causal mask...

        # Compute loss only on *non-padded* and *non-ignored* tokens...
        lang_example_loss = (per_token_loss * lang_loss_mask).sum(dim=-1) / lang_loss_mask.sum(dim=-1)
        lang_loss = (lang_example_loss * lang_gen_weight).sum() / (self.eps + lang_gen_weight.sum())  # Divide by 0...

        # TODO (Remove) -- NaN Check...
        if reconstruction_loss.isnan().any() or lang_loss.isnan().any():
            # fmt: off
            print(
                f"Found Nan -- "
                f"ctx_reconstructions: {ctx_reconstructions.isnan().any()} -- "
                f"generated_language: {generated_language.isnan().any()} -- "
                f"zero_avg_loss_per_patch: {zero_avg_loss_per_patch.isnan().any()} -- "
                f"k_avg_loss_per_patch: {k_avg_loss_per_patch.isnan().any()} -- "
                f"zero_loss: {zero_loss.isnan().any()} -- "
                f"k_loss: {k_loss.isnan().any()} -- "
                f"reconstruction_loss: {reconstruction_loss.isnan().any()} -- "
                f"ce_loss: {ce_loss.isnan().any()} -- "
                f"per_token_loss: {per_token_loss.isnan().any()} -- "
                f"lang_example_loss: {lang_example_loss.isnan().any()} -- "
                f"lang_loss: {lang_loss.isnan().any()}"
            )
            exit(1)
            # fmt: on

        # Compute weighted loss...
        loss = self.mae_weight * reconstruction_loss + self.lm_weight * lang_loss

        return loss, reconstruction_loss, lang_loss, zero_loss, k_loss

    def forward(
        self,
        imgs: torch.Tensor,
        lang_con: torch.Tensor,
        lang_con_mask: torch.Tensor,
        lang_gen: torch.Tensor,
        lang_gen_mask: torch.Tensor,
        lang_gen_weight: torch.Tensor,
        mask_ratio: Optional[float] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Run a forward pass through the model, computing the MAE reconstruction loss (language-conditioned if
        applicable) on the 0th + Kth frame temporal context, as well as the generated language given masked context
        (if applicable).

        :param imgs: A [bsz, 2, in_channels, resolution, resolution] tensor of (0th frame, Kth frame) sequences.
        :param lang_con: A [bsz, seq_len] tensor of language context to condition on.
        :param lang_con_mask: A [bsz, seq_len] binary mask tensor to indicate padding/null in `lang_condition`.
        :param lang_gen: A [bsz, seq_len] tensor of language to generate.
        :param lang_gen_mask: A [bsz, seq_len] binary mask tensor to indicate padding/null in `lang_gen`.
        :param lang_gen_weight: A [bsz] tensor of per-example weights to indicate when to 0 `lm` loss.
        :param mask_ratio: Optional masking ratio to use instead of the default.

        :return: Tuple of losses and intermediates, as follows:
            > (combined loss, reconstruction loss, lm loss, [reconstruction loss per frame in {0, K}])
        """
        visible_ctx_patches, mask, restore_idxs = self.forward_encoder(imgs, lang_con, lang_con_mask, mask_ratio)

        # Get token embeddings -- *NOT CONTEXTUAL* -- for the lang_gen tokens...
        lang_gen_embeddings = self.embed_language(lang_gen)

        # Run patches, and lang_gen through decoder --> note that we need a causal mask on language generation...
        ctx_reconstructions, generated_language = self.forward_decoder(
            visible_ctx_patches, restore_idxs, lang_gen_embeddings, lang_gen_mask
        )

        # Compute loss for reconstructed patches & generated language
        loss, reconstruction_loss, lang_loss, zero_loss, k_loss = self.compute_loss(
            imgs, ctx_reconstructions, mask, lang_gen, generated_language, lang_gen_mask, lang_gen_weight
        )

        return loss, reconstruction_loss, lang_loss, [zero_loss, k_loss]

    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        """Build an AdamW optimizer with decay/no-decay parameter groups and the matching LR-update callable."""
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adamw"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")

        # Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
        #   > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)

        # Build Parameter Groups
        groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

        # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
        #   > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
        self.lr = self.base_lr * (self.effective_bsz / 256)

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)

        return optimizer, update_lr


================================================ FILE: voltron/models/instantiate.py ================================================
"""
instantiate.py

Simple wrapping script for instantiating a core Voltron/reproduction model and configuring the torch.Optimizer for DDP
pretraining. Meant to be modular and extensible!
"""
from typing import Callable, Tuple

import torch.nn as nn
from torch.optim import Optimizer

from voltron.conf import DatasetConfig, ModelConfig

from .core.vcond import VCond
from .core.vdual import VDual
from .core.vgen import VGen
from .reproductions.vmvp import VMVP
from .reproductions.vr3m import VR3M
from .reproductions.vrn3m import VRN3M


def get_model_optimizer(
    model_cfg: ModelConfig, dataset_cfg: DatasetConfig
) -> Tuple[nn.Module, Optimizer, Callable[[int, float], float]]:
    """Switch on `model_cfg.arch` --> instantiate the correct nn.Module and Optimizer (on CPU/default device)."""
    # Data-Locked Reproductions
    if model_cfg.arch == "v-mvp":
        model = VMVP(
            resolution=dataset_cfg.resolution,
            patch_size=model_cfg.patch_size,
            encoder_depth=model_cfg.encoder_depth,
            encoder_embed_dim=model_cfg.encoder_embed_dim,
            encoder_n_heads=model_cfg.encoder_n_heads,
            decoder_depth=model_cfg.decoder_depth,
            decoder_embed_dim=model_cfg.decoder_embed_dim,
            decoder_n_heads=model_cfg.decoder_n_heads,
            optimizer=model_cfg.optimizer,
            schedule=model_cfg.schedule,
            base_lr=model_cfg.base_lr,
            min_lr=model_cfg.min_lr,
            effective_bsz=model_cfg.effective_bsz,
            betas=model_cfg.betas,
            weight_decay=model_cfg.weight_decay,
            warmup_epochs=dataset_cfg.warmup_epochs,
            max_epochs=dataset_cfg.max_epochs,
            mlp_ratio=model_cfg.mlp_ratio,
            norm_pixel_loss=model_cfg.norm_pixel_loss,
        )

    # R3M Reproduction (ViT backbone)
    elif model_cfg.arch == "v-r3m":
        model = VR3M(
            resolution=dataset_cfg.resolution,
            patch_size=model_cfg.patch_size,
            depth=model_cfg.depth,
            embed_dim=model_cfg.embed_dim,
            n_heads=model_cfg.n_heads,
            language_model=model_cfg.language_model,
            hf_cache=model_cfg.hf_cache,
            language_dim=model_cfg.language_dim,
            reward_dim=model_cfg.reward_dim,
            n_negatives=model_cfg.n_negatives,
            lang_reward_weight=model_cfg.lang_reward_weight,
            tcn_weight=model_cfg.tcn_weight,
            l1_weight=model_cfg.l1_weight,
            l2_weight=model_cfg.l2_weight,
            optimizer=model_cfg.optimizer,
            schedule=model_cfg.schedule,
            lr=model_cfg.lr,
            min_lr=model_cfg.min_lr,
            warmup_epochs=dataset_cfg.warmup_epochs,
            max_epochs=dataset_cfg.max_epochs,
            mlp_ratio=model_cfg.mlp_ratio,
        )

    # R3M Reproduction (ResNet-50 backbone)
    elif model_cfg.arch == "v-rn3m":
        model = VRN3M(
            resolution=dataset_cfg.resolution,
            fc_dim=model_cfg.fc_dim,
            language_model=model_cfg.language_model,
            hf_cache=model_cfg.hf_cache,
            language_dim=model_cfg.language_dim,
            reward_dim=model_cfg.reward_dim,
            n_negatives=model_cfg.n_negatives,
            lang_reward_weight=model_cfg.lang_reward_weight,
            tcn_weight=model_cfg.tcn_weight,
            l1_weight=model_cfg.l1_weight,
            l2_weight=model_cfg.l2_weight,
            optimizer=model_cfg.optimizer,
            lr=model_cfg.lr,
        )

    # Voltron Models
    elif model_cfg.arch == "v-cond":
        model = VCond(
            resolution=dataset_cfg.resolution,
            patch_size=model_cfg.patch_size,
            encoder_depth=model_cfg.encoder_depth,
            encoder_embed_dim=model_cfg.encoder_embed_dim,
            encoder_n_heads=model_cfg.encoder_n_heads,
            decoder_depth=model_cfg.decoder_depth,
            decoder_embed_dim=model_cfg.decoder_embed_dim,
            decoder_n_heads=model_cfg.decoder_n_heads,
            language_model=model_cfg.language_model,
            hf_cache=model_cfg.hf_cache,
            language_dim=model_cfg.language_dim,
            optimizer=model_cfg.optimizer,
            schedule=model_cfg.schedule,
            base_lr=model_cfg.base_lr,
            min_lr=model_cfg.min_lr,
            effective_bsz=model_cfg.effective_bsz,
            betas=model_cfg.betas,
            weight_decay=model_cfg.weight_decay,
            warmup_epochs=dataset_cfg.warmup_epochs,
            max_epochs=dataset_cfg.max_epochs,
            mlp_ratio=model_cfg.mlp_ratio,
            norm_pixel_loss=model_cfg.norm_pixel_loss,
        )
    elif model_cfg.arch == "v-dual":
        model = VDual(
            resolution=dataset_cfg.resolution,
            patch_size=model_cfg.patch_size,
            encoder_depth=model_cfg.encoder_depth,
            encoder_embed_dim=model_cfg.encoder_embed_dim,
            encoder_n_heads=model_cfg.encoder_n_heads,
            decoder_depth=model_cfg.decoder_depth,
            decoder_embed_dim=model_cfg.decoder_embed_dim,
            decoder_n_heads=model_cfg.decoder_n_heads,
            language_model=model_cfg.language_model,
            hf_cache=model_cfg.hf_cache,
            language_dim=model_cfg.language_dim,
            optimizer=model_cfg.optimizer,
            schedule=model_cfg.schedule,
            base_lr=model_cfg.base_lr,
            min_lr=model_cfg.min_lr,
            effective_bsz=model_cfg.effective_bsz,
            betas=model_cfg.betas,
            weight_decay=model_cfg.weight_decay,
            warmup_epochs=dataset_cfg.warmup_epochs,
            max_epochs=dataset_cfg.max_epochs,
            mlp_ratio=model_cfg.mlp_ratio,
            norm_pixel_loss=model_cfg.norm_pixel_loss,
        )
    elif model_cfg.arch == "v-gen":
        model = VGen(
            resolution=dataset_cfg.resolution,
            patch_size=model_cfg.patch_size,
            encoder_depth=model_cfg.encoder_depth,
            encoder_embed_dim=model_cfg.encoder_embed_dim,
            encoder_n_heads=model_cfg.encoder_n_heads,
            decoder_depth=model_cfg.decoder_depth,
            decoder_embed_dim=model_cfg.decoder_embed_dim,
            decoder_n_heads=model_cfg.decoder_n_heads,
            language_model=model_cfg.language_model,
            hf_cache=model_cfg.hf_cache,
            language_dim=model_cfg.language_dim,
            max_lang_len=dataset_cfg.max_lang_len,
            vocab_size=model_cfg.vocab_size,
            mae_weight=model_cfg.mae_weight,
            lm_weight=model_cfg.lm_weight,
            optimizer=model_cfg.optimizer,
            schedule=model_cfg.schedule,
            base_lr=model_cfg.base_lr,
            min_lr=model_cfg.min_lr,
effective_bsz=model_cfg.effective_bsz, betas=model_cfg.betas, weight_decay=model_cfg.weight_decay, warmup_epochs=dataset_cfg.warmup_epochs, max_epochs=dataset_cfg.max_epochs, mlp_ratio=model_cfg.mlp_ratio, norm_pixel_loss=model_cfg.norm_pixel_loss, ) else: raise ValueError(f"Model Architecture `{model_cfg.arch}` is not implemented!") # Configure Optimizer --> on same device (CPU) optimizer, update_lr = model.configure_optimizer() return model, optimizer, update_lr ================================================ FILE: voltron/models/materialize.py ================================================ """ materialize.py Core functionality for using pretrained models; defines the package-level `load` functionality for downloading and instantiating pretrained Voltron (and baseline) models. """ import json import os from pathlib import Path from typing import Callable, List, Tuple import gdown import torch import torch.nn as nn import torchvision.transforms as T from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen # === Define Useful Variables for Loading Models === DEFAULT_CACHE = "cache/" NORMALIZATION = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Pretrained Model Registry :: "model id" -> {"config" -> gdown ID, "checkpoint" -> gdown ID, "cls" -> Model Class} MODEL_REGISTRY = { # === Voltron ViT-Small (Sth-Sth) Models === "v-cond": { "config": "1O4oqRIblfS6PdFlZzUcYIX-Rqe6LbvnD", "checkpoint": "12g5QckQSMKqrfr4lFY3UPdy7oLw4APpG", "cls": VCond, }, "v-dual": { "config": "1zgKiK81SF9-0lg0XbMZwNhUh1Q7YdZZU", "checkpoint": "1CCRqrwcvF8xhIbJJmwnCbcWfWTJCK40T", "cls": VDual, }, "v-gen": { "config": "18-mUBDsr-2_-KrGoL2E2YzjcUO8JOwUF", "checkpoint": "1TzSQpKVKBWKCSvYJf22c45hrKczTQz24", "cls": VGen, }, # === Voltron ViT-Base Model === "v-cond-base": { "config": "1CLe7CaIzTEcGCijIgw_S-uqMXHfBFSLI", "checkpoint": "1PwczOijL0hfYD8DI4xLOPLf1xL_7Kg9S", "cls": VCond, }, # === Data-Locked Reproductions === "r-mvp": { "config": "1KKNWag6aS1xkUiUjaJ1Khm9D6F3ROhCR", 
def available_models() -> List[str]:
    """Return the list of valid model IDs accepted by `load()`."""
    return list(MODEL_REGISTRY.keys())


def load(
    model_id: str, device: torch.device = "cpu", freeze: bool = True, cache: str = DEFAULT_CACHE
) -> Tuple[nn.Module, Callable[[torch.Tensor], torch.Tensor]]:
    """
    Download & cache specified model configuration & checkpoint, then load & return module & image processor.

    Note :: We *override* the default `forward()` method of each of the respective model classes with the
            `get_representations` method --> by default passing "NULL" language for any language-conditioned models.
            This can be overridden either by passing in language (as a `str`) or by invoking the corresponding methods.

    :param model_id: Pretrained model identifier; must be a key of `MODEL_REGISTRY` (see `available_models()`).
    :param device: Device to map checkpoint weights onto & move the model to (default: "cpu").
    :param freeze: If True (default), disable gradients for every model parameter after loading.
    :param cache: Root directory for caching downloaded configs, checkpoints, and HF language models.

    :return: Tuple of (loaded model in eval mode, torchvision preprocessing transform for input images).
    """
    assert model_id in MODEL_REGISTRY, f"Model ID `{model_id}` not valid, try one of {list(MODEL_REGISTRY.keys())}"

    # Download Config & Checkpoint (if not in cache) --> fetch each missing artifact *individually*, so that a
    # partially-populated cache (e.g., config present, checkpoint interrupted) doesn't re-download everything.
    model_cache = Path(cache) / model_id
    config_path, checkpoint_path = model_cache / f"{model_id}-config.json", model_cache / f"{model_id}.pt"
    os.makedirs(model_cache, exist_ok=True)
    if not config_path.exists():
        gdown.download(id=MODEL_REGISTRY[model_id]["config"], output=str(config_path), quiet=False)
    if not checkpoint_path.exists():
        gdown.download(id=MODEL_REGISTRY[model_id]["checkpoint"], output=str(checkpoint_path), quiet=False)

    # Load Configuration --> patch `hf_cache` key if present (don't download to random locations on filesystem)
    with open(config_path, "r") as f:
        model_kwargs = json.load(f)
        if "hf_cache" in model_kwargs:
            model_kwargs["hf_cache"] = str(Path(cache) / "hf-cache")

    # By default, the model's `__call__` method defaults to `forward` --> for downstream applications, override!
    #   > Switch `__call__` to `get_representations` (NOTE: class-level patch --> affects *all* instances).
    MODEL_REGISTRY[model_id]["cls"].__call__ = MODEL_REGISTRY[model_id]["cls"].get_representations

    # Materialize Model (load weights from checkpoint; note that unused element `_` is the optimizer state...)
    model = MODEL_REGISTRY[model_id]["cls"](**model_kwargs)
    state_dict, _ = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    # Freeze model parameters if specified (default: True) -- names aren't needed, so iterate `parameters()`.
    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    # Build Visual Preprocessing Transform (assumes image is read into a torch.Tensor, but can be adapted)
    if model_id in {"v-cond", "v-dual", "v-gen", "v-cond-base", "r-mvp"}:
        # All models except R3M are by default normalized subject to default IN1K normalization...
        preprocess = T.Compose(
            [
                T.Resize(model_kwargs["resolution"]),
                T.CenterCrop(model_kwargs["resolution"]),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=NORMALIZATION[0], std=NORMALIZATION[1]),
            ]
        )
    else:
        # R3M models (following original work) expect unnormalized images with values in range [0 - 255)
        preprocess = T.Compose(
            [
                T.Resize(model_kwargs["resolution"]),
                T.CenterCrop(model_kwargs["resolution"]),
                T.ConvertImageDtype(torch.float),
                T.Lambda(lambda x: x * 255.0),
            ]
        )

    return model, preprocess
class VMVP(nn.Module):
    """Masked Autoencoder (MAE) a la Masked Visual Pretraining for Motor Control (MVP), with a CLS token."""

    def __init__(
        self,
        resolution: int,
        patch_size: int,
        encoder_depth: int,
        encoder_embed_dim: int,
        encoder_n_heads: int,
        decoder_depth: int,
        decoder_embed_dim: int,
        decoder_n_heads: int,
        optimizer: str,
        schedule: str,
        base_lr: float,
        min_lr: float,
        effective_bsz: float,
        betas: Tuple[float, float],
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        mask_ratio: float = 0.75,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        norm_pixel_loss: bool = True,
    ) -> None:
        """
        Initialize an VMVP (MAE) model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
        :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
        :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
        :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
        :param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
        :param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
        :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
        :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
        :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
        :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
        :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
        :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
        :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
        :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
        :param max_epochs: Total number of training epochs to be run.
        :param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
        :param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
        :param in_channels: Default number of channels in the base image -- almost always 3.
        :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
        """
        super().__init__()
        self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
        self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
        self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
        self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
        self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs

        # Encoder/Decoder Parameters
        self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
        self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
        self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads

        # MAE Encoder Parameters --> MVP uses a CLS Token for feature extraction!
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
        self.patch2embed = PatchEmbed(
            self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
        )
        self.encoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + 1, self.encoder_embed_dim), requires_grad=False
        )
        self.encoder_blocks = nn.ModuleList(
            [Block(self.encoder_embed_dim, self.encoder_n_heads, self.mlp_ratio) for _ in range(self.encoder_depth)]
        )
        self.encoder_norm = nn.LayerNorm(self.encoder_embed_dim, eps=1e-6)

        # Projection from Encoder to Decoder
        self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)

        # MAE Decoder Parameters -- Remember the CLS Token!
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim))
        self.decoder_pe = nn.Parameter(
            torch.zeros(1, self.patch2embed.num_patches + 1, self.decoder_embed_dim), requires_grad=False
        )
        self.decoder_blocks = nn.ModuleList(
            [Block(self.decoder_embed_dim, self.decoder_n_heads, self.mlp_ratio) for _ in range(self.decoder_depth)]
        )
        self.decoder_norm = nn.LayerNorm(self.decoder_embed_dim, eps=1e-6)
        self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)

        # Initialize all Weights
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Initialize position embeddings (fixed sin-cos), patch projection, tokens, and Transformer weights."""
        # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
        enc_pe = get_2D_position_embeddings(self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
        self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
        dec_pe = get_2D_position_embeddings(self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
        self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Initialize PatchEmbedding as a Linear...
        nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))

        # Initialize CLS Token & Mask Token w/ Truncated Normal
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)

        # Everything else...
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        """Initialize Linear layers with xavier_uniform (following Jax ViT) and LayerNorms to identity."""
        if isinstance(m, nn.Linear):
            # Use xavier_uniform following Jax ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def mask(
        self, patches: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches"""
        bsz, n_patches, embed_dim = patches.shape
        if mask_ratio is not None:
            n_keep = int(n_patches * (1 - mask_ratio))
        else:
            n_keep = int(n_patches * (1 - self.mask_ratio))

        # Sample some noise of n_patches size, argsort to get shuffled IDs (keep small), argsort again to "unshuffle"
        #   > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
        shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
        restore_idxs = torch.argsort(shuffle_idxs, dim=1)

        # Get "keep" (visible) patches
        visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))

        # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following FAIR MAE convention)
        mask = torch.ones(bsz, n_patches, device=patches.device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=restore_idxs)

        return visible_patches, mask, restore_idxs

    def get_representations(self, img: torch.Tensor, mode: str = "patch") -> torch.Tensor:
        """
        Given a single image, extract representations subject to the specified mode in < patch | cls >, where
        "cls" denotes extracting the <CLS> token embedding; for our experiments, we find that running multiheaded
        attention pooling on top of the "patch" embeddings is *always* better!

        :param img: Processed batch of images :: [bsz, 3, 224, 224]
        :param mode: Type of representation to extract -- `patch` (sequence of patch embeddings) or `cls` (CLS token)

        :return: Extracted representations given img input.
        """
        assert img.ndim == 4, "Invalid input to `get_representations()`"
        assert mode in {"patch", "cls"}, f"Extraction mode `{mode}` not supported!"

        # Extract desired representations
        representations = self.encode(img)
        return representations[:, 1:] if mode == "patch" else representations[:, :1]

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        """Run a single image through the MAE and extract patch embeddings."""
        # Note: All of this code is taken near-verbatim from the MVP repository...
        #   > Ref: https://github.com/ir413/mvp/blob/master/mvp/backbones/vit.py#L30
        patches = self.patch2embed(img)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)

        # FIX :: concatenate the CLS token along the *sequence* dimension (dim=1). The previous call omitted `dim`,
        # so `torch.cat` defaulted to dim=0 (batch) and raised a shape mismatch for any multi-patch input --
        # this mirrors `forward_encoder` below, which already (correctly) passes dim=1.
        cls_patches = torch.cat([cls_tokens, patches], dim=1) + self.encoder_pe

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            cls_patches = block(cls_patches)
        cls_patches = self.encoder_norm(cls_patches)

        return cls_patches

    def forward_encoder(
        self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Patchify, mask, prepend the CLS token, and encode -- returns (encoded visible patches, mask, restore)."""
        # Patchify + Position Embedding (without the CLS Token)
        patches = self.patch2embed(imgs)
        patches_pe = patches + self.encoder_pe[:, 1:, :]

        # Create mask (and go ahead and mask out patches at the same time)
        visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)

        # Add the CLS Token
        cls_token = self.cls_token + self.encoder_pe[:, :1, :]
        cls_tokens = cls_token.expand(imgs.shape[0], -1, -1)
        cls_visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)

        # Apply Transformer Blocks...
        for block in self.encoder_blocks:
            cls_visible_patches = block(cls_visible_patches)
        cls_visible_patches = self.encoder_norm(cls_visible_patches)

        return cls_visible_patches, mask, restore_idxs

    def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
        """Project encoded patches into the decoder space, unshuffle with mask tokens, and predict pixels."""
        # Project patches into decoder embedding dimension
        projected_patches = self.encoder2decoder(visible_patches)

        # Add Mask Tokens to Sequence (note the `+ 1` accounts for the CLS token in `visible_patches`)
        mask_tokens = self.mask_token.repeat(
            projected_patches.shape[0], restore_idxs.shape[1] - visible_patches.shape[1] + 1, 1
        )

        # Remove & add back CLS Token as part of the "unshuffling"
        concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1)  # Skip CLS Token
        unshuffled_patches = torch.gather(
            concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
        )
        cls_unshuffled_patches = torch.cat([projected_patches[:, :1, :], unshuffled_patches], dim=1)  # Add CLS Token

        # Add Position Embeddings
        cls_decoder_patches = cls_unshuffled_patches + self.decoder_pe

        # Apply Transformer Blocks...
        for block in self.decoder_blocks:
            cls_decoder_patches = block(cls_decoder_patches)
        cls_decoder_patches = self.decoder_norm(cls_decoder_patches)

        # Run final projection, remove the CLS token, and return
        cls_decoder_prediction = self.decoder_prediction(cls_decoder_patches)
        decoder_prediction = cls_decoder_prediction[:, 1:, :]
        return decoder_prediction

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images to their patched equivalents, by naive reshaping"""
        return rearrange(
            imgs,
            "bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
            patch_h=self.patch_size,
            patch_w=self.patch_size,
        )

    def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the per-patch-normalized MSE reconstruction loss, averaged over *masked* patches only."""
        assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
        targets = self.patchify(imgs)

        # Normalize targets...
        mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
        targets = (targets - mu) / ((var + 1e-6) ** 0.5)

        # Compute mean loss per patch first...
        mse = (reconstructions - targets) ** 2
        avg_loss_per_patch = mse.mean(dim=-1)

        # Compute mean loss only on *removed* patches and return
        return (avg_loss_per_patch * mask).sum() / mask.sum()

    def forward(
        self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Full MAE pass :: encode visible patches, decode reconstructions, and return (loss, recons, mask)."""
        visible_patches, mask, restore_idxs = self.forward_encoder(imgs, mask_ratio)
        reconstructions = self.forward_decoder(visible_patches, restore_idxs)
        loss = self.compute_loss(imgs, reconstructions, mask)

        return loss, reconstructions, mask

    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        """Build the AdamW optimizer with decay/no-decay parameter groups and the LR-update callable."""
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adamw"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")

        # Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
        #   > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)

        # Build Parameter Groups
        groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

        # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
        #   > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
        self.lr = self.base_lr * (self.effective_bsz / 256)

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)

        return optimizer, update_lr
class VR3M(nn.Module):
    # R3M with a ViT encoder (Nair et al. 2021) :: TCN + language-reward InfoNCE losses + L1/L2 regularization.
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        depth: int,
        embed_dim: int,
        n_heads: int,
        language_model: str,
        hf_cache: str,
        language_dim: int,
        reward_dim: int,
        n_negatives: int,
        lang_reward_weight: float,
        tcn_weight: float,
        l1_weight: float,
        l2_weight: float,
        optimizer: str,
        schedule: str,
        lr: float,
        min_lr: float,
        warmup_epochs: int,
        max_epochs: int,
        mlp_ratio: float = 4.0,
        in_channels: int = 3,
        eps: float = 1e-8,
    ) -> None:
        """
        Initialize a ViT R3M model with the requisite architecture parameters.

        :param resolution: Base image resolution -- usually 224 (ImageNet size).
        :param patch_size: Height/Width of each patch in pixels -- usually 16.
        :param depth: Number of Transformer blocks in the ViT image encoder.
        :param embed_dim: Core embedding/hidden dimension for the vision transformer backbone.
        :param n_heads: Number of heads for multi-headed self-attention.
        :param language_model: Language model to freeze for encoding narrations/utterances.
        :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
        :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
        :param reward_dim: Hidden layer dimensionality for the language-reward MLP.
        :param n_negatives: Number of cross-batch negatives to sample for contrastive learning.
        :param lang_reward_weight: Weight applied to the contrastive "language alignment" loss term.
        :param tcn_weight: Weight applied to the time contrastive loss term.
        :param l1_weight: Weight applied to the L1 regularization loss term.
        :param l2_weight: Weight applied to the L2 regularization loss term.
        :param optimizer: String denoting which optimizer to use (for R3M, usually `adam`)
        :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
        :param lr: Peak learning rate to use over the course of training -- warms up to this value, then decays.
        :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
        :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
        :param max_epochs: Total number of training epochs to be run.
        :param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down; default 4).
        :param in_channels: Default number of channels in the base image -- almost always 3.
        :param eps: Epsilon for preventing divide by zero in the InfoNCE loss terms.
        """
        super().__init__()
        self.resolution, self.patch_size, self.n_negatives, self.eps = resolution, patch_size, n_negatives, eps
        self.optimizer, self.schedule, self.lr, self.min_lr = optimizer, schedule, lr, min_lr
        self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
        self.mlp_ratio, self.in_channels = mlp_ratio, in_channels
        self.language_dim, self.reward_dim = language_dim, reward_dim

        # Weights for each loss term
        self.lang_reward_weight, self.tcn_weight = lang_reward_weight, tcn_weight
        self.l1_weight, self.l2_weight = l1_weight, l2_weight

        # ViT Backbone
        self.depth, self.embed_dim, self.n_heads = depth, embed_dim, n_heads

        # Create ViT --> some differences from the original architectures in "An Image is Worth 16x16 Words":
        #   > Namely, subset of authors show that (https://arxiv.org/abs/2205.01580):
        #       1) 2D sinusoidal embeddings are better than learned embeddings...
        #       2) Average pooling is better/simpler than learning a [CLS] token embedding...
        # Note: Output "embedding" is just mean pooled final transformer layer --> after extra layer norm!
        #   > Ref: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L446
        self.patch2embed = PatchEmbed(self.resolution, self.patch_size, self.embed_dim, in_channels=self.in_channels)
        self.pe = nn.Parameter(torch.zeros(1, self.patch2embed.num_patches, self.embed_dim), requires_grad=False)
        self.blocks = nn.ModuleList([Block(self.embed_dim, self.n_heads, self.mlp_ratio) for _ in range(self.depth)])
        self.norm = nn.LayerNorm(self.embed_dim, eps=1e-6)

        # Create Language Reward Model -- MLP over [initial embedding; later embedding; language embedding]
        self.language_reward = nn.Sequential(
            nn.Linear(self.embed_dim + self.embed_dim + self.language_dim, self.reward_dim),
            nn.ReLU(),
            nn.Linear(self.reward_dim, self.reward_dim),
            nn.ReLU(),
            nn.Linear(self.reward_dim, self.reward_dim),
            nn.ReLU(),
            nn.Linear(self.reward_dim, self.reward_dim),
            nn.ReLU(),
            nn.Linear(self.reward_dim, 1),
            nn.Sigmoid(),
        )

        # Initialize all Weights
        self.initialize_weights()

        # @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
        #   > For BERT models, our "embedding" is just going to be the last hidden state
        #   > Assumes inputs to forward pass are pre-tokenized!
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
        self.lm.eval()

        # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
        #   NOTE(review): `config.dim` is a DistilBERT-style config attribute -- confirm if other LMs are used.
        assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"

        # Freeze the LM
        for _name, param in self.lm.named_parameters():
            param.requires_grad = False

    def initialize_weights(self) -> None:
        # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
        pe = get_2D_position_embeddings(self.embed_dim, int(self.patch2embed.num_patches**0.5))
        self.pe.data.copy_(torch.from_numpy(pe).float().unsqueeze(0))

        # Initialize PatchEmbedding as a Linear...
        nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))

        # Everything else...
        self.apply(self.transformer_initializer)

    @staticmethod
    def transformer_initializer(m: nn.Module) -> None:
        # Linear layers get xavier_uniform weights + zero bias; LayerNorms get identity init.
        if isinstance(m, nn.Linear):
            # Use xavier_uniform following Jax ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def get_representations(self, img: torch.Tensor) -> torch.Tensor:
        """
        Given a single image, extract R3M "default" (mean-pooled) dense representation.

        :param img: Processed batch of images :: [bsz, 3, 224, 224]
        :return: Extracted R3M dense representation given img input :: [bsz, 1, embed_dim]
        """
        assert img.ndim == 4, "Invalid input to `get_representations()`"
        patches = self.patch2embed(img)
        img_embedding = patches + self.pe

        # Apply Transformer Blocks...
        for block in self.blocks:
            img_embedding = block(img_embedding)
        img_embedding = self.norm(img_embedding)

        return img_embedding.mean(1, keepdim=True)

    def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
        """Feed images through ViT, get single embedding via global mean pooling."""
        patches = self.patch2embed(imgs)
        img_embedding = patches + self.pe

        # Apply Transformer Blocks...
        for block in self.blocks:
            img_embedding = block(img_embedding)
        img_embedding = self.norm(img_embedding)

        # Return mean pooled embeddings
        return img_embedding.mean(dim=1)

    def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
        """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
        self.lm.eval()
        with torch.no_grad():
            transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
        return transformer_embeddings.mean(dim=1)

    def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang: torch.Tensor) -> torch.Tensor:
        # Scalar (sigmoid) reward per example from the concatenated [initial; later; language] embeddings.
        return self.language_reward(torch.cat([initial, later, lang], dim=-1)).squeeze()

    def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Run a forward pass through the model, computing the *full* R3M loss -- the TCN contrastive loss, the Language
        Alignment loss, and both sparsity losses, as well as the full loss (which will get optimized)!

        :param imgs: A [bsz, 5, in_channels, resolution, resolution] tensor of (start, i, j, k, end) sequences.
        :param lang: Tokenized language of dimensionality [bsz, seq_len] to be fed to the language model.
        :param lang_mask: Attention mask computed by the tokenizer, as a result of padding to the max_seq_len.

        :return: Tuple of losses, as follows:
            > (combined_loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, reward_acc)
        """
        # Encode each image separately... feed to transformer... then reshape
        all_images = rearrange(imgs, "bsz n_states c res1 res2 -> (bsz n_states) c res1 res2", n_states=5)
        all_embeddings = self.encode_images(all_images)
        initial, state_i, state_j, state_k, final = rearrange(
            all_embeddings, "(bsz n_states) embed -> n_states bsz embed", n_states=5
        )

        # Compute Regularization Losses
        l1_loss = torch.linalg.norm(all_embeddings, ord=1, dim=-1).mean()
        l2_loss = torch.linalg.norm(all_embeddings, ord=2, dim=-1).mean()

        # Compute TCN Loss
        tcn_loss, tcn_acc = self.get_time_contrastive_loss(state_i, state_j, state_k)

        # Compute Language Alignment/Predictive Loss
        lang_reward_loss, rew_acc = self.get_reward_loss(lang, lang_mask, initial, state_i, state_j, state_k, final)

        # Compute full weighted loss & return...
        loss = (
            (self.l1_weight * l1_loss)
            + (self.l2_weight * l2_loss)
            + (self.tcn_weight * tcn_loss)
            + (self.lang_reward_weight * lang_reward_loss)
        )
        return loss, tcn_loss, lang_reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc

    @staticmethod
    def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_l2: bool = True) -> torch.Tensor:
        """Computes similarity between embeddings via -L2 distance."""
        assert use_l2, "Non-L2 time-similarity functions not yet implemented!"
        return -torch.linalg.norm(state_x - state_y, dim=-1)

    def get_time_contrastive_loss(
        self, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Evaluates the Time-Contrastive Loss, computed via InfoNCE."""
        # *Punchline* - we want `sim(i, j)` to be higher than `sim(i, k)` for some k > j (goes both ways)
        #   > As our positive examples --> we sample (s_i, s_j) and (s_j, s_k).
        #   > Our negatives --> other pairs from the triplet, cross-batch negatives!
        sim_i_j_exp = torch.exp(self.time_similarity(state_i, state_j))
        sim_j_k_exp = torch.exp(self.time_similarity(state_j, state_k))

        # Add a "hard" negative!
        neg_i_k_exp = torch.exp(self.time_similarity(state_i, state_k))

        # Obtain *cross-batch* negatives (NOTE: each `torch.randperm` advances global RNG state -- order matters)
        bsz, neg_i, neg_j = state_i.shape[0], [], []
        for _ in range(self.n_negatives):
            neg_idx = torch.randperm(bsz)
            state_i_shuf = state_i[neg_idx]
            neg_idx = torch.randperm(bsz)
            state_j_shuf = state_j[neg_idx]
            neg_i.append(self.time_similarity(state_i, state_i_shuf))
            neg_j.append(self.time_similarity(state_j, state_j_shuf))
        neg_i_exp, neg_j_exp = torch.exp(torch.stack(neg_i, -1)), torch.exp(torch.stack(neg_j, -1))

        # Compute InfoNCE
        denominator_i = sim_i_j_exp + neg_i_k_exp + neg_i_exp.sum(-1)
        denominator_j = sim_j_k_exp + neg_i_k_exp + neg_j_exp.sum(-1)
        nce_i = -torch.log(self.eps + (sim_i_j_exp / (self.eps + denominator_i)))
        nce_j = -torch.log(self.eps + (sim_j_k_exp / (self.eps + denominator_j)))
        nce = (nce_i + nce_j) / 2

        # Compute "accuracy" -- fraction of examples where the positive beats the hard negative
        i_j_acc = (1.0 * (sim_i_j_exp > neg_i_k_exp)).mean()
        j_k_acc = (1.0 * (sim_j_k_exp > neg_i_k_exp)).mean()
        acc = (i_j_acc + j_k_acc) / 2

        return nce.mean(), acc

    def get_reward_loss(
        self,
        lang: torch.Tensor,
        lang_mask: torch.Tensor,
        initial: torch.Tensor,
        state_i: torch.Tensor,
        state_j: torch.Tensor,
        state_k: torch.Tensor,
        final: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """Evaluates the Language-Alignment Reward Loss, computed via InfoNCE."""
        lang_embed = self.encode_language(lang, lang_mask)

        # *Punchline* - we want `Reward(s_0, s_t, l)` to be higher than the reward for any mismatched pair.
        #   > As our positive examples --> we sample s_j, s_k, and s_final (excluding s_i)
        pos_final_exp = torch.exp(self.get_reward(initial, final, lang_embed))
        pos_j_exp = torch.exp(self.get_reward(initial, state_j, lang_embed))
        pos_k_exp = torch.exp(self.get_reward(initial, state_k, lang_embed))

        # Add the within-context negatives <--> these are the most informative examples!
        #   > We use initial, initial as a negative for the first one, just to get reward model to "capture progress"
        negs_final = [self.get_reward(initial, initial, lang_embed)]
        negs_j = [self.get_reward(initial, state_i, lang_embed)]
        negs_k = [self.get_reward(initial, state_j, lang_embed)]

        # Cross Batch Negatives -- same as positives (indexing), but from a different batch!
        #   > @SK :: Unclear how well this will unroll on TPUs...
        bsz = initial.shape[0]
        for _ in range(self.n_negatives):
            # We get three random indices to further minimize correlation... from the R3M codebase!
            neg_idx = torch.randperm(bsz)
            negs_final.append(self.get_reward(initial[neg_idx], final[neg_idx], lang_embed))
            neg_idx = torch.randperm(bsz)
            negs_j.append(self.get_reward(initial[neg_idx], state_j[neg_idx], lang_embed))
            neg_idx = torch.randperm(bsz)
            negs_k.append(self.get_reward(initial[neg_idx], state_k[neg_idx], lang_embed))

        # Flatten & exponentiate; get ready for the InfoNCE
        negs_final, negs_j, negs_k = torch.stack(negs_final, -1), torch.stack(negs_j, -1), torch.stack(negs_k, -1)
        negs_final_exp, negs_j_exp, negs_k_exp = torch.exp(negs_final), torch.exp(negs_j), torch.exp(negs_k)

        # Compute InfoNCE
        denominator_final = pos_final_exp + negs_final_exp.sum(-1)
        denominator_j = pos_j_exp + negs_j_exp.sum(-1)
        denominator_k = pos_k_exp + negs_k_exp.sum(-1)
        nce_final = -torch.log(self.eps + (pos_final_exp / (self.eps + denominator_final)))
        nce_j = -torch.log(self.eps + (pos_j_exp / (self.eps + denominator_j)))
        nce_k = -torch.log(self.eps + (pos_k_exp / (self.eps + denominator_k)))

        # Compute "accuracy" -- positive must beat the *max* over all its negatives
        acc_final = (1.0 * (negs_final_exp.max(dim=-1)[0] < pos_final_exp)).mean()
        acc_j = (1.0 * (negs_j_exp.max(dim=-1)[0] < pos_j_exp)).mean()
        acc_k = (1.0 * (negs_k_exp.max(dim=-1)[0] < pos_k_exp)).mean()
        acc = (acc_final + acc_j + acc_k) / 3

        nce = (nce_final + nce_j + nce_k) / 3
        return nce.mean(), acc

    def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
        # Short-Circuit on Valid Optimizers
        if self.optimizer not in ["adam"]:
            raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adam`] instead!")

        # Create Optimizer & LR Scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)

        return optimizer, update_lr
:param n_negatives: Number of cross-batch negatives to sample for contrastive learning. :param lang_reward_weight: Weight applied to the contrastive "language alignment" loss term. :param tcn_weight: Weight applied to the time contrastive loss term. :param l1_weight: Weight applied to the L1 regularization loss term. :param l2_weight: Weight applied to the L2 regularization loss term. :param optimizer: String denoting which optimizer to use (for R3M, usually `adam`). :param lr: Learning rate (fixed for ResNet R3M models) for training. :param eps: Epsilon for preventing divide by zero in the InfoNCE loss terms. """ super().__init__() self.resolution, self.fc_dim, self.n_negatives, self.eps = resolution, fc_dim, n_negatives, eps self.language_dim, self.reward_dim, self.optimizer, self.lr = language_dim, reward_dim, optimizer, lr self.embed_dim = self.fc_dim # Weights for each loss term self.lang_reward_weight, self.tcn_weight = lang_reward_weight, tcn_weight self.l1_weight, self.l2_weight = l1_weight, l2_weight # Create ResNet50 --> set `rn.fc` to the Identity() to extract final features of dim = `fc_dim` self.resnet = resnet50(weights=None) self.resnet.fc = nn.Identity() self.resnet.train() # Create Language Reward Model self.language_reward = nn.Sequential( nn.Linear(self.fc_dim + self.fc_dim + self.language_dim, self.reward_dim), nn.ReLU(), nn.Linear(self.reward_dim, self.reward_dim), nn.ReLU(), nn.Linear(self.reward_dim, self.reward_dim), nn.ReLU(), nn.Linear(self.reward_dim, self.reward_dim), nn.ReLU(), nn.Linear(self.reward_dim, 1), nn.Sigmoid(), ) # Create Language Model & Language Reward MLP --> LM has requires_grad = False # > For BERT models, our "embedding" is just going to be the last hidden state # > Assumes inputs to forward pass are pre-tokenized! 
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache) self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache) self.lm.eval() # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension! assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!" # Freeze the LM for _name, param in self.lm.named_parameters(): param.requires_grad = False def get_representations(self, img: torch.Tensor) -> torch.Tensor: """ Given a single image, extract R3M "default" (ResNet pooled) dense representation. :param img: Processed batch of images :: [bsz, 3, 224, 224] :return: Extracted R3M dense representation given img input. """ assert img.ndim == 4, "Invalid input to `get_representations()`" representation = self.resnet(img) return representation.unsqueeze(1) def encode_images(self, imgs: torch.Tensor) -> torch.Tensor: """Feed images through ResNet-50 to get single embedding after global average pooling.""" return self.resnet(imgs) def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor: """Encode language by feeding the *pre-tokenized text* through the frozen language model.""" self.lm.eval() with torch.no_grad(): transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state return transformer_embeddings.mean(dim=1) def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang: torch.Tensor) -> torch.Tensor: return self.language_reward(torch.cat([initial, later, lang], dim=-1)).squeeze() def extract_features(self, img: torch.Tensor) -> torch.Tensor: """Run a single image of shape [1, 3, 224, 224] through the ResNet and extract the feature.""" return self.encode_images(img).detach() def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Run a forward pass through the model, computing the *full* R3M loss -- the TCN 
contrastive loss, the Language Alignment loss, and both sparsity losses, as well as the full loss (which will get optimized)! :param imgs: A [bsz, 5, in_channels, resolution, resolution] tensor of (start, i, j, k, end) sequences. :param lang: Tokenized language of dimensionality [bsz, seq_len] to be fed to the language model. :param lang_mask: Attention mask computed by the tokenizer, as a result of padding to the max_seq_len. :return: Tuple of losses, as follows: > (combined_loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, reward_acc) """ # Encode each image separately... feed to transformer... then reshape all_images = rearrange(imgs, "bsz n_states c res1 res2 -> (bsz n_states) c res1 res2", n_states=5) all_embeddings = self.encode_images(all_images) initial, state_i, state_j, state_k, final = rearrange( all_embeddings, "(bsz n_states) embed -> n_states bsz embed", n_states=5 ) # Compute Regularization Losses l1_loss = torch.linalg.norm(all_embeddings, ord=1, dim=-1).mean() l2_loss = torch.linalg.norm(all_embeddings, ord=2, dim=-1).mean() # Compute TCN Loss tcn_loss, tcn_acc = self.get_time_contrastive_loss(state_i, state_j, state_k) # Compute Language Alignment/Predictive Loss lang_reward_loss, rew_acc = self.get_reward_loss(lang, lang_mask, initial, state_i, state_j, state_k, final) # Compute full weighted loss & return... loss = ( (self.l1_weight * l1_loss) + (self.l2_weight * l2_loss) + (self.tcn_weight * tcn_loss) + (self.lang_reward_weight * lang_reward_loss) ) return loss, tcn_loss, lang_reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc @staticmethod def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_l2: bool = True) -> torch.Tensor: """Computes similarity between embeddings via -L2 distance.""" assert use_l2, "Non-L2 time-similarity functions not yet implemented!" 
return -torch.linalg.norm(state_x - state_y, dim=-1) def get_time_contrastive_loss( self, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor ) -> Tuple[torch.Tensor, ...]: """Evaluates the Time-Contrastive Loss, computed via InfoNCE.""" # *Punchline* - we want `sim(i, j)` to be higher than `sim(i, k)` for some k > j (goes both ways) # `Reward(s*_0, s*_ As our positive examples --> we sample (s_i, s_j) and (s_j, s_k). # > Our negatives --> other pairs from the triplet, cross-batch negatives! sim_i_j_exp = torch.exp(self.time_similarity(state_i, state_j)) sim_j_k_exp = torch.exp(self.time_similarity(state_j, state_k)) # Add a "hard" negative! neg_i_k_exp = torch.exp(self.time_similarity(state_i, state_k)) # Obtain *cross-batch* negatives bsz, neg_i, neg_j = state_i.shape[0], [], [] for _ in range(self.n_negatives): neg_idx = torch.randperm(bsz) state_i_shuf = state_i[neg_idx] neg_idx = torch.randperm(bsz) state_j_shuf = state_j[neg_idx] neg_i.append(self.time_similarity(state_i, state_i_shuf)) neg_j.append(self.time_similarity(state_j, state_j_shuf)) neg_i_exp, neg_j_exp = torch.exp(torch.stack(neg_i, -1)), torch.exp(torch.stack(neg_j, -1)) # Compute InfoNCE denominator_i = sim_i_j_exp + neg_i_k_exp + neg_i_exp.sum(-1) denominator_j = sim_j_k_exp + neg_i_k_exp + neg_j_exp.sum(-1) nce_i = -torch.log(self.eps + (sim_i_j_exp / (self.eps + denominator_i))) nce_j = -torch.log(self.eps + (sim_j_k_exp / (self.eps + denominator_j))) nce = (nce_i + nce_j) / 2 # Compute "accuracy" i_j_acc = (1.0 * (sim_i_j_exp > neg_i_k_exp)).mean() j_k_acc = (1.0 * (sim_j_k_exp > neg_i_k_exp)).mean() acc = (i_j_acc + j_k_acc) / 2 return nce.mean(), acc def get_reward_loss( self, lang: torch.Tensor, lang_mask: torch.Tensor, initial: torch.Tensor, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor, final: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: """Evaluates the Language-Alignment Reward Loss, computed via InfoNCE.""" lang_embed = self.encode_language(lang, 
lang_mask) # *Punchline* - we want `Reward(s_0, s_t, l)` to be higher than `Reward(s_0, s_ As our positive examples --> we sample s_j, s_k, and s_final (excluding s_i) pos_final_exp = torch.exp(self.get_reward(initial, final, lang_embed)) pos_j_exp = torch.exp(self.get_reward(initial, state_j, lang_embed)) pos_k_exp = torch.exp(self.get_reward(initial, state_k, lang_embed)) # Add the within-context negatives <--> these are the most informative examples! # > We use initial, initial as a negative for the first one, just to get reward model to "capture progress" negs_final = [self.get_reward(initial, initial, lang_embed)] negs_j = [self.get_reward(initial, state_i, lang_embed)] negs_k = [self.get_reward(initial, state_j, lang_embed)] # Cross Batch Negatives -- same as positives (indexing), but from a different batch! # > @SK :: Unclear how well this will unroll on TPUs... bsz = initial.shape[0] for _ in range(self.n_negatives): # We get three random indices to further minimize correlation... from the R3M codebase! 
neg_idx = torch.randperm(bsz) negs_final.append(self.get_reward(initial[neg_idx], final[neg_idx], lang_embed)) neg_idx = torch.randperm(bsz) negs_j.append(self.get_reward(initial[neg_idx], state_j[neg_idx], lang_embed)) neg_idx = torch.randperm(bsz) negs_k.append(self.get_reward(initial[neg_idx], state_k[neg_idx], lang_embed)) # Flatten & exponentiate; get ready for the InfoNCE negs_final, negs_j, negs_k = torch.stack(negs_final, -1), torch.stack(negs_j, -1), torch.stack(negs_k, -1) negs_final_exp, negs_j_exp, negs_k_exp = torch.exp(negs_final), torch.exp(negs_j), torch.exp(negs_k) # Compute InfoNCE denominator_final = pos_final_exp + negs_final_exp.sum(-1) denominator_j = pos_j_exp + negs_j_exp.sum(-1) denominator_k = pos_k_exp + negs_k_exp.sum(-1) nce_final = -torch.log(self.eps + (pos_final_exp / (self.eps + denominator_final))) nce_j = -torch.log(self.eps + (pos_j_exp / (self.eps + denominator_j))) nce_k = -torch.log(self.eps + (pos_k_exp / (self.eps + denominator_k))) # Compute "accuracy" acc_final = (1.0 * (negs_final_exp.max(dim=-1)[0] < pos_final_exp)).mean() acc_j = (1.0 * (negs_j_exp.max(dim=-1)[0] < pos_j_exp)).mean() acc_k = (1.0 * (negs_k_exp.max(dim=-1)[0] < pos_k_exp)).mean() acc = (acc_final + acc_j + acc_k) / 3 nce = (nce_final + nce_j + nce_k) / 3 return nce.mean(), acc def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]: # Short-Circuit on Valid Optimizers if self.optimizer not in ["adam"]: raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adam`] instead!") # Create Optimizer and (No-Op) LR Scheduler optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) update_lr = get_lr_update( optimizer, schedule="none", lr=self.lr, min_lr=self.lr, warmup_epochs=-1, max_epochs=-1 ) return optimizer, update_lr ================================================ FILE: voltron/models/util/__init__.py ================================================ from .extraction import 
instantiate_extractor


================================================
FILE: voltron/models/util/extraction.py
================================================
"""
extraction.py

General Extraction module definitions & associated utilities.

References:
    - Set Transformers (MAP): https://arxiv.org/abs/1810.00825.pdf
"""
from typing import Callable

import torch
import torch.nn as nn
from einops import repeat

from voltron.models.util.transformer import RMSNorm, SwishGLU


# === Multiheaded Attention Pooling ===
# As defined in Set Transformers () -- basically the above, additionally taking in
# a set of $k$ learned "seed vectors" that are used to "pool" information.
class MAPAttention(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int) -> None:
        """Multi-Input Multi-Headed Attention Operation"""
        super().__init__()
        assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
        self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5

        # Projections (no bias) --> separate for Q (seed vector), and KV ("pool" inputs)
        self.q, self.kv = nn.Linear(embed_dim, embed_dim, bias=False), nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, seed: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Cross-attention: `seed` [B, K, C] attends over pool inputs `x` [B, N, C] --> [B, K, C]."""
        (B_s, K, C_s), (B_x, N, C_x) = seed.shape, x.shape
        assert C_s == C_x, "Seed vectors and pool inputs must have the same embedding dimensionality!"

        # Project Seed Vectors to `queries` --> [B, n_heads, K, head_dim]
        q = self.q(seed).reshape(B_s, K, self.n_heads, C_s // self.n_heads).permute(0, 2, 1, 3)
        kv = self.kv(x).reshape(B_x, N, 2, self.n_heads, C_x // self.n_heads).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        # Attention --> compute weighted sum over values!
        # > Scaling `k` before the matmul is algebraically equivalent to scaling the score matrix.
        scores = q @ (k.transpose(-2, -1) * self.scale)
        attn = scores.softmax(dim=-1)
        vals = (attn @ v).transpose(1, 2).reshape(B_s, K, C_s)

        # Project back to `embed_dim`
        return self.proj(vals)


class MAPBlock(nn.Module):
    def __init__(
        self,
        n_latents: int,
        embed_dim: int,
        n_heads: int,
        mlp_ratio: float = 4.0,
        do_rms_norm: bool = True,
        do_swish_glu: bool = True,
    ) -> None:
        """Multiheaded Attention Pooling Block -- note that for MAP, we adopt earlier post-norm conventions."""
        super().__init__()
        # Internally the head count is doubled relative to the backbone's `n_heads`
        self.n_latents, self.embed_dim, self.n_heads = n_latents, embed_dim, 2 * n_heads

        # Projection Operator (maps encoder outputs into this block's embedding space)
        self.projection = nn.Linear(embed_dim, self.embed_dim)

        # Initialize Latents -- learned "seed" vectors, one per latent
        self.latents = nn.Parameter(torch.zeros(self.n_latents, self.embed_dim))
        nn.init.normal_(self.latents, std=0.02)

        # Custom MAP Attention (seed, encoder outputs) -> seed
        self.attn_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.attn = MAPAttention(self.embed_dim, n_heads=self.n_heads)

        # Position-wise Feed-Forward Components
        self.mlp_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            # Handle SwishGLU vs. GELU MLP...
            (
                SwishGLU(self.embed_dim, int(mlp_ratio * self.embed_dim))
                if do_swish_glu
                else nn.Sequential(nn.Linear(self.embed_dim, int(mlp_ratio * self.embed_dim)), nn.GELU())
            ),
            nn.Linear(int(mlp_ratio * self.embed_dim), self.embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool encoder outputs `x` [bsz, N, embed] into [bsz, embed] (n_latents == 1) or [bsz, n_latents, embed]."""
        latents = repeat(self.latents, "n_latents d -> bsz n_latents d", bsz=x.shape[0])
        # Post-norm residual blocks (norm applied *after* the residual add)
        latents = self.attn_norm(latents + self.attn(latents, self.projection(x)))
        latents = self.mlp_norm(latents + self.mlp(latents))

        # `squeeze(dim=1)` only collapses the latent dimension when n_latents == 1
        return latents.squeeze(dim=1)


# MAP Extractor Instantiation --> factory for creating extractors with the given parameters.
def instantiate_extractor(backbone: nn.Module, n_latents: int = 1) -> Callable[[], nn.Module]:
    """Return a zero-arg factory that builds a MAPBlock sized to `backbone`'s embed_dim/n_heads."""

    def initialize() -> nn.Module:
        return MAPBlock(n_latents, backbone.embed_dim, backbone.n_heads)

    return initialize


================================================
FILE: voltron/models/util/optimization.py
================================================
"""
optimization.py

General utilities for optimization, e.g., schedulers such as Linear Warmup w/ Cosine Decay for Transformer training.

Notably *does not* use the base PyTorch LR Scheduler, since we call it continuously, across epochs, across steps;
PyTorch has no built-in way of separating the two without coupling to the DataLoader, so may as well make this
explicit in the parent loop.

References
    - MAE: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/lr_sched.py
    - ⚡️-Bolts: https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py
"""
import math
from typing import Callable

from torch.optim.optimizer import Optimizer


def get_lr_update(
    opt: Optimizer, schedule: str, lr: float, min_lr: float, warmup_epochs: int, max_epochs: int
) -> Callable[[int, float], float]:
    """
    Build an `lr_update(epoch, fractional_progress) -> new_lr` callable for the given schedule.

    :param opt: Optimizer whose param groups get mutated in place on each call.
    :param schedule: One of "linear-warmup+cosine-decay" or "none".
    :param lr: Peak (post-warmup) learning rate.
    :param min_lr: Floor learning rate reached at the end of cosine decay.
    :param warmup_epochs: Number of epochs of linear warmup from 0 --> `lr`.
    :param max_epochs: Total number of training epochs (decay horizon).

    :raises NotImplementedError: For any unrecognized `schedule` string.
    """
    if schedule == "linear-warmup+cosine-decay":

        def lr_update(epoch: int, fractional_progress: float) -> float:
            """Run the warmup check for linear increase, else cosine decay."""
            if (epoch + fractional_progress) < warmup_epochs:
                new_lr = lr * (epoch + fractional_progress) / max(1.0, warmup_epochs)
            else:
                # Cosine Decay --> as defined in the SGDR Paper...
                progress = ((epoch + fractional_progress) - warmup_epochs) / max(1.0, max_epochs - warmup_epochs)
                new_lr = min_lr + (lr - min_lr) * (0.5 * (1 + math.cos(math.pi * progress)))

            # Apply... (honoring any per-group `lr_scale`, e.g. for layer-wise LR decay)
            for group in opt.param_groups:
                if "lr_scale" in group:
                    group["lr"] = new_lr * group["lr_scale"]
                else:
                    group["lr"] = new_lr

            return new_lr

    elif schedule == "none":

        def lr_update(_: int, __: float) -> float:
            # Fixed LR: never touches `opt.param_groups` -- relies on the optimizer's construction-time `lr`
            return lr

    else:
        raise NotImplementedError(f"Schedule `{schedule}` not implemented!")

    return lr_update


================================================
FILE: voltron/models/util/transformer.py
================================================
"""
transformer.py

General Transformer modules & utilities.

References:
    - https://github.com/facebookresearch/mae
    - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


# === Position Encoding Utilities ===


# Helper/Utility Function -- computes simple 1D sinusoidal position embeddings for both 1D/2D use cases.
# > We'll be combining two 1D sin-cos (traditional) position encodings for height/width of an image (grid features).
def get_1D_sine_cosine(dim: int, pos: np.ndarray) -> np.ndarray:
    """Return [len(flatten(pos)), dim] sin-cos embeddings; first half sin, second half cos."""
    omega = np.arange(dim // 2, dtype=np.float32) / (dim / 2.0)
    omega = 1.0 / (10000**omega)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)  # [flatten(pos) x omega] -- outer product!
    emb_sin, emb_cos = np.sin(out), np.cos(out)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # [flatten(pos) x D]


# 1D Sine-Cosine Position Embedding -- standard from "Attention is all you need!"
def get_1D_position_embeddings(embed_dim: int, length: int) -> np.ndarray:
    return get_1D_sine_cosine(embed_dim, np.arange(length))


# 2D Sine-Cosine Position Embedding (from MAE repository)
# > https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2D_position_embeddings(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
    # Create 2D Position embeddings by taking cross product of height and width and splicing 1D embeddings...
grid_h, grid_w = np.arange(grid_size, dtype=np.float32), np.arange(grid_size, dtype=np.float32)
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, grid_size, grid_size)  # w goes first?

# Use half of dimensions to encode grid_h, other half to encode grid_w
emb_h, emb_w = get_1D_sine_cosine(embed_dim // 2, grid[0]), get_1D_sine_cosine(embed_dim // 2, grid[1])
pos_embed = np.concatenate([emb_h, emb_w], axis=1)

# CLS token handling (only for R-MVP) -- prepend an all-zeros embedding row for the [CLS] position
if cls_token:
    pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)

return pos_embed


# === Vision Transformer Building Blocks ===


# Patch Embedding Module
class PatchEmbed(nn.Module):
    def __init__(
        self,
        resolution: int,
        patch_size: int,
        embed_dim: int,
        in_channels: int = 3,
        flatten: bool = True,
    ):
        """Convolutional patchifier: [bsz, C, res, res] --> [bsz, n_patches, embed_dim] (when `flatten`)."""
        super().__init__()
        self.resolution, self.patch_size = (resolution, resolution), (patch_size, patch_size)
        self.grid_size = (self.resolution[0] // self.patch_size[0], self.resolution[1] // self.patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # kernel == stride == patch_size --> non-overlapping patch projection
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        patch_embeddings = self.proj(patches)
        if self.flatten:
            return rearrange(patch_embeddings, "bsz embed patch_h patch_w -> bsz (patch_h patch_w) embed")
        return patch_embeddings


# === Stability Utilities ===


# LayerScale -- Trainable scaling for residual blocks -- Mistral/CaIT
class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: float = 0.1) -> None:  # CaIT :: 0.1 -> lay 12, 1e-5 -> lay 24, 1e-6...
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-channel learned scaling of the residual branch
        return x * self.gamma


# RMSNorm -- Better, simpler alternative to LayerNorm
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.scale, self.eps = dim**-0.5, eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ||x|| * dim^{-1/2} == RMS(x); clamp guards against division by zero
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


# SwishGLU -- A Gated Linear Unit (GLU) with the Swish activation; always better than GELU MLP!
class SwishGLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        # Single projection to 2 * out_dim, then split into (value, gate) halves
        self.act, self.project = nn.SiLU(), nn.Linear(in_dim, 2 * out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected, gate = self.project(x).tensor_split(2, dim=-1)
        return projected * self.act(gate)


# === Fundamental Transformer Building Blocks ===
class Attention(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.0) -> None:
        """Multi-Headed Self-Attention Operation"""
        super().__init__()
        assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
        self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5
        # Cached post-softmax attention from the most recent forward pass (for inspection/visualization)
        self.attn_softmax = None

        # Projections
        self.qkv, self.proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True), nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, N, C = x.shape

        # Project to Q-K-V --> each [B, n_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Self-attention -- with masking!
        scores = q @ (k.transpose(-2, -1) * self.scale)
        if mask is not None:
            if mask.ndim == 2:
                # NOTE(review): a 2D [bsz, seq] mask broadcasts over scores [B, H, N, N] as [bsz, 1, seq, 1],
                # i.e. it zeros *query rows* rather than key columns -- confirm this is the intended convention.
                mask = rearrange(mask, "bsz seq -> bsz 1 seq 1")
            elif mask.ndim != 4:
                raise NotImplementedError("Attention got `mask` of shape not in {2, 4}!")

            # Mask out by filling indices with negative infinity...
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Compute weighted sum over values
        self.attn_softmax = scores.softmax(dim=-1)
        vals = (self.attn_softmax @ v).transpose(1, 2).reshape(B, N, C)

        # Project back to `embed_dim` -- with optional dropout
        vals = self.dropout(self.proj(vals))
        return vals


class Block(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        do_rms_norm: bool = False,
        do_swish_glu: bool = False,
        do_layer_scale: bool = False,
    ) -> None:
        """
        Transformer Block Implementation (modality-agnostic).

        :param embed_dim: Core embedding/hidden dimension for vision transformer backbone.
        :param n_heads: Number of heads for multi-headed self-attention.
        :param mlp_ratio: Ratio for embedding size to position-wise feed-forward MLP (gets shrunk back down).
        :param dropout: [Optional] dropout for projection layer and MLPs -- for MAEs, always 0.0!
        :param do_rms_norm: Boolean whether or not to use RMSNorm in lieu of LayerNorm within block.
        :param do_swish_glu: Use the Swish-variant of the Gated Linear Unit for the feed-forward layers.
        :param do_layer_scale: Boolean whether or not to use LayerScale from Mistral/CaIT w/ initialization of 0.1.
        """
        super().__init__()
        self.embed_dim, self.n_heads, self.do_layer_scale = embed_dim, n_heads, do_layer_scale

        # Attention Components (pre-norm convention)
        self.pre_norm_attn = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.attn = Attention(self.embed_dim, n_heads=n_heads, dropout=dropout)
        if do_layer_scale:
            self.layer_scale_attn = LayerScale(self.embed_dim)

        # Position-wise Feed-Forward Components
        self.pre_norm_mlp = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            # Handle SwishGLU vs. GELU MLP...
            (
                SwishGLU(embed_dim, int(mlp_ratio * embed_dim))
                if do_swish_glu
                else nn.Sequential(nn.Linear(embed_dim, int(mlp_ratio * embed_dim)), nn.GELU())
            ),
            nn.Dropout(dropout),
            nn.Linear(int(mlp_ratio * embed_dim), embed_dim),
        )
        if self.do_layer_scale:
            self.layer_scale_mlp = LayerScale(self.embed_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.do_layer_scale:
            x = x + self.layer_scale_attn(self.attn(self.pre_norm_attn(x), mask))
            x = x + self.layer_scale_mlp(self.mlp(self.pre_norm_mlp(x)))
        else:
            x = x + self.attn(self.pre_norm_attn(x), mask)
            x = x + self.mlp(self.pre_norm_mlp(x))
        return x


================================================
FILE: voltron/overwatch/__init__.py
================================================
from .overwatch import OverwatchRich


================================================
FILE: voltron/overwatch/overwatch.py
================================================
"""
overwatch.py

Utility class for creating a centralized/standardized logger (to pass to Hydra), with a sane default format.
"""
from dataclasses import dataclass, field
from typing import Any, Dict

# Overwatch Default Format String
FORMATTER, DATEFMT = "[*] %(asctime)s - %(name)s >> %(levelname)s :: %(message)s", "%m/%d [%H:%M:%S]"
RICH_FORMATTER = "| >> %(message)s"


# Rich Overwatch Variant --> Good for debugging, and tracing!
# `logging.config.dictConfig`-style schema, registered with Hydra (note the `${hydra.job.name}` interpolation)
@dataclass
class OverwatchRich:
    version: int = 1
    formatters: Dict[str, Any] = field(
        default_factory=lambda: {
            "simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT},
            "simple-file": {"format": FORMATTER, "datefmt": DATEFMT},
        }
    )
    handlers: Dict[str, Any] = field(
        default_factory=lambda: {
            "console": {
                "class": "rich.logging.RichHandler",
                "formatter": "simple-console",
                "rich_tracebacks": True,
                "show_level": True,
                "show_path": True,
                "show_time": True,
            },
            "file": {
                "class": "logging.FileHandler",
                "formatter": "simple-file",
                "filename": "${hydra.job.name}.log",
            },
        }
    )
    root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
    disable_existing_loggers: bool = True


# Standard Overwatch Variant --> Performant, no bells & whistles
@dataclass
class OverwatchStandard:
    version: int = 1
    formatters: Dict[str, Any] = field(default_factory=lambda: {"simple": {"format": FORMATTER, "datefmt": DATEFMT}})
    handlers: Dict[str, Any] = field(
        default_factory=lambda: {
            "console": {"class": "logging.StreamHandler", "formatter": "simple", "stream": "ext://sys.stdout"},
            "file": {
                "class": "logging.FileHandler",
                "formatter": "simple",
                "filename": "${hydra.job.name}.log",
            },
        }
    )
    root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
    disable_existing_loggers: bool = True


================================================
FILE: voltron/preprocessing/__init__.py
================================================
from .process import extract_frames, preprocess_language, unify_batches


================================================
FILE: voltron/preprocessing/core.py
================================================
"""
utils.py

Preprocessing utilities, including dry-run and single-video (single-example) processing. This file effectively defines
the "atomic" logic (take one video --> extract all frames, etc.), while the `process.py` functions invoke each unit in
a multiprocessing pool.
""" import glob import json import logging import os import time from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Tuple import av import h5py import numpy as np import pandas as pd from hurry.filesize import alternative, size from PIL import Image from rich.progress import track from tqdm import tqdm # Grab Logger overwatch = logging.getLogger(__file__) logging.getLogger("libav").setLevel(logging.ERROR) # === General Utilities === # Videos are saved as `train_dir/{vid}/{vid}_idx={i}.jpg || if `relpath` then *relative path* `{split}/{vid}/... def get_path(save_dir: Path, v: str, i: int, relpath: bool = False) -> str: return str((save_dir if not relpath else Path(save_dir.name)) / v / f"{v}_idx={i}.jpg") # === Dry-Run Functionality === def do_dry_run( name: str, path: str, train_ids: List[str], val_ids: List[str], preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]], n_train_videos: int = 1000, n_val_videos: int = 100, n_samples: int = 1000, ) -> None: """Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation.""" overwatch.info(f"Performing Dry-Run with {n_train_videos} Train Videos and {n_val_videos} Validation Videos") dry_run_metrics = { "n_frames": [], "jpg_sizes": [], "n_samples": n_samples, "time_per_example": [], "blank": str(Path(path) / "blank.jpg"), } # Switch on dataset (`name`) if name == "sth-sth-v2": for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]: for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True): container = av.open(str(Path(path) / "videos" / f"{vids[idx]}.webm")) assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!" 
try: imgs = [f.to_image() for f in container.decode(video=0)] except (RuntimeError, ZeroDivisionError) as e: overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vids[idx]}.webm` - continuing...") continue container.close() # Apply `preprocess_transform` imgs = preprocess_transform(imgs) # Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on... dry_run_metrics["n_frames"].append(len(imgs)) while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0: img = imgs.pop(0) img.save(str(dry_run_metrics["blank"])) dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"])) dry_run_metrics["n_samples"] -= 1 # Compute nice totals for "dry-run" estimate... total_clips = len(train_ids) + len(val_ids) else: raise ValueError(f"Dry Run for Dataset `{name}` not implemented!") # Compute aggregate statistics and gently exit... avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"])) overwatch.info("Dry-Run Statistics =>>") overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}") overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}") overwatch.info( f"\t> With the full dataset of {total_clips} Train + Val videos ~" f" {size(total_clips * avg_frames * avg_size, system=alternative)}" ) overwatch.info("Dry-Run complete, do what you will... exiting ✌️") # Remove dummy file... 
    # Remove the dummy JPEG written during size estimation, then end the dry run.
    os.remove(dry_run_metrics["blank"])
    # NOTE(review): `exit()` is the interactive-shell helper; the v1 pipeline uses `sys.exit(0)` — consider unifying.
    exit(0)


# === Atomic "Processing" Steps ===
def process_clip(
    name: str,
    path: Path,
    save: Path,
    preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]],
    item: Tuple[str, str],
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
    """Processes a single video clip and extracts/serializes all frames (as jpeg), returning the registry contents."""
    if name == "sth-sth-v2":
        # `item` is a (video id, language annotation) pair for a single clip.
        vid, lang = item
        container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
        assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
        try:
            imgs = [f.to_image() for f in container.decode(video=0)]
        except (RuntimeError, ZeroDivisionError) as e:
            # Some WebM files are unreadable --> log & skip so a bad clip doesn't kill the worker pool.
            # NOTE(review): `container` is not closed on this error path.
            overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
            return None, None
        container.close()

        # Book-Keeping
        os.makedirs(save / vid, exist_ok=True)
        registration["n_frames"] = len(imgs)

        # Short Circuit --> Writes are Expensive! (If all frames are already on disk, skip re-encoding.)
        if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
            return vid, registration

        # Apply `preprocess_transform` --> write individual frames, register, and move on!
        imgs = preprocess_transform(imgs)
        for idx in range(len(imgs)):
            imgs[idx].save(get_path(save, vid, idx))

        # Return title & registration
        return vid, registration
    else:
        raise ValueError(f"Clip Processing for Dataset `{name}` is not implemented!")


# ruff: noqa: C901
def serialize_epoch(
    index_dir: Path,
    registry: Dict[str, Any],
    vid_dir: Path,
    batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
    do_initial: bool,
    do_final: bool,
    initial_final_alpha: float,
    n_int: int,
    epoch: int,
    is_validation: bool = False,
) -> Tuple[int, int, Optional[Set[str]]]:
    """
    Samples frame indices for every video in `registry` and writes one "epoch" worth of batches to disk, once per
    batch format, as both a JSON and an HDF5 index.

    Returns (epoch, n_examples, set of unique frame paths), or (-1, -1, None) if this epoch's indices already exist.
    """
    index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"
    index_hdf5 = "validation-batches.hdf5" if is_validation else f"train-epoch={epoch}-batches.hdf5"

    # Short-Circuit --> if every format's JSON index exists, this epoch was already serialized.
    if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
        return -1, -1, None

    # Random seed is inherited from parent process... we want new randomness w/ each process
    np.random.seed((os.getpid() * int(time.time())) % 123456789)

    # Create Tracking Variables
    unique_states, batches = set(), {b: [] for b, _ in batch_formats}

    # Iterate through Registry --> Note we're using `tqdm` instead of `track` here because of `position` feature!
    for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
        # The initial/final states are sampled from the first [0, \alpha) and final 1-\alpha, 1] percent of the video
        n_frames = registry[vid]["n_frames"]
        initial_idx, final_idx = 0, n_frames - 1
        if do_initial:
            initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))

        if do_final:
            final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)

        # Assertion --> initial_idx < final_idx - len(state_elements)
        assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"

        # Assume remaining elements are just random "interior" states --> sort to get ordering!
        sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
        sampled_idxs = sorted(list(sampled_idxs))

        # Compile full-set "batch" (relative frame paths; `get_path` w/ `relpath` is defined earlier in this module)
        retrieved_states = [get_path(vid_dir, vid, x, relpath=True) for x in [initial_idx, *sampled_idxs] + [final_idx]]

        # Add batch to index for specific batch_format key... (last format is the full superset)
        batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
        unique_states.update(retrieved_states)

        # Add all other batch formats to indices...
        for key, elements in batch_formats[:-1]:
            n_states = len([x for x in elements if "state_" in x])
            assert (n_states <= 2) or (
                n_states == len(retrieved_states)
            ), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"

            # States are all independent -- each of the retrieved states is its own example...
            if n_states == 1:
                for idx in range(len(retrieved_states)):
                    batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})

            # 0K-Context is the only "valid" context for n_states == 2
            elif n_states == 2:
                assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"

                # Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
                for idx in range(1, len(retrieved_states)):
                    batches[key].append(
                        {"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
                    )

            # We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
            else:
                batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})

    # Write JSON Index directly to disk...
    for key in batches:
        with open(index_dir / key / index_file, "w") as f:
            json.dump(batches[key], f)

    # Write HDF5 Index directly to disk... (only for the non-superset formats, mirroring the loop above)
    for key, elements in batch_formats[:-1]:
        n_states = len([x for x in elements if "state_" in x])

        # Create HDF5 File
        df = pd.DataFrame(batches[key])
        h5 = h5py.File(index_dir / key / index_hdf5, "w")
        for k in ["vid", "n_frames"]:
            h5.create_dataset(k, data=df[k].values)

        # Handle "state(s)" --> (image path strings) --> add leading dimension (`n_states`)
        if n_states == 1:
            dfs = df["state"].apply(pd.Series)
            h5.create_dataset("states", data=dfs.values)
        else:
            dfs = df["states"].apply(pd.Series)
            h5.create_dataset("states", data=dfs.values)

        # Close HDF5 File
        h5.close()

    # NOTE(review): assumes a batch format literally keyed "state" exists in `batch_formats` — verify against callers.
    return epoch, len(batches["state"]), unique_states


================================================
FILE: voltron/preprocessing/process.py
================================================
"""
process.py

Utility functions for preprocessing large-scale video/vision-language datasets in multiple passes, using
multiprocessing for parallelization.

Exposes a three-phase sequence for preprocessing --> batching data:
    - Phase I (`extract_frames`): Read in raw (video clip, language) pairs, extract and serialize *all frames* to disk.

This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up extraction; however, for
larger datasets YMMV. You might consider extracting the relevant logic, and using tools like SLURM Job Arrays, AWS
Lambda Functions, or GCP Cloud Run to "burst preprocess" data.
"""

import json
import logging
import multiprocessing as mp
import os
import shutil
from functools import partial
from pathlib import Path
from typing import Tuple

import torch
from rich.progress import track
from transformers import AutoTokenizer

from voltron.preprocessing.core import do_dry_run, process_clip, serialize_epoch
from voltron.preprocessing.transforms import get_preprocess_transform

# Grab Logger
overwatch = logging.getLogger(__file__)


def extract_frames(
    name: str,
    path: str,
    artifact_path: str,
    preprocess_resolution: int,
    n_val_videos: int,
    dry_run: bool = False,
) -> Tuple[Path, Path, Path, Path]:
    """Phase I: Extract and serialize *all frames* from video clips; uses multiprocessing to parallelize."""
    overwatch.info(f"Phase 1 Preprocessing :: Extracting Frames for Dataset `{name}`")

    # Overview of Return Values:
    #   `t_registry` and `v_registry` =>> store mappings of "video id" -> {metadata}
    #   `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
    t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
    t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"

    # Short-Circuit --> both registries on disk means this phase already completed.
    if t_registry.exists() and v_registry.exists():
        return t_registry, v_registry, t_dir, v_dir

    # Setup / Book-Keeping
    os.makedirs(t_dir, exist_ok=True)
    os.makedirs(v_dir, exist_ok=True)

    # Retrieve "pre-serialization" frame transform --> we scale down video frames (*while preserving aspect ratios*)
    #   and center crop each frame to `(preprocess_resolution, preprocess_resolution)`; saves on disk space (by a lot!)
    preprocess_transform = get_preprocess_transform(name, preprocess_resolution=preprocess_resolution)

    # Switch on dataset (`name`) --> parse out (video id, language) pairs for train & validation.
    if name == "sth-sth-v2":
        with open(Path(path) / "labels/train.json", "r") as f:
            annotations = json.load(f)
            train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]

        with open(Path(path) / "labels/validation.json", "r") as f:
            annotations = json.load(f)[:n_val_videos]
            val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
    else:
        raise ValueError(f"Language/Metadata Extraction Pipeline for Dataset `{name}` not implemented!")

    # Run Dry-Run (if specified) --> single-threaded for debugging
    if dry_run:
        do_dry_run(name, path, train_ids, val_ids, preprocess_transform)

    # Otherwise =>> Iterate through all videos, dump all frames subject to the following structure:
    #   |-> .../processed/something-something-v2/
    #       |-> <vid>/
    #           |-> <vid>/frames<0..k>.jpg
    #
    # We'll build a single metadata file with a mapping <vid> : ("language", n_frames)
    #   > To speed up serialization, we'll use a multiprocessing.Pool and max out CPU workers
    with mp.Pool(mp.cpu_count()) as pool:
        for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
            overwatch.info(f"\tWriting `{k}` videos to disk...")

            # Spawn!
            process_fn, registration = partial(process_clip, name, Path(path), save, preprocess_transform), {}
            for key, value in track(
                pool.imap_unordered(process_fn, zip(vids, langs)),
                total=len(vids),
                transient=True,
            ):
                # `key is None` marks an unreadable clip --> dropped from the registry entirely.
                if key is not None:
                    registration[key] = value

            # Write Registration to Disk
            with open(t_registry if k == "train" else v_registry, "w") as f:
                json.dump(registration, f)

    # Return Paths to Registry & Extract Directories...
    return t_registry, v_registry, t_dir, v_dir


def preprocess_language(
    name: str,
    train_registry: Path,
    val_registry: Path,
    artifact_path: str,
    max_lang_len: int,
    language_model: str,
    hf_cache: str,
) -> Path:
    """Phase II: Iterate through Language Captions/Narrations and Normalize/Tokenize (truncate/pad to max length)."""
    overwatch.info(f"Phase 2 Preprocessing :: Normalizing & Tokenizing Language for Dataset `{name}`")
    t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
    t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
    index_dir = Path(artifact_path) / name / "index"
    os.makedirs(index_dir, exist_ok=True)

    # Short-Circuit --> both final index files exist, so language preprocessing already ran.
    if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
        return index_dir

    # Grab Language --> retain metadata for building index structures!
    with open(train_registry, "r") as f:
        train_metadata = json.load(f)
        train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]

    with open(val_registry, "r") as f:
        val_metadata = json.load(f)
        val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]

    # Assemble *all* language (train first, then val --> the split offsets below rely on this ordering)
    language = [x[1] for x in train + val]

    # Build AutoTokenizer (from `language_model` identifier)
    tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)

    # If `max_lang_len` not specified, dump some statistics to compute...
    if max_lang_len == -1:
        # Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
        encoded_language = tokenizer(language, return_tensors="pt", padding=True)
        lengths = encoded_language["attention_mask"].sum(dim=1)

        # Compute a histogram of lengths
        hist = lengths.float().histc(bins=lengths.max()).int()
        overwatch.info(f"Histogram: {hist.numpy().tolist()}")

        # Intentional hard stop --> a human inspects the histogram and sets `max_lang_len` in the dataset config.
        raise AssertionError("Compute max length and update dataset configuration!")

    # Otherwise, we've already set the maximum length, so let's use it!
    overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
    encoded_language = tokenizer(
        language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
    )
    input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
    train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
    val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]

    # Assertion, just to sanity check
    assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."

    # Compute `index.pt` contents --> per-video metadata merged w/ its tokenized language.
    overwatch.info("\tAssembling `train` and `val` index structures...")
    train_pt = {
        train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
        for i in range(len(train))
    }
    val_pt = {
        val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
        for i in range(len(val))
    }

    # Additionally dump JSON versions of the same --> downstream interpretability, XLA
    overwatch.info("JSONifying both Train and Validation Language")
    train_json, val_json = {}, {}
    for vid in track(train_pt, description="Train Language :: ", transient=True):
        train_json[vid] = {
            "language": train_pt[vid]["language"],
            "n_frames": train_pt[vid]["n_frames"],
            "input_ids": train_pt[vid]["input_ids"].numpy().tolist(),
            "attention_mask": train_pt[vid]["attention_mask"].numpy().tolist(),
        }

    for vid in track(val_pt, description="Validation Language :: ", transient=True):
        val_json[vid] = {
            "language": val_pt[vid]["language"],
            "n_frames": val_pt[vid]["n_frames"],
            "input_ids": val_pt[vid]["input_ids"].numpy().tolist(),
            "attention_mask": val_pt[vid]["attention_mask"].numpy().tolist(),
        }

    # Dump Structures...
    overwatch.info(f"Saving Torch indices to `{t_index}` and `{v_index}` respectively...")
    torch.save(train_pt, t_index)
    torch.save(val_pt, v_index)

    overwatch.info(f"Saving JSON indices to `{t_json}` and `{v_json}` respectively...")
    with open(t_json, "w") as f:
        json.dump(train_json, f)

    with open(v_json, "w") as f:
        json.dump(val_json, f)

    # Pull relevant files out into their own `index` directory...
    shutil.copy(t_json, index_dir / "train-language-index.json")
    shutil.copy(v_json, index_dir / "val-language-index.json")

    return index_dir


def unify_batches(
    name: str,
    train_registry: Path,
    val_registry: Path,
    train_dir: Path,
    val_dir: Path,
    index_dir: Path,
    batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
    max_epochs: int = 400,
    initial_final_alpha: float = 0.2,
) -> None:
    """Phase III: Assemble "Data-Locked" Batches for *all models* for *all epochs* for consistency!"""
    overwatch.info(f"Phase 3 Preprocessing :: Assembling *Data-Locked* Batches for Dataset `{name}`")

    # Load Registries
    with open(train_registry, "r") as f:
        train_registrations = json.load(f)

    with open(val_registry, "r") as f:
        val_registrations = json.load(f)

    # Assert last element of `batch_formats` assumes all prior subsets...
    full_set_inputs = set(batch_formats[-1][1])
    for _, subset_inputs in batch_formats[:-1]:
        assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."

    # Assemble Tracking Data
    b_keys, unique_states = {b[0] for b in batch_formats}, set()

    # Parse out all "state"-specific Elements...
    state_elements = [s for s in full_set_inputs if "state_" in s]
    do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements

    # `n_int` = # of random "interior" frames to sample (only meaningful when both endpoints are requested).
    n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0

    # Serialize Epochs
    overwatch.info("\tSerializing Epochs to JSON --> Storing mapping of Epoch -> Image Paths")
    for b in b_keys:
        os.makedirs(index_dir / b, exist_ok=True)

    # We only write the Validation Epoch once --> held constant across *all* of training!
    overwatch.info("\tWriting Validation Epoch to Disk")
    val_epoch_idx, _, uniq_s = serialize_epoch(
        index_dir,
        val_registrations,
        val_dir,
        batch_formats,
        do_initial,
        do_final,
        initial_final_alpha,
        n_int,
        epoch=0,
        is_validation=True,
    )

    # Update Trackers... (epoch index of -1 signals "already serialized / cached")
    if val_epoch_idx != -1:
        unique_states |= uniq_s

    # Compute length of epochs --> CPU Count should be no higher...
    epochs, n_frames_per_epoch = list(range(max_epochs)), -1

    # Parallelize Train Epoch Serialization
    overwatch.info("\tPlacing the Train Registry into Shared Memory")
    manager = mp.Manager()
    mg_registry = manager.dict(train_registrations)

    # Multiprocess --> the memory demands here are a bit higher, so limit workers by factor of 4
    with mp.Pool(mp.cpu_count() // 4) as pool:
        overwatch.info("\tWriting Train Batches per Epoch to Disk")
        precompute_fn = partial(
            serialize_epoch,
            index_dir,
            mg_registry,
            train_dir,
            batch_formats,
            do_initial,
            do_final,
            initial_final_alpha,
            n_int,
        )
        for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
            if epoch_idx == -1:
                continue

            # Update Trackers
            unique_states |= uniq_s
            n_frames_per_epoch = n_frames

    # Dump Statistics (Note :: Only makes sense on "initial" computation --> uninterrupted!)
    overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
    overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")


================================================
FILE: voltron/preprocessing/transforms.py
================================================
"""
transforms.py

Default video/image transforms for Voltron preprocessing and training. Provides utilities for defining different scale
and crop transformations on a dataset-specific basis.

There are two key desiderata we ensure with the transforms:
    - Aspect Ratio --> We *never* naively reshape images in a way that distorts the aspect ratio; we crop instead!
    - Minimum Size --> We *never* upsample images; processing strictly reduces dimensionality!
"""

from functools import partial
from typing import Any, Callable, List, Tuple

import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize, Resize


# Simple Identity Function --> needs to be top-level/pickleable for mp/distributed.spawn()
def identity(x: torch.Tensor) -> torch.Tensor:
    return x.float()


def scaled_center_crop(target_resolution: int, frames: List[Image.Image]) -> List[Image.Image]:
    """Scale every frame down so height == `target_resolution` (aspect preserved), then center-crop to a square."""
    # Assert width >= height and height >= target_resolution
    orig_w, orig_h = frames[0].size
    assert orig_w >= orig_h >= target_resolution

    # Compute scale factor --> just a function of height and target_resolution
    scale_factor = target_resolution / orig_h
    for idx in range(len(frames)):
        # Frames are mutated in place --> scale, then crop a centered `target_resolution` square.
        frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
        left = (frames[idx].size[0] - target_resolution) // 2
        frames[idx] = frames[idx].crop((left, 0, left + target_resolution, target_resolution))

    # Return "scaled and squared" images
    return frames


def get_preprocess_transform(
    dataset_name: str, preprocess_resolution: int
) -> Callable[[List[Image.Image]], List[Image.Image]]:
    """Returns a transform that extracts square crops of `preprocess_resolution` from videos (as [T x H x W x C])."""
    if dataset_name == "sth-sth-v2":
        return partial(scaled_center_crop, preprocess_resolution)
    else:
        raise ValueError(f"Preprocessing transform for dataset `{dataset_name}` is not defined!")


def get_online_transform(
    dataset_name: str, model_arch: str, online_resolution: int, normalization: Tuple[Any, Any]
) -> Compose:
    """Returns an "online" torchvision Transform to be applied during training (batching/inference)."""
    if dataset_name == "sth-sth-v2":
        # Note: R3M does *not* expect normalized 0-1 (then ImageNet normalized) images --> drop the identity.
        if model_arch in {"v-r3m", "v-rn3m"}:
            return Compose([Resize((online_resolution, online_resolution), antialias=True), Lambda(identity)])
        else:
            return Compose(
                [
                    Resize((online_resolution, online_resolution), antialias=True),
                    ConvertImageDtype(torch.float),
                    Normalize(mean=normalization[0], std=normalization[1]),
                ]
            )
    else:
        raise ValueError(f"Online Transforms for Dataset `{dataset_name}` not implemented!")


================================================
FILE: voltron/preprocessing/v1/__init__.py
================================================
from .process import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches


================================================
FILE: voltron/preprocessing/v1/process.py
================================================
"""
process.py

Utility functions for serializing datasets in multiple passes, using multiprocessing for efficient parallelization.

Exposes a three-phase sequence for preprocessing:
    - Phase I: Read in raw videos (and language), serialize *all extracted* frames to a subdirectory for easy retrieval.
    - Phase II: Given image paths and language, assemble language statistics & pre-tokenize for easy batching.
    - Phase III: Given a total number of "conceivable epochs", create data-controlled "epoch" sets for each model.

This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up the serialization process.
It also tries to be somewhat safe & efficient, producing idempotent resumes.

Note :: This code represents the `v1` (initial release) preprocessing flow; this will eventually be deprecated!
"""

import json
import logging
import multiprocessing as mp
import os
import shutil
from functools import partial
from pathlib import Path
from typing import Tuple

import torch
from rich.progress import track
from transformers import AutoTokenizer

from voltron.preprocessing.v1.transforms import get_pre_transform
from voltron.preprocessing.v1.utils import do_dry_run, precompute_epoch, process_video

# Grab Logger
overwatch = logging.getLogger(__file__)


def preprocess_videos(
    name: str,
    path: str,
    artifact_path: str = "data/processed",
    resolution: int = 224,
    n_val_videos: int = 1000,
    dry_run: bool = False,
) -> Tuple[Path, Path, Path, Path]:
    """Phase I of Preprocessing :: Uses Multiprocessing to Read Videos & Serialize Frames."""
    overwatch.info(f"Phase 1 Preprocessing :: Frame serializing videos for dataset `{name}`")
    if name == "sth-sth-v2":
        # Overview of Return Values:
        #   `t_registry` and `v_registry` =>> store mappings of "vid_id" -> {metadata}
        #   `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
        t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
        t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"

        # Short-Circuit / Caching Logic
        if t_registry.exists() and v_registry.exists():
            return t_registry, v_registry, t_dir, v_dir

        # Setup / Book-Keeping
        os.makedirs(t_dir, exist_ok=True)
        os.makedirs(v_dir, exist_ok=True)

        # Retrieve Image Transforms (pre-serialization, while running "offline" pass); we crop and scale once, so we're
        #   not overdoing it on disk storage...
        pre_transform = get_pre_transform(name, resolution=resolution)

        # Open & Extract Video ID & Language Metadata
        with open(Path(path) / "something-something-v2-train.json", "r") as f:
            annotations = json.load(f)
            train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]

        with open(Path(path) / "something-something-v2-validation.json", "r") as f:
            annotations = json.load(f)[:n_val_videos]
            val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]

        # Do Dry-Run --> Single-Threaded!
        if dry_run:
            do_dry_run(
                name,
                path,
                n_train_videos=1000,
                n_val_videos=100,
                train_ids=train_ids,
                val_ids=val_ids,
                pre_transform=pre_transform,
            )

        # Go Go Go =>> Iterate through all videos, dump all frames subject to the following structure:
        #   |-> data/processed/sth-sth-v2/
        #       |-> <vid>/
        #           |-> <vid>/frames<0...k>.jpg
        # We'll track a single metadata file with the map of <vid> : ("language", n_frames).
        #   > To speed up the serialization, we'll use a multiprocessing.Pool and max out CPU workers
        with mp.Pool(mp.cpu_count()) as pool:
            for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
                overwatch.info(f"\tWriting `{k}` videos to disk...")

                # Multiprocess!
                process_fn, registration = partial(process_video, name, Path(path), save, pre_transform), {}
                for key, value in track(
                    pool.imap_unordered(process_fn, zip(vids, langs)),
                    description=f"\t[*] Processing {k}...",
                    total=len(vids),
                    transient=True,
                ):
                    # `key is None` marks an unreadable clip --> dropped from the registry entirely.
                    if key is not None:
                        registration[key] = value

                # Write Registration to Disk
                with open(t_registry if k == "train" else v_registry, "w") as f:
                    json.dump(registration, f)

        # Return Paths...
        return t_registry, v_registry, t_dir, v_dir
    else:
        raise NotImplementedError(f"Preprocessing Pipeline for Dataset `{name}` not implemented!")


def preprocess_language(
    name: str, train_registry: Path, val_registry: Path, max_lang_len: int, language_model: str, hf_cache: str
) -> None:
    """Phase II of Preprocessing :: Iterate through Language & Normalize/Tokenize to Max Length."""
    overwatch.info(f"Phase 2 Preprocessing :: Normalizing & tokenizing language for dataset `{name}`")
    t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
    t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"

    # Short-Circuit Logic
    # NOTE(review): returns (t_index, v_index) here but (implicitly) None on the normal path, despite the `-> None`
    #   annotation — callers appear to ignore the return value; consider unifying.
    if (t_index.exists() and v_index.exists()) or (t_json.exists() and v_json.exists()):
        return t_index, v_index

    # Grab Language, Retaining Metadata for Building Index Structures...
    with open(train_registry, "r") as f:
        train_metadata = json.load(f)
        train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]

    with open(val_registry, "r") as f:
        val_metadata = json.load(f)
        val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]

    # Assemble *all* language (train first, then val --> the split offsets below rely on this ordering)
    language = [x[1] for x in train + val]

    # Build AutoTokenizer (from `language_model` identifier)
    tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)

    # If `max_lang_len` not specified, dump some statistics to compute...
    if max_lang_len == -1:
        # Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
        encoded_language = tokenizer(language, return_tensors="pt", padding=True)
        lengths = encoded_language["attention_mask"].sum(dim=1)

        # Compute a histogram of lengths
        hist = lengths.float().histc(bins=lengths.max()).int()
        overwatch.info(f"Histogram: {hist.numpy().tolist()}")

        # Intentional hard stop --> a human inspects the histogram and sets `max_lang_len` in the dataset config.
        raise NotImplementedError("Compute max length and update dataset configuration!")

    # Otherwise, we've already set the maximum length, so let's use it!
    else:
        # Tokenize to the fixed `max_lang_len` --> truncating long captions, padding short ones.
        overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
        encoded_language = tokenizer(
            language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
        )
        input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
        train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
        val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]

        # Assertion, just to sanity check
        assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."

        # Compute `index.pt` contents --> per-video metadata merged w/ its tokenized language.
        overwatch.info("\tAssembling `train` and `val` index structures...")
        train_pt = {
            train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
            for i in range(len(train))
        }
        val_pt = {
            val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
            for i in range(len(val))
        }

        # Dump structures...
        overwatch.info(f"Saving index structures to `{t_index}` and `{v_index}` respectively...")
        torch.save(train_pt, t_index)
        torch.save(val_pt, v_index)


def jsonify_language(train_registry: Path, val_registry: Path) -> None:
    """Phase 2.5 (Aggregation) :: XLA is weird, won't load torch.Tensors in Dataset; JSONify instead."""
    overwatch.info("\tPhase 2 Aggregation :: JSONifying Language Index")
    t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
    t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
    train_json, val_json = {}, {}

    # Short-Circuit Logic
    if t_json.exists() and v_json.exists():
        return

    # Load Data, iterate through and "de-tensorize", while building up JSON symmetric structure...
    train_data, val_data = torch.load(t_index), torch.load(v_index)
    overwatch.info("JSONifying both Train and Validation")
    for vid in track(train_data, description="Train Language...", transient=True):
        train_json[vid] = {
            "language": train_data[vid]["language"],
            "n_frames": train_data[vid]["n_frames"],
            "input_ids": train_data[vid]["input_ids"].numpy().tolist(),
            "attention_mask": train_data[vid]["attention_mask"].numpy().tolist(),
        }

    for vid in track(val_data, description="Val Language...", transient=True):
        val_json[vid] = {
            "language": val_data[vid]["language"],
            "n_frames": val_data[vid]["n_frames"],
            "input_ids": val_data[vid]["input_ids"].numpy().tolist(),
            "attention_mask": val_data[vid]["attention_mask"].numpy().tolist(),
        }

    # Write Data to Disk
    overwatch.info("Writing JSON Indices")
    with open(t_json, "w") as f:
        json.dump(train_json, f)

    with open(v_json, "w") as f:
        json.dump(val_json, f)


def index(train_registry: Path, val_registry: Path, name: str, artifact_path: str = "data/processed") -> Path:
    """Phase 2.75 (Indexing) :: Pull out language.json & other `absolutely necessary` indices to separate directory."""
    overwatch.info("\tPhase 2 Indexing :: Indexing Language & Registry Files =>> Extracting to Separate Directory")

    # Create "index" directory...
    index_dir = Path(artifact_path) / name / "index"
    os.makedirs(index_dir, exist_ok=True)

    # Short-Circuit Logic
    if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
        return index_dir

    # Retrieve Language JSON indices (train & validation) & copy to new directory...
    t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
    shutil.copy(t_json, index_dir / "train-language-index.json")
    shutil.copy(v_json, index_dir / "val-language-index.json")

    return index_dir


def unify_batches(
    artifact_path: Path,
    name: str,
    train_registry: Path,
    val_registry: Path,
    train_dir: Path,
    val_dir: Path,
    index_dir: Path,
    batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
    max_epochs: int = 400,
    initial_final_alpha: float = 0.2,
) -> None:
    """Phase III of Preprocessing :: Assemble Batches for *all models* for *all epochs* in a consistent manner."""
    overwatch.info("Phase 3 Preprocessing :: Assembling Data-Equivalent Epochs for each Model Format")

    # Load Registry Files
    with open(train_registry, "r") as f:
        train_registrations = json.load(f)

    with open(val_registry, "r") as f:
        val_registrations = json.load(f)

    # Assert last element of `batch_formats` assumes all prior subsets...
    full_set_inputs = set(batch_formats[-1][1])
    for _, subset_inputs in batch_formats[:-1]:
        assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."

    # Assemble Tracking Data
    b_keys, unique_states = {b[0] for b in batch_formats}, set()

    # Parse out all "state"-specific elements...
    state_elements = [s for s in full_set_inputs if "state_" in s]
    do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements

    # `n_int` = # of random "interior" frames to sample (only meaningful when both endpoints are requested).
    n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0

    # Serialize Epochs to Disk
    overwatch.info("\tSerializing epochs to json file, pointing to image paths on disk via a dictionary...")
    for b in b_keys:
        os.makedirs(index_dir / b, exist_ok=True)

    # We only write the validation epoch once --> held constant across _all_ of training!
    overwatch.info("\tWriting Validation Epoch to Disk...")
    val_epoch_idx, _, uniq_s = precompute_epoch(
        index_dir,
        val_registrations,
        val_dir,
        batch_formats,
        do_initial,
        do_final,
        initial_final_alpha,
        n_int,
        0,
        is_validation=True,
    )

    # Update Trackers... (epoch index of -1 signals "already serialized / cached")
    if val_epoch_idx != -1:
        unique_states |= uniq_s

    # Compute length of epochs --> CPU Count should be no higher...
    epochs, n_frames_per_epoch = list(range(max_epochs)), -1

    # Load "existing" verification file (if possible) --> lets an interrupted run resume only the missing epochs.
    overwatch.info("\tLoading batch verification file (if possible)...")
    verified_batches = Path(artifact_path) / name / "verified-batches.json"
    if verified_batches.exists():
        with open(verified_batches, "r") as f:
            missing_epochs_per_format = json.load(f)

        # Set epochs list by taking union of missing epochs over formats...
        epochs = sorted(list(set().union(*missing_epochs_per_format.values())))

    # Dump the big objects into an mp.Manager() so that we can read efficiently from other workers...
    overwatch.info("\tPlacing the Train Registry into Shared Memory...")
    manager = mp.Manager()
    mg_registry = manager.dict(train_registrations)
    with mp.Pool(4) as pool:
        overwatch.info("\tWriting Train Batches per Epoch to Disk...")

        # Create partial function for multiprocessing pool...
        precompute_fn = partial(
            precompute_epoch,
            index_dir,
            mg_registry,
            train_dir,
            batch_formats,
            do_initial,
            do_final,
            initial_final_alpha,
            n_int,
        )
        for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
            if epoch_idx == -1:
                continue

            # Update Trackers
            unique_states |= uniq_s
            n_frames_per_epoch = n_frames

    # Statistics only make sense on initial computation... should unify with code above!
    overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
    overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")
    overwatch.info("Preprocessing Complete!")


================================================
FILE: voltron/preprocessing/v1/transforms.py
================================================
"""
transforms.py

Default image/video transformations for various datasets.
"""

from typing import Any, Tuple

import cv2
import numpy as np
import torch
from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize


# Definitions of Video Transformations (Reference: `something-something-v2-baseline`)
class ComposeMix:
    # Chains transforms tagged w/ a scope: "img" applies per-frame, "vid" applies to the full frame list.
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgs):
        for transformation, scope in self.transforms:
            if scope == "img":
                for idx, img in enumerate(imgs):
                    imgs[idx] = transformation(img)
            elif scope == "vid":
                imgs = transformation(imgs)
            else:
                raise ValueError("Please specify a valid transformation...")
        return imgs


class RandomCropVideo:
    # Samples a single (x, y) offset and applies the same crop to every frame --> temporally consistent.
    def __init__(self, size):
        self.size = size

    def __call__(self, imgs):
        th, tw = self.size
        h, w = imgs[0].shape[:2]
        # NOTE(review): `np.random.randint(0, w - tw)` raises when w == tw (high must exceed low) — presumably the
        #   pre-transform always scales frames strictly larger than the crop; verify against `get_pre_transform`.
        x1, y1 = np.random.randint(0, w - tw), np.random.randint(0, h - th)
        for idx, img in enumerate(imgs):
            imgs[idx] = img[y1 : y1 + th, x1 : x1 + tw]
        return imgs


class Scale:
    # Thin wrapper around `cv2.resize` to a fixed (width, height).
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        return cv2.resize(img, tuple(self.size))


def identity(x):
    """Transform needs to be pickleable for multiprocessing.spawn()."""
    return x.float()


def get_pre_transform(dataset: str, resolution: int, scale_factor: float = 1.1) -> ComposeMix:
    """Defines a `pre` transform to be applied *when serializing the images* (first pass)."""
    if dataset == "sth-sth-v2":
        if scale_factor > 1:
            # Scale slightly above target, then take one shared random crop per clip.
            transform = ComposeMix(
                [
                    [Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
                    [RandomCropVideo((resolution, resolution)), "vid"],
                ]
            )
        else:
            transform = ComposeMix(
                [
                    [Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
                ]
            )
        return transform
    else:
        raise NotImplementedError(f"(Pre) transforms for dataset `{dataset}` not yet implemented!")


def get_online_transform(dataset: str, model_arch: str, normalization: Tuple[Any, Any]) -> Compose:
    """Defines an `online` transform to be applied *when batching the images* (during training/validation)."""
    if dataset == "sth-sth-v2":
        # Note: R3M does *not* expect normalized 0-1 (then ImageNet normalized) images --> drop the identity.
        if model_arch in {"v-r3m", "v-rn3m"}:
            return Compose([Lambda(identity)])
        else:
            return Compose([ConvertImageDtype(torch.float), Normalize(mean=normalization[0], std=normalization[1])])
    else:
        # NOTE(review): error message below is missing a closing backtick after `{dataset}` (runtime string; left as-is).
        raise NotImplementedError(f"(Online) transforms for dataset `{dataset} not yet implemented!")


================================================
FILE: voltron/preprocessing/v1/utils.py
================================================
"""
utils.py

Preprocessing utilities, including functions for dry-runs and processing a single video (helpers for multiprocessing
calls down the lines.
def get_path(save_dir: Path, v: str, i: int) -> str:
    """Return the serialized frame path for frame `i` of video `v` (layout: `save_dir/{v}/{v}_idx={i}.jpg`)."""
    frame_name = f"{v}_idx={i}.jpg"
    return str(save_dir / v / frame_name)
def do_dry_run(
    name: str,
    path: str,
    n_train_videos: int,
    n_val_videos: int,
    train_ids: List[str],
    val_ids: List[str],
    pre_transform: ComposeMix,
    n_samples: int = 1000,
) -> None:
    """Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation.

    Decodes the first `n_train_videos`/`n_val_videos` clips, writes up to `n_samples` frames to a single
    scratch JPEG (`blank.jpg`) to measure on-disk size, then logs extrapolated storage estimates and calls
    `sys.exit(0)` -- this function never returns.

    :param name: Dataset identifier; only "sth-sth-v2" is currently supported.
    :param path: Root of the raw dataset (expects `videos/{vid}.webm` underneath).
    :param n_train_videos: Number of train clips to sample for the estimate.
    :param n_val_videos: Number of validation clips to sample for the estimate.
    :param train_ids: All train video IDs (used for total-count extrapolation).
    :param val_ids: All validation video IDs (used for total-count extrapolation).
    :param pre_transform: Serialization-time transform applied to decoded frames before measuring.
    :param n_samples: Maximum number of individual frames to write/measure.
    """
    # NOTE(review): "time_per_example" is initialized but never appended to in this function -- confirm
    # whether it is vestigial or meant to be populated.
    dry_run_metrics = {
        "n_frames": [],
        "jpg_sizes": [],
        "n_samples": n_samples,
        "time_per_example": [],
        "blank": str(Path(path) / "blank.jpg"),
    }
    if name == "sth-sth-v2":
        for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]:
            for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True):
                vid = vids[idx]
                container = av.open(str(Path(path) / "videos" / f"{vid}.webm"))
                try:
                    # Decode the entire clip to RGB ndarrays (one per frame).
                    imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
                except (RuntimeError, ZeroDivisionError) as e:
                    # Unreadable clips are logged & skipped rather than aborting the dry run.
                    overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
                    continue

                # Close container
                container.close()

                # Apply `pre_transform`
                imgs = pre_transform(imgs)

                # Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on...
                #   > Every measured frame overwrites the same `blank.jpg`, so only one scratch file is ever created.
                dry_run_metrics["n_frames"].append(len(imgs))
                while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0:
                    img = imgs.pop(0)
                    cv2.imwrite(str(dry_run_metrics["blank"]), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
                    dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"]))
                    dry_run_metrics["n_samples"] -= 1

        # Compute nice totals for "dry-run" estimation
        total_clips = len(train_ids) + len(val_ids)
    else:
        raise NotImplementedError(f"Dry Run for Dataset `{name}` not yet implemented!")

    # Compute Aggregate Statistics and gently exit...
    avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"]))
    overwatch.info("Dry-Run Statistics =>>")
    overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}")
    overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}")
    overwatch.info(
        f"\t> With the full dataset of {total_clips} Train + Val videos ~"
        f" {size(total_clips * avg_frames * avg_size, system=alternative)}"
    )
    overwatch.info("Dry-Run complete, do what you will... exiting ✌️")

    # Remove dummy file...
    os.remove(dry_run_metrics["blank"])
    sys.exit(0)
def process_video(
    name: str, path: Path, save: Path, pre_transform: ComposeMix, item: Tuple[str, str]
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
    """Processes a single video file, dumps to series of image files, and returns the registry contents.

    :param name: Dataset identifier; only "sth-sth-v2" is currently supported.
    :param path: Root of the raw dataset (expects `videos/{vid}.webm` underneath).
    :param save: Root of the serialized-frame output directory (frames land in `save/{vid}/`).
    :param pre_transform: Serialization-time transform applied to the decoded frames before writing.
    :param item: Tuple of (video ID, language annotation) for this clip.

    :return: Tuple of (video ID, registration dict with "language" & "n_frames"), or (None, None) if
             the clip could not be decoded.
    """
    if name == "sth-sth-v2":
        # For sth-sth-v2, `item` corresponds to a single video clip, so just a tuple!
        vid, lang = item
        container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
        try:
            imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
        except (RuntimeError, ZeroDivisionError) as e:
            # Undecodable clips are logged & skipped; callers must handle the (None, None) sentinel.
            overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - skipping...")
            return None, None

        # Close container
        container.close()

        # Book-keeping
        os.makedirs(save / vid, exist_ok=True)
        registration["n_frames"] = len(imgs)

        # Early exit (writes are expensive)
        #   > Only the *count* of existing JPEGs is checked, not their content -- a partially corrupt
        #     directory with the right number of files would be skipped.
        if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
            return vid, registration

        # Apply `pre_transform` --> write individual frames, register, and return
        imgs = pre_transform(imgs)
        for i in range(len(imgs)):
            cv2.imwrite(get_path(save, vid, i), cv2.cvtColor(imgs[i], cv2.COLOR_RGB2BGR))

        # Return title & registration
        return vid, registration
    else:
        raise NotImplementedError(f"Process Video for Dataset `{name}` not yet implemented!")
# ruff: noqa: C901
def precompute_epoch(
    index_dir: Path,
    registry: Dict[str, Any],
    vid_dir: Path,
    batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
    do_initial: bool,
    do_final: bool,
    initial_final_alpha: float,
    n_int: int,
    epoch: int,
    is_validation: bool = False,
) -> Tuple[int, int, Optional[Set[str]]]:
    """Precompute & serialize the per-epoch batch index (JSON) for every batch format.

    For each registered video, samples an initial frame, a final frame, and `n_int` sorted interior frames,
    then materializes one index entry per batch format. Runs in a worker process (reseeds numpy from
    pid/time so parallel epochs diverge).

    :param index_dir: Root directory for index files; entries are written to `index_dir/{key}/{index_file}`.
    :param registry: Mapping of video ID -> metadata dict (must contain "n_frames").
    :param vid_dir: Root directory of serialized frames (used to build per-frame paths via `get_path`).
    :param batch_formats: Tuple of (key, elements) pairs; the *last* entry receives the full state sequence.
    :param do_initial: If True, sample the initial frame from the first `initial_final_alpha` fraction.
    :param do_final: If True, sample the final frame from the last `initial_final_alpha` fraction.
    :param initial_final_alpha: Fraction of the clip reserved for initial/final sampling.
    :param n_int: Number of interior frames to sample per video.
    :param epoch: Epoch index (names the index file & positions the progress bar).
    :param is_validation: If True, write the single shared validation index instead of a per-epoch one.

    :return: (epoch, number of single-state examples, set of unique frame paths), or (-1, -1, None) if
             all index files already exist.
             NOTE(review): the return expression indexes `batches["state"]`, so `batch_formats` is assumed
             to contain a "state" key -- confirm against callers.
    """
    index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"

    # Short-Circuit
    if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
        return -1, -1, None

    # Random seed is inherited from parent process... we want new randomness w/ each process
    np.random.seed((os.getpid() * int(time.time())) % 123456789)

    # Create Tracking Variables
    unique_states, batches = set(), {b: [] for b, _ in batch_formats}

    # Iterate through Registry...
    for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
        # The initial/final states are sampled from the first [0, \alpha) and final 1-\alpha, 1] percent of the video
        n_frames = registry[vid]["n_frames"]
        initial_idx, final_idx = 0, n_frames - 1
        if do_initial:
            initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))
        if do_final:
            final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)

        # Assertion --> initial_idx < final_idx - len(state_elements)
        #   > Videos too short to fit `n_int` interior frames abort the whole precompute (AssertionError).
        assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"

        # Assume remaining elements are just random "interior" states --> sort to get ordering!
        sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
        sampled_idxs = sorted(list(sampled_idxs))

        # Compile full-set "batch"
        retrieved_states = [get_path(vid_dir, vid, x) for x in [initial_idx, *sampled_idxs] + [final_idx]]

        # Add batch to index for specific batch_format key...
        #   > The *last* batch format always receives the full ordered state sequence.
        batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
        unique_states.update(retrieved_states)

        # Add all other batch formats to indices...
        for key, elements in batch_formats[:-1]:
            n_states = len([x for x in elements if "state_" in x])
            assert (n_states <= 2) or (
                n_states == len(retrieved_states)
            ), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"

            # States are all independent -- each of the retrieved states is its own example...
            if n_states == 1:
                for idx in range(len(retrieved_states)):
                    batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})

            # OK-Context is the only "valid" context for n_states == 2
            elif n_states == 2:
                assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"

                # Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
                for idx in range(1, len(retrieved_states)):
                    batches[key].append(
                        {"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
                    )

            # We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
            else:
                batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})

    # Write JSON Index directly to disk...
    for key in batches:
        with open(index_dir / key / index_file, "w") as f:
            json.dump(batches[key], f)

    return epoch, len(batches["state"]), unique_states
class FixedDeck(deque):
    """Bounded deque whose `append` hands back the element it evicts (None while below capacity)."""

    def __init__(self, maxlen: int) -> None:
        super().__init__(maxlen=maxlen)

    def append(self, x: Any) -> Any:
        # At capacity, the parent append will silently drop the leftmost element -- capture it first.
        evicted = self[0] if len(self) == self.maxlen else None
        super().append(x)
        return evicted
class CheckpointSaver:
    def __init__(self, strategy: Tuple[int, int, int], run_dir: str, is_rank_zero: bool = False) -> None:
        """
        Create a checkpoint saver with the provided strategy that saves to the given path.

        :param strategy: Strategy, following the (k, -1, -1) -- (k, m, -1) -- (k, m, s) description above.
        :param run_dir: Path to root of `run_dir`.
        :param is_rank_zero: Boolean whether this process is global zero (no-op if not)!
        """
        (self.k, self.m, self.s), self.run_dir, self.is_rank_zero = strategy, run_dir, is_rank_zero
        self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()

        # If `self.s == -1` --> *Disable* step checkpoints (only save at the end of each epoch!)
        self.enable_step = self.s != -1

        # Create "checkpoints" subdirectory (only rank zero ever writes)
        self.path = Path(run_dir) / "checkpoints"
        if self.is_rank_zero:
            os.makedirs(self.path, exist_ok=True)

        # Populate `step_checkpoints` on __init__ (if resuming *within* an epoch!)
        self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])

        # Created Saver...
        overwatch.info(f"Created CheckpointSaver with `k = {self.k}` -- `m = {self.m}` -- s = {self.s}!")

    def save(
        self,
        epoch: int,
        is_local_step: bool,
        model: nn.Module,
        optimizer: Optimizer,
        duration: int,
        local_step: Optional[int] = None,
        train_loss: Optional[float] = None,
        val_loss: Optional[float] = None,
    ) -> None:
        """Performs a global zero save operation, unlinking stale checkpoints if necessary.

        :param epoch: Current epoch index (embedded in the checkpoint filename).
        :param is_local_step: True for a mid-epoch "step" save; False for an end-of-epoch save.
        :param model: Model whose `state_dict` is serialized.
        :param optimizer: Optimizer whose `state_dict` is serialized.
        :param duration: Elapsed training time in seconds (embedded in the filename).
        :param local_step: Step index within the epoch (required for step saves).
        :param train_loss: End-of-epoch train loss for the filename; "inf" placeholder if None.
        :param val_loss: End-of-epoch validation loss for the filename; "inf" placeholder if None.
        """
        if not self.is_rank_zero:
            return

        # Check if saving a `local_step` (within an epoch) or if end of epoch...
        if self.enable_step and is_local_step and (local_step % self.s) == 0:
            step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"
            torch.save(
                {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, step_checkpoint
            )

            # Update Relevant Trackers...
            self.step_checkpoints.add(step_checkpoint)

        elif not is_local_step:
            if train_loss is None and val_loss is None:
                checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt"
            else:
                checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt"
            torch.save(
                {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, checkpoint
            )

            # Update Relevant Trackers
            #   > BUGFIX: interval-pinning must only apply for a *positive* `m`. In Python, `epoch % -1 == 0`
            #     for every epoch, so the (k, -1, -1) strategy previously pinned *all* epoch checkpoints in
            #     `intervals` and never unlinked any of them (and `m = 0` would raise ZeroDivisionError).
            if self.m > 0 and epoch % self.m == 0:
                self.intervals.add(checkpoint)

            # Remove all "step_checkpoints" now that we've made it to the end of an epoch!
            while len(self.step_checkpoints) > 0:
                os.remove(self.step_checkpoints.pop())

            # Add to recents & flush stale checkpoints...
            to_remove = self.recents.append(checkpoint)
            if to_remove is not None and to_remove not in self.intervals:
                os.remove(to_remove)
def do_resume(resume: bool, run_dir: str) -> Tuple[Optional[Path], int, int]:
    """Handle `resume` logic --> consists of retrieving checkpoint_path and epoch/step computation (if resuming).

    :param resume: Whether to attempt resuming at all; if False, a fresh run is assumed.
    :param run_dir: Path to root of `run_dir` (checkpoints are expected under `run_dir/checkpoints`).

    :return: Tuple of (checkpoint path or None, resume epoch, resume step).
    """
    if not resume:
        # We're starting a fresh run --> return None for checkpoint_path, resume_epoch = 0, resume_step = 0
        return None, 0, 0

    # === Auto-Resume Logic ===
    # **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks:
    #   - `accumulate_grad_batches` is exactly the same when resuming; this means:
    #       + `model_cfg.effective_bsz`, `model_cfg.fabric_bsz`, & `accelerator_cfg.num_accelerators` are the same!
    #   - The Weights & Biases directory `run_dir/wandb` only contains a *single run*
    #   - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes!
    #       + This means that (and generally should be true for resuming altogether) the architecture is the same!
    #   - The `cfg.seed` should be the same (again, should generally be true...)
    all_checkpoints_path, resume_checkpoint, resume_epoch, resume_step = Path(run_dir) / "checkpoints", None, 0, 0
    if all_checkpoints_path.exists() and any(all_checkpoints_path.iterdir()):
        checkpoints = list(all_checkpoints_path.iterdir())

        # Parse out the latest "complete" epoch checkpoint, if any...
        #   > BUGFIX: when resuming mid-epoch-0, *only* "local step" checkpoints exist (the first full-epoch
        #     checkpoint hasn't been written yet), so the old unconditional `max(...)` raised
        #     `ValueError: max() arg is an empty sequence`. Guard the empty case instead.
        complete = [
            (c, int(re.search("epoch=(.+?)-train", c.name).group(1)))
            for c in checkpoints
            if "local-epoch=" not in str(c)
        ]
        complete_checkpoint, complete_epoch = max(complete, key=lambda x: x[1]) if complete else (None, None)

        # Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints...
        local = [
            (
                c,
                int(re.search("local-epoch=(.+?)-step", c.name).group(1)),
                int(re.search("step=(.+?)[.-]", c.name).group(1)),
            )
            for c in checkpoints
            if "local-epoch=" in str(c)
        ]
        if len(local) > 0:
            # Parse out (epoch, "highest" step) + sanity-check against the full-epoch checkpoint when one exists!
            resume_checkpoint, resume_epoch, resume_step = max(local, key=lambda x: x[1:])
            if complete_epoch is not None:
                assert resume_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!"

        # Case 2 :: Otherwise, we're just going to start with the last "complete" epoch
        elif complete_checkpoint is not None:
            resume_checkpoint, resume_epoch = complete_checkpoint, complete_epoch

    return resume_checkpoint, resume_epoch, resume_step
class JSONLinesLogger(Logger):
    """Logger that records hyperparameters & per-step metrics as JSON-lines entries in `{run_id}.jsonl`."""

    def write_hyperparameters(self) -> None:
        # Only log if `is_rank_zero`
        if not self.is_rank_zero:
            return

        # Mode "w" --> the hyperparameter header starts a fresh log file for this run.
        header = {
            "run_id": self.run_id,
            "start_time": datetime.now().strftime("%m-%d-%H:%M"),
            "hparams": self.hparams,
        }
        with jsonlines.open(f"{self.run_id}.jsonl", mode="w", sort_keys=True) as js_logger:
            js_logger.write(header)

    def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
        # Only log if `is_rank_zero`
        if not self.is_rank_zero:
            return

        # Mode "a" --> each metrics dict is appended as its own JSON line.
        with jsonlines.open(f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
            js_logger.write(metrics)
= True` if self.resume and self.resume_id is None: wandb_path = self.path / "wandb" if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()): # Parse unique `run_id` from the `.wandb.` file... wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if f.name.endswith(".wandb")] assert len(wandb_fns) == 1, f"There should only be 1 `.wandb.` file... found {len(wandb_fns)}!" # Regex Match on `run-{id}.wandb` self.resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1) elif wandb_path.exists(): raise ValueError("Starting Training from Scratch with Preexisting W&B Directory; Remove to Continue!") # Call W&B.init() self.initialize() def initialize(self) -> None: """Run W&B.init on the guarded / rank-zero process.""" if not self.is_rank_zero: return # Only initialize / log if `is_rank_zero` wandb.init( project=self.tracking_cfg.project, entity=self.tracking_cfg.entity, config=self.hparams, name=self.run_id, dir=self.path, tags=self.tags, notes=self.tracking_cfg.notes, resume="allow" if self.resume else False, id=self.resume_id, ) def write_hyperparameters(self) -> None: if not self.is_rank_zero: return # Only log if `is_rank_zero` wandb.config = self.hparams def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: if not self.is_rank_zero: return # Only log if `is_rank_zero` wandb.log(metrics, step=global_step) def finalize(self) -> None: wandb.finish() time.sleep(150) # === Core Metrics Container :: Responsible for Initializing Loggers and Compiling/Pushing Metrics === class Metrics: def __init__( self, active_loggers: List[str], run_id: str, hparams: Dict[str, Any], model_arch: str, is_rank_zero: bool, tracking_cfg: Optional[TrackingConfig] = None, tags: Optional[List[str]] = None, resume: bool = False, resume_id: Optional[str] = None, window: int = 128, ) -> None: """High-Level Container Logic for Metrics Logging; logic defined for each model architecture!""" self.model_arch, self.is_rank_zero, self.window = 
model_arch, is_rank_zero, window # Initialize Loggers self.loggers = [] for log_type in active_loggers: if log_type == "jsonl": logger = JSONLinesLogger(run_id, hparams, is_rank_zero=is_rank_zero) elif log_type == "wandb": logger = WeightsBiasesLogger( run_id, hparams, tracking_cfg, tags, resume, resume_id, is_rank_zero=is_rank_zero ) else: raise ValueError(f"Logger `{log_type}` is not defined!") # Add Hyperparameters --> Add to `self.loggers` logger.write_hyperparameters() self.loggers.append(logger) # Create Universal Trackers self.global_step, self.start_time, self.resume_time, self.step_start_time = 0, time.time(), 0, time.time() self.tracker = { "loss": deque(maxlen=self.window), "lr": [], "step_time": deque(maxlen=self.window), } # Create Model-Specific Trackers if self.model_arch == "v-mvp": self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)}) elif self.model_arch in {"v-r3m", "v-rn3m"}: self.tracker.update( { "tcn_loss": deque(maxlen=self.window), "reward_loss": deque(maxlen=self.window), "l1_loss": deque(maxlen=self.window), "l2_loss": deque(maxlen=self.window), "tcn_accuracy": deque(maxlen=self.window), "reward_accuracy": deque(maxlen=self.window), } ) elif self.model_arch == "v-cond": self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)}) elif self.model_arch == "v-dual": self.tracker.update( { "reconstruction_loss": deque(maxlen=self.window), "zero_reconstruction_loss": deque(maxlen=self.window), "k_reconstruction_loss": deque(maxlen=self.window), } ) elif self.model_arch == "v-gen": self.tracker.update( { "reconstruction_loss": deque(maxlen=self.window), "zero_reconstruction_loss": deque(maxlen=self.window), "k_reconstruction_loss": deque(maxlen=self.window), "lm_loss": deque(maxlen=self.window), "lm_ppl": deque(maxlen=self.window), } ) else: raise ValueError(f"Metrics for Model `{self.model_arch}` are not implemented!") def itemize(self) -> Dict[str, torch.Tensor]: """Utility method for converting 
`deque[torch.Tensor] --> mean over Tensors.""" return { k: torch.stack(list(v)).mean().item() for k, v in self.tracker.items() if k not in {"loss", "lr", "step_time"} } def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: for logger in self.loggers: logger.write(global_step, metrics) def finalize(self) -> None: for logger in self.loggers: logger.finalize() def get_status(self, epoch: int, loss: Optional[torch.Tensor] = None) -> str: lr = self.tracker["lr"][-1] if len(self.tracker["lr"]) > 0 else 0 if loss is None: return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}" # Otherwise, embed `loss` in status! return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}" def commit( self, *, global_step: Optional[int] = None, resume_time: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs, ) -> None: """Update all metrics in `self.tracker` by iterating through special positional arguments & kwargs.""" if not self.is_rank_zero: return # Special Positional Arguments if global_step is not None: self.global_step = global_step if resume_time is not None: self.resume_time = resume_time if lr is not None: self.tracker["lr"].append(lr) if update_step_time: self.tracker["step_time"].append(time.time() - self.step_start_time) self.step_start_time = time.time() # Generic Keyword Arguments for key, value in kwargs.items(): self.tracker[key].append(value.detach()) def push(self, epoch: int) -> str: """Push current metrics to loggers with model-specific handling.""" if not self.is_rank_zero: return loss = torch.stack(list(self.tracker["loss"])).mean().item() step_time, lr = np.mean(list(self.tracker["step_time"])), self.tracker["lr"][-1] status = self.get_status(epoch, loss) # Model-Specific Handling itemized = self.itemize() if self.model_arch == "v-mvp": self.log( self.global_step, metrics={ "Pretrain/Step": self.global_step, 
    def push(self, epoch: int) -> str:
        """Push current metrics to loggers with model-specific handling.

        :param epoch: Current epoch index (embedded in the logged metrics & returned status string).
        :return: Human-readable status line (None on non-zero ranks, which are a no-op).
        """
        if not self.is_rank_zero:
            return

        # Windowed mean of the loss deque, mean step time, and the most recent LR.
        loss = torch.stack(list(self.tracker["loss"])).mean().item()
        step_time, lr = np.mean(list(self.tracker["step_time"])), self.tracker["lr"][-1]
        status = self.get_status(epoch, loss)

        # Model-Specific Handling
        itemized = self.itemize()
        if self.model_arch == "v-mvp":
            self.log(
                self.global_step,
                metrics={
                    "Pretrain/Step": self.global_step,
                    "Pretrain/Epoch": epoch,
                    "Pretrain/V-MVP Train Loss": loss,
                    "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
                    "Pretrain/Learning Rate": lr,
                    "Pretrain/Step Time": step_time,
                },
            )
        elif self.model_arch in {"v-r3m", "v-rn3m"}:
            self.log(
                self.global_step,
                metrics={
                    "Pretrain/Step": self.global_step,
                    "Pretrain/Epoch": epoch,
                    f"Pretrain/V-{'R3M' if self.model_arch == 'v-r3m' else 'RN3M'} Train Loss": loss,
                    "Pretrain/TCN Loss": itemized["tcn_loss"],
                    "Pretrain/Reward Loss": itemized["reward_loss"],
                    "Pretrain/L1 Loss": itemized["l1_loss"],
                    "Pretrain/L2 Loss": itemized["l2_loss"],
                    "Pretrain/TCN Accuracy": itemized["tcn_accuracy"],
                    "Pretrain/Reward Accuracy": itemized["reward_accuracy"],
                    "Pretrain/Learning Rate": lr,
                    "Pretrain/Step Time": step_time,
                },
            )
        elif self.model_arch == "v-cond":
            self.log(
                self.global_step,
                metrics={
                    "Pretrain/Step": self.global_step,
                    "Pretrain/Epoch": epoch,
                    "Pretrain/V-Cond Train Loss": loss,
                    "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
                    "Pretrain/Learning Rate": lr,
                    "Pretrain/Step Time": step_time,
                },
            )
        elif self.model_arch == "v-dual":
            self.log(
                self.global_step,
                metrics={
                    "Pretrain/Step": self.global_step,
                    "Pretrain/Epoch": epoch,
                    "Pretrain/V-Dual Train Loss": loss,
                    "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
                    "Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
                    "Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
                    "Pretrain/Learning Rate": lr,
                    "Pretrain/Step Time": step_time,
                },
            )
        elif self.model_arch == "v-gen":
            # NOTE(review): both "Pretrain/CLM *" and "Pretrain/LM *" keys are emitted with identical
            # values (lm_loss / lm_ppl) -- confirm whether the duplication is intentional (e.g., kept for
            # dashboard backward compatibility) or a leftover rename.
            self.log(
                self.global_step,
                metrics={
                    "Pretrain/Step": self.global_step,
                    "Pretrain/Epoch": epoch,
                    "Pretrain/V-Gen Train Loss": loss,
                    "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
                    "Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
                    "Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
                    "Pretrain/CLM Loss": itemized["lm_loss"],
                    "Pretrain/CLM Perplexity": itemized["lm_ppl"],
                    "Pretrain/LM Loss": itemized["lm_loss"],
                    "Pretrain/LM Perplexity": itemized["lm_ppl"],
                    "Pretrain/Learning Rate": lr,
                    "Pretrain/Step Time": step_time,
                },
            )
        else:
            raise ValueError(f"Metrics.push() for Model `{self.model_arch}` is not implemented!")
        return status
itemized["lm_loss"], "Pretrain/LM Perplexity": itemized["lm_ppl"], "Pretrain/Learning Rate": lr, "Pretrain/Step Time": step_time, }, ) else: raise ValueError(f"Metrics.push() for Model `{self.model_arch}` is not implemented!") return status def push_epoch(self, epoch: int, val_loss: torch.Tensor) -> Tuple[str, torch.Tensor, int]: """End-of-Epoch => Push accumulated metrics to loggers with model-specific handling.""" if not self.is_rank_zero: return # Compute End-of-Epoch Specialized Metrics loss, step_time = torch.stack(list(self.tracker["loss"])).mean(), np.mean(list(self.tracker["step_time"])) lr, duration = self.tracker["lr"][-1], int(time.time() - self.start_time) + self.resume_time epoch_status = ( f"[Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f} " f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}" ) # Log for Model p_arch = { "v-mvp": "MVP", "v-r3m": "R3M (ViT)", "v-rn3m": "R3M (RN)", "v-cond": "V-Cond", "v-dual": "V-Dual", "v-gen": "V-Gen", }[self.model_arch] self.log( self.global_step, metrics={ "Pretrain/Step": self.global_step, "Pretrain/Epoch": epoch, "Pretrain/Training Duration": duration, f"Pretrain/{p_arch} Train Epoch Loss": loss.item(), f"Pretrain/{p_arch} Train Loss": loss.item(), f"Pretrain/{p_arch} Validation Loss": val_loss.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": step_time, }, ) return epoch_status, loss, duration ================================================ FILE: voltron/util/utilities.py ================================================ """ utilities.py General utilities for randomness, distributed training, and miscellaneous checks in PyTorch. 
def worker_init_function(worker_id: int) -> None:
    """
    Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
        > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

    Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
    you can run iterative splitting on to get new (predictable) randomness.

    :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
    """
    # Get current `rank` (if running distributed) and `process_seed`
    #   > NOTE(review): reads LOCAL_RANK unconditionally -- assumes a `torchrun`-style launcher exported it;
    #     confirm single-process runs also set LOCAL_RANK, else this raises KeyError.
    global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()

    # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
    #   > https://pytorch.org/docs/stable/data.html#data-loading-randomness
    base_seed = process_seed - worker_id

    # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
    seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])

    # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
    np.random.seed(seed_seq.generate_state(4))

    # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
    torch_seed_seq, random_seed_seq = seed_seq.spawn(2)

    # Torch manual seed takes 64 bits (so just specify a dtype of uint64)
    torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])

    # Use 128 Bits for `random`, but express as integer instead of as an array
    #   > The two 64-bit words are combined into a single big int: high_word * 2**64 + low_word.
    random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
    random.seed(random_seed)
class ResumeableDistributedSampler(DistributedSampler):
    """DistributedSampler that resumes mid-epoch by skipping the first `seen_examples` (counted across all replicas)."""

    def __init__(
        self,
        seen_examples: int,
        resume_epoch: int,
        dataset: Dataset,
        num_replicas: int,
        rank: int,
        shuffle: bool = True,
        seed: int = 0,
    ) -> None:
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
        self.seen_examples = seen_examples
        self.resume_epoch = resume_epoch
        self.do_resume = True

        # `seen_examples` is totaled across *all* replicas --> each replica only skips its own share.
        self.seen_examples_per_replica = self.seen_examples // self.num_replicas

    def __iter__(self) -> Iterator[T_co]:
        parent_iter = super().__iter__()
        if not self.do_resume:
            return parent_iter

        # Materialize this epoch's index order and drop the already-consumed prefix for this replica.
        remaining_idxs = list(parent_iter)[self.seen_examples_per_replica :]
        return iter(remaining_idxs)

    def __len__(self) -> int:
        # `num_samples` is *per replica*; shrink it by the skipped prefix only while resuming.
        skipped = self.seen_examples_per_replica if self.do_resume else 0
        return self.num_samples - skipped

    def set_epoch(self, epoch: int) -> None:
        # Truncation only applies on the resume epoch; any other epoch behaves like a vanilla DistributedSampler.
        #   > Intuition: We should *only* truncate examples on the first epoch upon resuming!
        self.epoch = epoch
        if self.epoch != self.resume_epoch:
            self.do_resume = False
:param run_dir: Path to root of `run_dir` """ import torch_xla.core.xla_model as xm (self.k, self.m, self.s), self.run_dir = strategy, run_dir self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set() # If `self.s` is -1 --> disable step_checkpoints self.enable_step = self.s != -1 # Create "checkpoints" subdirectory self.path = Path(run_dir) / "checkpoints" if xm.is_master_ordinal(local=False): os.makedirs(self.path, exist_ok=True) # Populate `step_checkpoints` on __init__ (if resuming *within* an epoch...) self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)]) # Create Saver xm.master_print(f"Created Saver w/ `k` = {self.k}, `m` = {self.m}`, `s` = {self.s}!") def save( self, epoch: int, is_local_step: bool, model: nn.Module, optimizer: Optimizer, duration: int, local_step: Optional[int] = None, train_loss: Optional[float] = None, val_loss: Optional[float] = None, ) -> None: """Performs the save operation, unlinking existing stale checkpoints, if necessary.""" import torch_xla.core.xla_model as xm # Check if saving a `local_step` (within an epoch) or if saving an `epoch` if self.enable_step and is_local_step and (local_step % self.s) == 0: # Create filename step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt" # Perform actual save action... # > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"... xm.save([model.state_dict(), optimizer.state_dict()["state"]], step_checkpoint) if xm.is_master_ordinal(local=False): self.step_checkpoints.add(step_checkpoint) elif not is_local_step: # Create filename if train_loss is None and val_loss is None: checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt" else: checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt" # Perform actual save action... 
# > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"... xm.save([model.state_dict(), optimizer.state_dict()["state"]], checkpoint) if xm.is_master_ordinal(local=False): # Conditional Check for M -- Keep if modulated by interval if epoch % self.m == 0: self.intervals.add(checkpoint) # Remove all "step_checkpoints" now that we successfully made it to the end of the epoch! while len(self.step_checkpoints) > 0: os.remove(self.step_checkpoints.pop()) # Finally, recency add & unlink/delete if necessary to_remove = self.recents.append(checkpoint) if to_remove is not None and to_remove not in self.intervals: os.remove(to_remove) ================================================ FILE: voltron/util/v1/distributed.py ================================================ """ distributed.py Key distributed utilities; notably provides a standard API for getting relevant data from either CPU/GPU or XLA (TPU) devices, since the underlying implementation does differ substantially. Assumes that code is running with one of the following strategies: - Single Process (on CPU, GPU) - DDP (CPU, GPU)... uses the torch.distributed.launch API & semantics - XMP Spawn (TPU)... TPU based XLA + Multiprocessing Spawn semantics Key Terminology -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! 
    -> Rank :: Integer index of current process in the total world size
    -> Local Rank :: Local index on given node in [0, Devices per Node]
"""
from importlib.util import find_spec
from typing import Iterator, TypeVar

import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

T_co = TypeVar("T_co", covariant=True)


class ResumeableDistributedSampler(DistributedSampler):
    """DistributedSampler that resumes mid-epoch by skipping the first `seen_examples` samples (split per replica)."""

    def __init__(
        self,
        seen_examples: int,
        resume_epoch: int,
        dataset: Dataset,
        num_replicas: int,
        rank: int,
        shuffle: bool = True,
        seed: int = 0,
    ) -> None:
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
        self.seen_examples, self.resume_epoch, self.do_resume = seen_examples, resume_epoch, True

        # Set `seen_examples_per_replica` --> this is necessary for when we re-wrap the iterator in self.__iter__()
        #   > Note: `seen_examples` is across _all_ replicas --> so divide!
        self.seen_examples_per_replica = self.seen_examples // self.num_replicas

    def __iter__(self) -> Iterator[T_co]:
        """Yield this replica's indices; on the resume epoch, drop the already-seen prefix."""
        epoch_iterator = super().__iter__()
        if self.do_resume:
            # Unpack iterator --> list, slice off the first `seen_examples_per_replica` examples, and re-wrap!
            leftover_idxs = list(epoch_iterator)[self.seen_examples_per_replica :]
            return iter(leftover_idxs)
        else:
            return epoch_iterator

    def __len__(self) -> int:
        if self.do_resume:
            # Remove the "seen" sample from self.num_samples; num_samples is *per replica*!
            return self.num_samples - self.seen_examples_per_replica
        else:
            return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        # If epoch != self.resume_epoch --> we're in "regular DistributedSampler" mode (just a wrapper class)
        #   > Intuition: We should *only* truncate examples on the first epoch upon resuming!
        self.epoch = epoch
        if self.epoch != self.resume_epoch:
            self.do_resume = False


def xla_available() -> bool:
    """Return True iff the `torch_xla` package is importable (i.e., running on/near a TPU environment)."""
    try:
        return find_spec("torch_xla") is not None
    except ModuleNotFoundError:
        return False


def get_rank() -> int:
    """Returns the global rank [0, World Size) of the current process."""
    if xla_available():
        import torch_xla.core.xla_model as xm

        # By default, if XLA is available, assume we're running under XMP Spawn
        return xm.get_ordinal()

    # Try to get rank via torch.distributed, but catch error if only single process
    try:
        return torch.distributed.get_rank()

    # RuntimeError => not running distributed (single process)
    except RuntimeError:
        return 0

================================================
FILE: voltron/util/v1/random.py
================================================
"""
random.py

Utilities for dealing with randomness for PyTorch, across devices (CPU, GPU, TPU). Loosely inspired by functionality
in PyTorch-Lightning:
    > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py

This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our
Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime we
inject randomness from non-PyTorch sources (e.g., numpy, random)!
    > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
"""
import os
import random
from typing import Callable

import numpy as np
import torch

from voltron.util.v1.distributed import get_rank


def set_global_seed(seed: int) -> Callable[[int], None]:
    """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
    # Seed must fit in a np.uint32 (exclusive bounds) since downstream seeding uses 32-bit words.
    assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
# Set Seed as an Environment Variable os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) return worker_init_function def worker_init_function(worker_id: int) -> None: """ Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that you can run iterative splitting on to get new (predictable) randomness. :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. """ # Get current `rank` (if running distributed) and `process_seed` global_rank, process_seed = get_rank(), torch.initial_seed() # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: # > https://pytorch.org/docs/stable/data.html#data-loading-randomness base_seed = process_seed - worker_id # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 
np.random.seed(seed_seq.generate_state(4)) # Spawn distinct child sequences for PyTorch (reseed) and stdlib random torch_seed_seq, random_seed_seq = seed_seq.spawn(2) # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) # Use 128 Bits for `random`, but express as integer instead of as an array random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() random.seed(random_seed) ================================================ FILE: voltron/util/v1/xla_logger.py ================================================ """ xla_logger.py Utility class defining various XLA logging methods (called within marked closures), for logging metrics periodically through training & validation. """ from typing import List import jsonlines import numpy as np import torch import torch_xla.core.xla_model as xm import wandb # === Generic (Cross-Model) Epoch End Update === def log_epoch_end_update( arch: str, epoch: int, global_step: int, run_id: str, duration: int, train_losses: List[torch.Tensor], val_loss: float, lr: float, step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Unclear if it'll work? 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} " f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}" ) # Get Log-Friendly Arch p_arch = { "v-mvp": "MVP", "v-r3m": "R3M (ViT)", "v-rn3m": "R3M (RN)", "v-cond": "V-Cond", "v-dual": "V-Dual", "v-gen": "V-Gen", }[arch] # Log to Weights & Biases & JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/Training Duration": duration, "Pretrain/Step Time": average_step_time, f"Pretrain/{p_arch} Train Epoch Loss": train_loss.item(), f"Pretrain/{p_arch} Train Loss": train_loss.item(), f"Pretrain/{p_arch} Validation Loss": val_loss, "Pretrain/Learning Rate": lr, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) # === Data-Locked Reproductions === def log_vmvp_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, reconstruction_losses: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() reconstruction_loss = torch.stack(list(reconstruction_losses)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-MVP Train Loss": train_loss.item(), "Pretrain/Reconstruction Loss": reconstruction_loss.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) def log_vr3m_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, tcn_losses: List[torch.Tensor], reward_losses: List[torch.Tensor], l1_losses: List[torch.Tensor], l2_losses: List[torch.Tensor], tcn_accuracies: List[torch.Tensor], reward_accuracies: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() tcn_loss = torch.stack(list(tcn_losses)).mean() reward_loss = torch.stack(list(reward_losses)).mean() l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean() tcn_accuracy = torch.stack(list(tcn_accuracies)).mean() reward_accuracy = torch.stack(list(reward_accuracies)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-R3M Train Loss": train_loss.item(), "Pretrain/TCN Loss": tcn_loss.item(), "Pretrain/Reward Loss": reward_loss.item(), "Pretrain/L1 Loss": l1_loss.item(), "Pretrain/L2 Loss": l2_loss.item(), "Pretrain/TCN Accuracy": tcn_accuracy.item(), "Pretrain/Reward Accuracy": reward_accuracy.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) def log_vrn3m_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, tcn_losses: List[torch.Tensor], reward_losses: List[torch.Tensor], l1_losses: List[torch.Tensor], l2_losses: List[torch.Tensor], tcn_accuracies: List[torch.Tensor], reward_accuracies: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() tcn_loss = torch.stack(list(tcn_losses)).mean() reward_loss = torch.stack(list(reward_losses)).mean() l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean() tcn_accuracy = torch.stack(list(tcn_accuracies)).mean() reward_accuracy = torch.stack(list(reward_accuracies)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-RN3M Train Loss": train_loss.item(), "Pretrain/TCN Loss": tcn_loss.item(), "Pretrain/Reward Loss": reward_loss.item(), "Pretrain/L1 Loss": l1_loss.item(), "Pretrain/L2 Loss": l2_loss.item(), "Pretrain/TCN Accuracy": tcn_accuracy.item(), "Pretrain/Reward Accuracy": reward_accuracy.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) # === Voltron Models === def log_vcond_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, reconstruction_losses: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() reconstruction_loss = torch.stack(list(reconstruction_losses)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-Cond Train Loss": train_loss.item(), "Pretrain/Reconstruction Loss": reconstruction_loss.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) def log_vdual_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, reconstruction_losses: List[torch.Tensor], zero_reconstruction_losses: List[torch.Tensor], k_reconstruction_losses: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() reconstruction_loss = torch.stack(list(reconstruction_losses)).mean() zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean() k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-Dual Train Loss": train_loss.item(), "Pretrain/Reconstruction Loss": reconstruction_loss.item(), "Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(), "Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob) def log_vgen_train_update( epoch: int, global_step: int, run_id: str, train_losses: List[torch.Tensor], lr: float, reconstruction_losses: List[torch.Tensor], lm_losses: List[torch.Tensor], lm_ppl: List[torch.Tensor], zero_reconstruction_losses: List[torch.Tensor], k_reconstruction_losses: List[torch.Tensor], step_times: List[float], ) -> None: train_loss = torch.stack(list(train_losses)).mean() reconstruction_loss = torch.stack(list(reconstruction_losses)).mean() lm_loss = torch.stack(list(lm_losses)).mean() lm_perplexity = torch.stack(list(lm_ppl)).mean() zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean() k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean() average_step_time = np.mean(list(step_times)) # Console Logging --> Just log the aggregated train loss... 
xm.master_print( f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} --" f" Reconstruction Loss {reconstruction_loss:.4f} -- LM Loss {lm_loss:.4f}" ) # Log to Weights & Biases + JSONL blob = { "Pretrain/Step": global_step, "Pretrain/Epoch": epoch, "Pretrain/V-Gen Train Loss": train_loss.item(), "Pretrain/Reconstruction Loss": reconstruction_loss.item(), "Pretrain/CLM Loss": lm_loss.item(), "Pretrain/CLM Perplexity": lm_perplexity.item(), "Pretrain/LM Loss": lm_loss.item(), "Pretrain/LM Perplexity": lm_perplexity.item(), "Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(), "Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(), "Pretrain/Learning Rate": lr, "Pretrain/Step Time": average_step_time, } wandb.log(blob, step=global_step) with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: js_logger.write(blob)