Repository: siddk/voltron-robotics
Branch: main
Commit: 1b299bf5cfa0
Files: 60
Total size: 431.8 KB
Directory structure:
gitextract_4mb9jrxp/
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── docs/
│ └── ROADMAP.md
├── examples/
│ ├── pretrain/
│ │ ├── README.md
│ │ ├── preprocess.py
│ │ └── pretrain.py
│ ├── usage.py
│ ├── verification/
│ │ └── verify.py
│ └── xla-reference/
│ ├── README.md
│ ├── xpreprocess.py
│ └── xpretrain.py
├── pyproject.toml
├── setup.py
└── voltron/
├── __init__.py
├── conf/
│ ├── __init__.py
│ ├── accelerators.py
│ ├── datasets.py
│ ├── models.py
│ └── tracking.py
├── datasets/
│ ├── __init__.py
│ ├── datasets.py
│ └── v1/
│ ├── __init__.py
│ └── stream_datasets.py
├── models/
│ ├── __init__.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── vcond.py
│ │ ├── vdual.py
│ │ └── vgen.py
│ ├── instantiate.py
│ ├── materialize.py
│ ├── reproductions/
│ │ ├── __init__.py
│ │ ├── vmvp.py
│ │ ├── vr3m.py
│ │ └── vrn3m.py
│ └── util/
│ ├── __init__.py
│ ├── extraction.py
│ ├── optimization.py
│ └── transformer.py
├── overwatch/
│ ├── __init__.py
│ └── overwatch.py
├── preprocessing/
│ ├── __init__.py
│ ├── core.py
│ ├── process.py
│ ├── transforms.py
│ └── v1/
│ ├── __init__.py
│ ├── process.py
│ ├── transforms.py
│ └── utils.py
└── util/
├── __init__.py
├── checkpointing.py
├── metrics.py
├── utilities.py
└── v1/
├── __init__.py
├── checkpointing.py
├── distributed.py
├── random.py
└── xla_logger.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Ruff
.ruff_cache/
# IDE caches
.idea/
.vscode/
# Mac OS
.DS_Store
# Cache
data/
cache/
# Scratch
scratch/
================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: ".git"
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.252
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ["--maxkb=40000"]
- id: check-ast
- id: check-case-conflict
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021-present, Siddharth Karamcheti and other contributors.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Makefile
================================================
.PHONY: help check autoformat
.DEFAULT: help
# Generates a useful overview/help message for various make features - add to this as necessary!
help:
@echo "make check"
@echo " Run code style and linting (black, ruff) *without* changing files!"
@echo "make autoformat"
@echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this."
check:
black --check .
ruff check --show-source .
autoformat:
black .
ruff check --fix --show-fixes .
================================================
FILE: README.md
================================================
[arXiv](https://arxiv.org/abs/2302.12766)
[PyTorch](https://pytorch.org/get-started/locally/)
[Code Style: Black](https://github.com/psf/black)
[Linting: Ruff](https://github.com/charliermarsh/ruff)

---
# Language-Driven Representation Learning for Robotics
Package repository for Voltron: Language-Driven Representation Learning for Robotics. Provides code for loading
pretrained Voltron, R3M, and MVP representations for adaptation to downstream tasks, as well as code for pretraining
such representations on arbitrary datasets.
---
## Quickstart
This repository is built with PyTorch; while specified as a dependency for the package, we highly recommend that
you install the desired version (e.g., with accelerator support) for your given hardware and environment
manager (e.g., `conda`).
PyTorch installation instructions [can be found here](https://pytorch.org/get-started/locally/). This repository
should work with PyTorch >= 1.12. Releases before 1.1.0 have been thoroughly tested with PyTorch 1.12.0,
Torchvision 0.13.0, and Torchaudio 0.12.0. **Note**: Releases 1.1.0 and after *assume PyTorch 2.0*!
Once PyTorch has been properly installed, you can install this package via PyPI, and you're off!
```bash
pip install voltron-robotics
```
You can also install this package locally via an editable installation in case you want to run examples/extend the
current functionality:
```bash
git clone https://github.com/siddk/voltron-robotics
cd voltron-robotics
pip install -e .
```
## Usage
Voltron Robotics (package: `voltron`) is structured to provide easy access to pretrained Voltron models (and
reproductions), to facilitate use for various downstream tasks. Using a pretrained Voltron model is easy:
```python
from torchvision.io import read_image
from voltron import instantiate_extractor, load
# Load a frozen Voltron (V-Cond) model & configure a vector extractor
vcond, preprocess = load("v-cond", device="cuda", freeze=True)
vector_extractor = instantiate_extractor(vcond)()
# Obtain & Preprocess an image =>> can be from a dataset, or camera on a robot, etc.
# => Feel free to add any language if you have it (Voltron models work either way!)
img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to("cuda")
lang = ["peeling a carrot"]
# Extract both multimodal AND vision-only embeddings!
multimodal_embeddings = vcond(img, lang, mode="multimodal")
visual_embeddings = vcond(img, mode="visual")
# Use the `vector_extractor` to output dense vector representations for downstream applications!
# => Pass this representation to model of your choice (object detector, control policy, etc.)
representation = vector_extractor(multimodal_embeddings)
```
Voltron representations can be used for a variety of different applications; in the
[`voltron-evaluation`](https://github.com/siddk/voltron-evaluation) repository, you can find code for adapting Voltron
representations to various downstream tasks (segmentation, object detection, control, etc.); all the applications from
our paper.
---
## API

The package `voltron` provides the following functionality for using and adapting existing representations:
#### `voltron.available_models()`
Returns the names of the available Voltron models; right now, the following models (all models trained in the paper) are
available:
- `v-cond` – V-Cond (ViT-Small) trained on Sth-Sth; single-frame w/ language-conditioning.
- `v-dual` – V-Dual (ViT-Small) trained on Sth-Sth; dual-frame w/ language-conditioning.
- `v-gen` – V-Gen (ViT-Small) trained on Sth-Sth; dual-frame w/ language conditioning AND generation.
- `r-mvp` – R-MVP (ViT-Small); reproduction of [MVP](https://github.com/ir413/mvp) trained on Sth-Sth.
- `r-r3m-vit` – R-R3M (ViT-Small); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
- `r-r3m-rn50` – R-R3M (ResNet-50); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
- `v-cond-base` – V-Cond (ViT-Base) trained on Sth-Sth; larger (86M parameter) variant of V-Cond.
#### `voltron.load(name: str, device: str, freeze: bool, cache: str = cache/)`
Returns the model and the Torchvision Transform needed by the model, where `name` is one of the strings returned
by `voltron.available_models()`; this in general follows the same API as
[OpenAI's CLIP](https://github.com/openai/CLIP).
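For instance, a minimal sketch of listing and loading models (the `device`/`cache` values here are placeholders):
```python
import voltron

# List the available model IDs (see the list above)
print(voltron.available_models())

# Load a frozen model + its matching Torchvision transform; `cache` controls where checkpoints are stored
model, preprocess = voltron.load("v-cond", device="cpu", freeze=True, cache="cache/")
```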
---
Voltron models (`v-{cond, dual, gen, ...}`) returned by `voltron.load()` support the following:
#### `model(img: Tensor, lang: Optional[List[str]], mode: str = "multimodal")`
Returns a sequence of embeddings corresponding to the output of the multimodal encoder; note that `lang` can be None,
which is totally fine for Voltron models! However, if you have any language (even a coarse task description), it'll
probably be helpful!
The parameter `mode` in `["multimodal", "visual"]` controls whether the output will contain the fused image patch and
language embeddings, or only the image patch embeddings.
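For concreteness, a rough sketch of the expected output shapes for `v-cond` (token counts taken from
[`examples/verification/verify.py`](examples/verification/verify.py), assuming the default 224x224 input resolution;
continuing from the Usage example above):
```python
multimodal_embeddings = vcond(img, lang, mode="multimodal")  # [1, 196 + 20, embed_dim] => patch + language tokens
visual_embeddings = vcond(img, mode="visual")                # [1, 196, embed_dim]      => patch tokens only
```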
**Note:** For the API for the non-Voltron models (e.g., R-MVP, R-R3M), take a look at
[`examples/verification/verify.py`](examples/verification/verify.py); this file shows how representations from *every* model can be extracted.
### Adaptation
See [`examples/usage.py`](examples/usage.py) and the [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation)
repository for more examples on the various ways to adapt/use Voltron representations.
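As a minimal sketch of one common adaptation pattern (the MLP head and `action_dim` below are illustrative
placeholders, not taken from the paper or the evaluation suite): freeze the backbone, pool features with a MAP-based
vector extractor, and train a small head on top.
```python
import torch
import torch.nn as nn

from voltron import instantiate_extractor, load

# Frozen backbone + trainable MAP-based vector extractor
vcond, preprocess = load("v-cond", device="cpu", freeze=True)
vector_extractor = instantiate_extractor(vcond, n_latents=1)()

# Hypothetical downstream head (e.g., a small control policy); `action_dim` is a placeholder
action_dim = 7
policy_head = nn.Sequential(nn.Linear(vcond.embed_dim, 256), nn.GELU(), nn.Linear(256, action_dim))

def predict_action(img: torch.Tensor, lang: list) -> torch.Tensor:
    with torch.no_grad():
        features = vcond(img, lang, mode="multimodal")
    return policy_head(vector_extractor(features))
```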
---
## Contributing
Before committing to the repository, make sure to set up your dev environment!
Here are the basic development environment setup guidelines:
+ Fork/clone the repository, performing an editable installation. Make sure to install with the development dependencies
(e.g., `pip install -e ".[dev]"`); this will install `black`, `ruff`, and `pre-commit`.
+ Install `pre-commit` hooks (`pre-commit install`).
+ Branch for the specific feature/issue, issuing PR against the upstream repository for review.
Additional Contribution Notes:
- This project has migrated to the recommended
[`pyproject.toml` based configuration for setuptools](https://setuptools.pypa.io/en/latest/userguide/quickstart.html).
However, as some tools haven't yet adopted [PEP 660](https://peps.python.org/pep-0660/), we provide a
[`setup.py` file](https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html).
- This package follows the [`flat-layout` structure](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#flat-layout)
described in `setuptools`.
- Make sure to add any new dependencies to the `pyproject.toml` file!
---
## Repository Structure
High-level overview of repository/project file-tree:
+ `docs/` - Package documentation & assets - including project roadmap.
+ `voltron` - Package source code; has all core utilities for model specification, loading, feature extraction,
preprocessing, etc.
+ `examples/` - Standalone examples scripts for demonstrating various functionality (e.g., extracting different types
of representations, adapting representations in various contexts, pretraining, amongst others).
+ `.pre-commit-config.yaml` - Pre-commit configuration file (sane defaults + `black` + `ruff`).
+ `LICENSE` - Code is made available under the MIT License.
+ `Makefile` - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
+ `pyproject.toml` - Following PEP 621, this file has all project configuration details (including dependencies), as
well as tool configurations (for `black` and `ruff`).
+ `README.md` - You are here!
---
## Citation
Please cite [our paper](https://arxiv.org/abs/2302.12766) if using any of the Voltron models, evaluation suite, or other parts of our framework in your work.
```bibtex
@inproceedings{karamcheti2023voltron,
title={Language-Driven Representation Learning for Robotics},
author={Siddharth Karamcheti and Suraj Nair and Annie S. Chen and Thomas Kollar and Chelsea Finn and Dorsa Sadigh and Percy Liang},
booktitle={Robotics: Science and Systems (RSS)},
year={2023}
}
```
================================================
FILE: docs/ROADMAP.md
================================================
# Project Roadmap
We document the future of this project (new features to be added, issues to address) here. For the most part, any
new features/bugfixes are documented as [Github Issues](https://github.com/siddk/voltron-robotics/issues).
## Timeline
- [X] **February 26th, 2023**: Initial Voltron-Robotics release with support for loading/adapting all pretrained models,
  with comprehensive verification scripts & a small adaptation example.
- [X] **April 4, 2023**: [#1](https://github.com/siddk/voltron-robotics/issues/1) - Add `xpretrain.py` reference script,
  mostly for completeness. Refactor/rewrite the preprocessing and pretraining pipeline to reflect
  the Qualcomm Sth-Sth data format, as well as PyTorch DDP vs. the patched PyTorch XLA!
- [X] **April 11, 2023**: [#2](https://github.com/siddk/voltron-robotics/issues/2) - Add support and a more general API
  for pretraining on other datasets.
- [ ] **Future**: [#5](https://github.com/siddk/voltron-robotics/issues/5) - Add better documentation and examples
  around using the MAP extractor (especially for adaptation tasks).
================================================
FILE: examples/pretrain/README.md
================================================
# Pretraining Voltron Models
We provide scripts for pretraining Voltron models on various datasets. Below, we provide the full pipeline from
downloading the raw Something-Something-v2 Dataset from Qualcomm, running preprocessing, then running Distributed
Data Parallel (DDP) pretraining on 1+ GPUs via `torchrun`. Adding support for new datasets should follow this same
general flow.
---
## Dataset Preprocessing
We provide end-to-end instructions for downloading, preprocessing, and serializing various pretraining datasets (and
combinations thereof). Where possible, we provide links to batch/dataset index files.
**Note:** We make a key assumption that you have enough local disk space (e.g., on your server, attached NFS volume) to
store all *raw* and *preprocessed* data; this can range from 100s of GBs to 10s of TBs! We did not have access to such
storage in the original work, necessitating the *streaming* dataloaders defined in
`voltron/datasets/v1/stream_datasets.py`. Given your resources, you might consider adopting a similar approach; feel
free to post an issue with any questions!
We currently support pretraining on the following datasets:
- [Something-Something-v2](https://developer.qualcomm.com/software/ai-datasets/something-something)
Instructions for downloading/preprocessing each dataset can be found below!
---
### Something-Something-v2
Dataset Download: [Qualcomm AI Datasets](https://developer.qualcomm.com/software/ai-datasets/something-something)
#### Obtaining the Raw Dataset
Follow the instructions [at the above link](https://developer.qualcomm.com/software/ai-datasets/something-something) to
download the dataset. Qualcomm requires that you register for a
[Qualcomm OneID Account](https://myaccount.qualcomm.com/signup?target=https%3A%2F%2Fdeveloper.qualcomm.com)
to get access to the data. Approval might take some time.
After registering for an account, make sure to download all of the following files to a directory of your choosing
(we create a directory `data/raw/something-something-v2/downloaded/`). *You will need to manually download all 22 of
the following files from the Qualcomm site*:
1. Datasheet / Instructions (PDF – optional, but useful): `20bn-something-something_download_instructions_-_091622.pdf`
2. Labels (includes language annotations): `20bn-something-something_download-package-labels.zip`
3. Chunked Videos (should be 20 `.zip` archives):
+ `20bn-something-something-v2-00.zip`
+ ...
+ `20bn-something-something-v2-19.zip`
To extract all the given files (we extract to `data/raw/something-something-v2/`), *execute the following from inside
the `downloaded/` subdirectory*:
```bash
# Labels (annotations/language) --> creates `data/raw/something-something-v2/labels`
unzip 20bn-something-something-download-package-labels.zip -d ../
# Videos (following instructions in `20-bn-something-something_download_instructions_-_091622.pdf`)
unzip "20bn-something-something-v2-*.zip" -d ../videos
cd ../videos
cat 20bn-something-something-?? | tar -xvzf -
find . -maxdepth 1 -type f -delete
cd 20bn-something-something-v2/
find . -mindepth 1 -maxdepth 1 -exec mv -t .. -- {} +
cd ..
rm -r 20bn-something-something-v2
ls | wc # Should have 220847 `.webm` files!
```
#### Dataset Information & Statistics
Something-Something-v2 consists of 220,847 `.webm` clips (168,913 in the `train` split) each with a height of exactly
240px, and variable width. The frames are encoded at a fixed 12 FPS.
Clips contain an average of ~45 frames (~7 KB per JPEG), for a total of ~7.6M frames (~56 GB).
#### Video/Image Transformations --> from Video Clip to "frame" --> "tensor"
```python
from typing import List

import av
import torch
from PIL import Image, ImageOps
# Resolutions for "preprocessing" (serialize to disk) and "training"
PREPROCESS_RESOLUTION, TRAIN_RESOLUTION = 240, 224
# Define Preprocessing Transformation
def preprocess_transform(frames: List[Image.Image]) -> List[Image.Image]:
# Assert width >= height and height >= PREPROCESS_RESOLUTION
orig_w, orig_h = frames[0].size
assert orig_w >= orig_h >= PREPROCESS_RESOLUTION
# Compute scale factor --> just a function of height and PREPROCESS_RESOLUTION
scale_factor = PREPROCESS_RESOLUTION / orig_h
# Full Transformation --> scale (preserve aspect ratio, then get square)
for idx in range(len(frames)):
frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
left = (frames[idx].size[0] - PREPROCESS_RESOLUTION) // 2
frames[idx] = frames[idx].crop((left, 0, left + PREPROCESS_RESOLUTION, PREPROCESS_RESOLUTION))
return frames
def train_transform(img) -> torch.Tensor:
# Assumes square, just resizes to TRAIN_RESOLUTION via `torchvision.transforms`
...
def extract_frames(webm_file: str) -> None:
container = av.open(webm_file)
assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
# Extract --> then serialize via `Image.save("frame_{idx}.jpg")`
frames = preprocess_transform([f.to_image() for f in container.decode(video=0)])
...
```
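For reference, a minimal `train_transform` consistent with the comment above might look like the following sketch
(this is not the packaged transform -- the `preprocess` transform returned by `voltron.load()` is the authoritative
version, and any normalization choice here would be an assumption):
```python
from torchvision import transforms

TRAIN_RESOLUTION = 224

# Sketch: resize the (already square) PIL frame to TRAIN_RESOLUTION and convert to a float tensor in [0, 1]
train_transform = transforms.Compose([
    transforms.Resize(TRAIN_RESOLUTION),
    transforms.ToTensor(),
])
```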
#### Citation
If you are pretraining on this dataset, make sure to cite the original research; Something-Something-v2 is the product
of two papers:
```bibtex
@inproceedings{goyal2017sthsthv1,
author = {Raghav Goyal and Samira Ebrahimi Kahou and Vincent Michalski and Joanna Materzynska and Susanne Westphal and Heuna Kim and Valentin Haenel and Ingo Fründ and Peter N. Yianilos and Moritz Mueller-Freitag and Florian Hoppe and Christian Thurau and Ingo Bax and Roland Memisevic},
booktitle = {International Conference on Computer Vision (ICCV)},
title = {The ``Something Something'' Video Database for Learning and Evaluating Visual Common Sense},
year = {2017},
}
@article{mahidisoltani2018sthsthv2,
author={Farzaneh Mahdisoltani and Guillaume Berger and Waseem Gharbieh and David J. Fleet and Roland Memisevic},
journal = {arXiv preprint arXiv:1804.09235},
title={On the Effectiveness of Task Granularity for Transfer Learning},
year={2018}
}
```
---
## PyTorch Native Pretraining Pipeline
To pretrain a Voltron model (e.g., `v-cond`) on the processed data, make sure to read `examples/pretrain/pretrain.py`.
A sample launch command to run with the Something-Something-v2 dataset on a single node with 8 GPUs is as follows:
```bash
torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py
```
Make sure to check the following configuration files and either update them manually (adding your own dataclass,
overriding [DEFAULTS](https://github.com/siddk/voltron-robotics/blob/main/examples/pretrain/pretrain.py#L38)), or
use Hydra semantics to override them at the command line (e.g., `... pretrain.py dataset.path="" ...`); a fuller
example follows the list below:
- [Accelerator Config](../../voltron/conf/accelerators.py): Depending on hardware, might need to tune `num_workers`
- [Dataset Config](../../voltron/conf/datasets.py): Make sure to override `path` and `artifact_path`
- [Tracking Config](../../voltron/conf/tracking.py): Disable Weights & Biases / change default entity/name
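For example, a single-node launch with inline Hydra overrides (the paths below are placeholders; the override names
follow the config dataclasses linked above) might look like:
```bash
torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py \
    dataset.path=/path/to/raw/sth-sth-v2 \
    dataset.artifact_path=/path/to/processed/artifacts \
    accelerator.num_workers=8
```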
================================================
FILE: examples/pretrain/preprocess.py
================================================
"""
preprocess.py
Centralized script for preprocessing various video/vision-language datasets for GPU pretraining, using a multi-stage,
multiprocessing approach.
Run as a standalone script, *prior* to calling `pretrain.py` =>> mostly because we want to preprocess the data once, as
a fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing import extract_frames, preprocess_language, unify_batches
from voltron.util import set_global_seed
# Grab Logger
overwatch = logging.getLogger(__file__)
# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]
@dataclass
class PreprocessingConfig:
# fmt: off
defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
hydra: Dict[str, Any] = field(
default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
)
# Command Line Arguments
seed: int = 21 # Random Seed (for reproducibility)
dry_run: bool = False # Dry Run --> Get a sense of preprocessing/serialization footprint
# Composable / Structured Arguments
dataset: DatasetConfig = MISSING # Dataset(s) for pretraining/preprocessing
# fmt: on
# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)
@hydra.main(config_path=None, config_name="config")
def preprocess(cfg: PreprocessingConfig) -> None:
overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")
# Set Randomness
set_global_seed(cfg.seed)
# Phase 1 :: Serialize Frames from Video Clips --> get `registry` (index files) for train and validation
train_registry, val_registry, train_dir, val_dir = extract_frames(
cfg.dataset.name,
path=cfg.dataset.path,
artifact_path=cfg.dataset.artifact_path,
preprocess_resolution=cfg.dataset.preprocess_resolution,
n_val_videos=cfg.dataset.n_val_videos,
dry_run=cfg.dry_run,
)
# Phase 2 :: Normalize & Tokenize Language --> create `index.pt` and `index.json` files
index_dir = preprocess_language(
cfg.dataset.name,
train_registry,
val_registry,
artifact_path=cfg.dataset.artifact_path,
max_lang_len=cfg.dataset.max_lang_len,
language_model=cfg.dataset.language_model,
hf_cache=cfg.dataset.hf_cache,
)
# Phase 3 :: Assemble "Data-Locked" Batch Sets for Various Models (e.g., for single-frame/dual-frame/quintet)
unify_batches(
cfg.dataset.name,
train_registry,
val_registry,
train_dir,
val_dir,
index_dir,
batch_formats=cfg.dataset.batch_formats,
max_epochs=cfg.dataset.max_epochs,
initial_final_alpha=cfg.dataset.initial_final_alpha,
)
overwatch.info("Preprocessing Complete!")
if __name__ == "__main__":
preprocess()
================================================
FILE: examples/pretrain/pretrain.py
================================================
"""
pretrain.py
Core pretraining script for Native PyTorch (Single/Multi-) GPU pretraining on the Something-Something-v2 dataset; this
is basically just a 1-1 reproduction of the XLA pretraining script (`examples/xla-reference/xpretrain.py`) with just
a bit of cleanup, the default PyTorch DDP semantics (`torchrun`), using PyTorch 2.0.
Other notable differences from `xpretrain.py`:
- Loads data from the local filesystem instead of streaming from a GCP bucket (can be added back easily!)
- No TPU/XLA specific dependencies --> just PyTorch 2.0!
Run with:
- [Single Node Multi-GPU ($K)]: `torchrun --standalone --nnodes 1 --nproc-per-node $K examples/pretrain/pretrain.py`
"""
import logging
import os
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import hydra
import torch
import torch.distributed as dist
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig
from voltron.datasets import get_datasets
from voltron.models import get_model_optimizer
from voltron.overwatch import OverwatchRich
from voltron.util import CheckpointSaver, Metrics, ResumeableDistributedSampler, do_resume, set_global_seed
# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = [
"_self_",
{"model": "v-cond"},
{"dataset": "sth-sth-v2"},
{"accelerator": "torchrun"},
{"tracking": "voltron-tracking"},
{"override hydra/job_logging": "overwatch_rich"},
]
@dataclass
class PretrainConfig:
# fmt: off
defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
hydra: Dict[str, Any] = field(default_factory=lambda: {
"run": {"dir": "runs/train/${model.identifier}+dataset-${dataset.name}"}
})
# Command Line Arguments
run_id: Optional[str] = None # Run ID for Logging
seed: int = 21 # Random Seed (for reproducibility)
# Resume / Debug Behavior
resume: bool = True # Whether to resume an existing run...
wandb_resume_id: Optional[str] = None # W&B Run ID for `resume` behavior...
# Composable / Structured Arguments
model: ModelConfig = MISSING # Model architecture for pretraining
dataset: DatasetConfig = MISSING # List of datasets for pretraining
accelerator: AcceleratorConfig = MISSING # Accelerator (should always keep `torchrun`)
tracking: TrackingConfig = MISSING # Run/experiment tracking configuration
# fmt: on
# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PretrainConfig)
@hydra.main(config_path=None, config_name="config")
def pretrain(cfg: PretrainConfig) -> None:
# Initialize Distributed Process Group --> assumes NCCL + Environment Variable Initialization (via `torchrun`)
dist.init_process_group(backend="nccl", init_method="env://")
device_id = dist.get_rank() % torch.cuda.device_count()
is_rank_zero, rank, world_size = dist.get_rank() == 0, dist.get_rank(), dist.get_world_size()
# Create Unique Run Name -- if `resume = True` we assume the same "run_id"
if cfg.run_id is None:
cfg.run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-ddp-x{cfg.seed}"
else:
run_dir = cfg.run_id
# Setup Logging (Rank 0 Only!) and Directory Handling
overwatch = logging.getLogger(__file__)
overwatch.setLevel(logging.INFO if is_rank_zero else logging.ERROR)
overwatch.info("Voltron Training :: Assembling the Legendary Defender...")
if is_rank_zero:
os.makedirs(run_dir, exist_ok=True)
# Let's Get Started!
overwatch.info(
'\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."'
)
# Set Randomness & Get Dataloader `worker_init_fn` to ensure proper randomness in augmentations (if any)
worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True)
# Initialize Model & Optimizer --> Wrap in DDP / Device Handling
# > Note :: For (Standard) DDP Training --> initializing Optimizer before DDP == initializing after!
overwatch.info("Initializing Model, Optimizer, and Learning Rate Scheduler")
model, optimizer, update_lr = get_model_optimizer(cfg.model, cfg.dataset)
model = DDP(model.to(device_id), device_ids=[device_id], output_device=device_id)
# Handle Resume / Checkpoint Loading
resume_checkpoint, resume_epoch, resume_step = do_resume(cfg.resume, run_dir=run_dir)
if resume_checkpoint is not None:
# IMPORTANT --> Load weights by mapping specifically to `cuda:`!
resume_state = torch.load(resume_checkpoint, map_location=f"cuda:{device_id}")
model.load_state_dict(resume_state["model_state_dict"])
optimizer.load_state_dict(resume_state["optimizer_state_dict"])
dist.barrier()
# Create Checkpoint Saver and Save Initial Checkpoint
saver = CheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, is_rank_zero=is_rank_zero)
if resume_checkpoint is None and resume_epoch == 0:
overwatch.info(" | Saving 0th Epoch Checkpoint (Model Initialization)")
saver.save(
epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None
)
dist.barrier()
# Get Datasets --> Barrier after I/O Intensive Operation
overwatch.info(f"Retrieving Dataset `{cfg.dataset.name}` prepared for Model `{cfg.model.arch}`")
train_dataset, val_dataset = get_datasets(
0,
cfg.dataset.name,
cfg.model.arch,
cfg.dataset.artifact_path,
cfg.model.data_modality,
cfg.dataset.resolution,
cfg.dataset.normalization,
cfg.model.get("lang_dropout", None),
cfg.model.get("gen_ratio", None),
)
dist.barrier()
# Create Metrics =>> Handles on-the-fly computation, logging to JSONL and Weights & Biases
metrics = Metrics(
active_loggers=cfg.tracking.active_loggers,
run_id=cfg.run_id,
hparams=OmegaConf.to_container(cfg),
model_arch=cfg.model.arch,
is_rank_zero=is_rank_zero,
tracking_cfg=cfg.tracking,
tags=cfg.tracking.tags,
resume=cfg.resume,
resume_id=cfg.wandb_resume_id,
)
dist.barrier()
# Configure Gradient Accumulation --> function of `effective_bsz`, `native_bsz`, and `WORLD_SIZE`
assert cfg.model.effective_bsz % cfg.model.native_bsz == 0, "Device `native_bsz` must evenly divide `effective_bsz`"
accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.native_bsz // world_size
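    # e.g., with (hypothetical) effective_bsz = 1024, native_bsz = 32, world_size = 8 --> accumulate_grad_batches = 4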
overwatch.info(f"Running `{cfg.model.identifier}` Model Pretraining with Parameters =>")
overwatch.info(f" | Effective Batch Size = `{cfg.model.effective_bsz}`")
overwatch.info(f" | Per-Device Batch Size = `{cfg.model.native_bsz}`")
overwatch.info(f" | Distributed World Size = `{world_size}`")
overwatch.info(f" | Accumulation Steps = `{accumulate_grad_batches}`")
# Start Train Loop --> Iterate through Epochs (Evaluation at end of Epoch)
overwatch.info("Starting Training Loop")
for epoch in range(resume_epoch, cfg.dataset.max_epochs):
overwatch.info(f" | [Epoch {epoch:03d}] Building Distributed Sampler & DataLoaders")
train_dataset.set_epoch(epoch)
dist.barrier()
# [Custom] ResumeableDistributedSampler operates over *examples* --> start_step (full batches) * effective_bsz
train_sampler = ResumeableDistributedSampler(
seen_examples=resume_step * cfg.model.effective_bsz,
resume_epoch=resume_epoch,
dataset=train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True,
seed=cfg.seed,
)
train_sampler.set_epoch(epoch)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
# Create Epoch DataLoaders
train_dl = DataLoader(
train_dataset,
batch_size=cfg.model.native_bsz,
sampler=train_sampler,
shuffle=False,
num_workers=cfg.accelerator.num_workers,
drop_last=True,
pin_memory=True,
prefetch_factor=4,
worker_init_fn=worker_init_fn,
)
val_dl = DataLoader(
val_dataset, batch_size=cfg.model.native_bsz, sampler=val_sampler, shuffle=False, num_workers=4
)
# Book-Keeping =>> Set LR when `resume = True` (or starting from scratch)
if epoch == resume_epoch or epoch == 0:
metrics.resume_time = (
int(re.search("-t=(.+?).pt", str(resume_checkpoint)).group(1)) if resume_checkpoint is not None else 0
)
metrics.commit(
global_step=resume_step + ((len(train_dataset) // cfg.model.effective_bsz) * resume_epoch),
lr=update_lr(resume_epoch, resume_step / (len(train_dataset) // cfg.model.effective_bsz)),
update_step_time=True,
)
# === Train Epoch ===
model.train()
status = metrics.get_status(epoch)
overwatch.info(f" | [Epoch {epoch:03d}] Running Train Loop")
with tqdm(
total=len(train_dl) // accumulate_grad_batches, desc=status, leave=False, disable=not is_rank_zero
) as progress:
for train_idx, batch in enumerate(train_dl):
# Model-Specific Handling
if cfg.model.arch == "v-mvp":
img = batch
loss, _, _ = model(img.to(device_id, non_blocking=True))
metrics.commit(reconstruction_loss=loss)
elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
imgs, lang, lang_mask = batch
loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(
imgs.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
metrics.commit(
tcn_loss=tcn_loss,
reward_loss=reward_loss,
l1_loss=l1_loss,
l2_loss=l2_loss,
tcn_accuracy=tcn_acc,
reward_accuracy=rew_acc,
)
elif cfg.model.arch == "v-cond":
img, lang, lang_mask = batch
loss, _, _ = model(
img.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
metrics.commit(reconstruction_loss=loss)
elif cfg.model.arch == "v-dual":
imgs, lang, lang_mask = batch
loss, [zero_loss, k_loss] = model(
imgs.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
metrics.commit(
reconstruction_loss=loss,
zero_reconstruction_loss=zero_loss,
k_reconstruction_loss=k_loss,
)
elif cfg.model.arch == "v-gen":
imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model(
imgs.to(device_id, non_blocking=True),
lang_con.to(device_id, non_blocking=True),
lang_con_mask.to(device_id, non_blocking=True),
lang_gen.to(device_id, non_blocking=True),
lang_gen_mask.to(device_id, non_blocking=True),
lang_gen_weight,
)
metrics.commit(
reconstruction_loss=reconstruction_loss,
zero_reconstruction_loss=zero_loss,
k_reconstruction_loss=k_loss,
lm_loss=lm_loss,
lm_ppl=torch.exp(lm_loss),
)
else:
raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")
# Commit Loss (Prior to Normalization)
metrics.commit(loss=loss)
# Normalize Loss to account for Gradient Accumulation --> Backward!
normalized_loss = loss / accumulate_grad_batches
normalized_loss.backward()
# Step =>> Check if done w/ Gradient Accumulation
if (train_idx + 1) % accumulate_grad_batches == 0:
metrics.commit(update_step_time=True)
# Push Metrics every `log_frequency` steps...
if metrics.global_step % cfg.tracking.log_frequency == 0:
status = metrics.push(epoch)
# Optimizer Step --> Increment Global Step, Learning Rate, and Checkpoint (if specified)
optimizer.step()
optimizer.zero_grad()
lr = update_lr(
epoch,
(resume_step + ((train_idx + 1) // accumulate_grad_batches))
/ (len(train_dataset) // cfg.model.effective_bsz),
)
metrics.commit(global_step=metrics.global_step + 1, lr=lr)
saver.save(
epoch,
is_local_step=True,
model=model,
optimizer=optimizer,
duration=int(time.time() - metrics.start_time) + metrics.resume_time,
local_step=resume_step + ((train_idx + 1) // accumulate_grad_batches),
)
# Update Progress Bar
progress.update()
progress.set_description(status)
# === After Train Epoch --> Clear Gradients and reset `resume_step` ===
optimizer.zero_grad()
resume_step = 0
# === Validation ===
overwatch.info(f" | [Epoch {epoch:03d}] Running Validation Loop")
model.eval()
# Accumulate `validation_losses` in order to `all_reduce` later!
val_losses = []
with torch.no_grad():
for batch in tqdm(val_dl, disable=not is_rank_zero, leave=False):
# Model-Specific Handling
if cfg.model.arch == "v-mvp":
img = batch
val_loss, _, _ = model(img.to(device_id, non_blocking=True))
elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
imgs, lang, lang_mask = batch
val_loss, _, _, _, _, _, _ = model(
imgs.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
elif cfg.model.arch == "v-cond":
img, lang, lang_mask = batch
val_loss, _, _ = model(
img.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
elif cfg.model.arch == "v-dual":
imgs, lang, lang_mask = batch
val_loss, _ = model(
imgs.to(device_id, non_blocking=True),
lang.to(device_id, non_blocking=True),
lang_mask.to(device_id, non_blocking=True),
)
elif cfg.model.arch == "v-gen":
imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
val_loss, _, _, _ = model(
imgs.to(device_id, non_blocking=True),
lang_con.to(device_id, non_blocking=True),
lang_con_mask.to(device_id, non_blocking=True),
lang_gen.to(device_id, non_blocking=True),
lang_gen_mask.to(device_id, non_blocking=True),
lang_gen_weight,
)
else:
raise ValueError(f"Forward() for Model `{cfg.model.arch}` is not implemented!")
# Add to Validation Losses
val_losses.append(val_loss)
# All Reduce --> Push Epoch Metrics --> Checkpoint!
validation_loss = torch.stack(val_losses).mean()
dist.all_reduce(validation_loss)
avg_val_loss = validation_loss / world_size
if is_rank_zero:
epoch_status, train_loss, training_duration = metrics.push_epoch(epoch, avg_val_loss)
saver.save(
epoch=epoch + 1,
is_local_step=False,
model=model,
optimizer=optimizer,
duration=training_duration,
train_loss=train_loss.item(),
val_loss=avg_val_loss.item(),
)
# === End of Epoch ===
dist.barrier()
# Finalize
metrics.finalize()
# And... we're done!
overwatch.info("...and that's all, folks!")
dist.barrier()
if __name__ == "__main__":
# General Defaults --> should use Tensor Cores (kinda) if you have them!
torch.set_float32_matmul_precision("high")
torch.multiprocessing.set_start_method("spawn", force=True)
pretrain()
================================================
FILE: examples/usage.py
================================================
"""
usage.py
Example script demonstrating how to load a Voltron model (`V-Cond`) and instantiate a Multiheaded Attention Pooling
extractor head for downstream tasks.
This is the basic formula/protocol for using Voltron for arbitrary downstream applications.
Run with (from root of repository): `python examples/usage.py`
"""
import torch
from torchvision.io import read_image
from voltron import instantiate_extractor, load
def usage() -> None:
print("[*] Demonstrating Voltron Usage for Various Adaptation Applications")
# Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Voltron model --> specify `freeze`, `device` and get model (nn.Module) and preprocessor
vcond, preprocess = load("v-cond", device=device, freeze=True)
# Obtain and preprocess an image =>> can be from a dataset, from a camera on a robot, etc.
img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to(device)
lang = ["peeling a carrot"]
# Get various representations...
with torch.no_grad():
multimodal_features = vcond(img, lang, mode="multimodal") # Fused vision & language features
visual_features = vcond(img, mode="visual") # Vision-only features (no language)
# Can instantiate various extractors for downstream applications
vector_extractor = instantiate_extractor(vcond, n_latents=1, device=device)()
seq_extractor = instantiate_extractor(vcond, n_latents=64, device=device)()
# Assertions...
assert list(vector_extractor(multimodal_features).shape) == [1, vcond.embed_dim], "Should return a dense vector!"
assert list(seq_extractor(visual_features).shape) == [1, 64, vcond.embed_dim], "Should return a sequence!"
if __name__ == "__main__":
usage()
================================================
FILE: examples/verification/verify.py
================================================
"""
verify.py
Example script demonstrating how to load all Voltron models (and reproduced models), take input image(s), and get the
various (e.g., multimodal, image-only) representations.
Also serves to verify that representation loading is working as advertised.
Run with (from root of repository): `python examples/verification/verify.py`
"""
import torch
from torchvision.io import read_image
from voltron import load
# Available Models
MODELS = ["v-cond", "v-dual", "v-gen", "r-mvp", "r-r3m-vit", "r-r3m-rn50"]
# Sample Inputs
IMG_A, IMG_B = "examples/verification/img/peel-carrot-initial.png", "examples/verification/img/peel-carrot-final.png"
LANGUAGE = "peeling a carrot"
def verify() -> None:
print("[*] Running `verify` =>> Verifying Model Representations!")
# Read both images (we'll use the second image for the dual-frame models)
image_a, image_b = read_image(IMG_A), read_image(IMG_B)
# Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
device = "cuda" if torch.cuda.is_available() else "cpu"
for model_id in MODELS:
print(f"\t=> Loading Model ID `{model_id}` and Verifying Representation Shapes!")
model, preprocess = load(model_id, device=device, freeze=True)
# Preprocess image, run feature extraction --> assert on shapes!
if model_id in {"v-cond", "v-cond-base"}:
for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
representation = model(preprocess(image_a)[None, ...].to(device), [LANGUAGE], mode=modality)
assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
elif model_id in {"v-dual", "v-gen"}:
for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
dual_img = torch.stack([preprocess(image_a), preprocess(image_b)])[None, ...].to(device)
representation = model(dual_img, [LANGUAGE], mode=modality)
assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
elif model_id == "r-mvp":
for mode, expected in [("patch", 196), ("cls", 1)]:
representation = model(preprocess(image_a)[None, ...].to(device), mode=mode)
assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
elif model_id in {"r-r3m-vit", "r-r3m-rn50"}:
representation = model(preprocess(image_a)[None, ...].to(device))
assert representation.squeeze(dim=0).shape[0] == 1, "Shape not expected!"
else:
raise ValueError(f"Model {model_id} not supported!")
# We're good!
print("[*] All representations & shapes verified! Yay!")
if __name__ == "__main__":
verify()
================================================
FILE: examples/xla-reference/README.md
================================================
# XLA Reference
*Note :: This code was written for the experimental PyTorch XLA build in PyTorch 1.12; no guarantees it works with later
versions!*
We trained the original Voltron models (and data-locked reproductions of R3M and MVP) on TPU v3-8 nodes generously
provided by the [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program. At the time we started
the project, PyTorch XLA still had some bumps, which was further complicated by the switch from
[TPU Nodes to TPU VMs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch).
To get things to work, we had to add some non-intuitive code to facilitate PyTorch + TPUs (vs. a standard distributed
data parallel training pipeline). As a result, `xpretrain.py` is here mostly for documentation purposes; a fully
refactored version lives at `examples/pretrain/pretrain.py`.
We also include the original cloud preprocessing script `xpreprocess.py` for completeness (this is more general).
================================================
FILE: examples/xla-reference/xpreprocess.py
================================================
"""
xpreprocess.py
Centralized script for preprocessing Sth-Sth-v2 for TPU/GCP pretraining, using a multi-stage, multiprocessing strategy.
Run as a standalone script, *prior* to calling `xpretrain.py` =>> mostly because we want to preprocess the data
once, as a fixed cost.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
from voltron.conf import DatasetConfig
from voltron.overwatch import OverwatchRich
from voltron.preprocessing.v1 import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
from voltron.util.v1.random import set_global_seed
# Grab Logger
overwatch = logging.getLogger(__file__)
# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]
@dataclass
class PreprocessingConfig:
# fmt: off
defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
hydra: Dict[str, Any] = field(
default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
)
# Command Line Arguments
seed: int = 21 # Random Seed (for reproducibility)
dry_run: bool = False # Dry Run --> Get a sense of preprocessing/serialization footprint
# Composable / Structured Arguments
dataset: DatasetConfig = MISSING # Dataset(s) for pretraining/preprocessing
# fmt: on
# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
cs.store(name="config", node=PreprocessingConfig)
@hydra.main(config_path=None, config_name="config")
def xpreprocess(cfg: PreprocessingConfig) -> None:
overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")
# Set Randomness
set_global_seed(cfg.seed)
# Phase 1 :: Serialize Frames from Video Clips --> Get `registry` for train and val (index structure)
train_registry, val_registry, train_dir, val_dir = preprocess_videos(
cfg.dataset.name,
path=cfg.dataset.path,
artifact_path=cfg.dataset.artifact_path,
resolution=cfg.dataset.resolution,
n_val_videos=cfg.dataset.n_val_videos,
dry_run=cfg.dry_run,
)
# Phase 2 :: Normalize & Tokenize Language --> Create `index.pt` & `index.json` files
preprocess_language(
cfg.dataset.name,
train_registry,
val_registry,
max_lang_len=cfg.dataset.max_lang_len,
language_model=cfg.dataset.language_model,
hf_cache=cfg.dataset.hf_cache,
)
jsonify_language(train_registry, val_registry)
index_dir = index(train_registry, val_registry, cfg.dataset.name, artifact_path=cfg.dataset.artifact_path)
# Phase 3 :: Assemble & Unify Batch "Sets" across the Varied Dataset Formats (for each Model =>> "data-locked")
unify_batches(
cfg.dataset.artifact_path,
cfg.dataset.name,
train_registry,
val_registry,
train_dir,
val_dir,
index_dir,
cfg.dataset.batch_formats,
max_epochs=cfg.dataset.max_epochs,
initial_final_alpha=cfg.dataset.initial_final_alpha,
)
if __name__ == "__main__":
xpreprocess()
================================================
FILE: examples/xla-reference/xpretrain.py
================================================
"""
xpretrain.py
(The `x` prefix indicates this is a script geared for XLA/TPU backends *only*)!
Reference script for PyTorch XLA (TPU-based) pretraining on the Something-Something-v2 dataset; this is
mostly for completeness =>> the hope is that the regular `pretrain.py` script is more general and maintained.
Focuses on multi-TPU (XLA) training --> but also supports single-core TPU training, as the default distributed mp.spawn
behavior just collapses into a single thread! Loads and preprocesses dataset, instantiates a model, and runs training.
Run with `python examples/xla-reference/xpretrain.py` (will use the configuration specified by `DEFAULTS` below).
"""
import os
import re
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import hydra
import jsonlines
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as parallel
import wandb
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig
from voltron.datasets.v1.stream_datasets import get_epoch_datasets
from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen
from voltron.overwatch import OverwatchRich
from voltron.util.v1.checkpointing import XLACheckpointSaver
from voltron.util.v1.distributed import ResumeableDistributedSampler
from voltron.util.v1.random import set_global_seed
from voltron.util.v1.xla_logger import (
log_epoch_end_update,
log_vcond_train_update,
log_vdual_train_update,
log_vgen_train_update,
log_vmvp_train_update,
log_vr3m_train_update,
log_vrn3m_train_update,
)
# Set Defaults (Hydra w/ Structured Configs)
DEFAULTS = [
"_self_",
{"model": "v-cond"},
{"dataset": "sth-sth-v2"},
{"accelerator": "tpu-v3-8"},
{"tracking": "voltron-tracking"},
{"override hydra/job_logging": "overwatch_rich"},
]
@dataclass
class PretrainConfig:
# fmt: off
defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
hydra: Dict[str, Any] = field(default_factory=lambda: {
"run": {"dir": "./runs/train/${model.identifier}+dataset-${dataset.name}"}
})
# Command Line Arguments
run_id: Optional[str] = None # Run ID for Logging
seed: int = 21 # Random Seed (for reproducibility)
# Resume / Debug Behavior
resume: bool = True # Whether to resume an existing run...
resume_epoch: Optional[int] = None # Epoch to resume (if auto-resuming)...
checkpoint_path: Optional[str] = None # Path to the specific checkpoint to load!
wandb_resume_id: Optional[str] = None # W&B Run ID for `resume` behavior...
# Composable / Structured Arguments
model: ModelConfig = MISSING # Model architecture for pretraining
dataset: DatasetConfig = MISSING # List of datasets for pretraining
accelerator: AcceleratorConfig = MISSING # Accelerator configuration
tracking: TrackingConfig = MISSING # Run/experiment tracking configuration
# fmt: on
# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
cs = ConfigStore.instance()
cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich) # Annoying - configure logger for Hydra
cs.store(name="config", node=PretrainConfig)
# ruff: noqa: C901
def xpretrain(cfg: PretrainConfig) -> None:
# Identify if `is_rank_zero` --> We only log from the rank zero process!
is_rank_zero = xm.is_master_ordinal(local=False)
xm.master_print("Voltron Training :: Assembling the Legendary Defender...")
# Create Unique Run Name -- if `resume = True` we assume same "run_id"
run_id = cfg.run_id
if run_id is None:
run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-x{cfg.seed}"
cfg.run_id = run_id
else:
cfg.run_id = run_dir = run_id
if is_rank_zero:
os.makedirs(run_dir, exist_ok=True)
xm.master_print(
'\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."'
)
# Set Randomness, get DataLoader worker initialization function (to ensure any random augmentations!)
worker_init_fn = set_global_seed(cfg.seed)
# Model Initialization Logic
xm.master_print("Initializing Model and Placing on Different Devices...")
if cfg.model.arch == "v-mvp":
xm.master_print(f"Initializing MVP variant `{cfg.model.identifier}`")
model = VMVP(
resolution=cfg.dataset.resolution,
patch_size=cfg.model.patch_size,
encoder_depth=cfg.model.encoder_depth,
encoder_embed_dim=cfg.model.encoder_embed_dim,
encoder_n_heads=cfg.model.encoder_n_heads,
decoder_depth=cfg.model.decoder_depth,
decoder_embed_dim=cfg.model.decoder_embed_dim,
decoder_n_heads=cfg.model.decoder_n_heads,
optimizer=cfg.model.optimizer,
schedule=cfg.model.schedule,
base_lr=cfg.model.base_lr,
min_lr=cfg.model.min_lr,
effective_bsz=cfg.model.effective_bsz,
betas=cfg.model.betas,
weight_decay=cfg.model.weight_decay,
warmup_epochs=cfg.dataset.warmup_epochs,
max_epochs=cfg.dataset.max_epochs,
mlp_ratio=cfg.model.mlp_ratio,
norm_pixel_loss=cfg.model.norm_pixel_loss,
)
elif cfg.model.arch == "v-r3m":
xm.master_print(f"Initializing R3M (ViT) Variant `{cfg.model.identifier}`")
model = VR3M(
resolution=cfg.dataset.resolution,
patch_size=cfg.model.patch_size,
depth=cfg.model.depth,
embed_dim=cfg.model.embed_dim,
n_heads=cfg.model.n_heads,
language_model=cfg.model.language_model,
hf_cache=cfg.model.hf_cache,
language_dim=cfg.model.language_dim,
reward_dim=cfg.model.reward_dim,
n_negatives=cfg.model.n_negatives,
lang_reward_weight=cfg.model.lang_reward_weight,
tcn_weight=cfg.model.tcn_weight,
l1_weight=cfg.model.l1_weight,
l2_weight=cfg.model.l2_weight,
optimizer=cfg.model.optimizer,
schedule=cfg.model.schedule,
lr=cfg.model.lr,
min_lr=cfg.model.min_lr,
warmup_epochs=cfg.dataset.warmup_epochs,
max_epochs=cfg.dataset.max_epochs,
mlp_ratio=cfg.model.mlp_ratio,
)
elif cfg.model.arch == "v-rn3m":
xm.master_print(f"Intializing R3M (ResNet) Variant `{cfg.model.identifier}`")
model = VRN3M(
resolution=cfg.dataset.resolution,
fc_dim=cfg.model.fc_dim,
language_model=cfg.model.language_model,
hf_cache=cfg.model.hf_cache,
language_dim=cfg.model.language_dim,
reward_dim=cfg.model.reward_dim,
n_negatives=cfg.model.n_negatives,
lang_reward_weight=cfg.model.lang_reward_weight,
tcn_weight=cfg.model.tcn_weight,
l1_weight=cfg.model.l1_weight,
l2_weight=cfg.model.l2_weight,
optimizer=cfg.model.optimizer,
lr=cfg.model.lr,
)
elif cfg.model.arch == "v-cond":
xm.master_print(f"Initializing Voltron V-Cond variant `{cfg.model.identifier}`")
model = VCond(
resolution=cfg.dataset.resolution,
patch_size=cfg.model.patch_size,
encoder_depth=cfg.model.encoder_depth,
encoder_embed_dim=cfg.model.encoder_embed_dim,
encoder_n_heads=cfg.model.encoder_n_heads,
decoder_depth=cfg.model.decoder_depth,
decoder_embed_dim=cfg.model.decoder_embed_dim,
decoder_n_heads=cfg.model.decoder_n_heads,
language_model=cfg.model.language_model,
hf_cache=cfg.model.hf_cache,
language_dim=cfg.model.language_dim,
optimizer=cfg.model.optimizer,
schedule=cfg.model.schedule,
base_lr=cfg.model.base_lr,
min_lr=cfg.model.min_lr,
effective_bsz=cfg.model.effective_bsz,
betas=cfg.model.betas,
weight_decay=cfg.model.weight_decay,
warmup_epochs=cfg.dataset.warmup_epochs,
max_epochs=cfg.dataset.max_epochs,
mlp_ratio=cfg.model.mlp_ratio,
norm_pixel_loss=cfg.model.norm_pixel_loss,
)
elif cfg.model.arch == "v-dual":
xm.master_print(f"Initializing Voltron V-Dual variant `{cfg.model.identifier}`")
model = VDual(
resolution=cfg.dataset.resolution,
patch_size=cfg.model.patch_size,
encoder_depth=cfg.model.encoder_depth,
encoder_embed_dim=cfg.model.encoder_embed_dim,
encoder_n_heads=cfg.model.encoder_n_heads,
decoder_depth=cfg.model.decoder_depth,
decoder_embed_dim=cfg.model.decoder_embed_dim,
decoder_n_heads=cfg.model.decoder_n_heads,
language_model=cfg.model.language_model,
hf_cache=cfg.model.hf_cache,
language_dim=cfg.model.language_dim,
optimizer=cfg.model.optimizer,
schedule=cfg.model.schedule,
base_lr=cfg.model.base_lr,
min_lr=cfg.model.min_lr,
effective_bsz=cfg.model.effective_bsz,
betas=cfg.model.betas,
weight_decay=cfg.model.weight_decay,
warmup_epochs=cfg.dataset.warmup_epochs,
max_epochs=cfg.dataset.max_epochs,
mlp_ratio=cfg.model.mlp_ratio,
norm_pixel_loss=cfg.model.norm_pixel_loss,
)
elif cfg.model.arch == "v-gen":
xm.master_print(f"Initializing Voltron V-Gen variant `{cfg.model.identifier}`")
model = VGen(
resolution=cfg.dataset.resolution,
patch_size=cfg.model.patch_size,
encoder_depth=cfg.model.encoder_depth,
encoder_embed_dim=cfg.model.encoder_embed_dim,
encoder_n_heads=cfg.model.encoder_n_heads,
decoder_depth=cfg.model.decoder_depth,
decoder_embed_dim=cfg.model.decoder_embed_dim,
decoder_n_heads=cfg.model.decoder_n_heads,
language_model=cfg.model.language_model,
hf_cache=cfg.model.hf_cache,
language_dim=cfg.model.language_dim,
max_lang_len=cfg.dataset.max_lang_len,
vocab_size=cfg.model.vocab_size,
mae_weight=cfg.model.mae_weight,
lm_weight=cfg.model.lm_weight,
optimizer=cfg.model.optimizer,
schedule=cfg.model.schedule,
base_lr=cfg.model.base_lr,
min_lr=cfg.model.min_lr,
effective_bsz=cfg.model.effective_bsz,
betas=cfg.model.betas,
weight_decay=cfg.model.weight_decay,
warmup_epochs=cfg.dataset.warmup_epochs,
max_epochs=cfg.dataset.max_epochs,
mlp_ratio=cfg.model.mlp_ratio,
norm_pixel_loss=cfg.model.norm_pixel_loss,
)
else:
raise NotImplementedError(f"Model Architecture `{cfg.model.arch}` is not supported!")
# We use gradient accumulation to honor the effective batch size specified...
assert cfg.model.effective_bsz % cfg.model.device_bsz == 0, "Device bsz must evenly divide effective bsz!"
accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.device_bsz // xm.xrt_world_size()
xm.master_print(
f"Running `{cfg.model.identifier}` model w/ Effective Batch Size of `{cfg.model.effective_bsz}`, "
f"Per-Device Batch Size of `{cfg.model.device_bsz}`, "
f"Distributed World Size of `{xm.xrt_world_size()}` and `{accumulate_grad_batches}` Accumulation Steps"
)
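# Worked example (illustrative numbers pulled from the default configs, not asserted by this script):
# with `effective_bsz=1024`, `device_bsz=128`, and a TPU v3-8 (world size 8), we get
# accumulate_grad_batches = 1024 // 128 // 8 = 1, i.e., every optimizer step consumes exactly one
# per-device batch from each replica; halving `device_bsz` to 64 would instead give 1024 // 64 // 8 = 2.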
# If Resuming =>> Load Model from Checkpoint
start_checkpoint, start_epoch, start_step = None, 0, 0
if cfg.resume:
# **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks:
# - `accumulate_grad_batches` is exactly the same when resuming; this means:
# + `cfg.model.effective_bsz`, `cfg.model.device_bsz`, & `cfg.accelerator.num_accelerators` are the same!
# - The Weights & Biases directory `run_dir/wandb` only contains a *single run*
# - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes!
# + This means that (and generally should be true for resuming altogether) the architecture is the same!
# - The `cfg.seed` should be the same (again, should generally be true...)
if cfg.checkpoint_path is None:
xm.master_print("Resuming :: Attempting to Automatically Load Checkpoint -- Searching!")
checkpoint_path = Path(run_dir) / "checkpoints"
if checkpoint_path.exists() and any(checkpoint_path.iterdir()):
# Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints...
checkpoints = list(checkpoint_path.iterdir())
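# For illustration only (hypothetical filenames -- the actual naming comes from `XLACheckpointSaver`):
#   > a "complete" epoch checkpoint such as `epoch=004-train=0.0412-t=86400.pt` matches `epoch=(.+?)-train`
#   > a "local step" checkpoint such as `local-epoch=004-step=1500-t=86400.pt` matches
#     `local-epoch=(.+?)-step` and `step=(.+?)[.-]`, yielding the (epoch, step) pair parsed below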
complete_checkpoint, complete_epoch = max(
[
(c, int(re.search("epoch=(.+?)-train", c.name).group(1)))
for c in checkpoints
if "local-epoch=" not in str(c)
],
key=lambda x: x[1],
)
# Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints...
local = [
(
c,
int(re.search("local-epoch=(.+?)-step", c.name).group(1)),
int(re.search("step=(.+?)[.-]", c.name).group(1)),
)
for c in checkpoints
if "local-epoch=" in str(c)
]
if len(local) > 0:
# Parse out (epoch, "highest" step) + assert no newer "full epoch" checkpoint exists!
start_checkpoint, start_epoch, start_step = max(local, key=lambda x: x[1:])
assert start_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!"
# Case 2 :: Otherwise, we're just going to start with the last "complete" epoch...
else:
start_checkpoint, start_epoch = complete_checkpoint, complete_epoch
else:
xm.master_print("No Checkpoints Found -- Starting Run from Scratch!")
else:
xm.master_print(f"Resuming :: Loading from Checkpoint `{cfg.checkpoint_path}`...")
start_checkpoint = cfg.checkpoint_path
# Actually Load the Checkpoint State!
if start_checkpoint is not None:
xm.master_print(f"Resuming :: Loading Model & Optimizer State Dictionaries from `{start_checkpoint}`")
checkpoint = torch.load(str(start_checkpoint))
model_state_dict, optimizer_state_dict = checkpoint
model.load_state_dict(model_state_dict)
# Logging / W&B Handling
if is_rank_zero:
xm.master_print("Initializing Weights & Biases + JSONL + Checkpoint Saver on Rank Zero ONLY...")
tags = None
if cfg.tracking.tags is None:
tags = [cfg.model.identifier, cfg.dataset.name, "pretraining"]
# W&B Initialize & Log all Hyperparameters (Only on ordinal 0)
wandb_resume_id = None
if cfg.resume and cfg.wandb_resume_id is None:
xm.master_print("Resuming :: Attempting to Automatically Load W&B Resume ID -- Searching!")
wandb_path = Path("wandb")
if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()):
# Parse out the unique resume_id from the `.wandb` file...
wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if str(f).endswith(".wandb")]
assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!"
# Regex match on `run-{id}.wandb`...
wandb_resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1)
# Otherwise, assert that we're starting from scratch!
else:
assert start_checkpoint is None, "Trying to restart a run from checkpoint without a valid W&B ID!"
elif cfg.resume:
xm.master_print(f"Resuming :: Using Specified W&B Resume ID = `{cfg.wandb_resume_id}`")
wandb_resume_id = cfg.wandb_resume_id
# Initialize Weights & Biases
xm.master_print(f"W&B Resume is {cfg.resume} w/ W&B Resume ID = {wandb_resume_id}!")
wandb.init(
project=cfg.tracking.project,
entity=cfg.tracking.entity,
config=cfg,
name=run_id,
dir=f"{os.getcwd()}" if cfg.tracking.directory is None else cfg.tracking.directory,
tags=tags,
notes=cfg.tracking.notes,
resume="allow" if start_checkpoint is not None else False,
id=wandb_resume_id,
# Weird line because PT-TPU VMs don't come with a working install of Tensorflow...
settings=wandb.Settings(_disable_stats=True),
)
# Initialize JSONL Logger (append only mode) --> last "global step" will always take precedence.
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(
{
"run_id": run_id,
"start_time": datetime.now().strftime("%m-%d-%H:%M"),
"hparams": OmegaConf.to_container(cfg),
}
)
# Rank Zero Node will take time to spin up the loggers & checkpointer... might as well rendezvous?
xm.rendezvous("Logging...")
# === Here Be Dragons ===
# Time to handle device placement -- Note - example code doesn't specify device idx - why not?
# > https://github.com/pytorch/xla/blob/3c0d68da07702995a592ea70f27868cd76fa0755/test/test_train_mp_mnist.py#L114
# > Results in printing [xla:0] and [xla:1] a bunch... no [xla:2-7]? This feels bad...?
#
# |=> Debugging Try: `xm.xla_device(n=xm.get_ordinal()) ---> hangs completely?
# +=> *ANSWER*: https://github.com/pytorch/xla/issues/2345#issuecomment-657114819
# >> "Make no assumptions and don't try to build them manually..."
device = xm.xla_device()
model = model.train().to(device)
optimizer, update_lr = model.configure_optimizer()
global_step, train_losses, lrs, start_time, resume_time = 0, deque(maxlen=128), [], time.time(), 0
# If resuming (valid `start_checkpoint`) -- patch the optimizer state dictionary, and load!
if start_checkpoint is not None:
patched_optimizer_state_dict = {
"state": optimizer_state_dict,
"param_groups": optimizer.state_dict()["param_groups"],
}
optimizer.load_state_dict(patched_optimizer_state_dict)
# Create step timing...
step_times, step_start_time = deque(maxlen=128), time.time()
# Create Model/Architecture-Specific Trackers...
if cfg.model.arch == "v-mvp":
reconstruction_losses = deque(maxlen=128)
elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
tcn_losses, reward_losses, l1_losses, l2_losses = [deque(maxlen=128) for _ in range(4)]
tcn_accuracies, reward_accuracies = [deque(maxlen=128) for _ in range(2)]
elif cfg.model.arch == "v-cond":
reconstruction_losses = deque(maxlen=128)
elif cfg.model.arch == "v-dual":
reconstruction_losses = deque(maxlen=128)
zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128)
elif cfg.model.arch == "v-gen":
reconstruction_losses, lm_losses, lm_ppl = deque(maxlen=128), deque(maxlen=128), deque(maxlen=128)
zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128)
else:
raise NotImplementedError(f"Trackers for Model `{cfg.model.arch}` not implemented!")
# 0th Checkpoint - Pull out optimizer state explicitly (`groups` are not serializable & can easily be replicated)
saver = XLACheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, cfg.accelerator.accelerator)
if start_checkpoint is None and start_epoch == 0:
xm.master_print("Saving 0th Epoch Checkpoint...")
saver.save(
epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None
)
# Run on all processes --> retrieve "0th epoch" dataset!
# =>> Important, ensures data is locked across models, for the given epoch!
xm.master_print(f"Retrieving Dataset `{cfg.dataset.name}` prepared for `{cfg.model.arch}`!")
train_dataset, val_dataset = get_epoch_datasets(
0,
cfg.dataset.name,
cfg.dataset.normalization,
cfg.model.arch,
cfg.dataset.stream,
cfg.dataset.artifact_path,
cfg.dataset.stream_prefix,
cfg.model.data_modality,
cfg.model.get("lang_dropout", None),
cfg.model.get("gen_ratio", None),
)
# Loading Datasets might take time... rendezvous to be safe
xm.rendezvous("Retrieved Datasets...")
# Iterate through Epochs, Evaluating at the end of each Training Epoch!
# >> Everything in this loop should happen across all workers, except for the logging (ordinal 0)!
xm.master_print("Starting Training Loop...")
for epoch in range(start_epoch, cfg.dataset.max_epochs):
xm.master_print(f"\t[Epoch {epoch}] Building Distributed Sampler & DataLoaders...")
train_dataset.set_epoch(epoch)
# ResumeableDistributedSampler operates over *examples* --> start_step (full batches) * effective_bsz
seen_examples = start_step * cfg.model.effective_bsz
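# Sketch of the arithmetic (illustrative numbers): resuming at local step 1500 with `effective_bsz=1024`
# gives seen_examples = 1500 * 1024 = 1,536,000 -- the number of examples the resumable sampler treats
# as already consumed within this epoch.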
train_sampler = ResumeableDistributedSampler(
seen_examples,
start_epoch,
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True,
seed=cfg.seed,
)
# Set epoch appropriately for the `train_sampler` --> necessary to trigger "normal" logic!
train_sampler.set_epoch(epoch)
train_dataloader = DataLoader(
train_dataset,
batch_size=cfg.model.device_bsz,
sampler=train_sampler,
shuffle=False if train_sampler else True,
num_workers=cfg.accelerator.num_workers,
drop_last=True,
worker_init_fn=worker_init_fn,
prefetch_factor=4,
)
# NOTE :: We're not sharding the Validation set --> *everybody* will run forward passes on the *same* data!
# > We will have to mesh_reduce() later... unclear why, but the torch_xla folks seem keen on it; might lead to
# > weird rendezvous/hang issues if Validation is big enough...
val_dataloader = DataLoader(
val_dataset,
batch_size=cfg.model.device_bsz,
shuffle=False,
num_workers=4,
drop_last=True,
worker_init_fn=worker_init_fn,
)
# Initializing the Dataloaders might take time depending on process...
xm.rendezvous("Initialized Dataloaders...")
# Leverage the *special* API that handles synchronizing TPU cores across batches!
# > NOTE: This is super important!
xm.master_print("\tSetting up Parallel MpDeviceLoaders...")
train_device_loader = parallel.MpDeviceLoader(train_dataloader, device)
val_device_loader = parallel.MpDeviceLoader(val_dataloader, device)
# Book-keeping & LR setting when `resuming` --> only do this on start_epoch!
if epoch == start_epoch:
if start_checkpoint is not None:
global_step = start_step + ((len(train_dataset) // cfg.model.effective_bsz) * start_epoch)
resume_time = int(re.search("-t=(.+?).pt", str(start_checkpoint)).group(1))
lrs.append(update_lr(start_epoch, start_step / (len(train_dataset) // cfg.model.effective_bsz)))
else:
lrs.append(update_lr(start_epoch, 0))
# Iterate...
step_start_time = time.time()
with tqdm(total=len(train_device_loader) // accumulate_grad_batches, disable=not is_rank_zero) as progress:
for train_idx, batch in enumerate(train_device_loader):
if cfg.model.arch == "v-mvp":
# Run a forward pass through the MAE... other return vals are reconstructions (pixel norm) & mask
loss, _, _ = model(batch)
reconstruction_losses.append(loss)
elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
imgs, lang, lang_mask = batch
loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(imgs, lang, lang_mask)
# Add to trackers
tcn_losses.append(tcn_loss)
reward_losses.append(reward_loss)
l1_losses.append(l1_loss)
l2_losses.append(l2_loss)
tcn_accuracies.append(tcn_acc)
reward_accuracies.append(rew_acc)
elif cfg.model.arch == "v-cond":
img, lang, lang_mask = batch
loss, _, _ = model(img, lang, lang_mask)
reconstruction_losses.append(loss)
elif cfg.model.arch == "v-dual":
imgs, lang, lang_mask = batch
loss, [zero_loss, k_loss] = model(imgs, lang, lang_mask)
# Add to trackers
reconstruction_losses.append(loss)
zero_reconstruction.append(zero_loss)
k_reconstruction.append(k_loss)
elif cfg.model.arch == "v-gen":
imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model(
imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight
)
# Add to trackers
reconstruction_losses.append(reconstruction_loss)
lm_losses.append(lm_loss)
lm_ppl.append(torch.exp(lm_loss))
zero_reconstruction.append(zero_loss)
k_reconstruction.append(k_loss)
else:
raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch}` not implemented!")
# Write Loss to Loggers (prior to accumulation normalization)
train_losses.append(loss)
# Normalize loss to account for accumulation
loss = loss / accumulate_grad_batches
loss.backward()
# Gradient Accumulation =>> Note: skip any errant batches at the end...
if (train_idx + 1) % accumulate_grad_batches == 0:
xm.optimizer_step(optimizer) # Note call to xm.optimizer_step() -- has implicit mark_step!
optimizer.zero_grad()
# Add to `step_times`
step_times.append(time.time() - step_start_time)
# Logging --> Because there is no guarantee processes will be in sync, we need a `closure`
# > Ref: https://pytorch.org/xla/release/1.11/index.html#torch_xla.core.xla_model.add_step_closure
if is_rank_zero and global_step % cfg.tracking.log_frequency == 0:
if cfg.model.arch == "v-mvp":
xm.add_step_closure(
log_vmvp_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
reconstruction_losses,
step_times,
),
)
elif cfg.model.arch == "v-r3m":
xm.add_step_closure(
log_vr3m_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
tcn_losses,
reward_losses,
l1_losses,
l2_losses,
tcn_accuracies,
reward_accuracies,
step_times,
),
)
elif cfg.model.arch == "v-rn3m":
xm.add_step_closure(
log_vrn3m_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
tcn_losses,
reward_losses,
l1_losses,
l2_losses,
tcn_accuracies,
reward_accuracies,
step_times,
),
)
elif cfg.model.arch == "v-cond":
xm.add_step_closure(
log_vcond_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
reconstruction_losses,
step_times,
),
)
elif cfg.model.arch == "v-dual":
xm.add_step_closure(
log_vdual_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
reconstruction_losses,
zero_reconstruction,
k_reconstruction,
step_times,
),
)
elif cfg.model.arch == "v-gen":
xm.add_step_closure(
log_vgen_train_update,
args=(
epoch,
global_step,
run_id,
train_losses,
lrs[-1],
reconstruction_losses,
lm_losses,
lm_ppl,
zero_reconstruction,
k_reconstruction,
step_times,
),
)
else:
raise NotImplementedError(f"Log Update for Model `{cfg.model.arch}` not implemented!")
# Increment Global Step _after_ logging!
global_step += 1
# Save checkpoint subject to *local_step = (train_idx + 1) // accumulate_grad_batches*
saver.save(
epoch=epoch,
is_local_step=True,
model=model,
optimizer=optimizer,
duration=int(time.time() - start_time) + resume_time,
local_step=start_step + ((train_idx + 1) // accumulate_grad_batches),
)
# Update LR every `accumulation_steps` iterations...
lrs.append(
update_lr(
epoch,
(start_step + ((train_idx + 1) // accumulate_grad_batches))
/ (len(train_dataset) // cfg.model.effective_bsz),
)
)
# Reset `step_start_time`
step_start_time = time.time()
# Update `progress` each time we take a gradient step!
progress.update()
# After each forward pass, mark a step, to compile XLA graph for a single forward pass!
# =>> Note :: this is important, with gradient accumulation, the graph can get massive otherwise!
xm.mark_step()
else:
# Clear gradients and reset start step (regardless) at end of the loop
optimizer.zero_grad()
start_step = 0
# Redundant, but Synchronous Validation Epoch...
xm.master_print("Validating...")
val_losses = []
with torch.no_grad():
for batch in tqdm(val_device_loader, disable=not is_rank_zero):
if cfg.model.arch == "v-mvp":
loss, _, _ = model(batch)
elif cfg.model.arch in {"v-r3m", "v-rn3m"}:
imgs, lang, lang_mask = batch
loss, _, _, _, _, _, _ = model(imgs, lang, lang_mask)
elif cfg.model.arch == "v-cond":
img, lang, lang_mask = batch
loss, _, _ = model(img, lang, lang_mask)
elif cfg.model.arch == "v-dual":
imgs, lang, lang_mask = batch
loss, _ = model(imgs, lang, lang_mask)
elif cfg.model.arch == "v-gen":
imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch
loss, _, _, _ = model(imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight)
else:
raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch} not implemented!")
# Just append to val_losses...
val_losses.append(loss)
# Compute Val Loss & *mesh reduce* --> Why? :: the XLA people said so!
val_loss = torch.stack(val_losses).mean().item()
val_loss = xm.mesh_reduce("val_loss", val_loss, np.mean) # All replicas should just return the same thing?
# Logging --> add another `closure` for end-of-epoch cleanup --> compute `duration` as well...
duration = int(time.time() - start_time) + resume_time
if is_rank_zero:
xm.add_step_closure(
log_epoch_end_update,
args=(
cfg.model.arch,
epoch,
global_step,
run_id,
duration,
train_losses,
val_loss,
lrs[-1],
step_times,
),
)
# Save Checkpoint (at end of Epoch)
saver.save(
epoch=epoch + 1,
is_local_step=False,
model=model,
optimizer=optimizer,
duration=duration,
train_loss=train_losses[-1].item(),
val_loss=val_loss,
)
# Dump TPU Debugging Metrics...
if is_rank_zero:
with open("tpu-debug-metrics.log", "w") as f:
f.write(met.metrics_report())
# Exiting w/ Multiprocessing is a Nightmare... try to join?
xm.master_print("...and that's all, folks!")
xm.rendezvous("Cheers!")
# Sleep for like 3 minutes... get W&B to finish syncing logs
wandb.finish()
time.sleep(150)
def mp_fn(_: int, cfg: PretrainConfig) -> None:
torch.set_default_tensor_type("torch.FloatTensor")
# Let's Start Pretraining!
xpretrain(cfg)
@hydra.main(config_path=None, config_name="config")
def main(cfg: PretrainConfig) -> None:
import torch_xla.distributed.xla_multiprocessing as xmp
# Call XMP Spawn w/ the Config as the sole argument...
xmp.spawn(mp_fn, args=(cfg,), nprocs=cfg.accelerator.num_accelerators, start_method="spawn")
if __name__ == "__main__":
main()
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "voltron-robotics"
authors = [
{name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}
]
description = "Voltron: Language-Driven Representation Learning for Robotics."
version = "1.1.0"
readme = "README.md"
requires-python = ">=3.8"
keywords = ["robotics", "representation learning", "natural language processing", "machine learning"]
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"av",
"einops",
"gdown",
"google-cloud-storage",
"h5py",
"hurry.filesize",
"hydra-core==1.1.1", # Lock Hydra =>> future versions break!
"jsonlines",
"omegaconf==2.1.2", # Lock OmegaConf =>> future versions break!
"opencv-python",
"pandas",
"rich",
"torch>=2.0.0", # Native PyTorch Code (Release 2.0.0) uses PyTorch 2.0!
"torchvision>=0.15.0",
"transformers",
"wandb",
]
[project.optional-dependencies]
dev = [
"black",
"ipython",
"pre-commit",
"ruff",
]
[project.urls]
homepage = "https://github.com/siddk/voltron-robotics"
repository = "https://github.com/siddk/voltron-robotics"
documentation = "https://github.com/siddk/voltron-robotics"
[tool.black]
line-length = 121
target-version = ["py38", "py39", "py310"]
preview = true
[tool.ruff]
line-length = 121
target-version = "py38"
select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401"]
[tool.setuptools.packages.find]
where = ["."]
exclude = ["cache"]
================================================
FILE: setup.py
================================================
"""
setup.py
PEP 621 moves most packaging configuration to `pyproject.toml`; we keep this "dummy" setup.py for external code that
has not yet upgraded.
"""
from setuptools import setup
setup()
================================================
FILE: voltron/__init__.py
================================================
from .models.materialize import available_models, load
from .models.util import instantiate_extractor
================================================
FILE: voltron/conf/__init__.py
================================================
from .accelerators import AcceleratorConfig
from .datasets import DatasetConfig
from .models import ModelConfig
from .tracking import TrackingConfig
================================================
FILE: voltron/conf/accelerators.py
================================================
"""
accelerator.py
Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure.
"""
import os
from dataclasses import dataclass
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
# === Vanilla Accelerators (Deprecated; mostly for XLA code) ===
@dataclass
class AcceleratorConfig:
accelerator: str = MISSING
num_accelerators: int = MISSING
num_workers: int = MISSING
@dataclass
class TPUv2OneConfig(AcceleratorConfig):
accelerator = "tpu"
num_accelerators = 1
num_workers = 4
@dataclass
class TPUv2EightConfig(AcceleratorConfig):
accelerator = "tpu"
num_accelerators = 8
num_workers = 4
@dataclass
class TPUv3OneConfig(AcceleratorConfig):
accelerator = "tpu"
num_accelerators = 1
num_workers = 8
@dataclass
class TPUv3EightConfig(AcceleratorConfig):
accelerator = "tpu"
num_accelerators = 8
num_workers = 8
# === GPU Default Config --> just set `num_workers`; `torchrun` takes care of the rest! ===
# > Note :: Defaults to 1 GPU if WORLD_SIZE not set (e.g., not running with `torchrun`)
@dataclass
class TorchRunDefaultConfig(AcceleratorConfig):
accelerator = "gpu"
num_accelerators = int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else 1)
num_workers = 8
# Create a configuration group `accelerator` and populate with the above...
cs = ConfigStore.instance()
cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig)
cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig)
cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig)
cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig)
cs.store(group="accelerator", name="torchrun", node=TorchRunDefaultConfig)
================================================
FILE: voltron/conf/datasets.py
================================================
"""
datasets.py
Base Hydra Structured Config for defining various pretraining datasets and appropriate configurations. Uses a simple,
single inheritance structure.
"""
from dataclasses import dataclass
from typing import Any, Tuple
from hydra.core.config_store import ConfigStore
from hydra.utils import to_absolute_path
from omegaconf import MISSING
@dataclass
class DatasetConfig:
name: str = MISSING
path: str = MISSING
artifact_path: str = MISSING
# Streaming Parameters (assumes fully preprocessed dataset lives at `stream_prefix/...`)
# =>> Deprecated as of `v2`
stream: bool = True
stream_prefix: str = "data/processed"
# Dataset-Specific Parameters
resolution: int = 224
normalization: Tuple[Any, Any] = MISSING
# For preprocessing --> maximum size of saved frames (assumed square)
preprocess_resolution: int = MISSING
# Validation Parameters
n_val_videos: int = MISSING
# Language Modeling Parameters
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
# Maximum Length for truncating language inputs... should be computed after the fact (set to -1 to compute!)
max_lang_len: int = MISSING
# Dataset sets the number of pretraining epochs (general rule :: warmup should be ~5% of full)
warmup_epochs: int = MISSING
max_epochs: int = MISSING
# Plausible Formats --> These are instantiations each "batch" could take, with a small DSL
# > Note: Assumes final element of the list is the "most expressive" --> used to back-off
batch_formats: Any = (
("state", ("state_i",)),
("state+language", ("state_i", "language")),
("state+ok", ("state_initial", "state_i", "language")),
("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")),
)
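# For reference, each format string maps onto a concrete Dataset class in `voltron/datasets/datasets.py`:
#   "state" -> StateDataset, "state+language" -> StateLanguageDataset,
#   "state+ok" -> StateOKDataset (or GenStateOKDataset when a `gen_ratio` is set),
#   "quintet+language" -> QuintetDataset.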
# Preprocessing :: Frame-Sampling Parameters
initial_final_alpha: float = 0.2
@dataclass
class SthSthv2Config(DatasetConfig):
# fmt: off
name: str = "sth-sth-v2"
path: str = to_absolute_path("data/raw/sth-sth-v2")
artifact_path: str = to_absolute_path("data/processed/sth-sth-v2")
# Dataset Specific arguments
normalization: Tuple[Any, Any] = ( # Mean & Standard Deviation (default :: ImageNet)
(0.485, 0.456, 0.406),
(0.229, 0.224, 0.225),
)
# Sth-Sth-v2 Videos have a fixed height of 240; we'll crop to square at this resolution!
preprocess_resolution: int = 240
# Validation Parameters
n_val_videos: int = 1000 # Number of Validation Clips (fast evaluation!)
# Epochs for Dataset
warmup_epochs: int = 20
max_epochs: int = 400
# Language Modeling Parameters
max_lang_len: int = 20
# fmt: on
# Create a configuration group `dataset` and populate with the above...
# =>> Note :: this is meant to be extendable --> add arbitrary datasets & mixtures!
cs = ConfigStore.instance()
cs.store(group="dataset", name="sth-sth-v2", node=SthSthv2Config)
================================================
FILE: voltron/conf/models.py
================================================
"""
models.py
Base Hydra Structured Config for defining various pretraining model architectures and appropriate configurations. Uses a
simple single inheritance structure.
"""
from dataclasses import dataclass
from typing import Tuple
from hydra.core.config_store import ConfigStore
from hydra.utils import to_absolute_path
from omegaconf import MISSING
@dataclass
class ModelConfig:
arch: str = MISSING
identifier: str = MISSING
# Dataset Modality
data_modality: str = MISSING
# Default Vision Transformer Configuration
patch_size: int = 16
mlp_ratio: float = 4.0
# Effective batch size --> total number of examples before gradient update
effective_bsz: int = MISSING
# Number of examples one can safely fit on an accelerator w/ this model!
device_bsz: int = MISSING # For backwards compatibility, only use device_bsz for XLA/TPU pretraining...
native_bsz: int = MISSING # For backwards compatibility, define a separate `native_bsz`...
# @Data-Locked Reproductions --- Encompasses MVP (MAE) + R3M
# MVP (Base Masked Autoencoder)
@dataclass
class MVPConfig(ModelConfig):
arch: str = "v-mvp"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "state"
# Base MAE Parameters
mask_ratio: float = 0.75
# Architecture Parameters
encoder_depth: int = MISSING
encoder_embed_dim: int = MISSING
encoder_n_heads: int = MISSING
decoder_depth: int = MISSING
decoder_embed_dim: int = MISSING
decoder_n_heads: int = MISSING
# MAE Loss/Objective Configuration
norm_pixel_loss: bool = True
effective_bsz: int = 1024
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adamw"
schedule: str = "linear-warmup+cosine-decay"
base_lr: float = 1.5e-4
min_lr: float = 0.0
betas: Tuple[float, float] = (0.9, 0.95)
weight_decay: float = 0.05
@dataclass
class MVPSmallConfig(MVPConfig):
identifier = "r-mvp"
# Architecture Parameters -- should match ViT Small Architecture to the letter!
# Note: Small is defined in TIMM:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
encoder_depth = 12
encoder_embed_dim = 384
encoder_n_heads = 6
decoder_depth = 6
decoder_embed_dim = 192
decoder_n_heads = 6
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 128 per device.
device_bsz = 128
native_bsz = 128
# R3M Models --> Just different visual encoders, roughly following the above!
@dataclass
class R3MConfig(ModelConfig):
arch: str = "v-r3m"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "quintet+language"
# ViT Architecture Parameters
depth: int = MISSING
embed_dim: int = MISSING
n_heads: int = MISSING
# Effective Batch Size
effective_bsz: int = 1024
# Language Model Parameters
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
language_dim: int = 768
vocab_size: int = 30522
reward_dim: int = 1024
# Loss/Objective Configuration
lang_reward_weight: float = 1.0
tcn_weight: float = 1.0
l1_weight: float = 1e-5
l2_weight: float = 1e-5
n_negatives: int = 3
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adam"
schedule: str = "linear-warmup+cosine-decay"
lr: float = 1e-4
min_lr: float = 0.0
@dataclass
class R3MSmallConfig(R3MConfig):
identifier = "r-r3m-vit"
# Architecture Parameters -- should match ViT Small Architecture to the letter!
# Note: Small is defined in TIMM:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
depth = 12
embed_dim = 384
n_heads = 6
# Device Batch Size
device_bsz = 32
native_bsz = 128
# R3M -- ResNet50 Encoder (instead of ViT)
@dataclass
class ResNet3MConfig(ModelConfig):
arch: str = "v-rn3m"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "quintet+language"
# Effective Batch Size
effective_bsz: int = 1024
# Architecture Parameters
fc_dim: int = MISSING
# Language Model Parameters
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
language_dim: int = 768
vocab_size: int = 30522
reward_dim: int = 1024
# Loss/Objective Configuration
lang_reward_weight: float = 1.0
tcn_weight: float = 1.0
l1_weight: float = 1e-5
l2_weight: float = 1e-5
n_negatives: int = 3
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adam"
lr: float = 1e-4
@dataclass
class RN3M50Config(ResNet3MConfig):
identifier = "r-r3m-rn50"
# Architecture Parameters
fc_dim = 2048
# Device Batch Size
device_bsz = 32
native_bsz = 128
# @Voltron Models -- VCond, VDual, VGen
# VCond -- Single Frame + Language Conditioning
@dataclass
class VCondConfig(ModelConfig):
arch: str = "v-cond"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "state+language"
# Base MAE Parameters
mask_ratio: float = 0.75
# Base Language Parameters --> full sentence dropout only...
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
language_dim: int = 768
vocab_size: int = 30522
lang_dropout: float = MISSING
# Architecture Parameters
encoder_depth: int = MISSING
encoder_embed_dim: int = MISSING
encoder_n_heads: int = MISSING
decoder_depth: int = MISSING
decoder_embed_dim: int = MISSING
decoder_n_heads: int = MISSING
use_cls_token: bool = True
# MAE Loss/Objective Configuration
norm_pixel_loss: bool = True
effective_bsz: int = 1024
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adamw"
schedule: str = "linear-warmup+cosine-decay"
base_lr: float = 1.5e-4
min_lr: float = 0.0
betas: Tuple[float, float] = (0.9, 0.95)
weight_decay: float = 0.05
@dataclass
class VCondSmallConfig(VCondConfig):
identifier = "v-cond"
# No language dropout...
lang_dropout = 0.0
# Architecture Parameters -- should match ViT Small Architecture to the letter!
# Note: Small is defined in TIMM:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
encoder_depth = 12
encoder_embed_dim = 384
encoder_n_heads = 6
decoder_depth = 6
decoder_embed_dim = 192
decoder_n_heads = 6
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 128 per device
# > GPU w/ 32G of RAM: max of 128 per device!
device_bsz = 128
native_bsz = 128
@dataclass
class VCondBaseConfig(VCondConfig):
identifier = "v-cond-base"
# No language dropout...
lang_dropout = 0.0
# Architecture Parameters -- should match ViT Base Architecture to the letter!
# Note: Base is defined in TIMM & Original MAE Repository:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
# > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
encoder_depth = 12
encoder_embed_dim = 768
encoder_n_heads = 12
decoder_depth = 8
decoder_embed_dim = 512
decoder_n_heads = 16
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 128 per device!
# > GPU w/ 32G of RAM: max of 128 per device!
device_bsz = 128
native_bsz = 128
# VDual - Dual Frame (0th Frame + Kth frame) + Language Conditioning
@dataclass
class VDualConfig(ModelConfig):
arch: str = "v-dual"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "state+ok"
# Base MAE Parameters
mae_weight: float = 1.0
mask_ratio: float = 0.75
# Base Language Parameters --> full sentence dropout only...
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
language_dim: int = 768
vocab_size: int = 30522
lang_dropout: float = MISSING
# Architecture Parameters
encoder_depth: int = MISSING
encoder_embed_dim: int = MISSING
encoder_n_heads: int = MISSING
decoder_depth: int = MISSING
decoder_embed_dim: int = MISSING
decoder_n_heads: int = MISSING
use_cls_token: bool = True
# MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example!
norm_pixel_loss: bool = True
effective_bsz: int = 1024
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adamw"
schedule: str = "linear-warmup+cosine-decay"
base_lr: float = 1.5e-4
min_lr: float = 0.0
betas: Tuple[float, float] = (0.9, 0.95)
weight_decay: float = 0.05
@dataclass
class VDualSmallConfig(VDualConfig):
identifier = "v-dual"
# No language dropout...
lang_dropout = 0.0
# Architecture Parameters -- should match ViT Small Architecture to the letter!
# Note: Small is defined in TIMM:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
encoder_depth = 12
encoder_embed_dim = 384
encoder_n_heads = 6
decoder_depth = 6
decoder_embed_dim = 192
decoder_n_heads = 6
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 128 per device!
# > GPU w/ 32G of RAM: max of 128 per device!
device_bsz = 128
native_bsz = 128
@dataclass
class VDualBaseConfig(VDualConfig):
identifier = "v-dual-base"
# No language dropout...
lang_dropout = 0.0
# Architecture Parameters -- should match ViT Base Architecture to the letter!
# Note: Base is defined in TIMM & Original MAE Repository:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
# > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
encoder_depth = 12
encoder_embed_dim = 768
encoder_n_heads = 12
decoder_depth = 8
decoder_embed_dim = 512
decoder_n_heads = 16
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 128 per device!
# > GPU w/ 32G of RAM: max of 64 per device!
device_bsz = 128
native_bsz = 64
# VGen - Dual Frame with Language Conditioning AND Language Generation
@dataclass
class VGenConfig(ModelConfig):
arch: str = "v-gen"
identifier: str = MISSING
# Dataset Modality
data_modality: str = "state+ok"
# Base MAE & LM Parameters --> LM Weight is set such that mae & lang loss are ~same order of magnitude
mae_weight: float = 1.0
lm_weight: float = 0.5
mask_ratio: float = 0.75
gen_ratio: float = MISSING
# Base Language Parameters
language_model: str = "distilbert-base-uncased"
hf_cache: str = to_absolute_path("data/hf-cache")
language_dim: int = 768
vocab_size: int = 30522
# Architecture Parameters
encoder_depth: int = MISSING
encoder_embed_dim: int = MISSING
encoder_n_heads: int = MISSING
decoder_depth: int = MISSING
decoder_embed_dim: int = MISSING
decoder_n_heads: int = MISSING
use_cls_token: bool = True
# MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example!
norm_pixel_loss: bool = True
effective_bsz: int = 1024
device_bsz: int = MISSING
native_bsz: int = MISSING
# Optimization Parameters
optimizer: str = "adamw"
schedule: str = "linear-warmup+cosine-decay"
base_lr: float = 1.5e-4
min_lr: float = 0.0
betas: Tuple[float, float] = (0.9, 0.95)
weight_decay: float = 0.05
@dataclass
class VGen50SmallConfig(VGenConfig):
identifier = "v-gen"
# LM Parameters --> control % of examples that are for "language generation" (no conditioning)
gen_ratio = 0.50
# Architecture Parameters -- should match ViT Small Architecture to the letter!
# Note: Small is defined in TIMM:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683
encoder_depth = 12
encoder_embed_dim = 384
encoder_n_heads = 6
decoder_depth = 6
decoder_embed_dim = 192
decoder_n_heads = 6
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 64 per device!
# > GPU w/ 32G of RAM: max of 64 per device!
device_bsz = 64
native_bsz = 64
@dataclass
class VGen50BaseConfig(VGenConfig):
identifier = "v-gen-base"
# LM Parameters --> control % of examples that are for "language generation" (no conditioning)
gen_ratio = 0.50
# Architecture Parameters -- should match ViT Base Architecture to the letter!
# Note: Base is defined in TIMM & Original MAE Repository:
# > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723
# > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223
encoder_depth = 12
encoder_embed_dim = 768
encoder_n_heads = 12
decoder_depth = 8
decoder_embed_dim = 512
decoder_n_heads = 16
# Number of examples one can safely fit on an accelerator w/ this model!
# > TPU-v3: max of 32 per device!
# > GPU w/ 32G of RAM: max of 32 per device!
device_bsz = 32
native_bsz = 32
# Create a configuration group `model` and populate with the above...
cs = ConfigStore.instance()
# === @Data-Locked Reproductions ===
# Image-Only MAE/MVP Architectures
cs.store(group="model", name="r-mvp", node=MVPSmallConfig)
# R3M Architectures - ViT & ResNet50
cs.store(group="model", name="r-r3m-vit", node=R3MSmallConfig)
cs.store(group="model", name="r-r3m-rn50", node=RN3M50Config)
# === @Voltron ===
# VCond Architectures
cs.store(group="model", name="v-cond", node=VCondSmallConfig)
cs.store(group="model", name="v-cond-base", node=VCondBaseConfig)
# VDual
cs.store(group="model", name="v-dual", node=VDualSmallConfig)
cs.store(group="model", name="v-dual-base", node=VDualBaseConfig)
# VGen
cs.store(group="model", name="v-gen", node=VGen50SmallConfig)
cs.store(group="model", name="v-gen-base", node=VGen50BaseConfig)
================================================
FILE: voltron/conf/tracking.py
================================================
"""
tracking.py
Base Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via Weights & Biases.
Uses a simple single inheritance structure.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
@dataclass
class TrackingConfig:
# Active Loggers --> List of Loggers
active_loggers: List[str] = field(default_factory=lambda: ["jsonl", "wandb"])
# Generic Logging Frequency --> Matters more for XLA/TPUs... set this to be as large as you can stomach!
log_frequency: int = 100
# Checkpointing Strategy --> Save each epoch, keep most recent `idx[0]` checkpoints & *every* `idx[1]` checkpoints
# Additionally, save (locally) a checkpoint every `idx[2]` steps for the current epoch (-1).
checkpoint_strategy: Tuple[int, int, int] = (1, 1, 1500)
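# e.g., one reading of the default (1, 1, 1500): keep the single most recent epoch checkpoint, keep
# every 1st (i.e., every) epoch checkpoint permanently, and write a "local step" checkpoint every
# 1500 steps within the in-progress epoch.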
# Weights & Biases Setup
project: str = "voltron-pretraining"
entity: str = "voltron-robotics"
# Notes & Tags are at the discretion of the user... see below
notes: str = MISSING
tags: Optional[List[str]] = None
# Directory to save W&B Metadata & Logs in General -- if None, defaults to `logs/` in the Hydra CWD
directory: Optional[str] = None
@dataclass
class VoltronTrackingConfig(TrackingConfig):
# Note: I really like using notes to keep track of things, so this will crash unless `notes` is specified with the run.
# > For `tags` I like to populate based on other args in the script, so letting it remain None
notes: str = MISSING
# Create a configuration group `trackers` and populate with the above...
cs = ConfigStore.instance()
cs.store(group="tracking", name="voltron-tracking", node=VoltronTrackingConfig)
================================================
FILE: voltron/datasets/__init__.py
================================================
from .datasets import get_datasets
================================================
FILE: voltron/datasets/datasets.py
================================================
"""
datasets.py
Core Pytorch Dataset implementations for the various "data flavors" used by the different representation learning models
(Voltron and data-locked reproductions). Crucially, ach dataset loads from the corresponding serialized batch files that
define the exact data (and order of iterating through the data) to see during each epoch.
Notably, these serialized files control exactly what data is seen by *all* methods *across epochs*; using these files is
critical to reproducibility & comparisons.
The file contains logic for a "standard" Dataset; all files (batch index files, image/video/language files) are stored
on local disk, assuming storage conducive to fast random reads. For a "streaming" dataset (loading data directly from
GCP Buckets/Amazon S3), see `v1/stream_datasets.py`.
"""
from pathlib import Path
from typing import Any, Optional, Tuple
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import Compose
from voltron.preprocessing.transforms import get_online_transform
class PretrainDataset(Dataset):
def __init__(self) -> None:
super().__init__()
self.epoch, self.h5, self.vid, self.states = 0, None, None, None
self.index_path, self.language_path, self.language = None, None, None
def hydrate(self, path: Path) -> None:
# Create Open HDF5 Handle
self.h5 = h5py.File(path, "r")
self.vid, self.states = self.h5["vid"].asstr(), self.h5["states"].asstr()
# Load Language Index
if self.language_path is not None:
self.language = torch.load(self.index_path / self.language_path)
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")
def __len__(self) -> int:
raise NotImplementedError("PretrainDataset is an abstract class; should never be initialized directly!")
class StateDataset(PretrainDataset):
def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
super().__init__()
self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
# === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
self.set_epoch(epoch)
def set_epoch(self, epoch: int) -> None:
# Load Validation Batches
if self.is_val and not self.val_loaded:
self.hdf5_path = self.index_path / "state" / "validation-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Set `val_loaded`
self.val_loaded = True
# Load Train Batches
elif not self.is_val:
self.hdf5_path = self.index_path / "state" / f"train-epoch={epoch}-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
def __getitem__(self, idx: int) -> torch.Tensor:
"""Return processed image frame as a Tensor."""
if self.h5 is None:
self.hydrate(self.hdf5_path)
return self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))
def __len__(self) -> int:
return self.n_examples
class StateLanguageDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
lang_dropout: Optional[float] = None,
is_val: bool = False,
) -> None:
super().__init__()
self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()
# Set Language Path
self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"
# === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
self.set_epoch(epoch)
def set_epoch(self, epoch: int) -> None:
# Load Validation Batches
if self.is_val and not self.val_loaded:
self.hdf5_path = self.index_path / "state+language" / "validation-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Set `val_loaded`
self.val_loaded = True
# Load Train Batches
elif not self.is_val:
self.hdf5_path = self.index_path / "state+language" / f"train-epoch={epoch}-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Assemble Dropout Indices
n_drop = int(self.lang_dropout * self.n_examples)
self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return processed image frame and language, decomposed as the input_ids and attention_mask."""
if self.h5 is None:
self.hydrate(self.hdf5_path)
# Get Vid ID --> parse out language, transform frame!
vid = self.vid[idx]
lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]
# Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
if idx in self.dropout_idxs:
# Initial language token is *always* = `101` --> last token always = `102`
lang[1:] *= 0
lang_mask[1:] *= 0
# Retrieve Frame & Return
img = self.img_transform(read_image(str(self.index_path.parent / self.states[idx][0])))
return img, lang, lang_mask
def __len__(self) -> int:
return self.n_examples
class StateOKDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
lang_dropout: Optional[float] = None,
is_val: bool = False,
) -> None:
super().__init__()
self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
self.lang_dropout, self.dropout_idxs = 0.0 if (lang_dropout is None) else lang_dropout, set()
# Set Language Path
self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"
# === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
self.set_epoch(epoch)
def set_epoch(self, epoch: int) -> None:
# Load Validation Batches
if self.is_val and not self.val_loaded:
self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Set `val_loaded`
self.val_loaded = True
# Load Train Batches
elif not self.is_val:
self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Assemble Dropout Indices
n_drop = int(self.lang_dropout * self.n_examples)
self.dropout_idxs = set(np.random.choice(self.n_examples, n_drop, replace=False))
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return processed dual frames and language, decomposed as the input_ids and attention_mask."""
if self.h5 is None:
self.hydrate(self.hdf5_path)
# Get Vid ID --> parse out language, transform frames!
vid = self.vid[idx]
lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]
# Dropout Language (Naive Zeroing leads to NaN --> just want the "CLS" token)
if idx in self.dropout_idxs:
# Initial language token is *always* = `101` --> last token always = `102`
lang[1:] *= 0
lang_mask[1:] *= 0
# Retrieve Frames & Return
imgs = self.states[idx]
imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
return imgs, lang, lang_mask
def __len__(self) -> int:
return self.n_examples
class GenStateOKDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
gen_ratio: float,
is_val: bool = False,
) -> None:
super().__init__()
self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
self.gen_ratio, self.gen_idxs = gen_ratio, set()
# Set Language Path
self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"
# === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
self.set_epoch(epoch)
def set_epoch(self, epoch: int) -> None:
# Load Validation Batches
if self.is_val and not self.val_loaded:
self.hdf5_path = self.index_path / "state+ok" / "validation-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Set `val_loaded`
self.val_loaded = True
# Load Train Batches
elif not self.is_val:
self.hdf5_path = self.index_path / "state+ok" / f"train-epoch={epoch}-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Assemble Generation Indices
n_gen = int(self.gen_ratio * self.n_examples)
self.gen_idxs = set(np.random.choice(self.n_examples, n_gen, replace=False))
def __getitem__(
self, idx: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
"""Return dual frames, conditioning language, language to generate, decomposed as input_ids/attention mask."""
if self.h5 is None:
self.hydrate(self.hdf5_path)
# Get Vid ID --> parse out language to condition on / generate
vid = self.vid[idx]
lang_con, lang_con_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]
lang_gen, lang_gen_mask, lang_gen_weight = lang_con.clone(), lang_con_mask.clone(), None
# Generate / Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
if idx in self.gen_idxs:
# When Generating --> just condition on the token and generate the rest!
lang_con[1:] *= 0
lang_con_mask[1:] *= 0
lang_gen_weight = 1
else:
# When Conditioning -> just generate the token (so things don't crash) but set weight to 0
lang_gen[1:] *= 0
lang_gen_mask[1:] *= 0
lang_gen_weight = 0
# Retrieve Frames and Return
imgs = self.states[idx]
imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
return imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight
def __len__(self) -> int:
return self.n_examples
class QuintetDataset(PretrainDataset):
def __init__(self, epoch: int, index_path: Path, img_transform: Compose, is_val: bool = False) -> None:
super().__init__()
self.index_path, self.is_val, self.val_loaded = index_path, is_val, False
self.epoch, self.img_transform, self.hdf5_path, self.n_examples = epoch, img_transform, None, None
# Set Language Path
self.language_path = "val-language-index.pt" if self.is_val else "train-language-index.pt"
# === Retrieve Epoch Batches --> only call before/between epochs (as we create new DataLoaders) ===
self.set_epoch(epoch)
def set_epoch(self, epoch: int) -> None:
# Load Validation Batches
if self.is_val and not self.val_loaded:
self.hdf5_path = self.index_path / "quintet+language" / "validation-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
# Set `val_loaded`
self.val_loaded = True
# Load Train Batches
elif not self.is_val:
self.hdf5_path = self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.hdf5"
with h5py.File(self.hdf5_path, "r") as h5:
self.n_examples = len(h5["states"])
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return all five processed frames and language, decomposed as the input_ids/attention_mask."""
if self.h5 is None:
self.hydrate(self.hdf5_path)
# Get Vid ID --> parse out language, transform frames!
vid = self.vid[idx]
lang, lang_mask = self.language[vid]["input_ids"], self.language[vid]["attention_mask"]
# Retrieve Frames & Return
imgs = self.states[idx]
imgs = torch.stack([self.img_transform(read_image(str(self.index_path.parent / fn))) for fn in imgs])
return imgs, lang, lang_mask
def __len__(self) -> int:
return self.n_examples
def get_datasets(
epoch: int,
dataset_name: str,
model_arch: str,
artifact_path: str,
data_modality: str,
resolution: int,
normalization: Tuple[Any, Any],
lang_dropout: Optional[float] = None,
gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
index = Path(artifact_path) / dataset_name / "index"
img_transform = get_online_transform(dataset_name, model_arch, resolution, normalization)
# Switch on `data_modality` --> differs based on `model_arch` (e.g., MVP --> img, V-Cond --> img, language)
if data_modality == "state":
train_ds = StateDataset(epoch, index, img_transform)
val_ds = StateDataset(epoch, index, img_transform, is_val=True)
elif data_modality == "state+language":
train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout)
val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, is_val=True)
elif data_modality == "state+ok":
# V-Dual --> don't return language modeling elements (causal attention mask, suffix language, etc.)
if gen_ratio is None:
train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout)
val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, is_val=True)
# V-Gen --> add language modeling elements!
else:
train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio)
val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, is_val=True)
elif data_modality == "quintet+language":
train_ds = QuintetDataset(epoch, index, img_transform)
val_ds = QuintetDataset(epoch, index, img_transform, is_val=True)
else:
raise ValueError(f"Data Modality `{data_modality}` is not supported!")
return train_ds, val_ds
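# Example usage (a sketch; the values below are illustrative, mirroring the `sth-sth-v2` and `v-cond`
# Hydra configs rather than hard-coded defaults):
#
#   train_ds, val_ds = get_datasets(
#       epoch=0,
#       dataset_name="sth-sth-v2",
#       model_arch="v-cond",
#       artifact_path="data/processed",
#       data_modality="state+language",
#       resolution=224,
#       normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#   )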
================================================
FILE: voltron/datasets/v1/__init__.py
================================================
================================================
FILE: voltron/datasets/v1/stream_datasets.py
================================================
"""
stream_datasets.py
Core PyTorch Datasets for the various "flavors" of data used by the various models under study. Crucially, each dataset
loads from the corresponding serialized "batch" files that define the exact data to use.
Notably, these serialized files control exactly what data is seen by *all* methods **across epochs.** Using them is
fairly critical to reproducibility & fair comparison.
This specific file contains logic for a "streaming" Dataset; data is fetched (within the dataloader, by each
worker) via an open connection over the network to a GCS bucket, materializing data as raw BytesIO objects fed to
PIL.Image constructors.
"""
import json
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Tuple
import numpy as np
import torch
from google.api_core.exceptions import NotFound
from google.auth.exceptions import TransportError
from google.cloud import storage
from google.resumable_media._helpers import _LOGGER
from PIL import Image, UnidentifiedImageError
from torch.utils.data import Dataset, get_worker_info
from torchvision.io import read_image
from torchvision.transforms import Compose
from torchvision.transforms.functional import pil_to_tensor
from voltron.preprocessing.v1.transforms import get_online_transform
from voltron.util.distributed import get_rank
# NOTE --> IF STREAMING JPEGS, WE NEED TO USE PILLOW TO READ FILES (w/o extracting locally...)
# =>> Instead of `read_image(file)` assume we have "fname" and open fileobj (as BytesIO) -- remember to `seek(0)`
#
# > from PIL import Image
# > from torchvision.transforms.functional import pil_to_tensor
# > tensor = pil_to_tensor(Image.open(fileobj))
# |--> This returns a `torch.uint8` Tensor of shape [3, 224, 224] --> *verified* equivalent to `read_image`
# Create Global GCS Client...
# =>> Multiprocessing.spawn() does not inherit from base shell --> need to set service account key...
# =>> TODO :: Figure out how to fetch num_accelerators & num_workers programmatically...
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/mnt/home/auth/gcp-auth.json"
N_CORES, BUCKETS = 8, [storage.Client().bucket("voltron-ANONYMIZED") for _ in range(8 * 8)]
# Suppress Google Cloud Loggers
_LOGGER.propagate = False
storage.blob._logger.propagate = False
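# Note (descriptive, added for clarity): each (rank, worker) pair indexes into its own pre-built GCS client so
# that processes never share a single connection. With N_CORES = 8, the bucket index resolved at fetch time is:
# > r = N_CORES * get_rank() + worker_info.id      # e.g., rank = 2, worker_id = 3 --> BUCKETS[19]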
class PretrainDataset(Dataset):
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
class StateDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
stream: bool = False,
prefix: Optional[Path] = None,
is_val: bool = False,
do_retry: bool = True,
n_retries: int = 3,
) -> None:
super().__init__()
self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
self.r = N_CORES * get_rank()
self.do_retry, self.n_retries = do_retry, n_retries
# === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
self.set_epoch(self.epoch)
def set_epoch(self, epoch: int) -> None:
# Not Streaming --> Read from local disk...
if not self.stream:
if self.is_val and not self.val_loaded:
with open(self.index_path / "state" / "validation-batches.json", "r") as f:
self.elements = json.load(f)
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
with open(self.index_path / "state" / f"train-epoch={epoch}-batches.json", "r") as f:
self.elements = json.load(f)
# Streaming --> Beam directly from Bucket
else:
if self.is_val and not self.val_loaded:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / "validation-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg`
for element in self.elements:
element["state"] = "/".join(element["state"].split("/")[-2:])
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / f"train-epoch={epoch}-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
for element in self.elements:
element["state"] = "/".join(element["state"].split("/")[-2:])
def __getitem__(self, index: int) -> torch.Tensor:
"""Return single frame as torch Tensor."""
if not self.stream:
return self.transform(read_image(self.elements[index]["state"]))
else:
# Multiplex w/ num_worker idx...
worker_info = get_worker_info()
r = (self.r + worker_info.id) if worker_info is not None else self.r
# Streaming + Retry Logic (in case of a bad connection -- retry same file!)
frame_path = self.elements[index]["state"]
for _i in range(self.n_retries):
try:
# Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open()
if self.is_val:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO()
else:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO()
# Error Handling --> we've already run several dry-run/verification trials, but ~0.004% of
# the time, we'll hit some sort of TCP/Transport error; this might even go up
# with multiple runs happening at the same time.
#
# To address this, we're adopting the simplest possible "retry" strategy that
# immediately tries to re-download the same file (and crashes if not possible).
# This ensures reproducibility, but puts some extra effort onto the user...
# File download...
blob.download_to_file(fobj)
fobj.seek(0)
# Image loading...
img_tensor = pil_to_tensor(Image.open(fobj))
# Return transformed image...
return self.transform(img_tensor)
except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
# At the minimum --> print the broken file (obnoxiously!)
print(f"=>> BROKEN FILE :: {frame_path}")
if not self.do_retry:
raise e
else:
continue
# If we've exhausted our retries --> raise an informative ValueError
raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...")
def __len__(self) -> int:
return len(self.elements)
class StateLanguageDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
lang_dropout: Optional[float] = None,
stream: bool = False,
prefix: Optional[Path] = None,
is_val: bool = False,
do_retry: bool = True,
n_retries: int = 3,
) -> None:
super().__init__()
self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
self.dropout_indices = set()
self.r = N_CORES * get_rank()
self.do_retry, self.n_retries = do_retry, n_retries
# Load Language Index & Retrieve Epoch 0 Batches
language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
if not self.stream:
with open(self.index_path / language_path, "r") as f:
self.language_index = json.load(f)
else:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
self.language_index = json.loads(blob.download_as_string())
# === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
self.set_epoch(self.epoch)
def set_epoch(self, epoch: int) -> None:
# Not Streaming --> Read from local disk...
if not self.stream:
if self.is_val and not self.val_loaded:
with open(self.index_path / "state+language" / "validation-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
with open(self.index_path / "state+language" / f"train-epoch={epoch}-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Streaming --> Beam directly from Bucket
else:
if self.is_val and not self.val_loaded:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+language" / "validation-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg`
for element in self.elements:
element["state"] = "/".join(element["state"].split("/")[-2:])
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
blob = BUCKETS[self.r].blob(
str(self.prefix / "index" / "state+language" / f"train-epoch={epoch}-batches.json")
)
self.elements = json.loads(blob.download_as_string())
# `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg`
for element in self.elements:
element["state"] = "/".join(element["state"].split("/")[-2:])
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return the frame and language, decomposed as the input_ids, and attention_mask."""
vid = self.elements[index]["vid"]
lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
# Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
if index in self.dropout_indices:
# Initial language token is *always* = `101` --> last token always = `102`
lang[1:] *= 0
lang_mask[1:] *= 0
# Retrieve Single Frame
if not self.stream:
img = self.transform(read_image(self.elements[index]["state"]))
return img, lang, lang_mask
else:
# Multiplex w/ num_worker idx...
worker_info = get_worker_info()
r = (self.r + worker_info.id) if worker_info is not None else self.r
# Streaming + Retry Logic (in case of a bad connection -- retry same file!)
frame_path = self.elements[index]["state"]
for _i in range(self.n_retries):
try:
# Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open()
if self.is_val:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO()
else:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO()
# Error Handling --> we've already run several dry-run/verification trials, but ~0.004% of
# the time, we'll hit some sort of TCP/Transport error; this might even go up
# with multiple runs happening at the same time.
#
# To address this, we're adopting the simplest possible "retry" strategy that
# immediately tries to re-download the same file (and crashes if not possible).
# This ensures reproducibility, but puts some extra effort onto the user...
# File download...
blob.download_to_file(fobj)
fobj.seek(0)
# Image loading...
img_tensor = pil_to_tensor(Image.open(fobj))
# Assemble transformed image and return...
img = self.transform(img_tensor)
return img, lang, lang_mask
except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
# At the minimum --> print the broken file (obnoxiously!)
print(f"=>> BROKEN FILE :: {frame_path}")
if not self.do_retry:
raise e
else:
continue
# If we've exhausted our retries --> raise an informative ValueError
raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...")
def __len__(self) -> int:
return len(self.elements)
class StateOKDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
lang_dropout: Optional[float] = None,
stream: bool = False,
prefix: Optional[Path] = None,
no_lang: bool = False,
is_val: bool = False,
do_retry: bool = True,
n_retries: int = 3,
) -> None:
super().__init__()
self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
self.no_lang, self.lang_dropout = no_lang, 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
self.dropout_indices = set()
self.r = N_CORES * get_rank()
self.do_retry, self.n_retries = do_retry, n_retries
# Load Language Index & Retrieve Epoch 0 Batches
if not self.no_lang:
language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
if not self.stream:
with open(self.index_path / language_path, "r") as f:
self.language_index = json.load(f)
else:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
self.language_index = json.loads(blob.download_as_string())
# === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
self.set_epoch(self.epoch)
def set_epoch(self, epoch: int) -> None:
# Not Streaming --> Read from local disk...
if not self.stream:
if self.is_val and not self.val_loaded:
with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
with open(self.index_path / "state + ok" / f"train-epoch={epoch}-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Streaming --> Beam directly from Bucket
else:
if self.is_val and not self.val_loaded:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
blob = BUCKETS[self.r].blob(
str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json")
)
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# ruff: noqa: C901
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return both states/frames and language, decomposed as the input_ids and attention_mask."""
vid = self.elements[index]["vid"]
# Fetch language if desired...
if not self.no_lang:
lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
# Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
if index in self.dropout_indices:
# Initial language token is *always* = `101` --> last token always = `102`
lang[1:] *= 0
lang_mask[1:] *= 0
# Retrieve Frames
if not self.stream:
imgs = self.elements[index]["states"]
imgs = torch.stack([self.transform(read_image(s)) for s in imgs])
# Return --> based on `self.no_lang`
if not self.no_lang:
return imgs, lang, lang_mask
else:
return imgs
else:
# Multiplex w/ num_worker idx...
worker_info = get_worker_info()
r = (self.r + worker_info.id) if worker_info is not None else self.r
# Streaming + Retry Logic (in case of a bad connection -- retry same files!)
frame_paths, current_frame = list(self.elements[index]["states"]), None
for _i in range(self.n_retries):
try:
# Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
imgs = []
for _current_idx, current_frame in enumerate(frame_paths):
if self.is_val:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
else:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()
# Error Handling --> we've already run several dry-run/verification trials, but ~0.004% of
# the time, we'll hit some sort of TCP/Transport error; this might even go up
# with multiple runs happening at the same time.
#
# To address this, we're adopting the simplest possible "retry" strategy that
# immediately tries to re-download the same file (crashes if not possible).
# This ensures reproducibility, but puts some extra effort onto the user...
# File download...
blob.download_to_file(fobj)
fobj.seek(0)
# Image loading...
img_tensor = pil_to_tensor(Image.open(fobj))
imgs.append(self.transform(img_tensor))
# Stack...
assert len(imgs) == 2, "Something went awry with try/except in StateOK Dataset..."
imgs = torch.stack(imgs)
# Return --> based on `self.no_lang`
if not self.no_lang:
return imgs, lang, lang_mask
else:
return imgs
except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
# At the minimum --> print the broken file (obnoxiously!)
print(f"=>> BROKEN FILE :: {current_frame}")
if not self.do_retry:
raise e
else:
continue
# If we've exhausted our retries --> raise an informative ValueError
raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")
def __len__(self) -> int:
return len(self.elements)
class GenStateOKDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
gen_ratio: float,
stream: bool = False,
prefix: Optional[Path] = None,
is_val: bool = False,
do_retry: bool = True,
n_retries: int = 3,
) -> None:
super().__init__()
self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
self.gen_ratio, self.gen_indices = gen_ratio, set()
self.r = N_CORES * get_rank()
self.do_retry, self.n_retries = do_retry, n_retries
# Load Language Index & Retrieve Epoch 0 Batches
language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
if not self.stream:
with open(self.index_path / language_path, "r") as f:
self.language_index = json.load(f)
else:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
self.language_index = json.loads(blob.download_as_string())
# === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
self.set_epoch(self.epoch)
def set_epoch(self, epoch: int) -> None:
# Not Streaming --> Read from local disk...
if not self.stream:
if self.is_val and not self.val_loaded:
with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_gen = int(self.gen_ratio * len(self.elements))
self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
with open(self.index_path / "state+ok" / f"train-epoch={epoch}-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_gen = int(self.gen_ratio * len(self.elements))
self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))
# Streaming --> Beam directly from Bucket
else:
if self.is_val and not self.val_loaded:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_gen = int(self.gen_ratio * len(self.elements))
self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
blob = BUCKETS[self.r].blob(
str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json")
)
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_gen = int(self.gen_ratio * len(self.elements))
self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))
def __getitem__(
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
"""Return both states/frames, language to condition on, language to generate, decomposed as input_ids/mask."""
vid = self.elements[index]["vid"]
# Fetch language to condition on / generate...
lang_condition = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
lang_condition_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
lang_gen = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
lang_gen_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
# Generate/Condition Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
if index in self.gen_indices:
# If generating, just condition on the CLS token (always the initial token), but generate everything!
lang_condition[1:] *= 0
lang_condition_mask[1:] *= 0
lang_gen_weight = 1
else:
# If conditioning, generate the CLS token (dummy target so things don't crash), but set weight to 0
lang_gen[1:] *= 0
lang_gen_mask[1:] *= 0
lang_gen_weight = 0
# Retrieve Frames
if not self.stream:
imgs = self.elements[index]["states"]
imgs = torch.stack([self.transform(read_image(s)) for s in imgs])
# Return...
return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight
else:
# Multiplex w/ num_worker idx...
worker_info = get_worker_info()
r = (self.r + worker_info.id) if worker_info is not None else self.r
# Streaming + Retry Logic (in case of a bad connection -- retry same files!)
frame_paths, current_frame = list(self.elements[index]["states"]), None
for _i in range(self.n_retries):
try:
# Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
imgs = []
for _current_idx, current_frame in enumerate(frame_paths):
if self.is_val:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / current_frame)), BytesIO()
else:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / current_frame)), BytesIO()
# Error Handling --> we've already run several dry-run/verification trials, but ~0.004% of
# the time, we'll hit some sort of TCP/Transport error; this might even go up
# with multiple runs happening at the same time.
#
# To address this, we're adopting the simplest possible "retry" strategy that
# immediately tries to re-download the same file (crashes if not possible).
# This ensures reproducibility, but puts some extra effort onto the user...
# File download...
blob.download_to_file(fobj)
fobj.seek(0)
# Image loading...
img_tensor = pil_to_tensor(Image.open(fobj))
imgs.append(self.transform(img_tensor))
# Stack...
assert len(imgs) == 2, "Something went awry with try/except in GenStateOK Dataset..."
imgs = torch.stack(imgs)
# Return...
return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight
except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
# At the minimum --> print the broken file (obnoxiously!)
print(f"=>> BROKEN FILE :: {current_frame}")
if not self.do_retry:
raise e
else:
continue
# If we've exhausted our retries --> raise an informative ValueError
raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")
def __len__(self) -> int:
return len(self.elements)
class QuintetDataset(PretrainDataset):
def __init__(
self,
epoch: int,
index_path: Path,
img_transform: Compose,
lang_dropout: Optional[float] = None,
stream: bool = False,
prefix: Optional[Path] = None,
is_val: bool = False,
) -> None:
super().__init__()
self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
self.dropout_indices = set()
self.r = N_CORES * get_rank()
# Load Language Index & Retrieve Epoch 0 Batches
language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
if not self.stream:
with open(self.index_path / language_path, "r") as f:
self.language_index = json.load(f)
else:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
self.language_index = json.loads(blob.download_as_string())
# === Retrieve Epoch Batches (only call `set_epoch` inside `__init__()`) ===
self.set_epoch(self.epoch)
def set_epoch(self, epoch: int) -> None:
# Not Streaming --> Read from local disk...
if not self.stream:
if self.is_val and not self.val_loaded:
with open(self.index_path / "quintet+language" / "validation-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
with open(self.index_path / "quintet+language" / f"train-epoch={epoch}-batches.json", "r") as f:
self.elements = json.load(f)
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Streaming --> Beam directly from Bucket
else:
if self.is_val and not self.val_loaded:
blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "quintet+language" / "validation-batches.json"))
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
# Set `val_loaded` and move on...
self.val_loaded = True
elif not self.is_val:
blob = BUCKETS[self.r].blob(
str(self.prefix / "index" / "quintet+language" / f"train-epoch={epoch}-batches.json")
)
self.elements = json.loads(blob.download_as_string())
# `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
for element in self.elements:
element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]
# Assemble the set of dropout indices for the given epoch...
n_drop = int(self.lang_dropout * len(self.elements))
self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return all 5 states/frames, and language, decomposed as the input_ids and attention_mask."""
vid = self.elements[index]["vid"]
lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
# Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
if index in self.dropout_indices:
# Initial language token is *always* = `101` --> last token always = `102`
lang[1:] *= 0
lang_mask[1:] *= 0
# Retrieve Frames
if not self.stream:
imgs = self.elements[index]["states"]
imgs = torch.stack([self.transform(read_image(s)) for s in imgs])
else:
# Multiplex w/ num_worker idx...
worker_info, imgs = get_worker_info(), []
r = (self.r + worker_info.id) if worker_info is not None else self.r
# Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
for state in self.elements[index]["states"]:
if self.is_val:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / state)), BytesIO()
else:
blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / state)), BytesIO()
# Download into FileObj & Rewind...
blob.download_to_file(fobj)
fobj.seek(0)
# Add to imgs...
imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))
# Stack...
imgs = torch.stack(imgs)
return imgs, lang, lang_mask
def __len__(self) -> int:
return len(self.elements)
def get_epoch_datasets(
epoch: int,
name: str,
normalization: Tuple[Any, Any],
model_arch: str,
stream: bool,
artifact_path: str,
stream_prefix: str,
data_modality: str,
lang_dropout: Optional[float] = None,
gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
"""Retrieve the custom Dataset classes for the train & val set for the given dataset & data modality."""
index, img_transform = Path(artifact_path) / name / "index", get_online_transform(name, model_arch, normalization)
prefix = Path(stream_prefix) / name if stream else stream_prefix
# Switch on `data_modality`
if data_modality == "state":
train_ds = StateDataset(epoch, index, img_transform, stream, prefix)
val_ds = StateDataset(epoch, index, img_transform, stream, prefix, is_val=True)
elif data_modality == "state+language":
train_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
val_ds = StateLanguageDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)
elif data_modality == "state+ok":
if gen_ratio is None:
nl = model_arch == "v-dual"
train_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl)
val_ds = StateOKDataset(epoch, index, img_transform, lang_dropout, stream, prefix, no_lang=nl, is_val=True)
else:
# Special Generative Language Dataset...
train_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix)
val_ds = GenStateOKDataset(epoch, index, img_transform, gen_ratio, stream, prefix, is_val=True)
elif data_modality == "quintet+language":
train_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix)
val_ds = QuintetDataset(epoch, index, img_transform, lang_dropout, stream, prefix, is_val=True)
else:
raise NotImplementedError(f"Support for data modality `{data_modality}` not yet implemented!")
return train_ds, val_ds
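# === Illustrative per-epoch usage (a minimal sketch; names/values below are placeholder assumptions) ===
# > from torch.utils.data import DataLoader
# >
# > train_ds, val_ds = get_epoch_datasets(
# >     epoch=0, name="sth-sth-v2", normalization=(mean, std), model_arch="v-cond", stream=True,
# >     artifact_path="data/processed", stream_prefix="voltron-data", data_modality="state+language",
# >     lang_dropout=0.0,
# > )
# > for epoch in range(max_epochs):
# >     train_ds.set_epoch(epoch)    # re-reads `train-epoch={epoch}-batches.json` --> same data order across methods
# >     for batch in DataLoader(train_ds, batch_size=128, num_workers=8):
# >         ...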
================================================
FILE: voltron/models/__init__.py
================================================
from .instantiate import VMVP, VR3M, VRN3M, VCond, VDual, VGen, get_model_optimizer
================================================
FILE: voltron/models/core/__init__.py
================================================
================================================
FILE: voltron/models/core/vcond.py
================================================
"""
vcond.py
PyTorch Module defining the Voltron `V-Cond` variant (single-frame with language-conditioning). In general, follows the
MAE recipe, with the architectural modifications described in the paper:
- RMSNorm, for stability/performance ("Do Transformer Modifications Transfer...")
- SwishGLU activations in the Attention Block Feed-Forward MLP (gated linear units) as used in PaLM
- LayerScale with a default value of 0.1 (from Mistral/CaiT)
References:
- https://github.com/facebookresearch/mae
- https://github.com/lucidrains/x-transformers
"""
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat
from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings
# Suppress Transformers Logging
transformers.logging.set_verbosity_error()
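# Note (descriptive, added for clarity): the architectural modifications listed in the module docstring map
# directly onto the per-block flags used when the encoder/decoder stacks are constructed below, i.e.:
# > Block(embed_dim, n_heads, mlp_ratio, do_rms_norm=True, do_swish_glu=True, do_layer_scale=True)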
class VCond(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
encoder_depth: int,
encoder_embed_dim: int,
encoder_n_heads: int,
decoder_depth: int,
decoder_embed_dim: int,
decoder_n_heads: int,
language_model: str,
hf_cache: str,
language_dim: int,
optimizer: str,
schedule: str,
base_lr: float,
min_lr: float,
effective_bsz: int,
betas: Tuple[float, float],
weight_decay: float,
warmup_epochs: int,
max_epochs: int,
mask_ratio: float = 0.75,
mlp_ratio: float = 4.0,
in_channels: int = 3,
norm_pixel_loss: bool = True,
use_cls_token: bool = False,
) -> None:
"""
Initialize a VCond model with the requisite architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param patch_size: Height/Width of each patch in pixels -- usually 16.
:param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
:param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
:param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
:param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
:param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
:param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
:param language_model: Language model to freeze for encoding narrations/utterances.
:param hf_cache: Cache directory to store pretrained models, for safe distributed training.
:param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
:param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
:param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
:param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
:param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
:param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
:param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
:param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
:param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
:param max_epochs: Total number of training epochs to be run.
:param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
:param mlp_ratio: Ratio for embedding size to Position-wise Feed-Forward MLP (gets shrunk back down).
:param in_channels: Default number of channels in the base image -- almost always 3.
:param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
:param use_cls_token: Add CLS token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
"""
super().__init__()
self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
self.use_cls_token = use_cls_token
self.language_dim = language_dim
# Encoder/Decoder Parameters
self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads
# General Parameters (for downstream adaptation)
self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads
# (Optional) Token Handling
if self.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# MAE Encoder Parameters
self.patch2embed = PatchEmbed(
self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
)
self.encoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
requires_grad=False,
)
self.encoder_blocks = nn.ModuleList(
[
Block(
self.encoder_embed_dim,
self.encoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.encoder_depth)
]
)
self.encoder_norm = RMSNorm(self.encoder_embed_dim)
# Projection from Language Embedding to Decoder
self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)
# Projection from Encoder to Decoder
self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)
# MAE Decoder Parameters -- Remember the CLS Token (if specified)!
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim))
self.decoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
requires_grad=False,
)
self.decoder_blocks = nn.ModuleList(
[
Block(
self.decoder_embed_dim,
self.decoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.decoder_depth)
]
)
self.decoder_norm = RMSNorm(self.decoder_embed_dim)
self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
# VCond -- Add "Image" and "Language" Modifier Tokens...
self.img_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
self.lang_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# Initialize all Weights
self.initialize_weights()
# @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
# > For BERT models, our "embedding" is just going to be the last hidden state
# > Assumes inputs to forward pass are pre-tokenized!
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
self.lm.eval()
# Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
# Freeze the LM
for _, param in self.lm.named_parameters():
param.requires_grad = False
def initialize_weights(self) -> None:
# Position Encoding -- Fixed 2D Sine-Cosine Embeddings
enc_pe = get_2D_position_embeddings(
self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
dec_pe = get_2D_position_embeddings(
self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))
# Initialize PatchEmbedding as a Linear...
nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
# Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal
nn.init.normal_(self.mask_token, std=0.02)
nn.init.normal_(self.img_token, std=0.02)
nn.init.normal_(self.lang_token, std=0.02)
if self.use_cls_token:
nn.init.normal_(self.cls_token, std=0.02)
# Default Transformer initializer on everything else...
self.apply(self.transformer_initializer)
@staticmethod
def transformer_initializer(m: nn.Module) -> None:
# Use `xavier_uniform` following Jax ViT
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Encode language by feeding the *pre-tokenized text* through the frozen language model."""
self.lm.eval()
with torch.no_grad():
transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
return transformer_embeddings
def mask(
self, patches: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches."""
bsz, n_patches, embed_dim = patches.shape
if mask_ratio is not None:
n_keep = int(n_patches * (1 - mask_ratio))
else:
n_keep = int(n_patches * (1 - self.mask_ratio))
# Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
# > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
restore_idxs = torch.argsort(shuffle_idxs, dim=1)
# Get "keep" (visible) patches
visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))
# Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following MAE convention)
mask = torch.ones(bsz, n_patches, device=patches.device)
mask[:, :n_keep] = 0
mask = torch.gather(mask, dim=1, index=restore_idxs)
return visible_patches, mask, restore_idxs
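# Worked example (illustrative only, not executed anywhere): with n_patches = 4 and mask_ratio = 0.5,
# > noise        = [0.90, 0.10, 0.40, 0.70]
# > shuffle_idxs = argsort(noise)        = [1, 2, 3, 0]
# > restore_idxs = argsort(shuffle_idxs) = [3, 0, 1, 2]
# > n_keep = 2 --> the visible patches are original patches 1 and 2
# > mask = gather([0, 0, 1, 1], restore_idxs) = [1, 0, 0, 1]    (1 = masked, per the MAE convention above)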
def get_representations(
self, img: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
) -> torch.Tensor:
"""
Given either a singleton (img, language) pair or a batch of images and language, extract representations
subject to the specified mode in < multimodal | visual >.
:param img: Processed batch of images :: [bsz, 3, 224, 224]
:param language: Input language as `List[str] | Tuple[str] | None`
:param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)
:return: Extracted representations given (img, language) input as sequence.
"""
assert img.ndim == 4 and (
language is None or isinstance(language, list) or isinstance(language, tuple)
), "Invalid input to `get_representations()`"
assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"
# Tokenize Language --> note max length is 20!
if language is None:
lang, lang_mask = [torch.zeros(img.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
else:
tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)
# Tile Language & Language Mask if mismatch with # images!
if not len(lang) == len(img):
lang = repeat(lang, "b seq -> (bsz b) seq", bsz=img.size(0))
lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=img.size(0))
# Extract desired representations...
representations = self.encode(img, lang, lang_mask)
return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]
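# Illustrative extraction calls (a sketch; the batch shape and instruction string are assumptions):
# > imgs = torch.randn(4, 3, 224, 224)                                     # preprocessed image batch
# > multimodal_reps = model.get_representations(imgs, ["pick up the red block"], mode="multimodal")
# > visual_reps = model.get_representations(imgs, mode="visual")           # language defaults to a CLS-only dummy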
def encode(self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Default representation extraction function, given a batch of images and tokenized language."""
lang_embeddings = self.encode_language(lang, lang_mask)
projected_language = self.lang2encoder(lang_embeddings)
# Patchify
patches = self.patch2embed(img)
# (Optional) Token Handling
if self.use_cls_token:
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
patches = torch.cat([cls_tokens, patches], dim=1)
# Position Encoding
patches_pe = patches + self.encoder_pe
# Add "modality" embeddings to patches & language
img_embeddings, lang_embeddings = patches_pe + self.img_token, projected_language + self.lang_token
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
multimodal_embeddings = self.encoder_norm(multimodal_embeddings)
# Return the full sequence of multimodal embeddings...
return multimodal_embeddings
def forward_encoder(
self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
lang_embeddings = self.encode_language(lang, lang_mask)
projected_lang = self.lang2encoder(lang_embeddings)
# Patchify + Position Embedding (without Token!)
patches = self.patch2embed(img)
patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])
# Create mask (and go ahead and mask out patches at the same time)
visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :]
cls_tokens = cls_token_pe.expand(img.shape[0], -1, -1)
visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)
# Add "modality" embeddings to patches & language
visible_patches, projected_lang = visible_patches + self.img_token, projected_lang + self.lang_token
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_mask.dtype)
multimodal_embedding = torch.cat([visible_patches, projected_lang], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([visible_mask, lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embedding = block(multimodal_embedding, multimodal_mask)
multimodal_embedding = self.encoder_norm(multimodal_embedding)
# Split multimodal embedding, remove language and return only the visible patches (+ optional token)!
visible_patches = multimodal_embedding[:, : -lang_mask.shape[-1], ...]
return visible_patches, mask, restore_idxs
def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
# Project patches into decoder embedding dimension
projected_patches = self.encoder2decoder(visible_patches)
# Add Mask Tokens to Sequence & Unshuffle
mask_tokens = self.mask_token.repeat(
projected_patches.shape[0],
restore_idxs.shape[1] - visible_patches.shape[1] + (1 if self.use_cls_token else 0),
1,
)
# (Optional) Token Handling
if self.use_cls_token:
# Remove & add back CLS Token as part of the "unshuffling"
concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1) # Skip CLS Token
no_cls_unshuffled_patches = torch.gather(
concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
)
unshuffled_patches = torch.cat([projected_patches[:, :1, :], no_cls_unshuffled_patches], dim=1)
else:
concatenated_patches = torch.cat([projected_patches, mask_tokens], dim=1)
unshuffled_patches = torch.gather(
concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
)
# Add Position Embeddings
decoder_patches = unshuffled_patches + self.decoder_pe
# Apply Transformer Blocks...
for block in self.decoder_blocks:
decoder_patches = block(decoder_patches)
decoder_patches = self.decoder_norm(decoder_patches)
# Run final projection & return --> note token handling!
decoder_prediction = self.decoder_prediction(decoder_patches)
return decoder_prediction if not self.use_cls_token else decoder_prediction[:, 1:, :]
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""Convert a batch of images to their patched equivalents, by naive reshaping"""
return rearrange(
imgs,
"bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
patch_h=self.patch_size,
patch_w=self.patch_size,
)
def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
targets = self.patchify(imgs)
# Normalize targets...
mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
targets = (targets - mu) / ((var + 1e-6) ** 0.5)
# Compute mean loss per patch first...
mse = (reconstructions - targets) ** 2
avg_loss_per_patch = mse.mean(dim=-1)
# Compute mean loss only on *removed* patches and return
return (avg_loss_per_patch * mask).sum() / mask.sum()
def forward(
self, img: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
visible_patches, mask, restore_idxs = self.forward_encoder(img, lang, lang_mask, mask_ratio)
reconstructions = self.forward_decoder(visible_patches, restore_idxs)
loss = self.compute_loss(img, reconstructions, mask)
return loss, reconstructions, mask
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adamw"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")
# Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed...
# > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
# Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
if param.ndim <= 1 or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
# Build Parameter Groups
groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
# Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
# > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
self.lr = self.base_lr * (self.effective_bsz / 256)
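# Worked example of the scaling rule (illustrative numbers, not repo defaults):
# > base_lr = 1.5e-4, effective_bsz = 1024 --> lr = 1.5e-4 * (1024 / 256) = 6e-4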
# Create Optimizer & LR Scheduler
optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
return optimizer, update_lr
================================================
FILE: voltron/models/core/vdual.py
================================================
"""
vdual.py
PyTorch Module defining the Voltron `V-Dual` variant (dual-frame with language-conditioning). In general, follows the
MAE recipe, with the same modifications described in `vcond.py`.
When masking visual patches, the same patches are elided for both the 0th frame and the Kth frame to avoid cheating!
References:
- https://github.com/huggingface/m4/blob/main/m4/modeling/pretraining/video/videomae.py
- https://github.com/lucidrains/vit-pytorch
- https://github.com/MCG-NJU/VideoMAE
"""
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from einops import rearrange, repeat
from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings
# Suppress Transformers Logging
transformers.logging.set_verbosity_error()
class VDual(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
encoder_depth: int,
encoder_embed_dim: int,
encoder_n_heads: int,
decoder_depth: int,
decoder_embed_dim: int,
decoder_n_heads: int,
language_model: str,
hf_cache: str,
language_dim: int,
optimizer: str,
schedule: str,
base_lr: float,
min_lr: float,
effective_bsz: int,
betas: Tuple[float, float],
weight_decay: float,
warmup_epochs: int,
max_epochs: int,
mask_ratio: float = 0.75,
mlp_ratio: float = 4.0,
in_channels: int = 3,
norm_pixel_loss: bool = True,
use_cls_token: bool = False,
) -> None:
"""
Initialize a VDual model with the requisite architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param patch_size: Height/Width of each patch in pixels -- usually 16.
:param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
:param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
:param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
:param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
:param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
:param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
:param language_model: Language model to freeze for encoding narrations/utterances.
:param hf_cache: Cache directory to store pretrained models, for safe distributed training.
:param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
:param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
:param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
:param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
:param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
:param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
:param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
:param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
:param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
:param max_epochs: Total number of training epochs to be run.
:param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
:param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
:param in_channels: Default number of channels in the base image -- almost always 3.
:param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
:param use_cls_token: Add CLS token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
"""
super().__init__()
self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
self.use_cls_token = use_cls_token
self.language_dim = language_dim
# Encoder/Decoder Parameters
self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads
# General Parameters (for downstream adaptation)
self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads
# (Optional) Token Handling
if self.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# MAE Encoder Parameters
self.patch2embed = PatchEmbed(
self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
)
self.encoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
requires_grad=False,
)
self.encoder_blocks = nn.ModuleList(
[
Block(
self.encoder_embed_dim,
self.encoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.encoder_depth)
]
)
self.encoder_norm = RMSNorm(self.encoder_embed_dim)
# Projection from Language Embedding to Decoder
self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)
# Projection from Encoder to Decoder
self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)
# MAE Decoder Parameters -- Remember the CLS Token (if specified)!
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.decoder_embed_dim))
self.decoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
requires_grad=False,
)
self.decoder_blocks = nn.ModuleList(
[
Block(
self.decoder_embed_dim,
self.decoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.decoder_depth)
]
)
self.decoder_norm = RMSNorm(self.decoder_embed_dim)
self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
# VDual -- Add "Image" and "Language" Modifier Tokens...
self.img_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder_embed_dim))
self.lang_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# VDual -- Learnable "ctx" position embeddings --> initialize via `randn` following original ViT & @lucidrains
# =>> Ref: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L99
# =>> Note that n_context = 2 (0th frame + Kth frame)
self.ctx_enc_pe = nn.Parameter(torch.randn(1, 2, 1, self.encoder_embed_dim))
self.ctx_dec_pe = nn.Parameter(torch.randn(1, 2, 1, self.decoder_embed_dim))
# Initialize all Weights
self.initialize_weights()
# @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
# > For BERT models, our "embedding" is just going to be the last hidden state
# > Assumes inputs to forward pass are pre-tokenized!
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
self.lm.eval()
# Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
# Freeze the LM
for _, param in self.lm.named_parameters():
param.requires_grad = False
def initialize_weights(self) -> None:
# Position Encoding -- Fixed 2D Sine-Cosine Embeddings
enc_pe = get_2D_position_embeddings(
self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
dec_pe = get_2D_position_embeddings(
self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))
# Initialize PatchEmbedding as a Linear...
nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
# Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal
nn.init.normal_(self.mask_token, std=0.02)
nn.init.normal_(self.img_token, std=0.02)
nn.init.normal_(self.lang_token, std=0.02)
if self.use_cls_token:
nn.init.normal_(self.cls_token, std=0.02)
# Everything else...
self.apply(self.transformer_initializer)
@staticmethod
def transformer_initializer(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
# Use xavier_uniform following Jax ViT
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Encode language by feeding the *pre-tokenized text* through the frozen language model."""
self.lm.eval()
with torch.no_grad():
transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
return transformer_embeddings
def mask(
self, ctx_patches: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform per-context random masking by shuffling :: uses argsort random noise to identify masked patches."""
bsz, ctx_len, n_patches, embed_dim = ctx_patches.shape
if mask_ratio is not None:
n_keep = int(n_patches * (1 - mask_ratio))
else:
n_keep = int(n_patches * (1 - self.mask_ratio))
# Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
# > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
# > Note that shuffle_idxs is defined solely as a function of *n_patches* and **not** context! Same mask!
shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=ctx_patches.device), dim=1)
restore_idxs = torch.argsort(shuffle_idxs, dim=1)
# Get "keep" (visible) patches --> make sure to get _same_ patches *across* context length!
visible_patches = torch.gather(
ctx_patches, dim=2, index=shuffle_idxs[:, None, :n_keep, None].repeat(1, ctx_len, 1, embed_dim)
)
# Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following MAE convention)
mask = torch.ones(bsz, n_patches, device=ctx_patches.device)
mask[:, :n_keep] = 0
mask = torch.gather(mask, dim=1, index=restore_idxs)
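# Toy example (illustrative, not executed): with n_patches = 4 and mask_ratio = 0.5, suppose the sampled noise
# argsorts to shuffle_idxs = [2, 0, 3, 1]; then restore_idxs = argsort(shuffle_idxs) = [1, 3, 0, 2]. The kept
# patches are shuffle_idxs[:2] = [2, 0], and gathering the pre-unshuffle mask [0, 0, 1, 1] with restore_idxs
# yields the final binary mask [0, 1, 0, 1] --> patches 0 and 2 visible, patches 1 and 3 masked out.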
return visible_patches, mask, restore_idxs
def get_representations(
self, imgs: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
) -> torch.Tensor:
"""
Given either a singleton (dual-imgs, language) pair or batch of dual-imgs and language, extract representations
subject to the specified mode in < multimodal | visual >.
:param imgs: Processed batch of images :: [bsz, 2, 3, 224, 224]
:param language: Input language as `List[str] | Tuple[str] | None`
:param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)
:return: Extracted representations given (imgs, language) input as sequence.
"""
assert (
imgs.ndim == 5
and imgs.shape[1] == 2
and (language is None or isinstance(language, list) or isinstance(language, tuple))
), "Invalid input to `get_representations()`"
assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"
# Tokenize Language --> note max length is 20!
if language is None:
lang, lang_mask = [torch.zeros(imgs.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
else:
tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)
# Tile Language & Language Mask if mismatch with # images!
if not len(lang) == len(imgs):
lang = repeat(lang, "b seq -> (bsz b) seq", bsz=imgs.size(0))
lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=imgs.size(0))
# Extract desired representations...
representations = self.encode(imgs, lang, lang_mask)
return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]
def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Default representation extraction function, given a batch of dual-images and tokenized language."""
lang_embeddings = self.encode_language(lang, lang_mask)
projected_language = self.lang2encoder(lang_embeddings)
# Patchify, broadcast position embedding ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe)
ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]
# Add "modality" embeddings to patches & language & flatten out context patches...
img_ctx_embeddings, lang_embeddings = ctx_patches_pe + self.img_token, projected_language + self.lang_token
img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_token[:, 0, :, :]
cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1)
img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1)
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
multimodal_embeddings = self.encoder_norm(multimodal_embeddings)
# Return the full sequence of multimodal embeddings (but ignore 0th frame) => the `~` denote what to remove!
# => [CLS] + ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
# => ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
if self.use_cls_token:
return torch.cat(
[multimodal_embeddings[:, :1, :], multimodal_embeddings[:, 1 + self.patch2embed.num_patches :, :]], dim=1
)
else:
return multimodal_embeddings[:, self.patch2embed.num_patches :]
def forward_encoder(
self, img_ctx: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
lang_embeddings = self.encode_language(lang, lang_mask)
projected_lang = self.lang2encoder(lang_embeddings)
# Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
patches = self.patch2embed(rearrange(img_ctx, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])
ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]
# Create mask (and go ahead and mask out patches at the same time)
visible_ctx_patches, mask, restore_idxs = self.mask(ctx_patches_pe, mask_ratio)
# Add "modality" embeddings to patches & language & flatten out context patches...
visible_ctx_patches, lang = visible_ctx_patches + self.img_token, projected_lang + self.lang_token
visible_patches = rearrange(visible_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_token[:, 0, :, :]
cls_tokens = cls_token_pe.expand(img_ctx.shape[0], -1, -1)
visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer...
visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_mask.dtype)
multimodal_embedding = torch.cat([visible_patches, lang], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([visible_mask, lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embedding = block(multimodal_embedding, multimodal_mask)
multimodal_embedding = self.encoder_norm(multimodal_embedding)
# Split multimodal embedding, remove language, return the visible ctx (0th + Kth frame) patches (+ <CLS>)!
visible_patches = multimodal_embedding[:, : -lang_mask.shape[-1], ...]
return visible_patches, mask, restore_idxs
def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
# Project patches into decoder embedding dimension (visible_ctx_patches :: [bsz, (CLS) + 2 * seq, enc_embed])
projected_patches = self.encoder2decoder(visible_patches)
visible_per_frame = (projected_patches.shape[1] - (1 if self.use_cls_token else 0)) // 2
# Add Mask Tokens to Sequence and Unshuffle
mask_tokens = self.mask_token.repeat(projected_patches.shape[0], 2, restore_idxs.shape[1] - visible_per_frame, 1)
# (Optional) Token Handling
if self.use_cls_token:
# Remove CLS Token as part of "unshuffling"
projected_ctx_patches = rearrange(
projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2
)
no_cls_concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
unshuffled_ctx_patches = torch.gather(
no_cls_concatenated_ctx_patches,
dim=2,
index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
)
else:
projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
unshuffled_ctx_patches = torch.gather(
concatenated_ctx_patches,
dim=2,
index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
)
# Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer...
decoder_ctx_patches_pe = unshuffled_ctx_patches + (
self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :]
)
decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...]
decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
# Add back <CLS> Token from `projected_patches[:, :1, :]`
cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :]
decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1)
# Apply Transformer Blocks...
for block in self.decoder_blocks:
decoder_patches = block(decoder_patches)
decoder_patches = self.decoder_norm(decoder_patches)
# Run final projection & return "unflattened" patches --> note token handling!
decoder_prediction = self.decoder_prediction(decoder_patches)
reconstructions = decoder_prediction if not self.use_cls_token else decoder_prediction[:, 1:, :]
return rearrange(reconstructions, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""Convert a batch of (0th + Kth frame) images to their patched equivalents by naive reshaping."""
return rearrange(
imgs,
"bsz ctx c (height patch_h) (width patch_w) -> bsz ctx (height width) (patch_h patch_w c)",
patch_h=self.patch_size,
patch_w=self.patch_size,
)
def compute_loss(
self, imgs: torch.Tensor, ctx_reconstructions: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
targets = self.patchify(imgs)
# Normalize targets...
mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
targets = (targets - mu) / ((var + 1e-6) ** 0.5)
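# i.e. each patch's pixel target is standardized as (target - mean) / sqrt(var + 1e-6), so the reconstruction
# MSE below is computed in per-patch normalized pixel space (the `norm_pixel_loss` behavior from MAE).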
# Split targets into 0 and K --> do the same for ctx_reconstructions
zero_target, k_target = targets[:, 0, ...], targets[:, 1, ...]
zero_reconstruction, k_reconstruction = ctx_reconstructions[:, 0, ...], ctx_reconstructions[:, 1, ...]
# Compute mean losses per patch first...
zero_mse, k_mse = (zero_reconstruction - zero_target) ** 2, (k_reconstruction - k_target) ** 2
zero_avg_loss_per_patch, k_avg_loss_per_patch = zero_mse.mean(dim=-1), k_mse.mean(dim=-1)
# Compute mean loss only on *removed* patches and return...
return (zero_avg_loss_per_patch * mask).sum() / mask.sum(), (k_avg_loss_per_patch * mask).sum() / mask.sum()
def forward(
self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Run a forward pass through the model, computing the language-conditioned MAE reconstruction loss on the
0th + Kth frame temporal context, given language prefix.
:param imgs: A [bsz, 2, in_channels, resolution, resolution] tensor of (0th frame, Kth frame) sequences.
:param lang: A [bsz, seq_len] tensor of language context to condition on.
:param lang_mask: A [bsz, seq_len] binary mask tensor to indicate padding locations in the lang tensor.
:param mask_ratio: Optional masking ratio to use instead of the default.
:return: Tuple of losses and intermediates, as follows:
> (combined loss, [reconstruction loss per frame in {0, K}])
"""
visible_ctx_patches, mask, restore_idxs = self.forward_encoder(imgs, lang, lang_mask, mask_ratio)
ctx_reconstructions = self.forward_decoder(visible_ctx_patches, restore_idxs)
zero_loss, k_loss = self.compute_loss(imgs, ctx_reconstructions, mask)
# Return average reconstruction loss, individual losses...
loss = (zero_loss + k_loss) / 2
return loss, [zero_loss, k_loss]
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adamw"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")
# Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
# > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
# Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
if param.ndim <= 1 or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
# Build Parameter Groups
groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
# Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
# > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
self.lr = self.base_lr * (self.effective_bsz / 256)
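# e.g. (illustrative numbers) base_lr = 1.5e-4 with effective_bsz = 1024 --> lr = 1.5e-4 * (1024 / 256) = 6e-4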
# Create Optimizer & LR Scheduler
optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
return optimizer, update_lr
================================================
FILE: voltron/models/core/vgen.py
================================================
"""
vgen.py
PyTorch Module defining the Voltron `V-Gen` variant (dual-frame with language-conditioning AND language-generation).
This model adds the ability to *either* condition on language context or (XOR) generate language given masked frame
context, with a hyperparameter (`alpha` in the paper) controlling the `gen_ratio` -- the ratio of examples for which
to generate language.
The objective this model seeks to optimize is the sum of the MAE reconstruction error (when conditioning on language)
and the log-likelihood of predicting the next token given prior tokens and the entire learned image representation.
Follows same dual-frame encoding structure as VDual, and architectural modifications from VCond.
References:
- https://github.com/lucidrains/x-transformers
"""
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from einops import rearrange, repeat
from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, RMSNorm, get_2D_position_embeddings
# Suppress Transformers Logging
transformers.logging.set_verbosity_error()
class VGen(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
encoder_depth: int,
encoder_embed_dim: int,
encoder_n_heads: int,
decoder_depth: int,
decoder_embed_dim: int,
decoder_n_heads: int,
language_model: str,
hf_cache: str,
language_dim: int,
max_lang_len: int,
vocab_size: int,
mae_weight: float,
lm_weight: float,
optimizer: str,
schedule: str,
base_lr: float,
min_lr: float,
effective_bsz: float,
betas: Tuple[float, float],
weight_decay: float,
warmup_epochs: int,
max_epochs: int,
mask_ratio: float = 0.75,
mlp_ratio: float = 4.0,
in_channels: int = 3,
norm_pixel_loss: bool = True,
use_cls_token: bool = False,
eps: float = 1e-8,
) -> None:
"""
Initialize a VGen model with the requisite architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param patch_size: Height/Width of each patch in pixels -- usually 16.
:param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
:param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
:param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
:param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
:param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
:param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
:param language_model: Language model to freeze for encoding narrations/utterances.
:param hf_cache: Cache directory to store pretrained models, for safe distributed training.
:param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
:param max_lang_len: Maximum length of input sequence (in tokens).
:param vocab_size: Vocabulary size for final cross-entropy loss over token prediction.
:param mae_weight: Weighting term for the MAE loss -- usually 1.0 (borrowed from M3AE paper as *rough* guide)
:param lm_weight: Weighting term for the LM loss -- usually 0.5 (borrowed from the M3AE paper as *rough* guide)
:param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
:param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
:param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
:param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
:param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
:param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
:param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
:param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
:param max_epochs: Total number of training epochs to be run.
:param mask_ratio: Ratio for number of patches AND tokens to mask out for M3AE -- should be fairly high!
:param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
:param in_channels: Default number of channels in the base image -- almost always 3.
:param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
:param use_cls_token: Add <CLS> token for continued pretraining (NOTE: not used in MAE pretraining/finetuning!)
:param eps: Epsilon for preventing divide by zero.
"""
super().__init__()
self.resolution, self.patch_size, self.mask_ratio, self.eps = resolution, patch_size, mask_ratio, eps
self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
self.mae_weight, self.lm_weight, self.language_dim = mae_weight, lm_weight, language_dim
self.max_lang_len, self.vocab_size = max_lang_len, vocab_size
self.use_cls_token = use_cls_token
self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
# Encoder/Decoder Parameters
self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads
# General Parameters (for downstream adaptation)
self.embed_dim, self.n_heads = self.encoder_embed_dim, self.encoder_n_heads
# (Optional) Token Handling
if self.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# MAE Encoder Parameters
self.patch2embed = PatchEmbed(
self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
)
self.encoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.encoder_embed_dim),
requires_grad=False,
)
self.encoder_blocks = nn.ModuleList(
[
Block(
self.encoder_embed_dim,
self.encoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.encoder_depth)
]
)
self.encoder_norm = RMSNorm(self.encoder_embed_dim)
# Projection from Language Embedding to Decoder
self.lang2encoder = nn.Linear(self.language_dim, self.encoder_embed_dim)
self.lang2decoder = nn.Linear(self.language_dim, self.decoder_embed_dim)
# Projection from Encoder to Decoder
self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)
# MAE Decoder Parameters -- Remember the CLS Token (if specified)!
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.decoder_embed_dim))
self.decoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + (1 if self.use_cls_token else 0), self.decoder_embed_dim),
requires_grad=False,
)
self.decoder_blocks = nn.ModuleList(
[
Block(
self.decoder_embed_dim,
self.decoder_n_heads,
self.mlp_ratio,
do_rms_norm=True,
do_swish_glu=True,
do_layer_scale=True,
)
for _ in range(self.decoder_depth)
]
)
self.decoder_norm = RMSNorm(self.decoder_embed_dim)
self.decoder_patch_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
self.decoder_lang_prediction = nn.Linear(self.decoder_embed_dim, self.vocab_size, bias=True)
# VGen -- Add "Image" and "Language" Modifier Tokens for Encoder & Decoder...
self.img_enc_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder_embed_dim))
self.lang_enc_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
# VGen -- Learnable "ctx" position embeddings --> initialize via `randn` following original ViT & @lucidrains
# =>> Ref: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L99
# =>> Note that n_context = 2 (0th frame + Kth frame)
self.ctx_enc_pe = nn.Parameter(torch.randn(1, 2, 1, self.encoder_embed_dim))
self.ctx_dec_pe = nn.Parameter(torch.randn(1, 2, 1, self.decoder_embed_dim))
# Register Prefix Mask --> Lower Triangular ==> set prefix to 1
n_patches, total_seq = 2 * self.patch2embed.num_patches, (2 * self.patch2embed.num_patches) + self.max_lang_len
prefix_mask = torch.tril(torch.ones((total_seq, total_seq), dtype=torch.uint8))
prefix_mask[:n_patches, :n_patches] = 1
# Register this once... we'll multiply by padding masks prior to feeding to Transformer
self.register_buffer("prefix_mask", prefix_mask.view(1, 1, total_seq, total_seq))
# Initialize all Weights
self.initialize_weights()
# @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
# > For BERT models, our "embedding" is just going to be the last hidden state
# > Assumes inputs to forward pass are pre-tokenized!
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
self.lm.eval()
# Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
# Freeze the LM
for _, param in self.lm.named_parameters():
param.requires_grad = False
def initialize_weights(self) -> None:
# Position Encoding -- Fixed 2D Sine-Cosine Embeddings
enc_pe = get_2D_position_embeddings(
self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
dec_pe = get_2D_position_embeddings(
self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), cls_token=self.use_cls_token
)
self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))
# Initialize PatchEmbedding as a Linear...
nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
# Initialize Mask Token, Img Token, Lang Token w/ Truncated Normal
nn.init.normal_(self.mask_token, std=0.02)
nn.init.normal_(self.img_enc_token, std=0.02)
nn.init.normal_(self.lang_enc_token, std=0.02)
if self.use_cls_token:
nn.init.normal_(self.cls_token, std=0.02)
# Everything else...
self.apply(self.transformer_initializer)
@staticmethod
def transformer_initializer(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
# Use xavier_uniform following Jax ViT
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def embed_language(self, lang: torch.Tensor) -> torch.Tensor:
"""Only feed language through the pretrained *embedding* matrix (no bidirectional cheating)."""
self.lm.eval()
with torch.no_grad():
# Note :: These have position_encodings included... no need for separate `decoder_lang_pe`
token_embeddings = self.lm.embeddings(lang)
return token_embeddings
def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Encode language by feeding the *pre-tokenized text* through the frozen language model."""
self.lm.eval()
with torch.no_grad():
transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
return transformer_embeddings
def mask(
self, ctx_patches: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform per-context random masking by shuffling :: uses argsort random noise to identify masked patches."""
bsz, ctx_len, n_patches, embed_dim = ctx_patches.shape
if mask_ratio is not None:
n_keep = int(n_patches * (1 - mask_ratio))
else:
n_keep = int(n_patches * (1 - self.mask_ratio))
# Sample noise of n_patches size, argsort to get shuffled IDs, argsort again to get "unshuffle"
# > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
# > Note that shuffle_idxs is defined solely as a function of *n_patches* and **not** context! Same mask!
shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=ctx_patches.device), dim=1)
restore_idxs = torch.argsort(shuffle_idxs, dim=1)
# Get "keep" (visible) patches --> make sure to get _same_ patches *across* context length!
visible_patches = torch.gather(
ctx_patches, dim=2, index=shuffle_idxs[:, None, :n_keep, None].repeat(1, ctx_len, 1, embed_dim)
)
# Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following FAIR MAE convention)
mask = torch.ones(bsz, n_patches, device=ctx_patches.device)
mask[:, :n_keep] = 0
mask = torch.gather(mask, dim=1, index=restore_idxs)
return visible_patches, mask, restore_idxs
def get_representations(
self, imgs: torch.Tensor, language: Optional[Union[List[str], Tuple[str]]] = None, mode: str = "multimodal"
) -> torch.Tensor:
"""
Given either a singleton (dual-imgs, language) pair or batch of dual-imgs and language, extract representations
subject to the specified mode in < multimodal | visual >.
:param imgs: Processed batch of images :: [bsz, 2, 3, 224, 224]
:param language: Input language as `List[str] | Tuple[str] | None`
:param mode: Type of representations to extract -- `multimodal` (both vision+text), `visual` (visual only)
:return: Extracted representations given (imgs, language) input as sequence.
"""
assert (
imgs.ndim == 5
and imgs.shape[1] == 2
and (language is None or isinstance(language, list) or isinstance(language, tuple))
), "Invalid input to `get_representations()`"
assert mode in {"multimodal", "visual"}, f"Extraction mode `{mode}` not supported!"
# Tokenize Language --> note max length is 20!
if language is None:
lang, lang_mask = [torch.zeros(imgs.shape[0], 20, dtype=int, device=self.lm.device) for _ in range(2)]
lang[:, 0], lang_mask[:, 0] = self.tokenizer.cls_token_id, 1
else:
tokens = self.tokenizer(language, return_tensors="pt", max_length=20, padding="max_length", truncation=True)
lang, lang_mask = tokens["input_ids"].to(self.lm.device), tokens["attention_mask"].to(self.lm.device)
# Tile Language & Language Mask if mismatch with # images!
if not len(lang) == len(imgs):
lang = repeat(lang, "b seq -> (bsz b) seq", bsz=imgs.size(0))
lang_mask = repeat(lang_mask, "b seq -> (bsz b) seq", bsz=imgs.size(0))
# Extract desired representations...
representations = self.encode(imgs, lang, lang_mask)
return representations if mode == "multimodal" else representations[:, : -lang_mask.shape[-1]]
def encode(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Default representation extraction function, given a batch of dual-images and tokenized language."""
lang_embeddings = self.encode_language(lang, lang_mask)
projected_language = self.lang2encoder(lang_embeddings)
# Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe)
ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]
# Add "modality" embeddings to patches & language & flatten out context patches...
img_ctx_embeddings, lang_embeddings = (
ctx_patches_pe + self.img_enc_token,
projected_language + self.lang_enc_token,
)
img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :]
cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1)
img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1)
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=lang_mask.dtype)
multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([patches_mask, lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
multimodal_embeddings = self.encoder_norm(multimodal_embeddings)
# Return the full sequence of multimodal embeddings (but ignore 0th frame) => the `~` denote what to remove!
# => [CLS] + ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
# => ~[n_patches x 0th frame]~ + [n_patches x Kth frame] + [max_lang_len language]
if self.use_cls_token:
return torch.cat(
[multimodal_embeddings[:, :1, :], multimodal_embeddings[:, 1 + self.patch2embed.num_patches :, :]], dim=1
)
else:
return multimodal_embeddings[:, self.patch2embed.num_patches :]
def score(self, imgs: torch.Tensor, langs: torch.Tensor, lang_masks: torch.Tensor) -> torch.Tensor:
"""
Given an example 0-K pair and a set of k language instructions, output scores under the generative language
model for each instruction.
:param imgs: 0-K pairs --> [1, 2, 3, 224, 224]
:param langs: Tokenized language input --> [1, k, seq]
:param lang_masks: Language padding masks --> [1, k, seq]
:return: [1, k] Tensor of LM probabilities given imgs.
"""
# Blank out the "encoder" language --> just [ = 101, 0 ...]
blank_lang = torch.zeros(1, self.max_lang_len, dtype=torch.int64, device=imgs.device)
blank_lang_mask = torch.zeros(1, self.max_lang_len, dtype=torch.int64, device=imgs.device)
blank_lang[0][0], blank_lang_mask[0][0] = 101, 1
# === Encoder Forward ===
lang_embeddings = self.encode_language(blank_lang, blank_lang_mask)
projected_language = self.lang2encoder(lang_embeddings)
# Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
patches = self.patch2embed(rearrange(imgs, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
patches_pe = patches + (self.encoder_pe[:, 1:, :] if self.use_cls_token else self.encoder_pe)
ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]
# Add "modality" embeddings to patches & language & flatten out context patches...
img_ctx_embeddings, lang_embeddings = (
ctx_patches_pe + self.img_enc_token,
projected_language + self.lang_enc_token,
)
img_embeddings = rearrange(img_ctx_embeddings, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :]
cls_tokens = cls_token_pe.expand(imgs.shape[0], -1, -1)
img_embeddings = torch.cat([cls_tokens, img_embeddings], dim=1)
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer
patches_mask = torch.ones_like(img_embeddings[..., -1], dtype=blank_lang_mask.dtype)
multimodal_embeddings = torch.cat([img_embeddings, lang_embeddings], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([patches_mask, blank_lang_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embeddings = block(multimodal_embeddings, multimodal_mask)
multimodal_embeddings = self.encoder_norm(multimodal_embeddings)
# Split multimodal embedding, remove language, and return only the (CLS +) 0th + Kth frame patches
enc_patches = multimodal_embeddings[:, : -blank_lang_mask.shape[-1], ...]
# === Encoder =>> Decoder Hand-Off ===
enc_patches = repeat(enc_patches, "b cseq embed -> (bsz b) cseq embed", bsz=langs.size(0))
lang_gen_embeddings = self.embed_language(langs)
# === Decoder Forward ===
projected_patches = self.encoder2decoder(enc_patches)
projected_lang_gen = self.lang2decoder(lang_gen_embeddings)
# (Optional) Token Handling
if self.use_cls_token:
projected_ctx_patches = rearrange(
projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2
)
else:
projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
# Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer...
decoder_ctx_patches_pe = projected_ctx_patches + (
self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :]
)
decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...]
decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
# Add back <CLS> Token from `projected_patches[:, :1, :]`
cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :]
decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1)
# Add language -> create "mask" by multiply padding by self.prefix_mask
decoder_patches_mask = torch.ones_like(decoder_patches[..., -1], dtype=lang_masks.dtype)
multimodal_embedding = torch.cat([decoder_patches, projected_lang_gen], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([decoder_patches_mask, lang_masks], dim=1) # Merge on sequence length...
# Compute prefix_padded_mask
prefix_padded_mask = rearrange(multimodal_mask, "bsz seq -> bsz 1 seq 1") * self.prefix_mask
# Apply Transformer Blocks...
for block in self.decoder_blocks:
multimodal_embedding = block(multimodal_embedding, prefix_padded_mask)
multimodal_embedding = self.decoder_norm(multimodal_embedding)
# Split multimodal embedding into *just* the language + project!
lang = multimodal_embedding[:, -lang_masks.shape[-1] :, ...]
generations = self.decoder_lang_prediction(lang)
# Compute cross-entropy loss (multiply by -1 for "final scoring") --> log-likelihood!
bsz, seq = langs.shape
lang_logits = rearrange(generations[:, :-1, ...], "bsz seq vocab -> (bsz seq) vocab")
lang_targets = rearrange(langs[:, 1:], "bsz seq -> (bsz seq)")
lang_loss_mask = lang_masks[:, :-1] # Defined where valid...
ce_loss = F.cross_entropy(lang_logits, lang_targets, reduction="none")
per_token_loss = rearrange(ce_loss, "(bsz seq) -> bsz seq", bsz=bsz, seq=seq - 1) # -1 because causal mask...
# Compute loss only on *non-padded* and *non-ignored* tokens...
lang_example_loss = (per_token_loss * lang_loss_mask).sum(dim=-1) / lang_loss_mask.sum(dim=-1)
return -1 * lang_example_loss.detach()
def forward_encoder(
self,
img_ctx: torch.Tensor,
lang_con: torch.Tensor,
lang_con_mask: torch.Tensor,
mask_ratio: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
lang_embeddings = self.encode_language(lang_con, lang_con_mask)
projected_lang = self.lang2encoder(lang_embeddings)
# Reshape image context to apply masking *identically*
# Patchify, broadcast position embedding across ctx_len (0 + K) dimension, unfold, add `ctx_enc_pe` embeddings!
patches = self.patch2embed(rearrange(img_ctx, "bsz ctx channels res1 res2 -> (bsz ctx) channels res1 res2"))
patches_pe = patches + (self.encoder_pe if not self.use_cls_token else self.encoder_pe[:, 1:, :])
ctx_patches = rearrange(patches_pe, "(bsz ctx) seq embed -> bsz ctx seq embed", ctx=2)
ctx_patches_pe = ctx_patches + self.ctx_enc_pe[:, :2, ...]
# Create mask (and go ahead and mask out patches at the same time)
visible_ctx_patches, mask, restore_idxs = self.mask(ctx_patches_pe, mask_ratio)
# Add "modality" embeddings to patches & language & flatten out context patches...
visible_ctx_patches, lang = visible_ctx_patches + self.img_enc_token, projected_lang + self.lang_enc_token
visible_patches = rearrange(visible_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
cls_token_pe = self.cls_token + self.encoder_pe[:, :1, :] + self.img_enc_token[:, 0, :, :]
cls_tokens = cls_token_pe.expand(img_ctx.shape[0], -1, -1)
visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)
# Create "dummy" visible mask, concatenate image patches & language, feed to Transformer...
visible_mask = torch.ones_like(visible_patches[..., -1], dtype=lang_con_mask.dtype)
multimodal_embedding = torch.cat([visible_patches, lang], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([visible_mask, lang_con_mask], dim=1) # Merge on sequence length...
# Apply Transformer Blocks...
for block in self.encoder_blocks:
multimodal_embedding = block(multimodal_embedding, multimodal_mask)
multimodal_embedding = self.encoder_norm(multimodal_embedding)
# Split multimodal embedding, remove language, return the visible ctx (0th + Kth frame) patches (+ <CLS>)!
visible_patches = multimodal_embedding[:, : -lang_con_mask.shape[-1], ...]
return visible_patches, mask, restore_idxs
def forward_decoder(
self,
visible_patches: torch.Tensor,
restore_idxs: torch.Tensor,
lang_gen: torch.Tensor,
lang_gen_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Project patches & lang_gen into decoder dimension (visible_patches :: [bsz, (CLS) + 2 * seq, enc_embed])
projected_patches = self.encoder2decoder(visible_patches)
projected_lang_gen = self.lang2decoder(lang_gen)
visible_per_frame = (projected_patches.shape[1] - (1 if self.use_cls_token else 0)) // 2
# Add Mask Tokens to Sequence and Unshuffle
mask_tokens = self.mask_token.repeat(projected_patches.shape[0], 2, restore_idxs.shape[1] - visible_per_frame, 1)
# (Optional) Token Handling
if self.use_cls_token:
# Remove CLS Token as part of "unshuffling"
projected_ctx_patches = rearrange(
projected_patches[:, 1:, :], "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2
)
no_cls_concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
unshuffled_ctx_patches = torch.gather(
no_cls_concatenated_ctx_patches,
dim=2,
index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
)
else:
projected_ctx_patches = rearrange(projected_patches, "bsz (ctx seq) embed -> bsz ctx seq embed", ctx=2)
concatenated_ctx_patches = torch.cat([projected_ctx_patches, mask_tokens], dim=2)
unshuffled_ctx_patches = torch.gather(
concatenated_ctx_patches,
dim=2,
index=restore_idxs[:, None, ..., None].repeat(1, 2, 1, self.decoder_embed_dim),
)
# Add position embeddings, `ctx_dec_pe` embeddings, and flatten patches for Transformer...
decoder_ctx_patches_pe = unshuffled_ctx_patches + (
self.decoder_pe[None, ...] if not self.use_cls_token else self.decoder_pe[None, :, 1:, :]
)
decoder_ctx_patches = decoder_ctx_patches_pe + self.ctx_dec_pe[:, :2, ...]
decoder_patches = rearrange(decoder_ctx_patches, "bsz ctx seq embed -> bsz (ctx seq) embed")
# (Optional) Token Handling
if self.use_cls_token:
# Add back <CLS> Token from `projected_patches[:, :1, :]`
cls_embedding = projected_patches[:, :1, :] + self.decoder_pe[:, :1, :]
decoder_patches = torch.cat([cls_embedding, decoder_patches], dim=1)
# Add language -> create "mask" by multiply padding by self.prefix_mask
decoder_patches_mask = torch.ones_like(decoder_patches[..., -1], dtype=lang_gen_mask.dtype)
multimodal_embedding = torch.cat([decoder_patches, projected_lang_gen], dim=1) # Merge on sequence length...
multimodal_mask = torch.cat([decoder_patches_mask, lang_gen_mask], dim=1) # Merge on sequence length...
# Compute prefix_padded_mask
prefix_padded_mask = rearrange(multimodal_mask, "bsz seq -> bsz 1 seq 1") * self.prefix_mask
# Apply Transformer Blocks...
for block in self.decoder_blocks:
multimodal_embedding = block(multimodal_embedding, prefix_padded_mask)
multimodal_embedding = self.decoder_norm(multimodal_embedding)
# Split multimodal embedding into patches and language...
patches_ctx = multimodal_embedding[:, : -lang_gen_mask.shape[-1], ...]
patches = rearrange(
patches_ctx if not self.use_cls_token else patches_ctx[:, 1:, :],
"bsz (ctx seq) embed -> bsz ctx seq embed",
ctx=2,
)
lang = multimodal_embedding[:, -lang_gen_mask.shape[-1] :, ...]
# Project each up to the output space...
reconstructions = self.decoder_patch_prediction(patches)
generations = self.decoder_lang_prediction(lang)
return reconstructions, generations
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""Convert a batch of (0th + Kth frame) images to their patched equivalents by naive reshaping."""
return rearrange(
imgs,
"bsz ctx c (height patch_h) (width patch_w) -> bsz ctx (height width) (patch_h patch_w c)",
patch_h=self.patch_size,
patch_w=self.patch_size,
)
def compute_loss(
self,
imgs: torch.Tensor,
ctx_reconstructions: torch.Tensor,
mask: torch.Tensor,
lang: torch.Tensor,
generated_language: torch.Tensor,
lang_gen_mask: torch.Tensor,
lang_gen_weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Reconstruction Loss...
assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
targets = self.patchify(imgs)
# Normalize targets...
mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
targets = (targets - mu) / ((var + 1e-6) ** 0.5)
# Split targets into 0 and K --> do the same for ctx_reconstructions
zero_target, k_target = targets[:, 0, ...], targets[:, 1, ...]
zero_reconstruction, k_reconstruction = ctx_reconstructions[:, 0, ...], ctx_reconstructions[:, 1, ...]
# Compute mean losses per patch first...
zero_mse, k_mse = (zero_reconstruction - zero_target) ** 2, (k_reconstruction - k_target) ** 2
zero_avg_loss_per_patch, k_avg_loss_per_patch = zero_mse.mean(dim=-1), k_mse.mean(dim=-1)
# Compute reconstruction losses...
zero_loss = (zero_avg_loss_per_patch * mask).sum() / mask.sum()
k_loss = (k_avg_loss_per_patch * mask).sum() / mask.sum()
reconstruction_loss = (zero_loss + k_loss) / 2
# Language Loss...
bsz, seq = lang.shape
lang_logits = rearrange(generated_language[:, :-1, ...], "bsz seq vocab -> (bsz seq) vocab")
lang_targets = rearrange(lang[:, 1:], "bsz seq -> (bsz seq)")
lang_loss_mask = lang_gen_mask[:, :-1] # Defined where valid...
ce_loss = F.cross_entropy(lang_logits, lang_targets, reduction="none")
per_token_loss = rearrange(ce_loss, "(bsz seq) -> bsz seq", bsz=bsz, seq=seq - 1) # -1 because causal mask...
# Compute loss only on *non-padded* and *non-ignored* tokens...
lang_example_loss = (per_token_loss * lang_loss_mask).sum(dim=-1) / lang_loss_mask.sum(dim=-1)
lang_loss = (lang_example_loss * lang_gen_weight).sum() / (self.eps + lang_gen_weight.sum()) # Divide by 0...
# TODO (Remove) -- NaN Check...
if reconstruction_loss.isnan().any() or lang_loss.isnan().any():
# fmt: off
print(
f"Found Nan -- "
f"ctx_reconstructions: {ctx_reconstructions.isnan().any()} -- "
f"generated_language: {generated_language.isnan().any()} -- "
f"zero_avg_loss_per_patch: {zero_avg_loss_per_patch.isnan().any()} -- "
f"k_avg_loss_per_patch: {k_avg_loss_per_patch.isnan().any()} -- "
f"zero_loss: {zero_loss.isnan().any()} -- "
f"k_loss: {k_loss.isnan().any()} -- "
f"reconstruction_loss: {reconstruction_loss.isnan().any()} -- "
f"ce_loss: {ce_loss.isnan().any()} -- "
f"per_token_loss: {per_token_loss.isnan().any()} -- "
f"lang_example_loss: {lang_example_loss.isnan().any()} -- "
f"lang_loss: {lang_loss.isnan().any()}"
)
exit(1)
# fmt: on
# Compute weighted loss...
loss = self.mae_weight * reconstruction_loss + self.lm_weight * lang_loss
return loss, reconstruction_loss, lang_loss, zero_loss, k_loss
def forward(
self,
imgs: torch.Tensor,
lang_con: torch.Tensor,
lang_con_mask: torch.Tensor,
lang_gen: torch.Tensor,
lang_gen_mask: torch.Tensor,
lang_gen_weight: torch.Tensor,
mask_ratio: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Run a forward pass through the model, computing the MAE reconstruction loss (language-conditioned if applicable)
on the 0th + Kth frame temporal context, as well as the generated language given masked context (if applicable).
:param imgs: A [bsz, 2, in_channels, resolution, resolution] tensor of (0th frame, Kth frame) sequences.
:param lang_con: A [bsz, seq_len] tensor of language context to condition on.
:param lang_con_mask: A [bsz, seq_len] binary mask tensor to indicate padding/null in `lang_con`.
:param lang_gen: A [bsz, seq_len] tensor of language to generate.
:param lang_gen_mask: A [bsz, seq_len] binary mask tensor to indicate padding/null in `lang_gen`.
:param lang_gen_weight: A [bsz] tensor of per-example weights indicating when to zero out the `lm` loss.
:param mask_ratio: Optional masking ratio to use instead of the default.
:return: Tuple of losses and intermediates, as follows:
> (combined loss, reconstruction loss, lm loss, [reconstruction loss per frame in {0, K}])
"""
visible_ctx_patches, mask, restore_idxs = self.forward_encoder(imgs, lang_con, lang_con_mask, mask_ratio)
# Get token embeddings -- *NOT CONTEXTUAL* -- for the lang_gen tokens...
lang_gen_embeddings = self.embed_language(lang_gen)
# Run patches, and lang_gen through decoder --> note that we need a causal mask on language generation...
ctx_reconstructions, generated_language = self.forward_decoder(
visible_ctx_patches, restore_idxs, lang_gen_embeddings, lang_gen_mask
)
# Compute loss for reconstructed patches & generated language
loss, reconstruction_loss, lang_loss, zero_loss, k_loss = self.compute_loss(
imgs, ctx_reconstructions, mask, lang_gen, generated_language, lang_gen_mask, lang_gen_weight
)
return loss, reconstruction_loss, lang_loss, [zero_loss, k_loss]
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adamw"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")
# Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
# > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
# Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
if param.ndim <= 1 or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
# Build Parameter Groups
groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
# Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
# > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
self.lr = self.base_lr * (self.effective_bsz / 256)
# Create Optimizer & LR Scheduler
optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
return optimizer, update_lr
================================================
FILE: voltron/models/instantiate.py
================================================
"""
instantiate.py
Simple wrapping script for instantiating a core Voltron/reproduction model and configuring the torch.Optimizer for DDP
pretraining. Meant to be modular and extensible!
"""
from typing import Callable, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from voltron.conf import DatasetConfig, ModelConfig
from .core.vcond import VCond
from .core.vdual import VDual
from .core.vgen import VGen
from .reproductions.vmvp import VMVP
from .reproductions.vr3m import VR3M
from .reproductions.vrn3m import VRN3M
def get_model_optimizer(
model_cfg: ModelConfig, dataset_cfg: DatasetConfig
) -> Tuple[nn.Module, Optimizer, Callable[[int, float], float]]:
"""Switch on `model_cfg.arch` --> instantiate the correct nn.Module and Optimizer (on CPU/default device)."""
# Data-Locked Reproductions
if model_cfg.arch == "v-mvp":
model = VMVP(
resolution=dataset_cfg.resolution,
patch_size=model_cfg.patch_size,
encoder_depth=model_cfg.encoder_depth,
encoder_embed_dim=model_cfg.encoder_embed_dim,
encoder_n_heads=model_cfg.encoder_n_heads,
decoder_depth=model_cfg.decoder_depth,
decoder_embed_dim=model_cfg.decoder_embed_dim,
decoder_n_heads=model_cfg.decoder_n_heads,
optimizer=model_cfg.optimizer,
schedule=model_cfg.schedule,
base_lr=model_cfg.base_lr,
min_lr=model_cfg.min_lr,
effective_bsz=model_cfg.effective_bsz,
betas=model_cfg.betas,
weight_decay=model_cfg.weight_decay,
warmup_epochs=dataset_cfg.warmup_epochs,
max_epochs=dataset_cfg.max_epochs,
mlp_ratio=model_cfg.mlp_ratio,
norm_pixel_loss=model_cfg.norm_pixel_loss,
)
elif model_cfg.arch == "v-r3m":
model = VR3M(
resolution=dataset_cfg.resolution,
patch_size=model_cfg.patch_size,
depth=model_cfg.depth,
embed_dim=model_cfg.embed_dim,
n_heads=model_cfg.n_heads,
language_model=model_cfg.language_model,
hf_cache=model_cfg.hf_cache,
language_dim=model_cfg.language_dim,
reward_dim=model_cfg.reward_dim,
n_negatives=model_cfg.n_negatives,
lang_reward_weight=model_cfg.lang_reward_weight,
tcn_weight=model_cfg.tcn_weight,
l1_weight=model_cfg.l1_weight,
l2_weight=model_cfg.l2_weight,
optimizer=model_cfg.optimizer,
schedule=model_cfg.schedule,
lr=model_cfg.lr,
min_lr=model_cfg.min_lr,
warmup_epochs=dataset_cfg.warmup_epochs,
max_epochs=dataset_cfg.max_epochs,
mlp_ratio=model_cfg.mlp_ratio,
)
elif model_cfg.arch == "v-rn3m":
model = VRN3M(
resolution=dataset_cfg.resolution,
fc_dim=model_cfg.fc_dim,
language_model=model_cfg.language_model,
hf_cache=model_cfg.hf_cache,
language_dim=model_cfg.language_dim,
reward_dim=model_cfg.reward_dim,
n_negatives=model_cfg.n_negatives,
lang_reward_weight=model_cfg.lang_reward_weight,
tcn_weight=model_cfg.tcn_weight,
l1_weight=model_cfg.l1_weight,
l2_weight=model_cfg.l2_weight,
optimizer=model_cfg.optimizer,
lr=model_cfg.lr,
)
# Voltron Models
elif model_cfg.arch == "v-cond":
model = VCond(
resolution=dataset_cfg.resolution,
patch_size=model_cfg.patch_size,
encoder_depth=model_cfg.encoder_depth,
encoder_embed_dim=model_cfg.encoder_embed_dim,
encoder_n_heads=model_cfg.encoder_n_heads,
decoder_depth=model_cfg.decoder_depth,
decoder_embed_dim=model_cfg.decoder_embed_dim,
decoder_n_heads=model_cfg.decoder_n_heads,
language_model=model_cfg.language_model,
hf_cache=model_cfg.hf_cache,
language_dim=model_cfg.language_dim,
optimizer=model_cfg.optimizer,
schedule=model_cfg.schedule,
base_lr=model_cfg.base_lr,
min_lr=model_cfg.min_lr,
effective_bsz=model_cfg.effective_bsz,
betas=model_cfg.betas,
weight_decay=model_cfg.weight_decay,
warmup_epochs=dataset_cfg.warmup_epochs,
max_epochs=dataset_cfg.max_epochs,
mlp_ratio=model_cfg.mlp_ratio,
norm_pixel_loss=model_cfg.norm_pixel_loss,
)
elif model_cfg.arch == "v-dual":
model = VDual(
resolution=dataset_cfg.resolution,
patch_size=model_cfg.patch_size,
encoder_depth=model_cfg.encoder_depth,
encoder_embed_dim=model_cfg.encoder_embed_dim,
encoder_n_heads=model_cfg.encoder_n_heads,
decoder_depth=model_cfg.decoder_depth,
decoder_embed_dim=model_cfg.decoder_embed_dim,
decoder_n_heads=model_cfg.decoder_n_heads,
language_model=model_cfg.language_model,
hf_cache=model_cfg.hf_cache,
language_dim=model_cfg.language_dim,
optimizer=model_cfg.optimizer,
schedule=model_cfg.schedule,
base_lr=model_cfg.base_lr,
min_lr=model_cfg.min_lr,
effective_bsz=model_cfg.effective_bsz,
betas=model_cfg.betas,
weight_decay=model_cfg.weight_decay,
warmup_epochs=dataset_cfg.warmup_epochs,
max_epochs=dataset_cfg.max_epochs,
mlp_ratio=model_cfg.mlp_ratio,
norm_pixel_loss=model_cfg.norm_pixel_loss,
)
elif model_cfg.arch == "v-gen":
model = VGen(
resolution=dataset_cfg.resolution,
patch_size=model_cfg.patch_size,
encoder_depth=model_cfg.encoder_depth,
encoder_embed_dim=model_cfg.encoder_embed_dim,
encoder_n_heads=model_cfg.encoder_n_heads,
decoder_depth=model_cfg.decoder_depth,
decoder_embed_dim=model_cfg.decoder_embed_dim,
decoder_n_heads=model_cfg.decoder_n_heads,
language_model=model_cfg.language_model,
hf_cache=model_cfg.hf_cache,
language_dim=model_cfg.language_dim,
max_lang_len=dataset_cfg.max_lang_len,
vocab_size=model_cfg.vocab_size,
mae_weight=model_cfg.mae_weight,
lm_weight=model_cfg.lm_weight,
optimizer=model_cfg.optimizer,
schedule=model_cfg.schedule,
base_lr=model_cfg.base_lr,
min_lr=model_cfg.min_lr,
effective_bsz=model_cfg.effective_bsz,
betas=model_cfg.betas,
weight_decay=model_cfg.weight_decay,
warmup_epochs=dataset_cfg.warmup_epochs,
max_epochs=dataset_cfg.max_epochs,
mlp_ratio=model_cfg.mlp_ratio,
norm_pixel_loss=model_cfg.norm_pixel_loss,
)
else:
raise ValueError(f"Model Architecture `{model_cfg.arch}` is not implemented!")
# Configure Optimizer --> on same device (CPU)
optimizer, update_lr = model.configure_optimizer()
return model, optimizer, update_lr
================================================
FILE: voltron/models/materialize.py
================================================
"""
materialize.py
Core functionality for using pretrained models; defines the package-level `load` functionality for downloading and
instantiating pretrained Voltron (and baseline) models.
"""
import json
import os
from pathlib import Path
from typing import Callable, List, Tuple
import gdown
import torch
import torch.nn as nn
import torchvision.transforms as T
from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen
# === Define Useful Variables for Loading Models ===
DEFAULT_CACHE = "cache/"
NORMALIZATION = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# Pretrained Model Registry :: "model id" -> {"config" -> gdown ID, "checkpoint" -> gdown ID, "cls" -> Model Class}
MODEL_REGISTRY = {
# === Voltron ViT-Small (Sth-Sth) Models ===
"v-cond": {
"config": "1O4oqRIblfS6PdFlZzUcYIX-Rqe6LbvnD",
"checkpoint": "12g5QckQSMKqrfr4lFY3UPdy7oLw4APpG",
"cls": VCond,
},
"v-dual": {
"config": "1zgKiK81SF9-0lg0XbMZwNhUh1Q7YdZZU",
"checkpoint": "1CCRqrwcvF8xhIbJJmwnCbcWfWTJCK40T",
"cls": VDual,
},
"v-gen": {
"config": "18-mUBDsr-2_-KrGoL2E2YzjcUO8JOwUF",
"checkpoint": "1TzSQpKVKBWKCSvYJf22c45hrKczTQz24",
"cls": VGen,
},
# === Voltron ViT-Base Model ===
"v-cond-base": {
"config": "1CLe7CaIzTEcGCijIgw_S-uqMXHfBFSLI",
"checkpoint": "1PwczOijL0hfYD8DI4xLOPLf1xL_7Kg9S",
"cls": VCond,
},
# === Data-Locked Reproductions ===
"r-mvp": {
"config": "1KKNWag6aS1xkUiUjaJ1Khm9D6F3ROhCR",
"checkpoint": "1-ExshZ6EC8guElOv_s-e8gOJ0R1QEAfj",
"cls": VMVP,
},
"r-r3m-vit": {
"config": "1JGk32BLXwI79uDLAGcpbw0PiupBknf-7",
"checkpoint": "1Yby5oB4oPc33IDQqYxwYjQV3-56hjCTW",
"cls": VR3M,
},
"r-r3m-rn50": {
"config": "1OS3mB4QRm-MFzHoD9chtzSmVhOA-eL_n",
"checkpoint": "1t1gkQYr6JbRSkG3fGqy_9laFg_54IIJL",
"cls": VRN3M,
},
}
def available_models() -> List[str]:
return list(MODEL_REGISTRY.keys())
def load(
model_id: str, device: torch.device = "cpu", freeze: bool = True, cache: str = DEFAULT_CACHE
) -> Tuple[nn.Module, Callable[[torch.Tensor], torch.Tensor]]:
"""
Download & cache specified model configuration & checkpoint, then load & return module & image processor.
Note :: We *override* the default `__call__` method of each of the respective model classes with the
`get_representations` method --> by default passing "NULL" language for any language-conditioned models.
This can be overridden either by passing in language (as a `str`) or by invoking the corresponding methods.
"""
assert model_id in MODEL_REGISTRY, f"Model ID `{model_id}` not valid, try one of {list(MODEL_REGISTRY.keys())}"
# Download Config & Checkpoint (if not in cache)
model_cache = Path(cache) / model_id
config_path, checkpoint_path = model_cache / f"{model_id}-config.json", model_cache / f"{model_id}.pt"
os.makedirs(model_cache, exist_ok=True)
if not checkpoint_path.exists() or not config_path.exists():
gdown.download(id=MODEL_REGISTRY[model_id]["config"], output=str(config_path), quiet=False)
gdown.download(id=MODEL_REGISTRY[model_id]["checkpoint"], output=str(checkpoint_path), quiet=False)
# Load Configuration --> patch `hf_cache` key if present (don't download to random locations on filesystem)
with open(config_path, "r") as f:
model_kwargs = json.load(f)
if "hf_cache" in model_kwargs:
model_kwargs["hf_cache"] = str(Path(cache) / "hf-cache")
# By default, the model's `__call__` method defaults to `forward` --> for downstream applications, override!
# > Switch `__call__` to `get_representations`
MODEL_REGISTRY[model_id]["cls"].__call__ = MODEL_REGISTRY[model_id]["cls"].get_representations
# Materialize Model (load weights from checkpoint; note that the unused element `_` holds the optimizer state...)
model = MODEL_REGISTRY[model_id]["cls"](**model_kwargs)
state_dict, _ = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
# Freeze model parameters if specified (default: True)
if freeze:
for _, param in model.named_parameters():
param.requires_grad = False
# Build Visual Preprocessing Transform (assumes image is read into a torch.Tensor, but can be adapted)
if model_id in {"v-cond", "v-dual", "v-gen", "v-cond-base", "r-mvp"}:
# All models except R3M expect images normalized with the default ImageNet-1K (IN1K) statistics...
preprocess = T.Compose(
[
T.Resize(model_kwargs["resolution"]),
T.CenterCrop(model_kwargs["resolution"]),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=NORMALIZATION[0], std=NORMALIZATION[1]),
]
)
else:
# R3M models (following original work) expect unnormalized images with values in range [0 - 255)
preprocess = T.Compose(
[
T.Resize(model_kwargs["resolution"]),
T.CenterCrop(model_kwargs["resolution"]),
T.ConvertImageDtype(torch.float),
T.Lambda(lambda x: x * 255.0),
]
)
return model, preprocess
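# Minimal usage sketch (assuming `load` is re-exported from the top-level `voltron` package; shapes are illustrative):
#   import torch
#   from voltron import load
#
#   model, preprocess = load("v-cond", device="cpu", freeze=True)
#   img = preprocess(torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8))[None, ...]
#   with torch.no_grad():
#       representation = model(img)  # `__call__` now routes to `get_representations`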
================================================
FILE: voltron/models/reproductions/__init__.py
================================================
================================================
FILE: voltron/models/reproductions/vmvp.py
================================================
"""
vmvp.py
PyTorch Module defining a basic MAE a la Masked Visual Pretraining for Motor Control (MVP), with the requisite
hyperparameters - as defined in the original ImageMAE paper, and as used by both MVP papers.
References:
- https://github.com/facebookresearch/mae
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, get_2D_position_embeddings
class VMVP(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
encoder_depth: int,
encoder_embed_dim: int,
encoder_n_heads: int,
decoder_depth: int,
decoder_embed_dim: int,
decoder_n_heads: int,
optimizer: str,
schedule: str,
base_lr: float,
min_lr: float,
effective_bsz: float,
betas: Tuple[float, float],
weight_decay: float,
warmup_epochs: int,
max_epochs: int,
mask_ratio: float = 0.75,
mlp_ratio: float = 4.0,
in_channels: int = 3,
norm_pixel_loss: bool = True,
):
"""
Initialize a VMVP (MAE) model with the requisite architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param patch_size: Height/Width of each patch in pixels -- usually 16.
:param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
:param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
:param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
:param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
:param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer.
:param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
:param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
:param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
:param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
:param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
:param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
:param betas: Adam optimizer betas (only applicable for `adam` and `adamw`); helps prevent early loss spiking.
:param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
:param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
:param max_epochs: Total number of training epochs to be run.
:param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
:param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
:param in_channels: Default number of channels in the base image -- almost always 3.
:param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
"""
super().__init__()
self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
# Encoder/Decoder Parameters
self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads
# MAE Encoder Parameters --> MVP uses a CLS Token for feature extraction!
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
self.patch2embed = PatchEmbed(
self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
)
self.encoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + 1, self.encoder_embed_dim), requires_grad=False
)
self.encoder_blocks = nn.ModuleList(
[Block(self.encoder_embed_dim, self.encoder_n_heads, self.mlp_ratio) for _ in range(self.encoder_depth)]
)
self.encoder_norm = nn.LayerNorm(self.encoder_embed_dim, eps=1e-6)
# Projection from Encoder to Decoder
self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)
# MAE Decoder Parameters -- Remember the CLS Token!
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim))
self.decoder_pe = nn.Parameter(
torch.zeros(1, self.patch2embed.num_patches + 1, self.decoder_embed_dim), requires_grad=False
)
self.decoder_blocks = nn.ModuleList(
[Block(self.decoder_embed_dim, self.decoder_n_heads, self.mlp_ratio) for _ in range(self.decoder_depth)]
)
self.decoder_norm = nn.LayerNorm(self.decoder_embed_dim, eps=1e-6)
self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
# Initialize all Weights
self.initialize_weights()
def initialize_weights(self) -> None:
# Position Encoding -- Fixed 2D Sine-Cosine Embeddings
enc_pe = get_2D_position_embeddings(self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
dec_pe = get_2D_position_embeddings(self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))
# Initialize PatchEmbedding as a Linear...
nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
# Initialize CLS Token & Mask Token w/ Truncated Normal
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.mask_token, std=0.02)
# Everything else...
self.apply(self.transformer_initializer)
@staticmethod
def transformer_initializer(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
# Use xavier_uniform following Jax ViT
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def mask(
self, patches: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches"""
bsz, n_patches, embed_dim = patches.shape
if mask_ratio is not None:
n_keep = int(n_patches * (1 - mask_ratio))
else:
n_keep = int(n_patches * (1 - self.mask_ratio))
# Sample some noise of n_patches size, argsort to get shuffled IDs (keep small), argsort again to "unshuffle"
# > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
restore_idxs = torch.argsort(shuffle_idxs, dim=1)
# Get "keep" (visible) patches
visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))
# Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following FAIR MAE convention)
mask = torch.ones(bsz, n_patches, device=patches.device)
mask[:, :n_keep] = 0
mask = torch.gather(mask, dim=1, index=restore_idxs)
return visible_patches, mask, restore_idxs
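# A tiny worked example of the shuffle/unshuffle trick above (hypothetical noise, bsz=1, n_patches=4, n_keep=2):
#   noise        = [0.7, 0.1, 0.9, 0.3]
#   shuffle_idxs = argsort(noise)        = [1, 3, 0, 2]  (indices sorted by ascending noise)
#   restore_idxs = argsort(shuffle_idxs) = [2, 0, 3, 1]  (the inverse permutation)
#   visible      = patches gathered at positions [1, 3]; gathering the [0, 0, 1, 1] keep/remove mask with
#   restore_idxs yields [1, 0, 1, 0] -- i.e., patches 0 and 2 are masked in the original patch order.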
def get_representations(self, img: torch.Tensor, mode: str = "patch") -> torch.Tensor:
"""
Given a single image, extract representations subject to the specified mode in < patch | cls >, where "cls"
denotes extracting the [CLS] token embedding; for our experiments, we find that running multiheaded attention
pooling on top of the "patch" embeddings is *always* better!
:param img: Processed batch of images :: [bsz, 3, 224, 224]
:param mode: Type of representation to extract -- `patch` (sequence of patch embeddings) or `cls` ([CLS] token embedding)
:return: Extracted representations given img input.
"""
assert img.ndim == 4, "Invalid input to `get_representations()`"
assert mode in {"patch", "cls"}, f"Extraction mode `{mode}` not supported!"
# Extract desired representations
representations = self.encode(img)
return representations[:, 1:] if mode == "patch" else representations[:, :1]
def encode(self, img: torch.Tensor) -> torch.Tensor:
"""Run a single image through the MAE and extract patch embeddings."""
# Note: All of this code is taken near-verbatim from the MVP repository...
# > Ref: https://github.com/ir413/mvp/blob/master/mvp/backbones/vit.py#L30
patches = self.patch2embed(img)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_patches = torch.cat([cls_tokens, patches], dim=1) + self.encoder_pe
# Apply Transformer Blocks...
for block in self.encoder_blocks:
cls_patches = block(cls_patches)
cls_patches = self.encoder_norm(cls_patches)
return cls_patches
def forward_encoder(
self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Patchify + Position Embedding (without the CLS Token)
patches = self.patch2embed(imgs)
patches_pe = patches + self.encoder_pe[:, 1:, :]
# Create mask (and go ahead and mask out patches at the same time)
visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)
# Add the CLS Token
cls_token = self.cls_token + self.encoder_pe[:, :1, :]
cls_tokens = cls_token.expand(imgs.shape[0], -1, -1)
cls_visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)
# Apply Transformer Blocks...
for block in self.encoder_blocks:
cls_visible_patches = block(cls_visible_patches)
cls_visible_patches = self.encoder_norm(cls_visible_patches)
return cls_visible_patches, mask, restore_idxs
def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
# Project patches into decoder embedding dimension
projected_patches = self.encoder2decoder(visible_patches)
# Add Mask Tokens to Sequence
mask_tokens = self.mask_token.repeat(
projected_patches.shape[0], restore_idxs.shape[1] - visible_patches.shape[1] + 1, 1
)
# Remove & add back CLS Token as part of the "unshuffling"
concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1) # Skip CLS Token
unshuffled_patches = torch.gather(
concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
)
cls_unshuffled_patches = torch.cat([projected_patches[:, :1, :], unshuffled_patches], dim=1) # Add CLS Token
# Add Position Embeddings
cls_decoder_patches = cls_unshuffled_patches + self.decoder_pe
# Apply Transformer Blocks...
for block in self.decoder_blocks:
cls_decoder_patches = block(cls_decoder_patches)
cls_decoder_patches = self.decoder_norm(cls_decoder_patches)
# Run final projection, remove the CLS token, and return
cls_decoder_prediction = self.decoder_prediction(cls_decoder_patches)
decoder_prediction = cls_decoder_prediction[:, 1:, :]
return decoder_prediction
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""Convert a batch of images to their patched equivalents, by naive reshaping"""
return rearrange(
imgs,
"bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
patch_h=self.patch_size,
patch_w=self.patch_size,
)
def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
targets = self.patchify(imgs)
# Normalize targets...
mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
targets = (targets - mu) / ((var + 1e-6) ** 0.5)
# Compute mean loss per patch first...
mse = (reconstructions - targets) ** 2
avg_loss_per_patch = mse.mean(dim=-1)
# Compute mean loss only on *removed* patches and return
return (avg_loss_per_patch * mask).sum() / mask.sum()
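# In equation form: each patch target p is standardized to (p - mean(p)) / sqrt(var(p) + 1e-6), and the loss is
# sum(mask * per_patch_MSE) / sum(mask) -- i.e., the reconstruction MSE averaged over *masked* patches only.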
def forward(
self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
visible_patches, mask, restore_idxs = self.forward_encoder(imgs, mask_ratio)
reconstructions = self.forward_decoder(visible_patches, restore_idxs)
loss = self.compute_loss(imgs, reconstructions, mask)
return loss, reconstructions, mask
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adamw"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")
# Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
# > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
# Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
if param.ndim <= 1 or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
# Build Parameter Groups
groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
# Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
# > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
self.lr = self.base_lr * (self.effective_bsz / 256)
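# e.g., with (hypothetical) base_lr = 1.5e-4 and effective_bsz = 1024 --> lr = 1.5e-4 * (1024 / 256) = 6e-4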
# Create Optimizer & LR Scheduler
optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
return optimizer, update_lr
================================================
FILE: voltron/models/reproductions/vr3m.py
================================================
"""
vr3m.py
PyTorch Module defining an R3M model (with a ViT encoder), with the remainder as described in Nair et al. 2021,
with all the requisite hyperparameters.
Reference:
- https://github.com/facebookresearch/r3m
"""
from typing import Callable, Tuple
import torch
import torch.nn as nn
import transformers
from einops import rearrange
from voltron.models.util.optimization import get_lr_update
from voltron.models.util.transformer import Block, PatchEmbed, get_2D_position_embeddings
# Suppress Transformers Logging
transformers.logging.set_verbosity_error()
class VR3M(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
depth: int,
embed_dim: int,
n_heads: int,
language_model: str,
hf_cache: str,
language_dim: int,
reward_dim: int,
n_negatives: int,
lang_reward_weight: float,
tcn_weight: float,
l1_weight: float,
l2_weight: float,
optimizer: str,
schedule: str,
lr: float,
min_lr: float,
warmup_epochs: int,
max_epochs: int,
mlp_ratio: float = 4.0,
in_channels: int = 3,
eps: float = 1e-8,
):
"""
Initialize a ViT R3M model with the requisite architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param patch_size: Height/Width of each patch in pixels -- usually 16.
:param depth: Number of Transformer blocks in the ViT image encoder.
:param embed_dim: Core embedding/hidden dimension for the vision transformer backbone.
:param n_heads: Number of heads for multi-headed self-attention.
:param language_model: Language model to freeze for encoding narrations/utterances.
:param hf_cache: Cache directory to store pretrained models, for safe distributed training.
:param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
:param reward_dim: Hidden layer dimensionality for the language-reward MLP.
:param n_negatives: Number of cross-batch negatives to sample for contrastive learning.
:param lang_reward_weight: Weight applied to the contrastive "language alignment" loss term.
:param tcn_weight: Weight applied to the time contrastive loss term.
:param l1_weight: Weight applied to the L1 regularization loss term.
:param l2_weight: Weight applied to the L2 regularization loss term.
:param optimizer: String denoting which optimizer to use (for R3M, usually `adam`)
:param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
:param lr: Peak learning rate to use over the course of training -- warms up to this value, then decays.
:param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
:param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
:param max_epochs: Total number of training epochs to be run.
:param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down; default 4).
:param in_channels: Default number of channels in the base image -- almost always 3.
:param eps: Epsilon for preventing divide by zero in the InfoNCE loss terms.
"""
super().__init__()
self.resolution, self.patch_size, self.n_negatives, self.eps = resolution, patch_size, n_negatives, eps
self.optimizer, self.schedule, self.lr, self.min_lr = optimizer, schedule, lr, min_lr
self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
self.mlp_ratio, self.in_channels = mlp_ratio, in_channels
self.language_dim, self.reward_dim = language_dim, reward_dim
# Weights for each loss term
self.lang_reward_weight, self.tcn_weight = lang_reward_weight, tcn_weight
self.l1_weight, self.l2_weight = l1_weight, l2_weight
# ViT Backbone
self.depth, self.embed_dim, self.n_heads = depth, embed_dim, n_heads
# Create ViT --> some differences from the original architectures in "An Image is Worth 16x16 Words":
# > Namely, subset of authors show that (https://arxiv.org/abs/2205.01580):
# 1) 2D sinusoidal embeddings are better than learned embeddings...
# 2) Average pooling is better/simpler than learning a [CLS] token embedding...
# Note: Output "embedding" is just mean pooled final transformer layer --> after extra layer norm!
# > Ref: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L446
self.patch2embed = PatchEmbed(self.resolution, self.patch_size, self.embed_dim, in_channels=self.in_channels)
self.pe = nn.Parameter(torch.zeros(1, self.patch2embed.num_patches, self.embed_dim), requires_grad=False)
self.blocks = nn.ModuleList([Block(self.embed_dim, self.n_heads, self.mlp_ratio) for _ in range(self.depth)])
self.norm = nn.LayerNorm(self.embed_dim, eps=1e-6)
# Create Language Reward Model
self.language_reward = nn.Sequential(
nn.Linear(self.embed_dim + self.embed_dim + self.language_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, 1),
nn.Sigmoid(),
)
# Initialize all Weights
self.initialize_weights()
# @AFTER INITIALIZATION -- Create Language Model & Language Reward MLP --> LM has requires_grad = False
# > For BERT models, our "embedding" is just going to be the last hidden state
# > Assumes inputs to forward pass are pre-tokenized!
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
self.lm.eval()
# Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
# Freeze the LM
for _name, param in self.lm.named_parameters():
param.requires_grad = False
def initialize_weights(self) -> None:
# Position Encoding -- Fixed 2D Sine-Cosine Embeddings
pe = get_2D_position_embeddings(self.embed_dim, int(self.patch2embed.num_patches**0.5))
self.pe.data.copy_(torch.from_numpy(pe).float().unsqueeze(0))
# Initialize PatchEmbedding as a Linear...
nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
# Everything else...
self.apply(self.transformer_initializer)
@staticmethod
def transformer_initializer(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
# Use xavier_uniform following Jax ViT
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def get_representations(self, img: torch.Tensor) -> torch.Tensor:
"""
Given a single image, extract R3M "default" (mean-pooled) dense representation.
:param img: Processed batch of images :: [bsz, 3, 224, 224]
:return: Extracted R3M dense representation given img input.
"""
assert img.ndim == 4, "Invalid input to `get_representations()`"
patches = self.patch2embed(img)
img_embedding = patches + self.pe
# Apply Transformer Blocks...
for block in self.blocks:
img_embedding = block(img_embedding)
img_embedding = self.norm(img_embedding)
return img_embedding.mean(1, keepdim=True)
def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
"""Feed images through ViT, get single embedding via global mean pooling."""
patches = self.patch2embed(imgs)
img_embedding = patches + self.pe
# Apply Transformer Blocks...
for block in self.blocks:
img_embedding = block(img_embedding)
img_embedding = self.norm(img_embedding)
# Return mean pooled embeddings
return img_embedding.mean(dim=1)
def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Encode language by feeding the *pre-tokenized text* through the frozen language model."""
self.lm.eval()
with torch.no_grad():
transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
return transformer_embeddings.mean(dim=1)
def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang: torch.Tensor) -> torch.Tensor:
return self.language_reward(torch.cat([initial, later, lang], dim=-1)).squeeze()
def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""
Run a forward pass through the model, computing the *full* R3M loss -- the TCN contrastive loss, the Language
Alignment loss, and both sparsity losses, as well as the full loss (which will get optimized)!
:param imgs: A [bsz, 5, in_channels, resolution, resolution] tensor of (start, i, j, k, end) sequences.
:param lang: Tokenized language of dimensionality [bsz, seq_len] to be fed to the language model.
:param lang_mask: Attention mask computed by the tokenizer, as a result of padding to the max_seq_len.
:return: Tuple of losses, as follows:
> (combined_loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, reward_acc)
"""
# Encode each image separately... feed to transformer... then reshape
all_images = rearrange(imgs, "bsz n_states c res1 res2 -> (bsz n_states) c res1 res2", n_states=5)
all_embeddings = self.encode_images(all_images)
initial, state_i, state_j, state_k, final = rearrange(
all_embeddings, "(bsz n_states) embed -> n_states bsz embed", n_states=5
)
# Compute Regularization Losses
l1_loss = torch.linalg.norm(all_embeddings, ord=1, dim=-1).mean()
l2_loss = torch.linalg.norm(all_embeddings, ord=2, dim=-1).mean()
# Compute TCN Loss
tcn_loss, tcn_acc = self.get_time_contrastive_loss(state_i, state_j, state_k)
# Compute Language Alignment/Predictive Loss
lang_reward_loss, rew_acc = self.get_reward_loss(lang, lang_mask, initial, state_i, state_j, state_k, final)
# Compute full weighted loss & return...
loss = (
(self.l1_weight * l1_loss)
+ (self.l2_weight * l2_loss)
+ (self.tcn_weight * tcn_loss)
+ (self.lang_reward_weight * lang_reward_loss)
)
return loss, tcn_loss, lang_reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc
@staticmethod
def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_l2: bool = True) -> torch.Tensor:
"""Computes similarity between embeddings via -L2 distance."""
assert use_l2, "Non-L2 time-similarity functions not yet implemented!"
return -torch.linalg.norm(state_x - state_y, dim=-1)
def get_time_contrastive_loss(
self, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor
) -> Tuple[torch.Tensor, ...]:
"""Evaluates the Time-Contrastive Loss, computed via InfoNCE."""
# *Punchline* - we want `sim(i, j)` to be higher than `sim(i, k)` for some k > j (goes both ways)
# > As our positive examples --> we sample (s_i, s_j) and (s_j, s_k).
# > Our negatives --> other pairs from the triplet, cross-batch negatives!
sim_i_j_exp = torch.exp(self.time_similarity(state_i, state_j))
sim_j_k_exp = torch.exp(self.time_similarity(state_j, state_k))
# Add a "hard" negative!
neg_i_k_exp = torch.exp(self.time_similarity(state_i, state_k))
# Obtain *cross-batch* negatives
bsz, neg_i, neg_j = state_i.shape[0], [], []
for _ in range(self.n_negatives):
neg_idx = torch.randperm(bsz)
state_i_shuf = state_i[neg_idx]
neg_idx = torch.randperm(bsz)
state_j_shuf = state_j[neg_idx]
neg_i.append(self.time_similarity(state_i, state_i_shuf))
neg_j.append(self.time_similarity(state_j, state_j_shuf))
neg_i_exp, neg_j_exp = torch.exp(torch.stack(neg_i, -1)), torch.exp(torch.stack(neg_j, -1))
# Compute InfoNCE
denominator_i = sim_i_j_exp + neg_i_k_exp + neg_i_exp.sum(-1)
denominator_j = sim_j_k_exp + neg_i_k_exp + neg_j_exp.sum(-1)
nce_i = -torch.log(self.eps + (sim_i_j_exp / (self.eps + denominator_i)))
nce_j = -torch.log(self.eps + (sim_j_k_exp / (self.eps + denominator_j)))
nce = (nce_i + nce_j) / 2
# Compute "accuracy"
i_j_acc = (1.0 * (sim_i_j_exp > neg_i_k_exp)).mean()
j_k_acc = (1.0 * (sim_j_k_exp > neg_i_k_exp)).mean()
acc = (i_j_acc + j_k_acc) / 2
return nce.mean(), acc
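# In equation form (as implemented above, with `eps` added for numerical stability and sim(x, y) = -||x - y||_2):
#   L_i = -log( exp(sim(i, j)) / (exp(sim(i, j)) + exp(sim(i, k)) + sum_n exp(sim(i, i_n))) )
#   L_j = -log( exp(sim(j, k)) / (exp(sim(j, k)) + exp(sim(i, k)) + sum_n exp(sim(j, j_n))) )
# where i_n, j_n are the cross-batch (shuffled) negatives; the returned loss is the batch mean of (L_i + L_j) / 2.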
def get_reward_loss(
self,
lang: torch.Tensor,
lang_mask: torch.Tensor,
initial: torch.Tensor,
state_i: torch.Tensor,
state_j: torch.Tensor,
state_k: torch.Tensor,
final: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
"""Evaluates the Language-Alignment Reward Loss, computed via InfoNCE."""
lang_embed = self.encode_language(lang, lang_mask)
# *Punchline* - we want `Reward(s_0, s_t, l)` to be higher than `Reward(s_0, s_t', l)` for any earlier state t'
# > As our positive examples --> we sample s_j, s_k, and s_final (excluding s_i)
pos_final_exp = torch.exp(self.get_reward(initial, final, lang_embed))
pos_j_exp = torch.exp(self.get_reward(initial, state_j, lang_embed))
pos_k_exp = torch.exp(self.get_reward(initial, state_k, lang_embed))
# Add the within-context negatives <--> these are the most informative examples!
# > We use initial, initial as a negative for the first one, just to get reward model to "capture progress"
negs_final = [self.get_reward(initial, initial, lang_embed)]
negs_j = [self.get_reward(initial, state_i, lang_embed)]
negs_k = [self.get_reward(initial, state_j, lang_embed)]
# Cross Batch Negatives -- same as positives (indexing), but from a different batch!
# > @SK :: Unclear how well this will unroll on TPUs...
bsz = initial.shape[0]
for _ in range(self.n_negatives):
# We get three random indices to further minimize correlation... from the R3M codebase!
neg_idx = torch.randperm(bsz)
negs_final.append(self.get_reward(initial[neg_idx], final[neg_idx], lang_embed))
neg_idx = torch.randperm(bsz)
negs_j.append(self.get_reward(initial[neg_idx], state_j[neg_idx], lang_embed))
neg_idx = torch.randperm(bsz)
negs_k.append(self.get_reward(initial[neg_idx], state_k[neg_idx], lang_embed))
# Flatten & exponentiate; get ready for the InfoNCE
negs_final, negs_j, negs_k = torch.stack(negs_final, -1), torch.stack(negs_j, -1), torch.stack(negs_k, -1)
negs_final_exp, negs_j_exp, negs_k_exp = torch.exp(negs_final), torch.exp(negs_j), torch.exp(negs_k)
# Compute InfoNCE
denominator_final = pos_final_exp + negs_final_exp.sum(-1)
denominator_j = pos_j_exp + negs_j_exp.sum(-1)
denominator_k = pos_k_exp + negs_k_exp.sum(-1)
nce_final = -torch.log(self.eps + (pos_final_exp / (self.eps + denominator_final)))
nce_j = -torch.log(self.eps + (pos_j_exp / (self.eps + denominator_j)))
nce_k = -torch.log(self.eps + (pos_k_exp / (self.eps + denominator_k)))
# Compute "accuracy"
acc_final = (1.0 * (negs_final_exp.max(dim=-1)[0] < pos_final_exp)).mean()
acc_j = (1.0 * (negs_j_exp.max(dim=-1)[0] < pos_j_exp)).mean()
acc_k = (1.0 * (negs_k_exp.max(dim=-1)[0] < pos_k_exp)).mean()
acc = (acc_final + acc_j + acc_k) / 3
nce = (nce_final + nce_j + nce_k) / 3
return nce.mean(), acc
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adam"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adam`] instead!")
# Create Optimizer & LR Scheduler
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
return optimizer, update_lr
================================================
FILE: voltron/models/reproductions/vrn3m.py
================================================
"""
vrn3m.py
PyTorch Module defining an R3M model (with a ResNet-50 encoder), exactly as described in Nair et al. 2021, with all the
requisite hyperparameters.
Reference:
- https://github.com/facebookresearch/r3m
"""
from typing import Callable, Tuple
import torch
import torch.nn as nn
import transformers
from einops import rearrange
from torchvision.models import resnet50
from voltron.models.util.optimization import get_lr_update
# Suppress Transformers Logging
transformers.logging.set_verbosity_error()
class VRN3M(nn.Module):
def __init__(
self,
resolution: int,
fc_dim: int,
language_model: str,
hf_cache: str,
language_dim: int,
reward_dim: int,
n_negatives: int,
lang_reward_weight: float,
tcn_weight: float,
l1_weight: float,
l2_weight: float,
optimizer: str,
lr: float,
eps: float = 1e-8,
):
"""
Initialize a ResNet-50 R3M model with the required architecture parameters.
:param resolution: Base image resolution -- usually 224 (ImageNet size).
:param fc_dim: Dimensionality of the pooled embedding coming out of the ResNet (for RN50, fc_dim = 2048)
:param language_model: Language model to freeze for encoding narrations/utterances.
:param hf_cache: Cache directory to store pretrained models, for safe distributed training.
:param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
:param reward_dim: Hidden layer dimensionality for the language-reward MLP.
:param n_negatives: Number of cross-batch negatives to sample for contrastive learning.
:param lang_reward_weight: Weight applied to the contrastive "language alignment" loss term.
:param tcn_weight: Weight applied to the time contrastive loss term.
:param l1_weight: Weight applied to the L1 regularization loss term.
:param l2_weight: Weight applied to the L2 regularization loss term.
:param optimizer: String denoting which optimizer to use (for R3M, usually `adam`).
:param lr: Learning rate (fixed for ResNet R3M models) for training.
:param eps: Epsilon for preventing divide by zero in the InfoNCE loss terms.
"""
super().__init__()
self.resolution, self.fc_dim, self.n_negatives, self.eps = resolution, fc_dim, n_negatives, eps
self.language_dim, self.reward_dim, self.optimizer, self.lr = language_dim, reward_dim, optimizer, lr
self.embed_dim = self.fc_dim
# Weights for each loss term
self.lang_reward_weight, self.tcn_weight = lang_reward_weight, tcn_weight
self.l1_weight, self.l2_weight = l1_weight, l2_weight
# Create ResNet50 --> set `rn.fc` to the Identity() to extract final features of dim = `fc_dim`
self.resnet = resnet50(weights=None)
self.resnet.fc = nn.Identity()
self.resnet.train()
# Create Language Reward Model
self.language_reward = nn.Sequential(
nn.Linear(self.fc_dim + self.fc_dim + self.language_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, self.reward_dim),
nn.ReLU(),
nn.Linear(self.reward_dim, 1),
nn.Sigmoid(),
)
# Create Language Model & Language Reward MLP --> LM has requires_grad = False
# > For BERT models, our "embedding" is just going to be the last hidden state
# > Assumes inputs to forward pass are pre-tokenized!
self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
self.lm.eval()
# Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
# Freeze the LM
for _name, param in self.lm.named_parameters():
param.requires_grad = False
def get_representations(self, img: torch.Tensor) -> torch.Tensor:
"""
Given a single image, extract R3M "default" (ResNet pooled) dense representation.
:param img: Processed batch of images :: [bsz, 3, 224, 224]
:return: Extracted R3M dense representation given img input.
"""
assert img.ndim == 4, "Invalid input to `get_representations()`"
representation = self.resnet(img)
return representation.unsqueeze(1)
def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
"""Feed images through ResNet-50 to get single embedding after global average pooling."""
return self.resnet(imgs)
def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
"""Encode language by feeding the *pre-tokenized text* through the frozen language model."""
self.lm.eval()
with torch.no_grad():
transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
return transformer_embeddings.mean(dim=1)
def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang: torch.Tensor) -> torch.Tensor:
return self.language_reward(torch.cat([initial, later, lang], dim=-1)).squeeze()
def extract_features(self, img: torch.Tensor) -> torch.Tensor:
"""Run a single image of shape [1, 3, 224, 224] through the ResNet and extract the feature."""
return self.encode_images(img).detach()
def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""
Run a forward pass through the model, computing the *full* R3M loss -- the TCN contrastive loss, the Language
Alignment loss, and both sparsity losses, as well as the full loss (which will get optimized)!
:param imgs: A [bsz, 5, in_channels, resolution, resolution] tensor of (start, i, j, k, end) sequences.
:param lang: Tokenized language of dimensionality [bsz, seq_len] to be fed to the language model.
:param lang_mask: Attention mask computed by the tokenizer, as a result of padding to the max_seq_len.
:return: Tuple of losses, as follows:
> (combined_loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, reward_acc)
"""
# Encode each image separately... feed to transformer... then reshape
all_images = rearrange(imgs, "bsz n_states c res1 res2 -> (bsz n_states) c res1 res2", n_states=5)
all_embeddings = self.encode_images(all_images)
initial, state_i, state_j, state_k, final = rearrange(
all_embeddings, "(bsz n_states) embed -> n_states bsz embed", n_states=5
)
# Compute Regularization Losses
l1_loss = torch.linalg.norm(all_embeddings, ord=1, dim=-1).mean()
l2_loss = torch.linalg.norm(all_embeddings, ord=2, dim=-1).mean()
# Compute TCN Loss
tcn_loss, tcn_acc = self.get_time_contrastive_loss(state_i, state_j, state_k)
# Compute Language Alignment/Predictive Loss
lang_reward_loss, rew_acc = self.get_reward_loss(lang, lang_mask, initial, state_i, state_j, state_k, final)
# Compute full weighted loss & return...
loss = (
(self.l1_weight * l1_loss)
+ (self.l2_weight * l2_loss)
+ (self.tcn_weight * tcn_loss)
+ (self.lang_reward_weight * lang_reward_loss)
)
return loss, tcn_loss, lang_reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc
@staticmethod
def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_l2: bool = True) -> torch.Tensor:
"""Computes similarity between embeddings via -L2 distance."""
assert use_l2, "Non-L2 time-similarity functions not yet implemented!"
return -torch.linalg.norm(state_x - state_y, dim=-1)
def get_time_contrastive_loss(
self, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor
) -> Tuple[torch.Tensor, ...]:
"""Evaluates the Time-Contrastive Loss, computed via InfoNCE."""
# *Punchline* - we want `sim(i, j)` to be higher than `sim(i, k)` for some k > j (goes both ways)
# > As our positive examples --> we sample (s_i, s_j) and (s_j, s_k).
# > Our negatives --> other pairs from the triplet, cross-batch negatives!
sim_i_j_exp = torch.exp(self.time_similarity(state_i, state_j))
sim_j_k_exp = torch.exp(self.time_similarity(state_j, state_k))
# Add a "hard" negative!
neg_i_k_exp = torch.exp(self.time_similarity(state_i, state_k))
# Obtain *cross-batch* negatives
bsz, neg_i, neg_j = state_i.shape[0], [], []
for _ in range(self.n_negatives):
neg_idx = torch.randperm(bsz)
state_i_shuf = state_i[neg_idx]
neg_idx = torch.randperm(bsz)
state_j_shuf = state_j[neg_idx]
neg_i.append(self.time_similarity(state_i, state_i_shuf))
neg_j.append(self.time_similarity(state_j, state_j_shuf))
neg_i_exp, neg_j_exp = torch.exp(torch.stack(neg_i, -1)), torch.exp(torch.stack(neg_j, -1))
# Compute InfoNCE
denominator_i = sim_i_j_exp + neg_i_k_exp + neg_i_exp.sum(-1)
denominator_j = sim_j_k_exp + neg_i_k_exp + neg_j_exp.sum(-1)
nce_i = -torch.log(self.eps + (sim_i_j_exp / (self.eps + denominator_i)))
nce_j = -torch.log(self.eps + (sim_j_k_exp / (self.eps + denominator_j)))
nce = (nce_i + nce_j) / 2
# Compute "accuracy"
i_j_acc = (1.0 * (sim_i_j_exp > neg_i_k_exp)).mean()
j_k_acc = (1.0 * (sim_j_k_exp > neg_i_k_exp)).mean()
acc = (i_j_acc + j_k_acc) / 2
return nce.mean(), acc
def get_reward_loss(
self,
lang: torch.Tensor,
lang_mask: torch.Tensor,
initial: torch.Tensor,
state_i: torch.Tensor,
state_j: torch.Tensor,
state_k: torch.Tensor,
final: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
"""Evaluates the Language-Alignment Reward Loss, computed via InfoNCE."""
lang_embed = self.encode_language(lang, lang_mask)
# *Punchline* - we want `Reward(s_0, s_t, l)` to be higher than `Reward(s_0, s_t', l)` for any earlier state t'
# > As our positive examples --> we sample s_j, s_k, and s_final (excluding s_i)
pos_final_exp = torch.exp(self.get_reward(initial, final, lang_embed))
pos_j_exp = torch.exp(self.get_reward(initial, state_j, lang_embed))
pos_k_exp = torch.exp(self.get_reward(initial, state_k, lang_embed))
# Add the within-context negatives <--> these are the most informative examples!
# > We use initial, initial as a negative for the first one, just to get reward model to "capture progress"
negs_final = [self.get_reward(initial, initial, lang_embed)]
negs_j = [self.get_reward(initial, state_i, lang_embed)]
negs_k = [self.get_reward(initial, state_j, lang_embed)]
# Cross Batch Negatives -- same as positives (indexing), but from a different batch!
# > @SK :: Unclear how well this will unroll on TPUs...
bsz = initial.shape[0]
for _ in range(self.n_negatives):
# We get three random indices to further minimize correlation... from the R3M codebase!
neg_idx = torch.randperm(bsz)
negs_final.append(self.get_reward(initial[neg_idx], final[neg_idx], lang_embed))
neg_idx = torch.randperm(bsz)
negs_j.append(self.get_reward(initial[neg_idx], state_j[neg_idx], lang_embed))
neg_idx = torch.randperm(bsz)
negs_k.append(self.get_reward(initial[neg_idx], state_k[neg_idx], lang_embed))
# Flatten & exponentiate; get ready for the InfoNCE
negs_final, negs_j, negs_k = torch.stack(negs_final, -1), torch.stack(negs_j, -1), torch.stack(negs_k, -1)
negs_final_exp, negs_j_exp, negs_k_exp = torch.exp(negs_final), torch.exp(negs_j), torch.exp(negs_k)
# Compute InfoNCE
denominator_final = pos_final_exp + negs_final_exp.sum(-1)
denominator_j = pos_j_exp + negs_j_exp.sum(-1)
denominator_k = pos_k_exp + negs_k_exp.sum(-1)
nce_final = -torch.log(self.eps + (pos_final_exp / (self.eps + denominator_final)))
nce_j = -torch.log(self.eps + (pos_j_exp / (self.eps + denominator_j)))
nce_k = -torch.log(self.eps + (pos_k_exp / (self.eps + denominator_k)))
# Compute "accuracy"
acc_final = (1.0 * (negs_final_exp.max(dim=-1)[0] < pos_final_exp)).mean()
acc_j = (1.0 * (negs_j_exp.max(dim=-1)[0] < pos_j_exp)).mean()
acc_k = (1.0 * (negs_k_exp.max(dim=-1)[0] < pos_k_exp)).mean()
acc = (acc_final + acc_j + acc_k) / 3
nce = (nce_final + nce_j + nce_k) / 3
return nce.mean(), acc
def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
# Short-Circuit on Valid Optimizers
if self.optimizer not in ["adam"]:
raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adam`] instead!")
# Create Optimizer and (No-Op) LR Scheduler
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
update_lr = get_lr_update(
optimizer, schedule="none", lr=self.lr, min_lr=self.lr, warmup_epochs=-1, max_epochs=-1
)
return optimizer, update_lr
================================================
FILE: voltron/models/util/__init__.py
================================================
from .extraction import instantiate_extractor
================================================
FILE: voltron/models/util/extraction.py
================================================
"""
extraction.py
General Extraction module definitions & associated utilities.
References:
- Set Transformers (MAP): https://arxiv.org/abs/1810.00825.pdf
"""
from typing import Callable
import torch
import torch.nn as nn
from einops import repeat
from voltron.models.util.transformer import RMSNorm, SwishGLU
# === Multiheaded Attention Pooling ===
# As defined in Set Transformers (https://arxiv.org/abs/1810.00825) -- basically multiheaded cross-attention, additionally taking in
# a set of $k$ learned "seed vectors" that are used to "pool" information.
class MAPAttention(nn.Module):
def __init__(self, embed_dim: int, n_heads: int) -> None:
"""Multi-Input Multi-Headed Attention Operation"""
super().__init__()
assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5
# Projections (no bias) --> separate for Q (seed vector), and KV ("pool" inputs)
self.q, self.kv = nn.Linear(embed_dim, embed_dim, bias=False), nn.Linear(embed_dim, 2 * embed_dim, bias=False)
self.proj = nn.Linear(embed_dim, embed_dim)
def forward(self, seed: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
(B_s, K, C_s), (B_x, N, C_x) = seed.shape, x.shape
assert C_s == C_x, "Seed vectors and pool inputs must have the same embedding dimensionality!"
# Project Seed Vectors to `queries`
q = self.q(seed).reshape(B_s, K, self.n_heads, C_s // self.n_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B_x, N, 2, self.n_heads, C_x // self.n_heads).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)
# Attention --> compute weighted sum over values!
scores = q @ (k.transpose(-2, -1) * self.scale)
attn = scores.softmax(dim=-1)
vals = (attn @ v).transpose(1, 2).reshape(B_s, K, C_s)
# Project back to `embed_dim`
return self.proj(vals)
class MAPBlock(nn.Module):
def __init__(
self,
n_latents: int,
embed_dim: int,
n_heads: int,
mlp_ratio: float = 4.0,
do_rms_norm: bool = True,
do_swish_glu: bool = True,
) -> None:
"""Multiheaded Attention Pooling Block -- note that for MAP, we adopt earlier post-norm conventions."""
super().__init__()
self.n_latents, self.embed_dim, self.n_heads = n_latents, embed_dim, 2 * n_heads
# Projection Operator
self.projection = nn.Linear(embed_dim, self.embed_dim)
# Initialize Latents
self.latents = nn.Parameter(torch.zeros(self.n_latents, self.embed_dim))
nn.init.normal_(self.latents, std=0.02)
# Custom MAP Attention (seed, encoder outputs) -> seed
self.attn_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
self.attn = MAPAttention(self.embed_dim, n_heads=self.n_heads)
# Position-wise Feed-Forward Components
self.mlp_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
self.mlp = nn.Sequential(
# Handle SwishGLU vs. GELU MLP...
(
SwishGLU(self.embed_dim, int(mlp_ratio * self.embed_dim))
if do_swish_glu
else nn.Sequential(nn.Linear(self.embed_dim, int(mlp_ratio * self.embed_dim)), nn.GELU())
),
nn.Linear(int(mlp_ratio * self.embed_dim), self.embed_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
latents = repeat(self.latents, "n_latents d -> bsz n_latents d", bsz=x.shape[0])
latents = self.attn_norm(latents + self.attn(latents, self.projection(x)))
latents = self.mlp_norm(latents + self.mlp(latents))
return latents.squeeze(dim=1)
# MAP Extractor Instantiation --> factory for creating extractors with the given parameters.
def instantiate_extractor(backbone: nn.Module, n_latents: int = 1) -> Callable[[], nn.Module]:
def initialize() -> nn.Module:
return MAPBlock(n_latents, backbone.embed_dim, backbone.n_heads)
return initialize
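# Minimal usage sketch (assuming a backbone loaded via `voltron.models.materialize.load`, e.g. "v-cond"; shapes illustrative):
#   backbone, preprocess = load("v-cond")
#   vector_extractor = instantiate_extractor(backbone, n_latents=1)()   # call the factory to build the MAPBlock
#   dense = backbone(preprocess(img)[None, ...])                        # [1, n_patches, embed_dim]
#   pooled = vector_extractor(dense)                                    # [1, embed_dim] pooled vector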
================================================
FILE: voltron/models/util/optimization.py
================================================
"""
optimization.py
General utilities for optimization, e.g., schedulers such as Linear Warmup w/ Cosine Decay for Transformer training.
Notably *does not* use the base PyTorch LR Scheduler, since we call it continuously, across epochs, across steps;
PyTorch has no built-in way of separating the two without coupling to the DataLoader, so we may as well make this explicit
in the parent loop.
References
- MAE: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/lr_sched.py
- ⚡️-Bolts: https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py
"""
import math
from typing import Callable
from torch.optim.optimizer import Optimizer
def get_lr_update(
opt: Optimizer, schedule: str, lr: float, min_lr: float, warmup_epochs: int, max_epochs: int
) -> Callable[[int, float], float]:
if schedule == "linear-warmup+cosine-decay":
def lr_update(epoch: int, fractional_progress: float) -> float:
"""Run the warmup check for linear increase, else cosine decay."""
if (epoch + fractional_progress) < warmup_epochs:
new_lr = lr * (epoch + fractional_progress) / max(1.0, warmup_epochs)
else:
# Cosine Decay --> as defined in the SGDR Paper...
progress = ((epoch + fractional_progress) - warmup_epochs) / max(1.0, max_epochs - warmup_epochs)
new_lr = min_lr + (lr - min_lr) * (0.5 * (1 + math.cos(math.pi * progress)))
# Apply...
for group in opt.param_groups:
if "lr_scale" in group:
group["lr"] = new_lr * group["lr_scale"]
else:
group["lr"] = new_lr
return new_lr
elif schedule == "none":
def lr_update(_: int, __: float) -> float:
return lr
else:
raise NotImplementedError(f"Schedule `{schedule}` not implemented!")
return lr_update
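# Minimal usage sketch (hypothetical training loop; each model's `configure_optimizer()` returns (optimizer, update_lr)):
#   optimizer, update_lr = model.configure_optimizer()
#   for epoch in range(max_epochs):
#       for step, batch in enumerate(dataloader):
#           update_lr(epoch, step / len(dataloader))   # fractional progress within the current epoch
#           loss = model(*batch)[0]
#           optimizer.zero_grad(); loss.backward(); optimizer.step()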
================================================
FILE: voltron/models/util/transformer.py
================================================
"""
transformer.py
General Transformer modules & utilities.
References:
- https://github.com/facebookresearch/mae
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
# === Position Encoding Utilities ===
# Helper/Utility Function -- computes simple 1D sinusoidal position embeddings for both 1D/2D use cases.
# > We'll be combining two 1D sin-cos (traditional) position encodings for height/width of an image (grid features).
def get_1D_sine_cosine(dim: int, pos: np.ndarray) -> np.ndarray:
omega = np.arange(dim // 2, dtype=np.float32) / (dim / 2.0)
omega = 1.0 / (10000**omega)
out = np.einsum("m,d->md", pos.reshape(-1), omega) # [flatten(pos) x omega] -- outer product!
emb_sin, emb_cos = np.sin(out), np.cos(out)
return np.concatenate([emb_sin, emb_cos], axis=1) # [flatten(pos) x D]
# 1D Sine-Cosine Position Embedding -- standard from "Attention is all you need!"
def get_1D_position_embeddings(embed_dim: int, length: int) -> np.ndarray:
return get_1D_sine_cosine(embed_dim, np.arange(length))
# 2D Sine-Cosine Position Embedding (from MAE repository)
# > https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2D_position_embeddings(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
# Create 2D Position embeddings by taking cross product of height and width and splicing 1D embeddings...
grid_h, grid_w = np.arange(grid_size, dtype=np.float32), np.arange(grid_size, dtype=np.float32)
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, grid_size, grid_size) # w goes first?
# Use half of dimensions to encode grid_h, other half to encode grid_w
emb_h, emb_w = get_1D_sine_cosine(embed_dim // 2, grid[0]), get_1D_sine_cosine(embed_dim // 2, grid[1])
pos_embed = np.concatenate([emb_h, emb_w], axis=1)
# CLS token handling (only for R-MVP)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
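# e.g., embed_dim=384 and grid_size=14 (a 224px image with 16px patches) yield an array of shape (196, 384);
# with cls_token=True, a row of zeros is prepended for the [CLS] position, giving (197, 384).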
# === Vision Transformer Building Blocks ===
# Patch Embedding Module
class PatchEmbed(nn.Module):
def __init__(
self,
resolution: int,
patch_size: int,
embed_dim: int,
in_channels: int = 3,
flatten: bool = True,
):
super().__init__()
self.resolution, self.patch_size = (resolution, resolution), (patch_size, patch_size)
self.grid_size = (self.resolution[0] // self.patch_size[0], self.resolution[1] // self.patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
def forward(self, patches: torch.Tensor) -> torch.Tensor:
patch_embeddings = self.proj(patches)
if self.flatten:
return rearrange(patch_embeddings, "bsz embed patch_h patch_w -> bsz (patch_h patch_w) embed")
return patch_embeddings
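# e.g., resolution=224 and patch_size=16 give a 14 x 14 grid (num_patches = 196), so the flattened output of
# `forward` has shape [bsz, 196, embed_dim].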
# === Stability Utilities ===
# LayerScale -- Trainable scaling for residual blocks -- Mistral/CaIT
class LayerScale(nn.Module):
def __init__(self, dim: int, init_values: float = 0.1) -> None: # CaIT :: 0.1 -> lay 12, 1e-5 -> lay 24, 1e-6...
super().__init__()
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.gamma
# RMSNorm -- Better, simpler alternative to LayerNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-8) -> None:
super().__init__()
self.scale, self.eps = dim**-0.5, eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
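# Equivalently: RMSNorm(x) = g * x / max(RMS(x), eps), where RMS(x) = sqrt(mean(x_i^2)) --
# `torch.norm(x, dim=-1) * dim**-0.5` above is exactly sqrt(sum(x_i^2) / dim).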
# SwishGLU -- A Gated Linear Unit (GLU) with the Swish activation; always better than GELU MLP!
class SwishGLU(nn.Module):
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.act, self.project = nn.SiLU(), nn.Linear(in_dim, 2 * out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
projected, gate = self.project(x).tensor_split(2, dim=-1)
return projected * self.act(gate)
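# i.e., SwishGLU(x) = (x W_out) * SiLU(x W_gate) (ignoring bias terms) -- `project` packs W_out and W_gate into a
# single Linear producing 2 * out_dim features that are then split in half along the last dimension.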
# === Fundamental Transformer Building Blocks ===
class Attention(nn.Module):
def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.0) -> None:
"""Multi-Headed Self-Attention Operation"""
super().__init__()
assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5
self.attn_softmax = None
# Projections
self.qkv, self.proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True), nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape
# Project to Q-K-V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# Self-attention -- with masking!
scores = q @ (k.transpose(-2, -1) * self.scale)
if mask is not None:
if mask.ndim == 2:
mask = rearrange(mask, "bsz seq -> bsz 1 seq 1")
elif mask.ndim != 4:
raise NotImplementedError("Attention got `mask` of shape not in {2, 4}!")
# Mask out by filling indices with negative infinity...
scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
# Compute weighted sum over values
self.attn_softmax = scores.softmax(dim=-1)
vals = (self.attn_softmax @ v).transpose(1, 2).reshape(B, N, C)
# Project back to `embed_dim` -- with optional dropout
vals = self.dropout(self.proj(vals))
return vals
class Block(nn.Module):
def __init__(
self,
embed_dim: int,
n_heads: int,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
do_rms_norm: bool = False,
do_swish_glu: bool = False,
do_layer_scale: bool = False,
) -> None:
"""
Transformer Block Implementation (modality-agnostic).
:param embed_dim: Core embedding/hidden dimension for vision transformer backbone.
:param n_heads: Number of heads for multi-headed self-attention.
:param mlp_ratio: Ratio for embedding size to position-wise feed-forward MLP (gets shrunk back down).
:param dropout: [Optional] dropout for projection layer and MLPs -- for MAEs, always 0.0!
:param do_rms_norm: Boolean whether or not to use RMSNorm in lieu of LayerNorm within block.
:param do_swish_glu: Use the Swish-variant of the Gated Linear Unit for the feed-forward layers.
        :param do_layer_scale: Boolean whether or not to use LayerScale from Mistral/CaiT w/ initialization of 0.1.
"""
super().__init__()
self.embed_dim, self.n_heads, self.do_layer_scale = embed_dim, n_heads, do_layer_scale
# Attention Components
self.pre_norm_attn = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
self.attn = Attention(self.embed_dim, n_heads=n_heads, dropout=dropout)
if do_layer_scale:
self.layer_scale_attn = LayerScale(self.embed_dim)
# Position-wise Feed-Forward Components
self.pre_norm_mlp = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
self.mlp = nn.Sequential(
# Handle SwishGLU vs. GELU MLP...
(
SwishGLU(embed_dim, int(mlp_ratio * embed_dim))
if do_swish_glu
else nn.Sequential(nn.Linear(embed_dim, int(mlp_ratio * embed_dim)), nn.GELU())
),
nn.Dropout(dropout),
nn.Linear(int(mlp_ratio * embed_dim), embed_dim),
)
if self.do_layer_scale:
self.layer_scale_mlp = LayerScale(self.embed_dim)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.do_layer_scale:
x = x + self.layer_scale_attn(self.attn(self.pre_norm_attn(x), mask))
x = x + self.layer_scale_mlp(self.mlp(self.pre_norm_mlp(x)))
else:
x = x + self.attn(self.pre_norm_attn(x), mask)
x = x + self.mlp(self.pre_norm_mlp(x))
return x
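# --- Usage sketch (illustrative only; not part of the original module) ---
# A minimal example of how the pieces above compose into a tiny ViT-style encoder; the resolution,
# patch size, width, and depth below are arbitrary choices for demonstration, not project defaults.
if __name__ == "__main__":
    resolution, patch_size, embed_dim, n_heads, depth = 224, 16, 192, 6, 2
    patchify = PatchEmbed(resolution, patch_size, embed_dim)
    pos_embed = torch.from_numpy(get_2D_position_embeddings(embed_dim, patchify.grid_size[0])).float()[None, ...]
    encoder = nn.ModuleList([Block(embed_dim, n_heads, do_rms_norm=True, do_swish_glu=True) for _ in range(depth)])

    imgs = torch.randn(2, 3, resolution, resolution)  # [bsz, channels, height, width]
    tokens = patchify(imgs) + pos_embed               # [bsz, n_patches, embed_dim] = [2, 196, 192]
    for block in encoder:
        tokens = block(tokens)
    print(tokens.shape)                               # torch.Size([2, 196, 192])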
================================================
FILE: voltron/overwatch/__init__.py
================================================
from .overwatch import OverwatchRich
================================================
FILE: voltron/overwatch/overwatch.py
================================================
"""
overwatch.py
Utility class for creating a centralized/standardized logger (to pass to Hydra), with a sane default format.
"""
from dataclasses import dataclass, field
from typing import Any, Dict
# Overwatch Default Format String
FORMATTER, DATEFMT = "[*] %(asctime)s - %(name)s >> %(levelname)s :: %(message)s", "%m/%d [%H:%M:%S]"
RICH_FORMATTER = "| >> %(message)s"
# Rich Overwatch Variant --> Good for debugging and tracing!
@dataclass
class OverwatchRich:
version: int = 1
formatters: Dict[str, Any] = field(
default_factory=lambda: {
"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT},
"simple-file": {"format": FORMATTER, "datefmt": DATEFMT},
}
)
handlers: Dict[str, Any] = field(
default_factory=lambda: {
"console": {
"class": "rich.logging.RichHandler",
"formatter": "simple-console",
"rich_tracebacks": True,
"show_level": True,
"show_path": True,
"show_time": True,
},
"file": {
"class": "logging.FileHandler",
"formatter": "simple-file",
"filename": "${hydra.job.name}.log",
},
}
)
root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
disable_existing_loggers: bool = True
# Standard Overwatch Variant --> Performant, no bells & whistles
@dataclass
class OverwatchStandard:
version: int = 1
formatters: Dict[str, Any] = field(default_factory=lambda: {"simple": {"format": FORMATTER, "datefmt": DATEFMT}})
handlers: Dict[str, Any] = field(
default_factory=lambda: {
"console": {"class": "logging.StreamHandler", "formatter": "simple", "stream": "ext://sys.stdout"},
"file": {
"class": "logging.FileHandler",
"formatter": "simple",
"filename": "${hydra.job.name}.log",
},
}
)
root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
disable_existing_loggers: bool = True
================================================
FILE: voltron/preprocessing/__init__.py
================================================
from .process import extract_frames, preprocess_language, unify_batches
================================================
FILE: voltron/preprocessing/core.py
================================================
"""
core.py
Preprocessing utilities, including dry-run and single-video (single-example) processing. This file effectively defines
the "atomic" logic (take one video --> extract all frames, etc.), while the `process.py` functions invoke each unit
in a multiprocessing pool.
"""
import glob
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import av
import h5py
import numpy as np
import pandas as pd
from hurry.filesize import alternative, size
from PIL import Image
from rich.progress import track
from tqdm import tqdm
# Grab Logger
overwatch = logging.getLogger(__file__)
logging.getLogger("libav").setLevel(logging.ERROR)
# === General Utilities ===
# Frames are saved as `train_dir/{vid}/{vid}_idx={i}.jpg`; if `relpath`, return the *relative* path `{split}/{vid}/{vid}_idx={i}.jpg`
def get_path(save_dir: Path, v: str, i: int, relpath: bool = False) -> str:
return str((save_dir if not relpath else Path(save_dir.name)) / v / f"{v}_idx={i}.jpg")
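# Example (illustrative): with save_dir = Path("data/processed/sth-sth-v2/train"), v = "12345", i = 7:
#   get_path(save_dir, "12345", 7)               -> "data/processed/sth-sth-v2/train/12345/12345_idx=7.jpg"
#   get_path(save_dir, "12345", 7, relpath=True) -> "train/12345/12345_idx=7.jpg"  (Path(save_dir.name) keeps only the split)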
# === Dry-Run Functionality ===
def do_dry_run(
name: str,
path: str,
train_ids: List[str],
val_ids: List[str],
preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]],
n_train_videos: int = 1000,
n_val_videos: int = 100,
n_samples: int = 1000,
) -> None:
"""Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation."""
overwatch.info(f"Performing Dry-Run with {n_train_videos} Train Videos and {n_val_videos} Validation Videos")
dry_run_metrics = {
"n_frames": [],
"jpg_sizes": [],
"n_samples": n_samples,
"time_per_example": [],
"blank": str(Path(path) / "blank.jpg"),
}
# Switch on dataset (`name`)
if name == "sth-sth-v2":
for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]:
for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True):
container = av.open(str(Path(path) / "videos" / f"{vids[idx]}.webm"))
assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
try:
imgs = [f.to_image() for f in container.decode(video=0)]
except (RuntimeError, ZeroDivisionError) as e:
overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vids[idx]}.webm` - continuing...")
continue
container.close()
# Apply `preprocess_transform`
imgs = preprocess_transform(imgs)
# Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on...
dry_run_metrics["n_frames"].append(len(imgs))
while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0:
img = imgs.pop(0)
img.save(str(dry_run_metrics["blank"]))
dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"]))
dry_run_metrics["n_samples"] -= 1
# Compute nice totals for "dry-run" estimate...
total_clips = len(train_ids) + len(val_ids)
else:
raise ValueError(f"Dry Run for Dataset `{name}` not implemented!")
# Compute aggregate statistics and gently exit...
avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"]))
overwatch.info("Dry-Run Statistics =>>")
overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}")
overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}")
overwatch.info(
f"\t> With the full dataset of {total_clips} Train + Val videos ~"
f" {size(total_clips * avg_frames * avg_size, system=alternative)}"
)
overwatch.info("Dry-Run complete, do what you will... exiting ✌️")
# Remove dummy file...
os.remove(dry_run_metrics["blank"])
    sys.exit(0)
# === Atomic "Processing" Steps ===
def process_clip(
name: str,
path: Path,
save: Path,
preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]],
item: Tuple[str, str],
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
"""Processes a single video clip and extracts/serializes all frames (as jpeg), returning the registry contents."""
if name == "sth-sth-v2":
vid, lang = item
container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
try:
imgs = [f.to_image() for f in container.decode(video=0)]
except (RuntimeError, ZeroDivisionError) as e:
overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
return None, None
container.close()
# Book-Keeping
os.makedirs(save / vid, exist_ok=True)
registration["n_frames"] = len(imgs)
# Short Circuit --> Writes are Expensive!
if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
return vid, registration
# Apply `preprocess_transform` --> write individual frames, register, and move on!
imgs = preprocess_transform(imgs)
for idx in range(len(imgs)):
imgs[idx].save(get_path(save, vid, idx))
# Return title & registration
return vid, registration
else:
raise ValueError(f"Clip Processing for Dataset `{name}` is not implemented!")
# ruff: noqa: C901
def serialize_epoch(
index_dir: Path,
registry: Dict[str, Any],
vid_dir: Path,
batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
do_initial: bool,
do_final: bool,
initial_final_alpha: float,
n_int: int,
epoch: int,
is_validation: bool = False,
) -> Tuple[int, int, Optional[Set[str]]]:
index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"
index_hdf5 = "validation-batches.hdf5" if is_validation else f"train-epoch={epoch}-batches.hdf5"
# Short-Circuit
if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
return -1, -1, None
# Random seed is inherited from parent process... we want new randomness w/ each process
np.random.seed((os.getpid() * int(time.time())) % 123456789)
# Create Tracking Variables
unique_states, batches = set(), {b: [] for b, _ in batch_formats}
# Iterate through Registry --> Note we're using `tqdm` instead of `track` here because of `position` feature!
for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
        # The initial/final states are sampled from the first [0, alpha) and final (1 - alpha, 1] percent of the video
n_frames = registry[vid]["n_frames"]
initial_idx, final_idx = 0, n_frames - 1
if do_initial:
initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))
if do_final:
final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)
# Assertion --> initial_idx < final_idx - len(state_elements)
assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"
# Assume remaining elements are just random "interior" states --> sort to get ordering!
sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
sampled_idxs = sorted(list(sampled_idxs))
# Compile full-set "batch"
retrieved_states = [get_path(vid_dir, vid, x, relpath=True) for x in [initial_idx, *sampled_idxs] + [final_idx]]
# Add batch to index for specific batch_format key...
batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
unique_states.update(retrieved_states)
# Add all other batch formats to indices...
for key, elements in batch_formats[:-1]:
n_states = len([x for x in elements if "state_" in x])
assert (n_states <= 2) or (
n_states == len(retrieved_states)
), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"
# States are all independent -- each of the retrieved states is its own example...
if n_states == 1:
for idx in range(len(retrieved_states)):
batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})
            # 0-K Context is the only "valid" context for n_states == 2
elif n_states == 2:
assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"
# Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
for idx in range(1, len(retrieved_states)):
batches[key].append(
{"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
)
# We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
else:
batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
# Write JSON Index directly to disk...
for key in batches:
with open(index_dir / key / index_file, "w") as f:
json.dump(batches[key], f)
# Write HDF5 Index directly to disk...
for key, elements in batch_formats[:-1]:
n_states = len([x for x in elements if "state_" in x])
# Create HDF5 File
df = pd.DataFrame(batches[key])
h5 = h5py.File(index_dir / key / index_hdf5, "w")
for k in ["vid", "n_frames"]:
h5.create_dataset(k, data=df[k].values)
# Handle "state(s)" --> (image path strings) --> add leading dimension (`n_states`)
if n_states == 1:
dfs = df["state"].apply(pd.Series)
h5.create_dataset("states", data=dfs.values)
else:
dfs = df["states"].apply(pd.Series)
h5.create_dataset("states", data=dfs.values)
# Close HDF5 File
h5.close()
return epoch, len(batches["state"]), unique_states
================================================
FILE: voltron/preprocessing/process.py
================================================
"""
process.py
Utility functions for preprocessing large-scale video/vision-language datasets in multiple passes, using multiprocessing
for parallelization. Exposes a three-phase sequence for preprocessing --> batching data:
    - Phase I (`extract_frames`): Read in raw (video clip, language) pairs, extract and serialize *all frames* to disk.
    - Phase II (`preprocess_language`): Normalize & tokenize all language captions/narrations (truncate/pad to max length).
    - Phase III (`unify_batches`): Assemble "data-locked" batches for *all models* and *all epochs* for consistency.
This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up extraction; however, for
larger datasets YMMV. You might consider extracting the relevant logic, and using tools like SLURM Job Arrays, AWS
Lambda Functions, or GCP Cloud Run to "burst preprocess" data.
"""
import json
import logging
import multiprocessing as mp
import os
import shutil
from functools import partial
from pathlib import Path
from typing import Tuple
import torch
from rich.progress import track
from transformers import AutoTokenizer
from voltron.preprocessing.core import do_dry_run, process_clip, serialize_epoch
from voltron.preprocessing.transforms import get_preprocess_transform
# Grab Logger
overwatch = logging.getLogger(__file__)
def extract_frames(
name: str,
path: str,
artifact_path: str,
preprocess_resolution: int,
n_val_videos: int,
dry_run: bool = False,
) -> Tuple[Path, Path, Path, Path]:
"""Phase I: Extract and serialize *all frames* from video clips; uses multiprocessing to parallelize."""
overwatch.info(f"Phase 1 Preprocessing :: Extracting Frames for Dataset `{name}`")
# Overview of Return Values:
# `t_registry` and `v_registry` =>> store mappings of "video id" -> {metadata}
# `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"
# Short-Circuit
if t_registry.exists() and v_registry.exists():
return t_registry, v_registry, t_dir, v_dir
# Setup / Book-Keeping
os.makedirs(t_dir, exist_ok=True)
os.makedirs(v_dir, exist_ok=True)
# Retrieve "pre-serialization" frame transform --> we scale down video frames (*while preserving aspect ratios*)
# and center crop each frame to `(preprocess_resolution, preprocess_resolution)`; saves on disk space (by a lot!)
preprocess_transform = get_preprocess_transform(name, preprocess_resolution=preprocess_resolution)
# Switch on dataset (`name`)
if name == "sth-sth-v2":
with open(Path(path) / "labels/train.json", "r") as f:
annotations = json.load(f)
train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
with open(Path(path) / "labels/validation.json", "r") as f:
annotations = json.load(f)[:n_val_videos]
val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
else:
raise ValueError(f"Language/Metadata Extraction Pipeline for Dataset `{name}` not implemented!")
# Run Dry-Run (if specified) --> single-threaded for debugging
if dry_run:
do_dry_run(name, path, train_ids, val_ids, preprocess_transform)
# Otherwise =>> Iterate through all videos, dump all frames subject to the following structure:
    #   |-> .../processed/something-something-v2/
    #       |-> {vid}/
    #           |-> {vid}_idx=<0..k>.jpg
    #
    # We'll build a single metadata file with a mapping of {vid} : ("language", n_frames)
# > To speed up serialization, we'll use a multiprocessing.Pool and max out CPU workers
with mp.Pool(mp.cpu_count()) as pool:
for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
overwatch.info(f"\tWriting `{k}` videos to disk...")
# Spawn!
process_fn, registration = partial(process_clip, name, Path(path), save, preprocess_transform), {}
for key, value in track(
pool.imap_unordered(process_fn, zip(vids, langs)),
total=len(vids),
transient=True,
):
if key is not None:
registration[key] = value
# Write Registration to Disk
with open(t_registry if k == "train" else v_registry, "w") as f:
json.dump(registration, f)
# Return Paths to Registry & Extract Directories...
return t_registry, v_registry, t_dir, v_dir
def preprocess_language(
name: str,
train_registry: Path,
val_registry: Path,
artifact_path: str,
max_lang_len: int,
language_model: str,
hf_cache: str,
) -> Path:
"""Phase II: Iterate through Language Captions/Narrations and Normalize/Tokenize (truncate/pad to max length)."""
overwatch.info(f"Phase 2 Preprocessing :: Normalizing & Tokenizing Language for Dataset `{name}`")
t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
index_dir = Path(artifact_path) / name / "index"
os.makedirs(index_dir, exist_ok=True)
# Short-Circuit
if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
return index_dir
# Grab Language --> retain metadata for building index structures!
with open(train_registry, "r") as f:
train_metadata = json.load(f)
train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]
with open(val_registry, "r") as f:
val_metadata = json.load(f)
val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]
# Assemble *all* language
language = [x[1] for x in train + val]
# Build AutoTokenizer (from `language_model` identifier)
tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
# If `max_lang_len` not specified, dump some statistics to compute...
if max_lang_len == -1:
# Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
encoded_language = tokenizer(language, return_tensors="pt", padding=True)
lengths = encoded_language["attention_mask"].sum(dim=1)
# Compute a histogram of lengths
hist = lengths.float().histc(bins=lengths.max()).int()
overwatch.info(f"Histogram: {hist.numpy().tolist()}")
raise AssertionError("Compute max length and update dataset configuration!")
# Otherwise, we've already set the maximum length, so let's use it!
overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
encoded_language = tokenizer(
language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
)
input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]
# Assertion, just to sanity check
assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."
# Compute `index.pt` contents
overwatch.info("\tAssembling `train` and `val` index structures...")
train_pt = {
train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
for i in range(len(train))
}
val_pt = {
val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
for i in range(len(val))
}
# Additionally dump JSON versions of the same --> downstream interpretability, XLA
overwatch.info("JSONifying both Train and Validation Language")
train_json, val_json = {}, {}
for vid in track(train_pt, description="Train Language :: ", transient=True):
train_json[vid] = {
"language": train_pt[vid]["language"],
"n_frames": train_pt[vid]["n_frames"],
"input_ids": train_pt[vid]["input_ids"].numpy().tolist(),
"attention_mask": train_pt[vid]["attention_mask"].numpy().tolist(),
}
for vid in track(val_pt, description="Validation Language :: ", transient=True):
val_json[vid] = {
"language": val_pt[vid]["language"],
"n_frames": val_pt[vid]["n_frames"],
"input_ids": val_pt[vid]["input_ids"].numpy().tolist(),
"attention_mask": val_pt[vid]["attention_mask"].numpy().tolist(),
}
# Dump Structures...
overwatch.info(f"Saving Torch indices to `{t_index}` and `{v_index}` respectively...")
torch.save(train_pt, t_index)
torch.save(val_pt, v_index)
overwatch.info(f"Saving JSON indices to `{t_json}` and `{v_json}` respectively...")
with open(t_json, "w") as f:
json.dump(train_json, f)
with open(v_json, "w") as f:
json.dump(val_json, f)
# Pull relevant files out into their own `index` directory...
shutil.copy(t_json, index_dir / "train-language-index.json")
shutil.copy(v_json, index_dir / "val-language-index.json")
return index_dir
def unify_batches(
name: str,
train_registry: Path,
val_registry: Path,
train_dir: Path,
val_dir: Path,
index_dir: Path,
batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
max_epochs: int = 400,
initial_final_alpha: float = 0.2,
) -> None:
"""Phase III: Assemble "Data-Locked" Batches for *all models* for *all epochs* for consistency!"""
overwatch.info(f"Phase 3 Preprocessing :: Assembling *Data-Locked* Batches for Dataset `{name}`")
# Load Registries
with open(train_registry, "r") as f:
train_registrations = json.load(f)
with open(val_registry, "r") as f:
val_registrations = json.load(f)
# Assert last element of `batch_formats` assumes all prior subsets...
full_set_inputs = set(batch_formats[-1][1])
for _, subset_inputs in batch_formats[:-1]:
assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."
# Assemble Tracking Data
b_keys, unique_states = {b[0] for b in batch_formats}, set()
# Parse out all "state"-specific Elements...
state_elements = [s for s in full_set_inputs if "state_" in s]
do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements
n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0
# Serialize Epochs
overwatch.info("\tSerializing Epochs to JSON --> Storing mapping of Epoch -> Image Paths")
for b in b_keys:
os.makedirs(index_dir / b, exist_ok=True)
# We only write the Validation Epoch once --> held constant across *all* of training!
overwatch.info("\tWriting Validation Epoch to Disk")
val_epoch_idx, _, uniq_s = serialize_epoch(
index_dir,
val_registrations,
val_dir,
batch_formats,
do_initial,
do_final,
initial_final_alpha,
n_int,
epoch=0,
is_validation=True,
)
# Update Trackers...
if val_epoch_idx != -1:
unique_states |= uniq_s
# Compute length of epochs --> CPU Count should be no higher...
epochs, n_frames_per_epoch = list(range(max_epochs)), -1
# Parallelize Train Epoch Serialization
overwatch.info("\tPlacing the Train Registry into Shared Memory")
manager = mp.Manager()
mg_registry = manager.dict(train_registrations)
# Multiprocess --> the memory demands here are a bit higher, so limit workers by factor of 4
with mp.Pool(mp.cpu_count() // 4) as pool:
overwatch.info("\tWriting Train Batches per Epoch to Disk")
precompute_fn = partial(
serialize_epoch,
index_dir,
mg_registry,
train_dir,
batch_formats,
do_initial,
do_final,
initial_final_alpha,
n_int,
)
for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
if epoch_idx == -1:
continue
# Update Trackers
unique_states |= uniq_s
n_frames_per_epoch = n_frames
# Dump Statistics (Note :: Only makes sense on "initial" computation --> uninterrupted!)
overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")
================================================
FILE: voltron/preprocessing/transforms.py
================================================
"""
transforms.py
Default video/image transforms for Voltron preprocessing and training. Provides utilities for defining different scale
and crop transformations on a dataset-specific basis.
There are two key desiderata we ensure with the transforms:
- Aspect Ratio --> We *never* naively reshape images in a way that distorts the aspect ratio; we crop instead!
- Minimum Size --> We *never* upsample images; processing strictly reduces dimensionality!
"""
from functools import partial
from typing import Any, Callable, List, Tuple
import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize, Resize
# Simple Identity Function --> needs to be top-level/pickleable for mp/distributed.spawn()
def identity(x: torch.Tensor) -> torch.Tensor:
return x.float()
def scaled_center_crop(target_resolution: int, frames: List[Image.Image]) -> List[Image.Image]:
# Assert width >= height and height >= target_resolution
orig_w, orig_h = frames[0].size
assert orig_w >= orig_h >= target_resolution
# Compute scale factor --> just a function of height and target_resolution
scale_factor = target_resolution / orig_h
for idx in range(len(frames)):
frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
left = (frames[idx].size[0] - target_resolution) // 2
frames[idx] = frames[idx].crop((left, 0, left + target_resolution, target_resolution))
# Return "scaled and squared" images
return frames
def get_preprocess_transform(
dataset_name: str, preprocess_resolution: int
) -> Callable[[List[Image.Image]], List[Image.Image]]:
"""Returns a transform that extracts square crops of `preprocess_resolution` from videos (as [T x H x W x C])."""
if dataset_name == "sth-sth-v2":
return partial(scaled_center_crop, preprocess_resolution)
else:
raise ValueError(f"Preprocessing transform for dataset `{dataset_name}` is not defined!")
def get_online_transform(
dataset_name: str, model_arch: str, online_resolution: int, normalization: Tuple[Any, Any]
) -> Compose:
"""Returns an "online" torchvision Transform to be applied during training (batching/inference)."""
if dataset_name == "sth-sth-v2":
        # Note: R3M does *not* expect [0-1] scaled (then ImageNet normalized) images --> only resize & cast to float (identity).
if model_arch in {"v-r3m", "v-rn3m"}:
return Compose([Resize((online_resolution, online_resolution), antialias=True), Lambda(identity)])
else:
return Compose(
[
Resize((online_resolution, online_resolution), antialias=True),
ConvertImageDtype(torch.float),
Normalize(mean=normalization[0], std=normalization[1]),
]
)
else:
raise ValueError(f"Online Transforms for Dataset `{dataset_name}` not implemented!")
================================================
FILE: voltron/preprocessing/v1/__init__.py
================================================
from .process import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
================================================
FILE: voltron/preprocessing/v1/process.py
================================================
"""
process.py
Utility functions for serializing datasets in multiple passes, using multiprocessing for efficient parallelization.
Exposes a three-phase sequence for preprocessing:
- Phase I: Read in raw videos (and language), serialize *all extracted* frames to a subdirectory for easy retrieval.
- Phase II: Given image paths and language, assemble language statistics & pre-tokenize for easy batching.
- Phase III: Given a total number of "conceivable epochs", create data-controlled "epoch" sets for each model.
This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up the serialization
process. It also tries to be somewhat safe & efficient, producing idempotent resumes.
Note :: This code represents the `v1` (initial release) preprocessing flow; this will eventually be deprecated!
"""
import json
import logging
import multiprocessing as mp
import os
import shutil
from functools import partial
from pathlib import Path
from typing import Tuple
import torch
from rich.progress import track
from transformers import AutoTokenizer
from voltron.preprocessing.v1.transforms import get_pre_transform
from voltron.preprocessing.v1.utils import do_dry_run, precompute_epoch, process_video
# Grab Logger
overwatch = logging.getLogger(__file__)
def preprocess_videos(
name: str,
path: str,
artifact_path: str = "data/processed",
resolution: int = 224,
n_val_videos: int = 1000,
dry_run: bool = False,
) -> Tuple[Path, Path, Path, Path]:
"""Phase I of Preprocessing :: Uses Multiprocessing to Read Videos & Serialize Frames."""
overwatch.info(f"Phase 1 Preprocessing :: Frame serializing videos for dataset `{name}`")
if name == "sth-sth-v2":
# Overview of Return Values:
# `t_registry` and `v_registry` =>> store mappings of "vid_id" -> {metadata}
# `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"
# Short-Circuit / Caching Logic
if t_registry.exists() and v_registry.exists():
return t_registry, v_registry, t_dir, v_dir
# Setup / Book-Keeping
os.makedirs(t_dir, exist_ok=True)
os.makedirs(v_dir, exist_ok=True)
# Retrieve Image Transforms (pre-serialization, while running "offline" pass); we crop and scale once, so we're
# not overdoing it on disk storage...
pre_transform = get_pre_transform(name, resolution=resolution)
# Open & Extract Video ID & Language Metadata
with open(Path(path) / "something-something-v2-train.json", "r") as f:
annotations = json.load(f)
train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
with open(Path(path) / "something-something-v2-validation.json", "r") as f:
annotations = json.load(f)[:n_val_videos]
val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
# Do Dry-Run --> Single-Threaded!
if dry_run:
do_dry_run(
name,
path,
n_train_videos=1000,
n_val_videos=100,
train_ids=train_ids,
val_ids=val_ids,
pre_transform=pre_transform,
)
# Go Go Go =>> Iterate through all videos, dump all frames subject to the following structure:
        #   |-> data/processed/sth-sth-v2/
        #       |-> {vid}/
        #           |-> {vid}_idx=<0...k>.jpg
        # We'll track a single metadata file with the map of {vid} : ("language", n_frames).
# > To speed up the serialization, we'll use a multiprocessing.Pool and max out CPU workers
with mp.Pool(mp.cpu_count()) as pool:
for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
overwatch.info(f"\tWriting `{k}` videos to disk...")
# Multiprocess!
process_fn, registration = partial(process_video, name, Path(path), save, pre_transform), {}
for key, value in track(
pool.imap_unordered(process_fn, zip(vids, langs)),
description=f"\t[*] Processing {k}...",
total=len(vids),
transient=True,
):
if key is not None:
registration[key] = value
# Write Registration to Disk
with open(t_registry if k == "train" else v_registry, "w") as f:
json.dump(registration, f)
# Return Paths...
return t_registry, v_registry, t_dir, v_dir
else:
raise NotImplementedError(f"Preprocessing Pipeline for Dataset `{name}` not implemented!")
def preprocess_language(
name: str, train_registry: Path, val_registry: Path, max_lang_len: int, language_model: str, hf_cache: str
) -> Tuple[Path, Path]:
"""Phase II of Preprocessing :: Iterate through Language & Normalize/Tokenize to Max Length."""
overwatch.info(f"Phase 2 Preprocessing :: Normalizing & tokenizing language for dataset `{name}`")
t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
# Short-Circuit Logic
if (t_index.exists() and v_index.exists()) or (t_json.exists() and v_json.exists()):
return t_index, v_index
# Grab Language, Retaining Metadata for Building Index Structures...
with open(train_registry, "r") as f:
train_metadata = json.load(f)
train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]
with open(val_registry, "r") as f:
val_metadata = json.load(f)
val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]
# Assemble *all* language
language = [x[1] for x in train + val]
# Build AutoTokenizer (from `language_model` identifier)
tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
# If `max_lang_len` not specified, dump some statistics to compute...
if max_lang_len == -1:
# Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
encoded_language = tokenizer(language, return_tensors="pt", padding=True)
lengths = encoded_language["attention_mask"].sum(dim=1)
# Compute a histogram of lengths
hist = lengths.float().histc(bins=lengths.max()).int()
overwatch.info(f"Histogram: {hist.numpy().tolist()}")
raise NotImplementedError("Compute max length and update dataset configuration!")
# Otherwise, we've already set the maximum length, so let's use it!
else:
overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
encoded_language = tokenizer(
language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
)
input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]
# Assertion, just to sanity check
assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."
# Compute `index.pt` contents
overwatch.info("\tAssembling `train` and `val` index structures...")
train_pt = {
train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
for i in range(len(train))
}
val_pt = {
val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
for i in range(len(val))
}
# Dump structures...
overwatch.info(f"Saving index structures to `{t_index}` and `{v_index}` respectively...")
    torch.save(train_pt, t_index)
    torch.save(val_pt, v_index)
    return t_index, v_index
def jsonify_language(train_registry: Path, val_registry: Path) -> None:
"""Phase 2.5 (Aggregation) :: XLA is weird, won't load torch.Tensors in Dataset; JSONify instead."""
overwatch.info("\tPhase 2 Aggregation :: JSONifying Language Index")
t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
train_json, val_json = {}, {}
# Short-Circuit Logic
if t_json.exists() and v_json.exists():
return
# Load Data, iterate through and "de-tensorize", while building up JSON symmetric structure...
train_data, val_data = torch.load(t_index), torch.load(v_index)
overwatch.info("JSONifying both Train and Validation")
for vid in track(train_data, description="Train Language...", transient=True):
train_json[vid] = {
"language": train_data[vid]["language"],
"n_frames": train_data[vid]["n_frames"],
"input_ids": train_data[vid]["input_ids"].numpy().tolist(),
"attention_mask": train_data[vid]["attention_mask"].numpy().tolist(),
}
for vid in track(val_data, description="Val Language...", transient=True):
val_json[vid] = {
"language": val_data[vid]["language"],
"n_frames": val_data[vid]["n_frames"],
"input_ids": val_data[vid]["input_ids"].numpy().tolist(),
"attention_mask": val_data[vid]["attention_mask"].numpy().tolist(),
}
# Write Data to Disk
overwatch.info("Writing JSON Indices")
with open(t_json, "w") as f:
json.dump(train_json, f)
with open(v_json, "w") as f:
json.dump(val_json, f)
def index(train_registry: Path, val_registry: Path, name: str, artifact_path: str = "data/processed") -> Path:
"""Phase 2.75 (Indexing) :: Pull out language.json & other `absolutely necessary` indices to separate directory."""
overwatch.info("\tPhase 2 Indexing :: Indexing Language & Registry Files =>> Extracting to Separate Directory")
# Create "index" directory...
index_dir = Path(artifact_path) / name / "index"
os.makedirs(index_dir, exist_ok=True)
# Short-Circuit Logic
if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
return index_dir
# Retrieve Language JSON indices (train & validation) & copy to new directory...
t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
shutil.copy(t_json, index_dir / "train-language-index.json")
shutil.copy(v_json, index_dir / "val-language-index.json")
return index_dir
def unify_batches(
artifact_path: Path,
name: str,
train_registry: Path,
val_registry: Path,
train_dir: Path,
val_dir: Path,
index_dir: Path,
batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
max_epochs: int = 400,
initial_final_alpha: float = 0.2,
) -> None:
"""Phase III of Preprocessing :: Assemble Batches for *all models* for *all epochs* in a consistent manner."""
overwatch.info("Phase 3 Preprocessing :: Assembling Data-Equivalent Epochs for each Model Format")
# Load Registry Files
with open(train_registry, "r") as f:
train_registrations = json.load(f)
with open(val_registry, "r") as f:
val_registrations = json.load(f)
# Assert last element of `batch_formats` assumes all prior subsets...
full_set_inputs = set(batch_formats[-1][1])
for _, subset_inputs in batch_formats[:-1]:
assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."
# Assemble Tracking Data
b_keys, unique_states = {b[0] for b in batch_formats}, set()
# Parse out all "state"-specific elements...
state_elements = [s for s in full_set_inputs if "state_" in s]
do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements
n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0
# Serialize Epochs to Disk
overwatch.info("\tSerializing epochs to json file, pointing to image paths on disk via a dictionary...")
for b in b_keys:
os.makedirs(index_dir / b, exist_ok=True)
# We only write the validation epoch once --> held constant across _all_ of training!
overwatch.info("\tWriting Validation Epoch to Disk...")
val_epoch_idx, _, uniq_s = precompute_epoch(
index_dir,
val_registrations,
val_dir,
batch_formats,
do_initial,
do_final,
initial_final_alpha,
n_int,
0,
is_validation=True,
)
# Update Trackers...
if val_epoch_idx != -1:
unique_states |= uniq_s
# Compute length of epochs --> CPU Count should be no higher...
epochs, n_frames_per_epoch = list(range(max_epochs)), -1
# Load "existing" verification file (if possible)
overwatch.info("\tLoading batch verification file (if possible)...")
verified_batches = Path(artifact_path) / name / "verified-batches.json"
if verified_batches.exists():
with open(verified_batches, "r") as f:
missing_epochs_per_format = json.load(f)
# Set epochs list by taking union of missing epochs over formats...
epochs = sorted(list(set().union(*missing_epochs_per_format.values())))
# Dump the big objects into an mp.Manager() so that we can read efficiently from other workers...
overwatch.info("\tPlacing the Train Registry into Shared Memory...")
manager = mp.Manager()
mg_registry = manager.dict(train_registrations)
with mp.Pool(4) as pool:
overwatch.info("\tWriting Train Batches per Epoch to Disk...")
# Create partial function for multiprocessing pool...
precompute_fn = partial(
precompute_epoch,
index_dir,
mg_registry,
train_dir,
batch_formats,
do_initial,
do_final,
initial_final_alpha,
n_int,
)
for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
if epoch_idx == -1:
continue
# Update Trackers
unique_states |= uniq_s
n_frames_per_epoch = n_frames
# Statistics only make sense on initial computation... should unify with code above!
overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")
overwatch.info("Preprocessing Complete!")
================================================
FILE: voltron/preprocessing/v1/transforms.py
================================================
"""
transforms.py
Default image/video transformations for various datasets.
"""
from typing import Any, Tuple
import cv2
import numpy as np
import torch
from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize
# Definitions of Video Transformations (Reference: `something-something-v2-baseline`)
class ComposeMix:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, imgs):
for transformation, scope in self.transforms:
if scope == "img":
for idx, img in enumerate(imgs):
imgs[idx] = transformation(img)
elif scope == "vid":
imgs = transformation(imgs)
else:
raise ValueError("Please specify a valid transformation...")
return imgs
class RandomCropVideo:
def __init__(self, size):
self.size = size
def __call__(self, imgs):
th, tw = self.size
h, w = imgs[0].shape[:2]
x1, y1 = np.random.randint(0, w - tw), np.random.randint(0, h - th)
for idx, img in enumerate(imgs):
imgs[idx] = img[y1 : y1 + th, x1 : x1 + tw]
return imgs
class Scale:
def __init__(self, size):
self.size = size
def __call__(self, img):
return cv2.resize(img, tuple(self.size))
def identity(x):
"""Transform needs to be pickleable for multiprocessing.spawn()."""
return x.float()
def get_pre_transform(dataset: str, resolution: int, scale_factor: float = 1.1) -> ComposeMix:
"""Defines a `pre` transform to be applied *when serializing the images* (first pass)."""
if dataset == "sth-sth-v2":
if scale_factor > 1:
transform = ComposeMix(
[
[Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
[RandomCropVideo((resolution, resolution)), "vid"],
]
)
else:
transform = ComposeMix(
[
[Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
]
)
return transform
else:
raise NotImplementedError(f"(Pre) transforms for dataset `{dataset}` not yet implemented!")
def get_online_transform(dataset: str, model_arch: str, normalization: Tuple[Any, Any]) -> Compose:
"""Defines an `online` transform to be applied *when batching the images* (during training/validation)."""
if dataset == "sth-sth-v2":
        # Note: R3M does *not* expect [0-1] scaled (then ImageNet normalized) images --> only cast to float (identity).
if model_arch in {"v-r3m", "v-rn3m"}:
return Compose([Lambda(identity)])
else:
return Compose([ConvertImageDtype(torch.float), Normalize(mean=normalization[0], std=normalization[1])])
else:
raise NotImplementedError(f"(Online) transforms for dataset `{dataset} not yet implemented!")
================================================
FILE: voltron/preprocessing/v1/utils.py
================================================
"""
utils.py
Preprocessing utilities, including functions for dry-runs and processing a single video (helpers for multiprocessing
calls down the line).
"""
import glob
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import av
import cv2
import numpy as np
from hurry.filesize import alternative, size
from rich.progress import track
from tqdm import tqdm
from voltron.preprocessing.v1.transforms import ComposeMix
# Grab Logger
overwatch = logging.getLogger(__file__)
logging.getLogger("libav").setLevel(logging.ERROR)
# Frames are saved as `train_dir/{vid}/{vid}_idx={i}.jpg`
def get_path(save_dir: Path, v: str, i: int) -> str:
return str(save_dir / v / f"{v}_idx={i}.jpg")
def do_dry_run(
name: str,
path: str,
n_train_videos: int,
n_val_videos: int,
train_ids: List[str],
val_ids: List[str],
pre_transform: ComposeMix,
n_samples: int = 1000,
) -> None:
"""Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation."""
dry_run_metrics = {
"n_frames": [],
"jpg_sizes": [],
"n_samples": n_samples,
"time_per_example": [],
"blank": str(Path(path) / "blank.jpg"),
}
if name == "sth-sth-v2":
for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]:
for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True):
vid = vids[idx]
container = av.open(str(Path(path) / "videos" / f"{vid}.webm"))
try:
imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
except (RuntimeError, ZeroDivisionError) as e:
overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
continue
# Close container
container.close()
# Apply `pre_transform`
imgs = pre_transform(imgs)
# Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on...
dry_run_metrics["n_frames"].append(len(imgs))
while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0:
img = imgs.pop(0)
cv2.imwrite(str(dry_run_metrics["blank"]), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"]))
dry_run_metrics["n_samples"] -= 1
# Compute nice totals for "dry-run" estimation
total_clips = len(train_ids) + len(val_ids)
else:
raise NotImplementedError(f"Dry Run for Dataset `{name}` not yet implemented!")
# Compute Aggregate Statistics and gently exit...
avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"]))
overwatch.info("Dry-Run Statistics =>>")
overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}")
overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}")
overwatch.info(
f"\t> With the full dataset of {total_clips} Train + Val videos ~"
f" {size(total_clips * avg_frames * avg_size, system=alternative)}"
)
overwatch.info("Dry-Run complete, do what you will... exiting ✌️")
# Remove dummy file...
os.remove(dry_run_metrics["blank"])
sys.exit(0)
def process_video(
name: str, path: Path, save: Path, pre_transform: ComposeMix, item: Tuple[str, str]
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
"""Processes a single video file, dumps to series of image files, and returns the registry contents."""
if name == "sth-sth-v2":
# For sth-sth-v2, `item` corresponds to a single video clip, so just a tuple!
vid, lang = item
container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
try:
imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
except (RuntimeError, ZeroDivisionError) as e:
overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - skipping...")
return None, None
# Close container
container.close()
# Book-keeping
os.makedirs(save / vid, exist_ok=True)
registration["n_frames"] = len(imgs)
# Early exit (writes are expensive)
if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
return vid, registration
# Apply `pre_transform` --> write individual frames, register, and return
imgs = pre_transform(imgs)
for i in range(len(imgs)):
cv2.imwrite(get_path(save, vid, i), cv2.cvtColor(imgs[i], cv2.COLOR_RGB2BGR))
# Return title & registration
return vid, registration
else:
raise NotImplementedError(f"Process Video for Dataset `{name}` not yet implemented!")
# ruff: noqa: C901
def precompute_epoch(
index_dir: Path,
registry: Dict[str, Any],
vid_dir: Path,
batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
do_initial: bool,
do_final: bool,
initial_final_alpha: float,
n_int: int,
epoch: int,
is_validation: bool = False,
) -> Tuple[int, int, Optional[Set[str]]]:
index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"
# Short-Circuit
if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
return -1, -1, None
# Random seed is inherited from parent process... we want new randomness w/ each process
np.random.seed((os.getpid() * int(time.time())) % 123456789)
# Create Tracking Variables
unique_states, batches = set(), {b: [] for b, _ in batch_formats}
# Iterate through Registry...
for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
        # The initial/final states are sampled from the first [0, alpha) and final (1 - alpha, 1] percent of the video
n_frames = registry[vid]["n_frames"]
initial_idx, final_idx = 0, n_frames - 1
if do_initial:
initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))
if do_final:
final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)
# Assertion --> initial_idx < final_idx - len(state_elements)
assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"
# Assume remaining elements are just random "interior" states --> sort to get ordering!
sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
sampled_idxs = sorted(list(sampled_idxs))
# Compile full-set "batch"
retrieved_states = [get_path(vid_dir, vid, x) for x in [initial_idx, *sampled_idxs] + [final_idx]]
# Add batch to index for specific batch_format key...
batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
unique_states.update(retrieved_states)
# Add all other batch formats to indices...
for key, elements in batch_formats[:-1]:
n_states = len([x for x in elements if "state_" in x])
assert (n_states <= 2) or (
n_states == len(retrieved_states)
), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"
# States are all independent -- each of the retrieved states is its own example...
if n_states == 1:
for idx in range(len(retrieved_states)):
batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})
            # 0-K Context is the only "valid" context for n_states == 2
elif n_states == 2:
assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"
# Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
for idx in range(1, len(retrieved_states)):
batches[key].append(
{"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
)
# We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
else:
batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
# Write JSON Index directly to disk...
for key in batches:
with open(index_dir / key / index_file, "w") as f:
json.dump(batches[key], f)
return epoch, len(batches["state"]), unique_states
================================================
FILE: voltron/util/__init__.py
================================================
from .checkpointing import CheckpointSaver, do_resume
from .metrics import Metrics
from .utilities import ResumeableDistributedSampler, set_global_seed
================================================
FILE: voltron/util/checkpointing.py
================================================
"""
checkpointing.py
Core utility class for handling model/optimizer serialization & checkpointing -- including resume from checkpoint logic.
Supports the following strategies:
    - (k, -1, -1) --> Keep only the most recent "k" epoch checkpoints
    - (k, m, -1) --> Keep the most recent "k" epoch checkpoints and *every* m-th epoch checkpoint
    - (k, m, s = 2500) --> Keep "k" and "m" subject to the above, but also save a checkpoint every *s* steps within the current epoch
"""
import logging
import os
import re
from collections import deque
from pathlib import Path
from typing import Any, Optional, Tuple
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
# Grab Logger
overwatch = logging.getLogger(__file__)
class FixedDeck(deque):
def __init__(self, maxlen: int) -> None:
super().__init__(maxlen=maxlen)
def append(self, x: Any) -> Any:
pop_value = None
if self.__len__() == self.maxlen:
pop_value = self.__getitem__(0)
# Perform parent append and return popped value, if any!
super().append(x)
return pop_value
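# Illustrative behavior: a FixedDeck of maxlen=2 evicts (and returns) the oldest element on overflow, e.g.:
#   deck = FixedDeck(maxlen=2)
#   deck.append("a"); deck.append("b")   # both return None (deck not yet full)
#   deck.append("c")                     # returns "a" (evicted) -- the caller can then unlink that stale checkpoint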
class CheckpointSaver:
def __init__(self, strategy: Tuple[int, int, int], run_dir: str, is_rank_zero: bool = False) -> None:
"""
Create a checkpoint saver with the provided strategy that saves to the given path.
:param strategy: Strategy, following the (k, -1, -1) -- (k, m, -1) -- (k, m, s) description above.
:param run_dir: Path to root of `run_dir`.
:param is_rank_zero: Boolean whether this process is global zero (no-op if not)!
"""
(self.k, self.m, self.s), self.run_dir, self.is_rank_zero = strategy, run_dir, is_rank_zero
self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()
# If `self.s == -1` --> *Disable* step checkpoints (only at save end of epoch!)
self.enable_step = self.s != -1
# Create "checkpoints" subdirectory
self.path = Path(run_dir) / "checkpoints"
if self.is_rank_zero:
os.makedirs(self.path, exist_ok=True)
# Populate `step_checkpoints` on __init__ (if resuming *within* an epoch!)
self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])
# Created Saver...
overwatch.info(f"Created CheckpointSaver with `k = {self.k}` -- `m = {self.m}` -- s = {self.s}!")
def save(
self,
epoch: int,
is_local_step: bool,
model: nn.Module,
optimizer: Optimizer,
duration: int,
local_step: Optional[int] = None,
train_loss: Optional[float] = None,
val_loss: Optional[float] = None,
) -> None:
"""Performs a global zero save operation, unlinking stale checkpoints if necessary."""
if not self.is_rank_zero:
return
# Check if saving a `local_step` (within an epoch) or if end of epoch...
if self.enable_step and is_local_step and (local_step % self.s) == 0:
step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"
torch.save(
{"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, step_checkpoint
)
# Update Relevant Trackers...
self.step_checkpoints.add(step_checkpoint)
elif not is_local_step:
if train_loss is None and val_loss is None:
checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt"
else:
checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt"
torch.save(
{"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, checkpoint
)
# Update Relevant Trackers
            if self.m != -1 and epoch % self.m == 0:
self.intervals.add(checkpoint)
# Remove all "step_checkpoints" now that we've made it to the end of an epoch!
while len(self.step_checkpoints) > 0:
os.remove(self.step_checkpoints.pop())
# Add to recents & flush stale checkpoints...
to_remove = self.recents.append(checkpoint)
if to_remove is not None and to_remove not in self.intervals:
os.remove(to_remove)
def do_resume(resume: bool, run_dir: str) -> Tuple[Optional[Path], int, int]:
"""Handle `resume` logic --> consists of retrieving checkpoint_path and epoch/step computation (if resuming)."""
if not resume:
# We're starting a fresh run --> return None for checkpoint_path, resume_epoch = 0, resume_step = 0
return None, 0, 0
# === Auto-Resume Logic ===
# **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks:
# - `accumulate_grad_batches` is exactly the same when resuming; this means:
# + `model_cfg.effective_bsz`, `model_cfg.fabric_bsz`, & `accelerator_cfg.num_accelerators` are the same!
# - The Weights & Biases directory `run_dir/wandb` only contains a *single run*
# - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes!
# + This means that (and generally should be true for resuming altogether) the architecture is the same!
# - The `cfg.seed` should be the same (again, should generally be true...)
all_checkpoints_path, resume_checkpoint, resume_epoch, resume_step = Path(run_dir) / "checkpoints", None, 0, 0
if all_checkpoints_path.exists() and any(all_checkpoints_path.iterdir()):
# Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints...
checkpoints = list(all_checkpoints_path.iterdir())
complete_checkpoint, complete_epoch = max(
[
(c, int(re.search("epoch=(.+?)-train", c.name).group(1)))
for c in checkpoints
if "local-epoch=" not in str(c)
],
key=lambda x: x[1],
)
# Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints...
local = [
(
c,
int(re.search("local-epoch=(.+?)-step", c.name).group(1)),
int(re.search("step=(.+?)[.-]", c.name).group(1)),
)
for c in checkpoints
if "local-epoch=" in str(c)
]
if len(local) > 0:
# Parse out (epoch, "highest" step) + assert no newer "full epoch" checkpoint exists!
resume_checkpoint, resume_epoch, resume_step = max(local, key=lambda x: x[1:])
assert resume_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!"
# Case 2 :: Otherwise, we're just going to start with the last "complete" epoch
else:
resume_checkpoint, resume_epoch = complete_checkpoint, complete_epoch
return resume_checkpoint, resume_epoch, resume_step
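
Taken together, `do_resume` and `CheckpointSaver` bracket the training loop: resolve any existing checkpoint before the model is built, then save every `s` steps and again at each epoch boundary. A minimal sketch of that wiring, using a placeholder linear model, a hypothetical `runs/example-run` directory, and illustrative loop bounds in place of the real pretraining script:

import torch
import torch.nn as nn

from voltron.util.checkpointing import CheckpointSaver, do_resume

# Hypothetical run directory & strategy: keep 3 recent epochs, every 5th epoch, and a step checkpoint every 2500 steps
run_dir = "runs/example-run"
resume_checkpoint, resume_epoch, resume_step = do_resume(resume=True, run_dir=run_dir)

model = nn.Linear(10, 10)                                       # stand-in for the real architecture
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
if resume_checkpoint is not None:
    state = torch.load(resume_checkpoint, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

saver = CheckpointSaver(strategy=(3, 5, 2500), run_dir=run_dir, is_rank_zero=True)
for epoch in range(resume_epoch, 10):
    for step in range(resume_step, 5000):
        # ... forward / backward / optimizer.step() ...
        saver.save(epoch, is_local_step=True, model=model, optimizer=optimizer, duration=0, local_step=step)
    saver.save(epoch, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=0.0, val_loss=0.0)
    resume_step = 0                                             # only the resumed epoch starts mid-way
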
================================================
FILE: voltron/util/metrics.py
================================================
"""
metrics.py
Utility classes defining Metrics containers with model-specific logging to various endpoints (JSONL local logs, W&B).
"""
import os
import re
import time
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import jsonlines
import numpy as np
import torch
import wandb
from voltron.conf import TrackingConfig
# === Define Loggers (`Logger` is an abstract base class) ===
class Logger(ABC):
def __init__(self, run_id: str, hparams: Dict[str, Any], is_rank_zero: bool = False) -> None:
self.run_id, self.hparams, self.is_rank_zero = run_id, hparams, is_rank_zero
@abstractmethod
def write_hyperparameters(self) -> None:
raise NotImplementedError("Logger is an abstract class!")
@abstractmethod
def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
raise NotImplementedError("Logger is an abstract class!")
def finalize(self) -> None:
time.sleep(1)
class JSONLinesLogger(Logger):
def write_hyperparameters(self) -> None:
if not self.is_rank_zero:
return
# Only log if `is_rank_zero`
with jsonlines.open(f"{self.run_id}.jsonl", mode="w", sort_keys=True) as js_logger:
js_logger.write(
{
"run_id": self.run_id,
"start_time": datetime.now().strftime("%m-%d-%H:%M"),
"hparams": self.hparams,
}
)
def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
if not self.is_rank_zero:
return
# Only log if `is_rank_zero`
with jsonlines.open(f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(metrics)
class WeightsBiasesLogger(Logger):
def __init__(
self,
run_id: str,
hparams: Dict[str, Any],
tracking_cfg: TrackingConfig,
tags: List[str],
resume: bool = False,
resume_id: Optional[str] = None,
is_rank_zero: bool = False,
) -> None:
super().__init__(run_id, hparams, is_rank_zero)
self.tracking_cfg, self.tags, self.resume, self.resume_id = tracking_cfg, tags, resume, resume_id
self.path = Path(os.getcwd() if self.tracking_cfg.directory is None else self.tracking_cfg.directory)
# Handle (Automatic) Resume if `resume = True`
if self.resume and self.resume_id is None:
wandb_path = self.path / "wandb"
if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()):
# Parse unique `run_id` from the `.wandb` file...
wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if f.name.endswith(".wandb")]
assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!"
# Regex Match on `run-{id}.wandb`
self.resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1)
elif wandb_path.exists():
raise ValueError("Starting Training from Scratch with Preexisting W&B Directory; Remove to Continue!")
# Call W&B.init()
self.initialize()
def initialize(self) -> None:
"""Run W&B.init on the guarded / rank-zero process."""
if not self.is_rank_zero:
return
# Only initialize / log if `is_rank_zero`
wandb.init(
project=self.tracking_cfg.project,
entity=self.tracking_cfg.entity,
config=self.hparams,
name=self.run_id,
dir=self.path,
tags=self.tags,
notes=self.tracking_cfg.notes,
resume="allow" if self.resume else False,
id=self.resume_id,
)
def write_hyperparameters(self) -> None:
if not self.is_rank_zero:
return
# Only log if `is_rank_zero`
wandb.config = self.hparams
def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
if not self.is_rank_zero:
return
# Only log if `is_rank_zero`
wandb.log(metrics, step=global_step)
def finalize(self) -> None:
wandb.finish()
time.sleep(150)
# === Core Metrics Container :: Responsible for Initializing Loggers and Compiling/Pushing Metrics ===
class Metrics:
def __init__(
self,
active_loggers: List[str],
run_id: str,
hparams: Dict[str, Any],
model_arch: str,
is_rank_zero: bool,
tracking_cfg: Optional[TrackingConfig] = None,
tags: Optional[List[str]] = None,
resume: bool = False,
resume_id: Optional[str] = None,
window: int = 128,
) -> None:
"""High-Level Container Logic for Metrics Logging; logic defined for each model architecture!"""
self.model_arch, self.is_rank_zero, self.window = model_arch, is_rank_zero, window
# Initialize Loggers
self.loggers = []
for log_type in active_loggers:
if log_type == "jsonl":
logger = JSONLinesLogger(run_id, hparams, is_rank_zero=is_rank_zero)
elif log_type == "wandb":
logger = WeightsBiasesLogger(
run_id, hparams, tracking_cfg, tags, resume, resume_id, is_rank_zero=is_rank_zero
)
else:
raise ValueError(f"Logger `{log_type}` is not defined!")
# Add Hyperparameters --> Add to `self.loggers`
logger.write_hyperparameters()
self.loggers.append(logger)
# Create Universal Trackers
self.global_step, self.start_time, self.resume_time, self.step_start_time = 0, time.time(), 0, time.time()
self.tracker = {
"loss": deque(maxlen=self.window),
"lr": [],
"step_time": deque(maxlen=self.window),
}
# Create Model-Specific Trackers
if self.model_arch == "v-mvp":
self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)})
elif self.model_arch in {"v-r3m", "v-rn3m"}:
self.tracker.update(
{
"tcn_loss": deque(maxlen=self.window),
"reward_loss": deque(maxlen=self.window),
"l1_loss": deque(maxlen=self.window),
"l2_loss": deque(maxlen=self.window),
"tcn_accuracy": deque(maxlen=self.window),
"reward_accuracy": deque(maxlen=self.window),
}
)
elif self.model_arch == "v-cond":
self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)})
elif self.model_arch == "v-dual":
self.tracker.update(
{
"reconstruction_loss": deque(maxlen=self.window),
"zero_reconstruction_loss": deque(maxlen=self.window),
"k_reconstruction_loss": deque(maxlen=self.window),
}
)
elif self.model_arch == "v-gen":
self.tracker.update(
{
"reconstruction_loss": deque(maxlen=self.window),
"zero_reconstruction_loss": deque(maxlen=self.window),
"k_reconstruction_loss": deque(maxlen=self.window),
"lm_loss": deque(maxlen=self.window),
"lm_ppl": deque(maxlen=self.window),
}
)
else:
raise ValueError(f"Metrics for Model `{self.model_arch}` are not implemented!")
def itemize(self) -> Dict[str, torch.Tensor]:
"""Utility method for converting `deque[torch.Tensor] --> mean over Tensors."""
return {
k: torch.stack(list(v)).mean().item()
for k, v in self.tracker.items()
if k not in {"loss", "lr", "step_time"}
}
def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
for logger in self.loggers:
logger.write(global_step, metrics)
def finalize(self) -> None:
for logger in self.loggers:
logger.finalize()
def get_status(self, epoch: int, loss: Optional[torch.Tensor] = None) -> str:
lr = self.tracker["lr"][-1] if len(self.tracker["lr"]) > 0 else 0
if loss is None:
return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}"
# Otherwise, embed `loss` in status!
return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}"
def commit(
self,
*,
global_step: Optional[int] = None,
resume_time: Optional[int] = None,
lr: Optional[float] = None,
update_step_time: bool = False,
**kwargs,
) -> None:
"""Update all metrics in `self.tracker` by iterating through special positional arguments & kwargs."""
if not self.is_rank_zero:
return
# Special Keyword Arguments
if global_step is not None:
self.global_step = global_step
if resume_time is not None:
self.resume_time = resume_time
if lr is not None:
self.tracker["lr"].append(lr)
if update_step_time:
self.tracker["step_time"].append(time.time() - self.step_start_time)
self.step_start_time = time.time()
# Generic Keyword Arguments
for key, value in kwargs.items():
self.tracker[key].append(value.detach())
def push(self, epoch: int) -> Optional[str]:
"""Push current metrics to loggers with model-specific handling."""
if not self.is_rank_zero:
return
loss = torch.stack(list(self.tracker["loss"])).mean().item()
step_time, lr = np.mean(list(self.tracker["step_time"])), self.tracker["lr"][-1]
status = self.get_status(epoch, loss)
# Model-Specific Handling
itemized = self.itemize()
if self.model_arch == "v-mvp":
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-MVP Train Loss": loss,
"Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
elif self.model_arch in {"v-r3m", "v-rn3m"}:
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
f"Pretrain/V-{'R3M' if self.model_arch == 'v-r3m' else 'RN3M'} Train Loss": loss,
"Pretrain/TCN Loss": itemized["tcn_loss"],
"Pretrain/Reward Loss": itemized["reward_loss"],
"Pretrain/L1 Loss": itemized["l1_loss"],
"Pretrain/L2 Loss": itemized["l2_loss"],
"Pretrain/TCN Accuracy": itemized["tcn_accuracy"],
"Pretrain/Reward Accuracy": itemized["reward_accuracy"],
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
elif self.model_arch == "v-cond":
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Cond Train Loss": loss,
"Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
elif self.model_arch == "v-dual":
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Dual Train Loss": loss,
"Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
"Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
"Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
elif self.model_arch == "v-gen":
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Gen Train Loss": loss,
"Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
"Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
"Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
"Pretrain/CLM Loss": itemized["lm_loss"],
"Pretrain/CLM Perplexity": itemized["lm_ppl"],
"Pretrain/LM Loss": itemized["lm_loss"],
"Pretrain/LM Perplexity": itemized["lm_ppl"],
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
else:
raise ValueError(f"Metrics.push() for Model `{self.model_arch}` is not implemented!")
return status
def push_epoch(self, epoch: int, val_loss: torch.Tensor) -> Optional[Tuple[str, torch.Tensor, int]]:
"""End-of-Epoch => Push accumulated metrics to loggers with model-specific handling."""
if not self.is_rank_zero:
return
# Compute End-of-Epoch Specialized Metrics
loss, step_time = torch.stack(list(self.tracker["loss"])).mean(), np.mean(list(self.tracker["step_time"]))
lr, duration = self.tracker["lr"][-1], int(time.time() - self.start_time) + self.resume_time
epoch_status = (
f"[Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f} "
f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}"
)
# Log for Model
p_arch = {
"v-mvp": "MVP",
"v-r3m": "R3M (ViT)",
"v-rn3m": "R3M (RN)",
"v-cond": "V-Cond",
"v-dual": "V-Dual",
"v-gen": "V-Gen",
}[self.model_arch]
self.log(
self.global_step,
metrics={
"Pretrain/Step": self.global_step,
"Pretrain/Epoch": epoch,
"Pretrain/Training Duration": duration,
f"Pretrain/{p_arch} Train Epoch Loss": loss.item(),
f"Pretrain/{p_arch} Train Loss": loss.item(),
f"Pretrain/{p_arch} Validation Loss": val_loss.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": step_time,
},
)
return epoch_status, loss, duration
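
In the training loop, the `Metrics` container is the only logging surface the trainer touches: `commit(...)` accumulates per-step values into the windowed trackers, `push(...)` flushes windowed means to every active logger, and `push_epoch(...)` emits the end-of-epoch summary. A minimal sketch with JSONL logging only, a `v-cond` architecture, and random tensors standing in for real losses (run id, learning rate, and loop bounds are illustrative):

import torch

from voltron.util.metrics import Metrics

# Hypothetical run -- JSONL logging only, so no TrackingConfig is needed
metrics = Metrics(
    active_loggers=["jsonl"],
    run_id="v-cond-example",
    hparams={"lr": 1e-4},
    model_arch="v-cond",
    is_rank_zero=True,
)

for epoch in range(2):
    for step in range(10):
        loss = torch.rand(1).mean()             # stand-in for the actual training loss
        metrics.commit(
            global_step=metrics.global_step + 1,
            lr=1e-4,
            update_step_time=True,
            loss=loss,
            reconstruction_loss=loss,
        )
        if step % 5 == 0:
            print(metrics.push(epoch))          # flush windowed means to the loggers
    metrics.push_epoch(epoch, val_loss=torch.rand(1).mean())

metrics.finalize()
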
================================================
FILE: voltron/util/utilities.py
================================================
"""
utilities.py
General utilities for randomness, distributed training, and miscellaneous checks in PyTorch.
=== Randomness ===
Random `seed_everything` functionality is taken directly from PyTorch-Lighting:
> Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
we inject randomness from non-PyTorch sources (e.g., numpy, random)!
> Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
=== Distributed / DDP Training ===
Utilities provide a standard API across single-GPU/multi-GPU/multi-node training. Assumes that code is running with
one of the following strategies:
- Single Process (on CPU, GPU)
- DDP (GPU, Multi-Node GPU) --> uses the `torchrun`/`torch.distributed` API & semantics
Key Terminology
-> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous!
-> Rank :: Integer index of current process in the total world size
-> Local Rank :: Local index on given node in [0, Devices per Node]
"""
import os
import random
from typing import Callable, Iterator, Optional, TypeVar
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
T_co = TypeVar("T_co", covariant=True)
# === Randomness ===
def worker_init_function(worker_id: int) -> None:
"""
Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
> Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
you can run iterative splitting on to get new (predictable) randomness.
:param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
"""
# Get current `rank` (if running distributed) and `process_seed`
global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()
# Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
# > https://pytorch.org/docs/stable/data.html#data-loading-randomness
base_seed = process_seed - worker_id
# "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
# Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
np.random.seed(seed_seq.generate_state(4))
# Spawn distinct child sequences for PyTorch (reseed) and stdlib random
torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
# Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
# Use 128 Bits for `random`, but express as integer instead of as an array
random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
random.seed(random_seed)
def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
"""Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
# Set Seed as an Environment Variable
os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
return worker_init_function if get_worker_init_fn else None
# === Distributed Training ===
class ResumeableDistributedSampler(DistributedSampler):
def __init__(
self,
seen_examples: int,
resume_epoch: int,
dataset: Dataset,
num_replicas: int,
rank: int,
shuffle: bool = True,
seed: int = 0,
) -> None:
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
self.seen_examples, self.resume_epoch, self.do_resume = seen_examples, resume_epoch, True
# Set `seen_examples_per_replica` --> this is necessary for when we re-wrap the iterator in self.__iter__()
# > Note: `seen_examples` is across _all_ replicas --> so divide!
self.seen_examples_per_replica = self.seen_examples // self.num_replicas
def __iter__(self) -> Iterator[T_co]:
epoch_iterator = super().__iter__()
if self.do_resume:
# Unpack iterator --> list, slice off the first `seen_examples_per_replica` examples, and re-wrap!
leftover_idxs = list(epoch_iterator)[self.seen_examples_per_replica :]
return iter(leftover_idxs)
else:
return epoch_iterator
def __len__(self) -> int:
if self.do_resume:
# Remove the "seen" sample from self.num_samples; num_samples is *per replica*!
return self.num_samples - self.seen_examples_per_replica
else:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
# If epoch != self.resume_epoch --> we're in "regular DistributedSampler" mode (just a wrapper class)
# > Intuition: We should *only* truncate examples on the first epoch upon resuming!
self.epoch = epoch
if self.epoch != self.resume_epoch:
self.do_resume = False
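
These two pieces are meant to be combined when building the pretraining `DataLoader`: `set_global_seed` seeds python/numpy/torch once per process and optionally returns the `worker_init_fn` that reseeds each DataLoader worker, while `ResumeableDistributedSampler` lets a resumed run skip examples already consumed in the interrupted epoch. A minimal single-replica sketch with a toy `TensorDataset` (the counts and the manual `LOCAL_RANK` default are illustrative; under `torchrun` the environment variable is set for you):

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from voltron.util.utilities import ResumeableDistributedSampler, set_global_seed

# `torchrun` normally sets LOCAL_RANK; default it here so the worker_init_fn works in a single process
os.environ.setdefault("LOCAL_RANK", "0")
worker_init_fn = set_global_seed(7, get_worker_init_fn=True)

dataset = TensorDataset(torch.arange(100).float())      # toy stand-in for the pretraining dataset
sampler = ResumeableDistributedSampler(
    seen_examples=40,                                    # hypothetical: 40 examples consumed before interruption
    resume_epoch=3,
    dataset=dataset,
    num_replicas=1,
    rank=0,
)

loader = DataLoader(dataset, batch_size=10, sampler=sampler, num_workers=2, worker_init_fn=worker_init_fn)
sampler.set_epoch(3)                                     # same epoch as `resume_epoch` --> only unseen examples yielded
print(len(sampler), sum(len(batch[0]) for batch in loader))   # 60, 60

sampler.set_epoch(4)                                     # later epochs fall back to standard DistributedSampler behavior
print(len(sampler))                                      # 100
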
================================================
FILE: voltron/util/v1/__init__.py
================================================
================================================
FILE: voltron/util/v1/checkpointing.py
================================================
"""
checkpointing.py
XLA-specific utility class for handling model/optimizer serialization & checkpointing.
Supports the following strategies:
- (k, -1, -1) --> Keep only the most recent "k" epoch checkpoints
- (k, m, -1) --> Keep the most recent "k" epoch checkpoints and *every* m epoch checkpoint
- (k, m, s = 2500) --> Keep "k" and "m" subject to above, but also save a step checkpoint every *s* steps within the current epoch
"""
import os
from collections import deque
from pathlib import Path
from typing import Any, Optional, Tuple
import torch.nn as nn
from torch.optim.optimizer import Optimizer
class FixedDeck(deque):
def __init__(self, maxlen: int) -> None:
super().__init__(maxlen=maxlen)
def append(self, x: Any) -> Any:
pop_value = None
if self.__len__() == self.maxlen:
pop_value = self.__getitem__(0)
# Perform parent append and return popped value, if any!
super().append(x)
return pop_value
class XLACheckpointSaver:
def __init__(self, strategy: Tuple[int, int, int], run_dir: str) -> None:
"""
Create a checkpoint saver with the provided strategy that saves to the given path, with XLA-specific handling.
:param strategy: Strategy, following the (k, -1, -1) -- (k, m, -1) -- (k, m, s) description above.
:param run_dir: Path to root of `run_dir`
"""
import torch_xla.core.xla_model as xm
(self.k, self.m, self.s), self.run_dir = strategy, run_dir
self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()
# If `self.s` is -1 --> disable step_checkpoints
self.enable_step = self.s != -1
# Create "checkpoints" subdirectory
self.path = Path(run_dir) / "checkpoints"
if xm.is_master_ordinal(local=False):
os.makedirs(self.path, exist_ok=True)
# Populate `step_checkpoints` on __init__ (if resuming *within* an epoch...)
self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])
# Create Saver
xm.master_print(f"Created Saver w/ `k` = {self.k}, `m` = {self.m}`, `s` = {self.s}!")
def save(
self,
epoch: int,
is_local_step: bool,
model: nn.Module,
optimizer: Optimizer,
duration: int,
local_step: Optional[int] = None,
train_loss: Optional[float] = None,
val_loss: Optional[float] = None,
) -> None:
"""Performs the save operation, unlinking existing stale checkpoints, if necessary."""
import torch_xla.core.xla_model as xm
# Check if saving a `local_step` (within an epoch) or if saving an `epoch`
if self.enable_step and is_local_step and (local_step % self.s) == 0:
# Create filename
step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"
# Perform actual save action...
# > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
xm.save([model.state_dict(), optimizer.state_dict()["state"]], step_checkpoint)
if xm.is_master_ordinal(local=False):
self.step_checkpoints.add(step_checkpoint)
elif not is_local_step:
# Create filename
if train_loss is None and val_loss is None:
checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt"
else:
checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt"
# Perform actual save action...
# > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
xm.save([model.state_dict(), optimizer.state_dict()["state"]], checkpoint)
if xm.is_master_ordinal(local=False):
# Conditional Check for M -- Keep if modulated by interval
if self.m != -1 and epoch % self.m == 0:
self.intervals.add(checkpoint)
# Remove all "step_checkpoints" now that we successfully made it to the end of the epoch!
while len(self.step_checkpoints) > 0:
os.remove(self.step_checkpoints.pop())
# Finally, recency add & unlink/delete if necessary
to_remove = self.recents.append(checkpoint)
if to_remove is not None and to_remove not in self.intervals:
os.remove(to_remove)
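
Note the payload here differs from the torch-native saver above: `xm.save` writes a two-element list (the model state dict plus only the optimizer's `"state"`, without `"param_groups"`). A minimal torch-only sketch of that round trip, with a placeholder linear model standing in for the real architecture, showing how the load side re-attaches the saved state to freshly constructed param groups:

import torch
import torch.nn as nn

# Recreate the payload layout `xm.save(...)` writes above: [model_state_dict, optimizer "state" only]
model = nn.Linear(10, 10)                               # stand-in for the real architecture
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
checkpoint_path = "epoch=0-train=inf-val=inf-t=0.pt"    # illustrative filename in the saver's format
torch.save([model.state_dict(), optimizer.state_dict()["state"]], checkpoint_path)

# On the load side, splice the saved `state` back into a fresh optimizer state dict (keeping its `param_groups`)
model_state, optimizer_state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(model_state)
full_state = optimizer.state_dict()
full_state["state"] = optimizer_state
optimizer.load_state_dict(full_state)
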
================================================
FILE: voltron/util/v1/distributed.py
================================================
"""
distributed.py
Key distributed utilities; notably provides a standard API for getting relevant data from either CPU/GPU or XLA (TPU)
devices, since the underlying implementation does differ substantially.
Assumes that code is running with one of the following strategies:
- Single Process (on CPU, GPU)
- DDP (CPU, GPU)... uses the torch.distributed.launch API & semantics
- XMP Spawn (TPU)... TPU based XLA + Multiprocessing Spawn semantics
Key Terminology
-> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous!
-> Rank :: Integer index of current process in the total world size
-> Local Rank :: Local index on given node in [0, Devices per Node]
"""
from importlib.util import find_spec
from typing import Iterator, TypeVar
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
T_co = TypeVar("T_co", covariant=True)
class ResumeableDistributedSampler(DistributedSampler):
def __init__(
self,
seen_examples: int,
resume_epoch: int,
dataset: Dataset,
num_replicas: int,
rank: int,
shuffle: bool = True,
seed: int = 0,
) -> None:
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
self.seen_examples, self.resume_epoch, self.do_resume = seen_examples, resume_epoch, True
# Set `seen_examples_per_replica` --> this is necessary for when we re-wrap the iterator in self.__iter__()
# > Note: `seen_examples` is across _all_ replicas --> so divide!
self.seen_examples_per_replica = self.seen_examples // self.num_replicas
def __iter__(self) -> Iterator[T_co]:
epoch_iterator = super().__iter__()
if self.do_resume:
# Unpack iterator --> list, slice off the first `seen_examples_per_replica` examples, and re-wrap!
leftover_idxs = list(epoch_iterator)[self.seen_examples_per_replica :]
return iter(leftover_idxs)
else:
return epoch_iterator
def __len__(self) -> int:
if self.do_resume:
# Remove the "seen" sample from self.num_samples; num_samples is *per replica*!
return self.num_samples - self.seen_examples_per_replica
else:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
# If epoch != self.resume_epoch --> we're in "regular DistributedSampler" mode (just a wrapper class)
# > Intuition: We should *only* truncate examples on the first epoch upon resuming!
self.epoch = epoch
if self.epoch != self.resume_epoch:
self.do_resume = False
def xla_available() -> bool:
try:
return find_spec("torch_xla") is not None
except ModuleNotFoundError:
return False
def get_rank() -> int:
"""Returns the global rank [0, World Size) of the current process."""
if xla_available():
import torch_xla.core.xla_model as xm
# By default, if XLA is available, assume we're running under XMP Spawn
return xm.get_ordinal()
# Try to get rank via torch.distributed, but catch error if only single process
try:
return torch.distributed.get_rank()
# RuntimeError => not running distributed (single process)
except RuntimeError:
return 0
================================================
FILE: voltron/util/v1/random.py
================================================
"""
random.py
Utilities for dealing with randomness for PyTorch, across devices (CPU, GPU, TPU).
Loosely inspired by functionality in PyTorch-Lightning:
> Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
we inject randomness from non-PyTorch sources (e.g., numpy, random)!
> Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
"""
import os
import random
from typing import Callable
import numpy as np
import torch
from voltron.util.v1.distributed import get_rank
def set_global_seed(seed: int) -> Callable[[int], None]:
"""Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
# Set Seed as an Environment Variable
os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
return worker_init_function
def worker_init_function(worker_id: int) -> None:
"""
Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
> Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
you can run iterative splitting on to get new (predictable) randomness.
:param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
"""
# Get current `rank` (if running distributed) and `process_seed`
global_rank, process_seed = get_rank(), torch.initial_seed()
# Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
# > https://pytorch.org/docs/stable/data.html#data-loading-randomness
base_seed = process_seed - worker_id
# "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
# Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
np.random.seed(seed_seq.generate_state(4))
# Spawn distinct child sequences for PyTorch (reseed) and stdlib random
torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
# Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
# Use 128 Bits for `random`, but express as integer instead of as an array
random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
random.seed(random_seed)
================================================
FILE: voltron/util/v1/xla_logger.py
================================================
"""
xla_logger.py
Utility module defining various XLA logging functions (called within marked closures), for logging metrics periodically
through training & validation.
"""
from typing import List
import jsonlines
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import wandb
# === Generic (Cross-Model) Epoch End Update ===
def log_epoch_end_update(
arch: str,
epoch: int,
global_step: int,
run_id: str,
duration: int,
train_losses: List[torch.Tensor],
val_loss: float,
lr: float,
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Unclear if it'll work?
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} "
f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}"
)
# Get Log-Friendly Arch
p_arch = {
"v-mvp": "MVP",
"v-r3m": "R3M (ViT)",
"v-rn3m": "R3M (RN)",
"v-cond": "V-Cond",
"v-dual": "V-Dual",
"v-gen": "V-Gen",
}[arch]
# Log to Weights & Biases & JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/Training Duration": duration,
"Pretrain/Step Time": average_step_time,
f"Pretrain/{p_arch} Train Epoch Loss": train_loss.item(),
f"Pretrain/{p_arch} Train Loss": train_loss.item(),
f"Pretrain/{p_arch} Validation Loss": val_loss,
"Pretrain/Learning Rate": lr,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
# === Data-Locked Reproductions ===
def log_vmvp_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
reconstruction_losses: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-MVP Train Loss": train_loss.item(),
"Pretrain/Reconstruction Loss": reconstruction_loss.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
def log_vr3m_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
tcn_losses: List[torch.Tensor],
reward_losses: List[torch.Tensor],
l1_losses: List[torch.Tensor],
l2_losses: List[torch.Tensor],
tcn_accuracies: List[torch.Tensor],
reward_accuracies: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
tcn_loss = torch.stack(list(tcn_losses)).mean()
reward_loss = torch.stack(list(reward_losses)).mean()
l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean()
tcn_accuracy = torch.stack(list(tcn_accuracies)).mean()
reward_accuracy = torch.stack(list(reward_accuracies)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-R3M Train Loss": train_loss.item(),
"Pretrain/TCN Loss": tcn_loss.item(),
"Pretrain/Reward Loss": reward_loss.item(),
"Pretrain/L1 Loss": l1_loss.item(),
"Pretrain/L2 Loss": l2_loss.item(),
"Pretrain/TCN Accuracy": tcn_accuracy.item(),
"Pretrain/Reward Accuracy": reward_accuracy.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
def log_vrn3m_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
tcn_losses: List[torch.Tensor],
reward_losses: List[torch.Tensor],
l1_losses: List[torch.Tensor],
l2_losses: List[torch.Tensor],
tcn_accuracies: List[torch.Tensor],
reward_accuracies: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
tcn_loss = torch.stack(list(tcn_losses)).mean()
reward_loss = torch.stack(list(reward_losses)).mean()
l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean()
tcn_accuracy = torch.stack(list(tcn_accuracies)).mean()
reward_accuracy = torch.stack(list(reward_accuracies)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-RN3M Train Loss": train_loss.item(),
"Pretrain/TCN Loss": tcn_loss.item(),
"Pretrain/Reward Loss": reward_loss.item(),
"Pretrain/L1 Loss": l1_loss.item(),
"Pretrain/L2 Loss": l2_loss.item(),
"Pretrain/TCN Accuracy": tcn_accuracy.item(),
"Pretrain/Reward Accuracy": reward_accuracy.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
# === Voltron Models ===
def log_vcond_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
reconstruction_losses: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Cond Train Loss": train_loss.item(),
"Pretrain/Reconstruction Loss": reconstruction_loss.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
def log_vdual_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
reconstruction_losses: List[torch.Tensor],
zero_reconstruction_losses: List[torch.Tensor],
k_reconstruction_losses: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean()
k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Dual Train Loss": train_loss.item(),
"Pretrain/Reconstruction Loss": reconstruction_loss.item(),
"Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(),
"Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
def log_vgen_train_update(
epoch: int,
global_step: int,
run_id: str,
train_losses: List[torch.Tensor],
lr: float,
reconstruction_losses: List[torch.Tensor],
lm_losses: List[torch.Tensor],
lm_ppl: List[torch.Tensor],
zero_reconstruction_losses: List[torch.Tensor],
k_reconstruction_losses: List[torch.Tensor],
step_times: List[float],
) -> None:
train_loss = torch.stack(list(train_losses)).mean()
reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
lm_loss = torch.stack(list(lm_losses)).mean()
lm_perplexity = torch.stack(list(lm_ppl)).mean()
zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean()
k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean()
average_step_time = np.mean(list(step_times))
# Console Logging --> Just log the aggregated train loss...
xm.master_print(
f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} --"
f" Reconstruction Loss {reconstruction_loss:.4f} -- LM Loss {lm_loss:.4f}"
)
# Log to Weights & Biases + JSONL
blob = {
"Pretrain/Step": global_step,
"Pretrain/Epoch": epoch,
"Pretrain/V-Gen Train Loss": train_loss.item(),
"Pretrain/Reconstruction Loss": reconstruction_loss.item(),
"Pretrain/CLM Loss": lm_loss.item(),
"Pretrain/CLM Perplexity": lm_perplexity.item(),
"Pretrain/LM Loss": lm_loss.item(),
"Pretrain/LM Perplexity": lm_perplexity.item(),
"Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(),
"Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(),
"Pretrain/Learning Rate": lr,
"Pretrain/Step Time": average_step_time,
}
wandb.log(blob, step=global_step)
with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
js_logger.write(blob)
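
Each helper takes the accumulated per-step values and handles the W&B + JSONL writes itself, so the XLA training loop only needs to hand over its deques (per the module docstring, from inside a marked closure). A minimal sketch for the V-Cond variant, assuming a TPU/XLA environment (the module imports `torch_xla` at the top level) and using an offline W&B run plus random tensors as placeholders:

from collections import deque

import torch
import wandb

from voltron.util.v1.xla_logger import log_vcond_train_update

# Offline mode keeps the sketch self-contained; a real run would use the project's TrackingConfig values
wandb.init(project="voltron-example", mode="offline")

train_losses, reconstruction_losses, step_times = deque(maxlen=128), deque(maxlen=128), deque(maxlen=128)
for _ in range(8):                                      # stand-in for accumulating real per-step metrics
    loss = torch.rand(1).mean()
    train_losses.append(loss)
    reconstruction_losses.append(loss)
    step_times.append(0.5)

log_vcond_train_update(
    epoch=0,
    global_step=8,
    run_id="v-cond-example",
    train_losses=list(train_losses),
    lr=1e-4,
    reconstruction_losses=list(reconstruction_losses),
    step_times=list(step_times),
)
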