Repository: FoundationVision/VAR
Branch: main
Commit: 78b95394fc58
Files: 19
Total size: 148.6 KB
Directory structure:
gitextract_24dcq1k1/
├── .gitignore
├── LICENSE
├── README.md
├── dist.py
├── models/
│   ├── __init__.py
│   ├── basic_vae.py
│   ├── basic_var.py
│   ├── helpers.py
│   ├── quant.py
│   ├── var.py
│   └── vqvae.py
├── train.py
├── trainer.py
└── utils/
    ├── amp_sc.py
    ├── arg_util.py
    ├── data.py
    ├── data_sampler.py
    ├── lr_control.py
    └── misc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.swp
**/__pycache__/**
**/.ipynb_checkpoints/**
.DS_Store
.idea/*
.vscode/*
llava/
_vis_cached/
_auto_*
ckpt/
log/
tb*/
img*/
local_output*
*.pth
*.pth.tar
*.ckpt
*.log
*.txt
*.ipynb
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 FoundationVision
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
<div align="center">
[Demo](https://opensource.bytedance.com/gmpt/t2i/invite)
[arXiv:2404.02905](https://arxiv.org/abs/2404.02905)
[Hugging Face weights](https://huggingface.co/FoundationVision/var)
[Papers with Code: ImageNet 256x256](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
</div>
<p align="center" style="font-size: larger;">
<a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
</p>
<div>
<p align="center" style="font-size: larger;">
<strong>NeurIPS 2024 Best Paper</strong>
</p>
</div>
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
<p>
<br>
## News
* **2025-11:** We released our text-to-video generation model **InfinityStar**, based on VAR & Infinity; please check [Infinity⭐️](https://github.com/FoundationVision/InfinityStar).
* **2025-11:** 🎉 InfinityStar is accepted as **NeurIPS 2025 Oral.**
* **2025-04:** 🎉 Infinity is accepted as **CVPR 2025 Oral.**
* **2024-12:** 🏆 VAR received **NeurIPS 2024 Best Paper Award**.
* **2024-12:** 🔥 We released our text-to-image research based on VAR; please check [Infinity](https://github.com/FoundationVision/Infinity).
* **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation.
* **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.
## 🕹️ Try and Play with VAR!
~~We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!~~
We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with VAR Text-to-Image and generate images interactively. Enjoy the fun of visual autoregressive modeling!
We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
[//]: # (<p align="center">)
[//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
[//]: # (<p>)
## What's New?
### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
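To make "next-scale prediction" concrete, the sketch below counts the tokens produced at each autoregressive step. It assumes the default scale schedule `patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)` from `build_vae_var` in `models/__init__.py` and assumes each scale is a square token map; the helper names `tokens_per_scale` and `scale_ranges` are illustrative and not part of the repository.

```python
# Default scale schedule, copied from `build_vae_var` in models/__init__.py.
PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)


def tokens_per_scale(patch_nums=PATCH_NUMS):
    """Token count of each square token map (side length pn -> pn * pn tokens)."""
    return [pn * pn for pn in patch_nums]


def scale_ranges(patch_nums=PATCH_NUMS):
    """(begin, end) offsets of every scale inside the flattened multi-scale sequence."""
    ranges, start = [], 0
    for n in tokens_per_scale(patch_nums):
        ranges.append((start, start + n))
        start += n
    return ranges


counts = tokens_per_scale()
print(f"autoregressive steps: {len(counts)}")
print(f"total tokens across all scales: {sum(counts)}")
print(f"first three scale ranges: {scale_ranges()[:3]}")
```

Under these assumptions, the whole token pyramid is generated in 10 steps (one per scale), whereas a raster-scan model over the final 16x16 map alone would need 256 sequential steps.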
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
<p>
### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
<p>
### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
<p>
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
<p>
### 🔥 Zero-shot generalizability🛠️:
<p align="center">
<img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
<p>
#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
## VAR zoo
We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
| model | reso. | FID | rel. cost | #params | HF weights🤗 |
|:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
| VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
| VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
| VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
| VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
| VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
| VAR-d36 | 512 | **2.63** | - | 2.3B | [var_d36.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d36.pth) |
You can load these models to generate images via the code in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
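As a convenience, the download step can be scripted. The following is a minimal sketch, not part of the repository: `checkpoint_urls` and `download` are hypothetical helpers, and the URLs are simply assembled from the link pattern used in the table above.

```python
import os
import urllib.request

# Base URL and file names taken from the links in the table above.
HF_BASE = "https://huggingface.co/FoundationVision/var/resolve/main"
VAE_CKPT = "vae_ch160v4096z32.pth"
VAR_DEPTHS = (16, 20, 24, 30, 36)


def checkpoint_urls(depth):
    """Return {file name: URL} for the VAE and the VAR checkpoint of a given depth."""
    if depth not in VAR_DEPTHS:
        raise ValueError(f"no released checkpoint for depth={depth}")
    var_ckpt = f"var_d{depth}.pth"
    return {
        VAE_CKPT: f"{HF_BASE}/{VAE_CKPT}",
        var_ckpt: f"{HF_BASE}/{var_ckpt}",
    }


def download(depth, out_dir="."):
    """Fetch both files into out_dir, skipping any file that already exists."""
    for name, url in checkpoint_urls(depth).items():
        dest = os.path.join(out_dir, name)
        if not os.path.exists(dest):
            urllib.request.urlretrieve(url, dest)


print(checkpoint_urls(16))
```

Calling `download(16)` would fetch the VAE and VAR-d16 weights into the current directory. Note that the VAR-d30-re row links to the same `var_d30.pth` file as VAR-d30, so the sketch does not treat it as a separate checkpoint.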
## Installation
1. Install `torch>=2.0.0`.
2. Install other pip packages via `pip3 install -r requirements.txt`.
3. Prepare the [ImageNet](http://image-net.org/) dataset
<details>
<summary> Assume ImageNet is in `/path/to/imagenet`. It should look like this:</summary>
```
/path/to/imagenet/:
    train/:
        n01440764:
            many_images.JPEG ...
        n01443537:
            many_images.JPEG ...
    val/:
        n01440764:
            ILSVRC2012_val_00000293.JPEG ...
        n01443537:
            ILSVRC2012_val_00000236.JPEG ...
```
**NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
</details>
4. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
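To see ahead of time which optional attention backends Python can locate, a quick check like the one below may help. It is a sketch only: `optional_backends` is a hypothetical helper, and it reports whether a package can be found, which does not guarantee that its compiled extensions import cleanly. The repository itself relies on `try`/`except ImportError` in `models/basic_var.py`.

```python
import importlib.util


def optional_backends(names=("flash_attn", "xformers")):
    """Map each top-level package name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in optional_backends().items():
    status = "found" if found else "not found (a fallback attention path will be used)"
    print(f"{name}: {status}")
```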
## Training Scripts
To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
```shell
# d16, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
# d20, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
# d24, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
# d30, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
```
A folder named `local_output` will be created to save the checkpoints and logs.
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
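The `--depth` flag in the commands above determines the transformer shape. The sketch below mirrors the arithmetic at the top of `build_vae_var` in `models/__init__.py` (`heads = depth`, `width = depth * 64`, `dpr = 0.1 * depth/24`); the function name `var_shape_from_depth` is illustrative only, and it is an assumption here that the training script derives these values the same way.

```python
def var_shape_from_depth(depth):
    """Derive head count, embedding width and drop-path rate from the depth."""
    heads = depth
    width = depth * 64            # so the per-head dimension is always 64
    drop_path_rate = 0.1 * depth / 24
    return {"heads": heads, "width": width, "drop_path_rate": drop_path_rate}


for d in (16, 20, 24, 30, 36):
    shape = var_shape_from_depth(d)
    print(f"d{d}: heads={shape['heads']}, width={shape['width']}, "
          f"drop_path_rate={shape['drop_path_rate']:.3f}")
```

Because width and head count both grow linearly with depth, a single integer is enough to describe each model in the zoo.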
## Sampling & Zero-shot Inference
For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L360](utils/misc.py#L360).
Then use [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and the reference ground-truth npz file for [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
Note that a relatively small `cfg=1.5` is used to trade off image quality against diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
We'll provide the sampling script later.
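Until then, the class schedule for FID sampling (50 images for each of the 1000 ImageNet classes, 50,000 in total) can be sketched as follows. `fid_label_batches` and the batch size of 25 are assumptions made for illustration; they are not part of the repository.

```python
def fid_label_batches(num_classes=1000, per_class=50, batch_size=25):
    """Yield lists of class labels so that every class appears exactly per_class times."""
    labels = [c for c in range(num_classes) for _ in range(per_class)]
    for i in range(0, len(labels), batch_size):
        yield labels[i:i + batch_size]


batches = list(fid_label_batches())
print(f"batches: {len(batches)}, images: {sum(len(b) for b in batches)}")
```

Each batch would then be turned into a label tensor and passed to `autoregressive_infer_cfg` with the settings quoted above, and the resulting images written out as PNG files before packing them into the `.npz` file.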
## Third-party Usage and Research
***In this paragraph, we cross-link third-party repositories and research that use VAR and report results. You can let us know by raising an issue.***
(`Note: please report accuracy numbers and provide trained models in your new repository so that others can get a sense of correctness and model behavior`)
| **Time** | **Research** | **Link** |
|--------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------|
| [5/12/2025] | [ICML 2025]Continuous Visual Autoregressive Generation via Score Maximization | https://github.com/shaochenze/EAR |
| [5/8/2025] | Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction | https://github.com/icon-lab/FedGAT |
| [4/7/2025] | FastVAR: Linear Visual Autoregressive Modeling via Cached Token Pruning | https://github.com/csguoh/FastVAR |
| [4/3/2025] | VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning | https://github.com/VARGPT-family/VARGPT-v1.1 |
| [3/31/2025] | Training-Free Text-Guided Image Editing with Visual Autoregressive Model | https://github.com/wyf0912/AREdit |
| [3/17/2025] | Next-Scale Autoregressive Models are Zero-Shot Single-Image Object View Synthesizers | https://github.com/Shiran-Yuan/ArchonView |
| [3/14/2025] | Safe-VAR: Safe Visual Autoregressive Model for Text-to-Image Generative Watermarking | https://arxiv.org/abs/2503.11324 |
| [3/3/2025] | [ICML 2025]Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator | https://research.nvidia.com/labs/dir/ddo/ |
| [2/28/2025] | Autoregressive Medical Image Segmentation via Next-Scale Mask Prediction | https://arxiv.org/abs/2502.20784 |
| [2/27/2025] | FlexVAR: Flexible Visual Autoregressive Modeling without Residual Prediction | https://github.com/jiaosiyu1999/FlexVAR |
| [2/17/2025] | MARS: Mesh AutoRegressive Model for 3D Shape Detailization | https://arxiv.org/abs/2502.11390 |
| [1/31/2025] | [ICML 2025]Visual Autoregressive Modeling for Image Super-Resolution | https://github.com/quyp2000/VARSR |
| [1/21/2025] | VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model | https://github.com/VARGPT-family/VARGPT |
| [1/26/2025] | [ICML 2025]Visual Generation Without Guidance | https://github.com/thu-ml/GFT |
| [12/30/2024] | Next Token Prediction Towards Multimodal Intelligence | https://github.com/LMM101/Awesome-Multimodal-Next-Token-Prediction |
| [12/30/2024] | Varformer: Adapting VAR’s Generative Prior for Image Restoration | https://arxiv.org/abs/2412.21063 |
| [12/22/2024] | [ICLR 2025]Distilled Decoding 1: One-step Sampling of Image Auto-regressive Models with Flow Matching | https://github.com/imagination-research/distilled-decoding |
| [12/19/2024] | FlowAR: Scale-wise Autoregressive Image Generation Meets Flow Matching | https://github.com/OliverRensu/FlowAR |
| [12/13/2024] | 3D representation in 512-Byte: Variational tokenizer is the key for autoregressive 3D generation | https://github.com/sparse-mvs-2/VAT |
| [12/9/2024] | CARP: Visuomotor Policy Learning via Coarse-to-Fine Autoregressive Prediction | https://carp-robot.github.io/ |
| [12/5/2024] | [CVPR 2025]Infinity ∞: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis | https://github.com/FoundationVision/Infinity |
| [12/5/2024] | [CVPR 2025]Switti: Designing Scale-Wise Transformers for Text-to-Image Synthesis | https://github.com/yandex-research/switti |
| [12/4/2024] | [CVPR 2025]TokenFlow🚀: Unified Image Tokenizer for Multimodal Understanding and Generation | https://github.com/ByteFlow-AI/TokenFlow |
| [12/3/2024] | XQ-GAN🚀: An Open-source Image Tokenization Framework for Autoregressive Generation | https://github.com/lxa9867/ImageFolder |
| [11/28/2024] | [CVPR 2025]CoDe: Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient | https://github.com/czg1225/CoDe |
| [11/28/2024] | [CVPR 2025]Scalable Autoregressive Monocular Depth Estimation | https://arxiv.org/abs/2411.11361 |
| [11/27/2024] | [CVPR 2025]SAR3D: Autoregressive 3D Object Generation and Understanding via Multi-scale 3D VQVAE | https://github.com/cyw-3d/SAR3D |
| [11/26/2024] | LiteVAR: Compressing Visual Autoregressive Modelling with Efficient Attention and Quantization | https://arxiv.org/abs/2411.17178 |
| [11/15/2024] | M-VAR: Decoupled Scale-wise Autoregressive Modeling for High-Quality Image Generation | https://github.com/OliverRensu/MVAR |
| [10/14/2024] | [ICLR 2025]HART: Efficient Visual Generation with Hybrid Autoregressive Transformer | https://github.com/mit-han-lab/hart |
| [10/12/2024] | [ICLR 2025 Oral]Toward Guidance-Free AR Visual Generation via Condition Contrastive Alignment | https://github.com/thu-ml/CCA |
| [10/3/2024] | [ICLR 2025]ImageFolder🚀: Autoregressive Image Generation with Folded Tokens | https://github.com/lxa9867/ImageFolder |
| [07/25/2024] | ControlVAR: Exploring Controllable Visual Autoregressive Modeling | https://github.com/lxa9867/ControlVAR |
| [07/3/2024] | VAR-CLIP: Text-to-Image Generator with Visual Auto-Regressive Modeling | https://github.com/daixiangzi/VAR-CLIP |
| [06/16/2024] | STAR: Scale-wise Text-to-image generation via Auto-Regressive representations | https://arxiv.org/abs/2406.10797 |
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Citation
If our work assists your research, feel free to give us a star ⭐ or cite us using:
```
@Article{VAR,
title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
year={2024},
eprint={2404.02905},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
```
@misc{Infinity,
title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis},
author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu},
year={2024},
eprint={2412.04431},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.04431},
}
```
================================================
FILE: dist.py
================================================
import datetime
import functools
import os
import sys
from typing import List
from typing import Union
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__initialized = False
def initialized():
    return __initialized
def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))
    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
def get_rank():
    return __rank
def get_local_rank():
    return __local_rank
def get_world_size():
    return __world_size
def get_device():
    return __device
def set_gpu_id(gpu_id: int):
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError
def is_master():
    return __rank == 0
def is_local_master():
    return __local_rank == 0
def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None
def barrier():
    if __initialized:
        tdist.barrier()
def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)
        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)
        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]
def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper
def finalize():
    if __initialized:
        tdist.destroy_process_group()
================================================
FILE: models/__init__.py
================================================
from typing import Tuple
import torch.nn as nn
from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE
def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1, # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    heads = depth
    width = depth * 64
    dpr = 0.1 * depth/24
    # disable built-in initialization for speed
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)
    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available,
    ).to(device)
    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)
    return vae_local, var_wo_ddp
================================================
FILE: models/basic_vae.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
# this file only provides the 2 modules used in VQVAE
__all__ = ['Encoder', 'Decoder',]
"""
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
"""
# swish
def nonlinearity(x):
    return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
class Downsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
    def forward(self, x):
        return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()
    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x), inplace=True))
        h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
        return self.nin_shortcut(x) + h
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels
        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
        self.w_ratio = int(in_channels) ** (-0.5)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape # should be B,3C,H,W
        C = self.C
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous() # B,HW,C
        k = k.view(B, C, H * W).contiguous() # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)
        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()
        return x + self.proj_out(h)
def make_attn(in_channels, using_sa=True):
    return AttnBlock(in_channels) if using_sa else nn.Identity()
class Encoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,
        z_channels, double_z=False, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
class Decoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3, # in_channels: raw img channels
        z_channels, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions-1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up) # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
================================================
FILE: models/basic_var.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.helpers import DropPath, drop_path
# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
# automatically import fused operators
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError: pass
# automatically import faster attention implementations
try: from xformers.ops import memory_efficient_attention
except ImportError: pass
try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
except ImportError: pass
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
except ImportError:
    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
        attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
        if attn_mask is not None: attn.add_(attn_mask)
        return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
class FFN(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
super().__init__()
self.fused_mlp_func = fused_mlp_func if fused_if_available else None
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU(approximate='tanh')
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
def forward(self, x):
if self.fused_mlp_func is not None:
return self.drop(self.fused_mlp_func(
x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
heuristic=0, process_group=None,
))
else:
return self.drop(self.fc2( self.act(self.fc1(x)) ))
def extra_repr(self) -> str:
return f'fused_mlp_func={self.fused_mlp_func is not None}'
class SelfAttention(nn.Module):
def __init__(
self, block_idx, embed_dim=768, num_heads=12,
attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
):
super().__init__()
assert embed_dim % num_heads == 0
self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64
self.attn_l2_norm = attn_l2_norm
if self.attn_l2_norm:
self.scale = 1
self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
self.max_scale_mul = torch.log(torch.tensor(100)).item()
else:
self.scale = 0.25 / math.sqrt(self.head_dim)
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
self.attn_drop: float = attn_drop
self.using_flash = flash_if_available and flash_attn_func is not None
self.using_xform = flash_if_available and memory_efficient_attention is not None
# only used during inference
self.caching, self.cached_k, self.cached_v = False, None, None
def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
# NOTE: attn_bias is None during inference because kv cache is enabled
def forward(self, x, attn_bias):
B, L, C = x.shape
qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
main_type = qkv.dtype
# qkv: BL3Hc
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc
if self.attn_l2_norm:
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1
q = F.normalize(q, dim=-1).mul(scale_mul)
k = F.normalize(k, dim=-1)
if self.caching:
if self.cached_k is None: self.cached_k = k; self.cached_v = v
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
dropout_p = self.attn_drop if self.training else 0.0
if using_flash:
oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
elif self.using_xform:
oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
else:
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
return self.proj_drop(self.proj(oup))
# attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL
# attn = self.attn_drop(attn.softmax(dim=-1))
# oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC
def extra_repr(self) -> str:
return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
class AdaLNSelfAttn(nn.Module):
def __init__(
self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
flash_if_available=False, fused_if_available=True,
):
super(AdaLNSelfAttn, self).__init__()
self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
self.C, self.D = embed_dim, cond_dim
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
self.shared_aln = shared_aln
if self.shared_aln:
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
else:
lin = nn.Linear(cond_dim, 6*embed_dim)
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
self.fused_add_norm_fn = None
# NOTE: attn_bias is None during inference because kv cache is enabled
def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
if self.shared_aln:
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
else:
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
return x
def extra_repr(self) -> str:
return f'shared_aln={self.shared_aln}'
class AdaLNBeforeHead(nn.Module):
def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
super().__init__()
self.C, self.D = C, D
self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
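# Illustrative sketch, not part of the original file. AdaLNSelfAttn and
# AdaLNBeforeHead both normalize without learned affine parameters and then apply
# a condition-dependent `x_hat * (1 + scale) + shift`. The helper below does this
# for one token vector in plain Python; its name is hypothetical and eps=1e-6 is
# taken from the norm_eps default used when VAR builds its norm_layer.
import math
def sketch_adaln_modulate(x, scale, shift, eps=1e-6):
    """LayerNorm over one vector (no affine), followed by the AdaLN scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as LayerNorm uses
    x_hat = [(v - mean) / math.sqrt(var + eps) for v in x]
    return [h * (1 + s) + b for h, s, b in zip(x_hat, scale, shift)]
# Usage: with scale = shift = 0 the result is a plain LayerNorm output; in
# AdaLNSelfAttn the branch output is additionally multiplied by gamma before
# being added back to the residual stream.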
================================================
FILE: models/helpers.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # returns idx, shaped (B, l, num_samples)
B, l, V = logits_BlV.shape
if top_k > 0:
idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
if top_p > 0:
sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
sorted_idx_to_remove[..., -1:] = False
logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
# sample (flatten to 2D first, because torch.multinomial only accepts 1D or 2D input)
replacement = num_samples >= 0
num_samples = abs(num_samples)
return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
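# Illustrative sketch, not part of the original file. It reproduces only the
# top-p (nucleus) rule used above for a single row of logits: sort ascending,
# accumulate probabilities, drop every token whose cumulative mass is still
# <= 1 - top_p, and always keep the most likely token. The function name is
# hypothetical; the real function also applies top-k and then samples.
import math
def sketch_top_p_keep(logits, top_p):
    """Return a list of booleans: True where the token survives top-p filtering."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(logits)), key=lambda i: logits[i])  # ascending, like sort(descending=False)
    keep = [True] * len(logits)
    cum = 0.0
    for rank, i in enumerate(order):
        cum += probs[i]
        is_last = rank == len(order) - 1  # mirrors `sorted_idx_to_remove[..., -1:] = False`
        if cum <= 1 - top_p and not is_last:
            keep[i] = False
    return keep
# Usage: for probabilities (0.1, 0.2, 0.7) and top_p=0.5, only the 0.7 token is kept;
# with top_p=0.95 all three tokens are kept.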
def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
if rng is None:
return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
gumbels = (logits + gumbels) / tau
y_soft = gumbels.softmax(dim)
if hard:
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
ret = y_soft
return ret
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm
if drop_prob == 0. or not training: return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module): # taken from timm
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={self.drop_prob:g}'
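# Illustrative sketch, not part of the original file. drop_path above zeroes whole
# samples with probability drop_prob and, when scale_by_keep is set, rescales the
# survivors by 1 / keep_prob so the expected value is unchanged. The pure-Python
# helper below shows that per-sample behaviour; its name is hypothetical, and it
# does not model the `training` flag or the early return for drop_prob == 0.
import random
def sketch_drop_path(batch, drop_prob, rng, scale_by_keep=True):
    """batch: list of rows; each row is either zeroed or rescaled as a whole."""
    keep_prob = 1 - drop_prob
    out = []
    for row in batch:
        kept = rng.random() < keep_prob  # one Bernoulli draw per sample, not per element
        if not kept:
            factor = 0.0
        elif scale_by_keep and keep_prob > 0:
            factor = 1 / keep_prob
        else:
            factor = 1.0
        out.append([v * factor for v in row])
    return out
# Usage: sketch_drop_path([[1.0, 2.0]] * 8, 0.5, random.Random(0)) yields rows that
# are each either [0.0, 0.0] or [2.0, 4.0].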
================================================
FILE: models/quant.py
================================================
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import distributed as tdist, nn as nn
from torch.nn import functional as F
import dist
# this file only provides the VectorQuantizer2 used in VQVAE
__all__ = ['VectorQuantizer2',]
class VectorQuantizer2(nn.Module):
# VQGAN originally used beta=1.0 (0.25 was never tried); SD seems to use 0.25
def __init__(
self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
):
super().__init__()
self.vocab_size: int = vocab_size
self.Cvae: int = Cvae
self.using_znorm: bool = using_znorm
self.v_patch_nums: Tuple[int] = v_patch_nums
self.quant_resi_ratio = quant_resi
if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
else: # partially shared: \phi_{1 to share_quant_resi} for K scales
self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))
self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
self.record_hit = 0
self.beta: float = beta
self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
# only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
def eini(self, eini):
if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
def extra_repr(self) -> str:
return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
# ===================== `forward` is only used in VAE training =====================
def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
dtype = f_BChw.dtype
if dtype != torch.float32: f_BChw = f_BChw.float()
B, C, H, W = f_BChw.shape
f_no_grad = f_BChw.detach()
f_rest = f_no_grad.clone()
f_hat = torch.zeros_like(f_rest)
with torch.cuda.amp.autocast(enabled=False):
mean_vq_loss: torch.Tensor = 0.0
vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
SN = len(self.v_patch_nums)
for si, pn in enumerate(self.v_patch_nums): # from small to large
# find the nearest embedding
if self.using_znorm:
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
rest_NC = F.normalize(rest_NC, dim=-1)
idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
else:
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
idx_N = torch.argmin(d_no_grad, dim=1)
hit_V = idx_N.bincount(minlength=self.vocab_size).float()
if self.training:
if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
# calc loss
idx_Bhw = idx_N.view(B, pn, pn)
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat = f_hat + h_BChw
f_rest -= h_BChw
if self.training and dist.initialized():
handler.wait()
if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)
elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
self.record_hit += 1
vocab_hit_V.add_(hit_V)
mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
mean_vq_loss *= 1. / SN
f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
# margin = pn*pn / 100
if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)]
else: usages = None
return f_hat, usages, mean_vq_loss
# ===================== `forward` is only used in VAE training =====================
def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
ls_f_hat_BChw = []
B = ms_h_BChw[0].shape[0]
H = W = self.v_patch_nums[-1]
SN = len(self.v_patch_nums)
if all_to_max_scale:
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
for si, pn in enumerate(self.v_patch_nums): # from small to large
h_BChw = ms_h_BChw[si]
if si < len(self.v_patch_nums) - 1:
h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h_BChw)
if last_one: ls_f_hat_BChw = f_hat
else: ls_f_hat_BChw.append(f_hat.clone())
else:
# WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
# WARNING: this should only be used for experimental purpose
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
for si, pn in enumerate(self.v_patch_nums): # from small to large
f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')
h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si])
f_hat.add_(h_BChw)
if last_one: ls_f_hat_BChw = f_hat
else: ls_f_hat_BChw.append(f_hat)
return ls_f_hat_BChw
def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]: # f_BChw is the feature from inp_img_no_grad
B, C, H, W = f_BChw.shape
f_no_grad = f_BChw.detach()
f_rest = f_no_grad.clone()
f_hat = torch.zeros_like(f_rest)
f_hat_or_idx_Bl: List[torch.Tensor] = []
patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)] # from small to large
assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
SN = len(patch_hws)
for si, (ph, pw) in enumerate(patch_hws): # from small to large
if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1
# find the nearest embedding
z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
if self.using_znorm:
z_NC = F.normalize(z_NC, dim=-1)
idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
else:
d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
idx_N = torch.argmin(d_no_grad, dim=1)
idx_Bhw = idx_N.view(B, ph, pw)
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h_BChw)
f_rest.sub_(h_BChw)
f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))
return f_hat_or_idx_Bl
# ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
next_scales = []
B = gt_ms_idx_Bl[0].shape[0]
C = self.Cvae
H = W = self.v_patch_nums[-1]
SN = len(self.v_patch_nums)
f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
pn_next: int = self.v_patch_nums[0]
for si in range(SN-1):
if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break # progressive training: not supported yet, prog_si always -1
h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
pn_next = self.v_patch_nums[si+1]
next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32
# ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
HW = self.v_patch_nums[-1]
if si != SN-1:
h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic')) # conv after upsample
f_hat.add_(h)
return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')
else:
h = self.quant_resi[si/(SN-1)](h_BChw)
f_hat.add_(h)
return f_hat, f_hat
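# Illustrative sketch, not part of the original file. The f_rest / f_hat loops in
# VectorQuantizer2 follow a residual-quantization pattern: each scale encodes what
# the previous scales have not yet explained. The helper below is a scalar
# analogue only: it drops the area/bicubic interpolation between scales and the
# Phi convolution, and its name and codebook values are hypothetical.
def sketch_residual_quantize(value, codebook, num_stages):
    """Greedily quantize one scalar in num_stages rounds; return (indices, f_hat)."""
    f_rest, f_hat, indices = value, 0.0, []
    for _ in range(num_stages):
        # nearest codebook entry to the current residual (squared distance, like argmin of d_no_grad)
        idx = min(range(len(codebook)), key=lambda i: (f_rest - codebook[i]) ** 2)
        indices.append(idx)
        f_hat += codebook[idx]   # accumulate the reconstruction
        f_rest -= codebook[idx]  # keep only what is still unexplained
    return indices, f_hat
# Usage: sketch_residual_quantize(1.5, [-1.0, -0.25, 0.25, 1.0], 3) returns
# ([3, 2, 2], 1.5), i.e. 1.0 + 0.25 + 0.25.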
class Phi(nn.Conv2d):
def __init__(self, embed_dim, quant_resi):
ks = 3
super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks//2)
self.resi_ratio = abs(quant_resi)
def forward(self, h_BChw):
return h_BChw.mul(1-self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
class PhiShared(nn.Module):
def __init__(self, qresi: Phi):
super().__init__()
self.qresi: Phi = qresi
def __getitem__(self, _) -> Phi:
return self.qresi
class PhiPartiallyShared(nn.Module):
def __init__(self, qresi_ls: nn.ModuleList):
super().__init__()
self.qresi_ls = qresi_ls
K = len(qresi_ls)
self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi:
return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
class PhiNonShared(nn.ModuleList):
def __init__(self, qresi: List):
super().__init__(qresi)
# self.qresi = qresi
K = len(qresi)
self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi:
return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
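# Illustrative sketch, not part of the original file. PhiPartiallyShared and
# PhiNonShared are indexed with a float in [0, 1] (si / (SN - 1)) and return the
# layer whose tick is nearest. The helper below recomputes that mapping without
# NumPy; its name is hypothetical, and the tick formula is written to give the
# same values as the np.linspace calls above.
def sketch_phi_index(si, num_scales, K):
    """Index of the Phi layer that scale `si` would select among K layers."""
    margin = 1 / 3 / K if K == 4 else 1 / 2 / K
    ticks = [margin + j * (1 - 2 * margin) / (K - 1) for j in range(K)] if K > 1 else [margin]
    at = si / (num_scales - 1)
    return min(range(K), key=lambda j: abs(ticks[j] - at))
# Usage: for 10 scales and K=4, scales 0-1 select layer 0, scales 3-4 layer 1,
# scales 5-6 layer 2 and scales 8-9 layer 3. Scales 2 and 7 sit exactly halfway
# between two ticks in exact arithmetic, so their assignment depends on
# floating-point rounding and this sketch may not match NumPy's choice there.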
================================================
FILE: models/var.py
================================================
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2
class SharedAdaLin(nn.Linear):
def forward(self, cond_BD):
C = self.weight.shape[0] // 6
return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
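# Illustrative sketch, not part of the original file. The VAR constructor that
# follows flattens all scales into one sequence of L = sum(pn ** 2) tokens, records
# each scale's (begin, end) span, and builds a block-wise causal mask in which a
# token may attend to every token of its own or any coarser scale. The helpers
# below recompute this in plain Python; their names are hypothetical.
def sketch_token_layout(patch_nums):
    """Return (begin_ends, levels): per-scale spans and the scale index of every token."""
    begin_ends, levels, cur = [], [], 0
    for i, pn in enumerate(patch_nums):
        begin_ends.append((cur, cur + pn * pn))
        levels.extend([i] * (pn * pn))
        cur += pn * pn
    return begin_ends, levels
def sketch_block_causal_mask(levels):
    """mask[q][k] is True where query q may attend to key k (same rule as `d >= dT`)."""
    return [[lq >= lk for lk in levels] for lq in levels]
# Usage: for the default patch_nums (1, 2, 3, 4, 5, 6, 8, 10, 13, 16) the last span
# is (424, 680), so L = 680. For patch_nums (1, 2) the first token sees only itself,
# while each of the four tokens of the second scale sees all five tokens.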
class VAR(nn.Module):
def __init__(
self, vae_local: VQVAE,
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
flash_if_available=True, fused_if_available=True,
):
super().__init__()
# 0. hyperparameters
assert embed_dim % num_heads == 0
self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
self.cond_drop_rate = cond_drop_rate
self.prog_si = -1 # progressive training
self.patch_nums: Tuple[int] = patch_nums
self.L = sum(pn ** 2 for pn in self.patch_nums)
self.first_l = self.patch_nums[0] ** 2
self.begin_ends = []
cur = 0
for i, pn in enumerate(self.patch_nums):
self.begin_ends.append((cur, cur+pn ** 2))
cur += pn ** 2
self.num_stages_minus_1 = len(self.patch_nums) - 1
self.rng = torch.Generator(device=dist.get_device())
# 1. input (word) embedding
quant: VectorQuantizer2 = vae_local.quantize
self.vae_proxy: Tuple[VQVAE] = (vae_local,)
self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
self.word_embed = nn.Linear(self.Cvae, self.C)
# 2. class embedding
init_std = math.sqrt(1 / self.C / 3)
self.num_classes = num_classes
self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
# 3. absolute position embedding
pos_1LC = []
for i, pn in enumerate(self.patch_nums):
pe = torch.empty(1, pn*pn, self.C)
nn.init.trunc_normal_(pe, mean=0, std=init_std)
pos_1LC.append(pe)
pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
assert tuple(pos_1LC.shape) == (1, self.L, self.C)
self.pos_1LC = nn.Parameter(pos_1LC)
# level embedding (similar to BERT's segment embedding, used to distinguish different levels of token pyramid)
self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
# 4. backbone blocks
self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
self.drop_path_rate = drop_path_rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing)
self.blocks = nn.ModuleList([
AdaLNSelfAttn(
cond_dim=self.D, shared_aln=shared_aln,
block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
attn_l2_norm=attn_l2_norm,
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
)
for block_idx in range(depth)
])
fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
self.using_fused_add_norm_fn = any(fused_add_norm_fns)
print(
f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
end='\n\n', flush=True
)
# 5. attention mask used in training (for masking out the future)
# it won't be used in inference, since kv cache is enabled
d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
dT = d.transpose(1, 2) # dT: 11L
lvl_1L = dT[:, 0].contiguous()
self.register_buffer('lvl_1L', lvl_1L)
attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
# 6. classifier head
self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
self.head = nn.Linear(self.C, self.V)
def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
if not isinstance(h_or_h_and_residual, torch.Tensor):
h, resi = h_or_h_and_residual # fused_add_norm must be used
h = resi + self.blocks[-1].drop_path(h)
else: # fused_add_norm is not used
h = h_or_h_and_residual
return self.head(self.head_nm(h.float(), cond_BD).float()).float()
@torch.no_grad()
def autoregressive_infer_cfg(
self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
more_smooth=False,
) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1]
"""
only used for inference, in autoregressive mode
:param B: batch size
:param label_B: ImageNet class label(s); if None, labels are sampled uniformly at random; if an int, it is used for the whole batch (a negative int selects the unconditional class)
:param g_seed: random seed
:param cfg: classifier-free guidance ratio
:param top_k: top-k sampling
:param top_p: top-p sampling
:param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
:return: reconstructed image tensor, shaped (B, 3, H, W), with values in [0, 1]
"""
if g_seed is None: rng = None
else: self.rng.manual_seed(g_seed); rng = self.rng
if label_B is None:
label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
elif isinstance(label_B, int):
label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)
sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))
lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]
cur_L = 0
f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
for b in self.blocks: b.attn.kv_caching(True)
for si, pn in enumerate(self.patch_nums): # si: i-th segment
ratio = si / self.num_stages_minus_1
# last_L = cur_L
cur_L += pn*pn
# assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
x = next_token_map
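AdaLNSelfAttn.forward # bare attribute reference: a no-op at runtime; the loop below invokes this method through each block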
for b in self.blocks:
x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
logits_BlV = self.get_logits(x, cond_BD)
t = cfg * ratio
logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
if not more_smooth: # this is the default case
h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae
else: # not used when evaluating FID/IS/Precision/Recall
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
if si != self.num_stages_minus_1: # prepare for next stage
next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG
for b in self.blocks: b.attn.kv_caching(False)
return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1]
def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor: # returns logits_BLV
"""
:param label_B: class labels, shaped (B,)
:param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
:return: logits BLV, V is vocab_size
"""
bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
B = x_BLCv_wo_first_l.shape[0]
with torch.cuda.amp.autocast(enabled=False):
label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B)
sos = cond_BD = self.class_emb(label_B)
sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
if self.prog_si == 0: x_BLC = sos
else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC
attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
# hack: get the dtype if mixed precision is used
temp = x_BLC.new_ones(8, 8)
main_type = torch.matmul(temp, temp).dtype
x_BLC = x_BLC.to(dtype=main_type)
cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
attn_bias = attn_bias.to(dtype=main_type)
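AdaLNSelfAttn.forward # bare attribute reference: a no-op at runtime; the loop below invokes this method through each block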
for i, b in enumerate(self.blocks):
x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
x_BLC = self.get_logits(x_BLC.float(), cond_BD)
if self.prog_si == 0:
if isinstance(self.word_embed, nn.Linear):
x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
else:
s = 0
for p in self.word_embed.parameters():
if p.requires_grad:
s += p.view(-1)[0] * 0
x_BLC[0, 0, 0] += s
return x_BLC # logits BLV, V is vocab_size
def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
for m in self.modules():
with_weight = hasattr(m, 'weight') and m.weight is not None
with_bias = hasattr(m, 'bias') and m.bias is not None
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if with_bias: m.bias.data.zero_()
elif isinstance(m, nn.Embedding):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if with_weight: m.weight.data.fill_(1.)
if with_bias: m.bias.data.zero_()
# conv: VAR has no conv, only VQVAE has conv
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
if with_bias: m.bias.data.zero_()
if init_head >= 0:
if isinstance(self.head, nn.Linear):
self.head.weight.data.mul_(init_head)
self.head.bias.data.zero_()
elif isinstance(self.head, nn.Sequential):
self.head[-1].weight.data.mul_(init_head)
self.head[-1].bias.data.zero_()
if isinstance(self.head_nm, AdaLNBeforeHead):
self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
self.head_nm.ada_lin[-1].bias.data.zero_()
depth = len(self.blocks)
for block_idx, sab in enumerate(self.blocks):
sab: AdaLNSelfAttn
sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
nn.init.ones_(sab.ffn.fcg.bias)
nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
if hasattr(sab, 'ada_lin'):
sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
sab.ada_lin[-1].bias.data.zero_()
elif hasattr(sab, 'ada_gss'):
sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
def extra_repr(self):
return f'drop_path_rate={self.drop_path_rate:g}'
class VARHF(VAR, PyTorchModelHubMixin):
# repo_url="https://github.com/FoundationVision/VAR",
# tags=["image-generation"]):
def __init__(
self,
vae_kwargs,
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
flash_if_available=True, fused_if_available=True,
):
vae_local = VQVAE(**vae_kwargs)
super().__init__(
vae_local=vae_local,
num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
attn_l2_norm=attn_l2_norm,
patch_nums=patch_nums,
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
)
================================================
FILE: models/vqvae.py
================================================
"""
References:
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from .basic_vae import Decoder, Encoder
from .quant import VectorQuantizer2
class VQVAE(nn.Module):
def __init__(
self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
beta=0.25, # commitment loss weight
using_znorm=False, # whether to normalize when computing the nearest neighbors
quant_conv_ks=3, # quant conv kernel size
quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
test_mode=True,
):
super().__init__()
self.test_mode = test_mode
self.V, self.Cvae = vocab_size, z_channels
# ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
ddconfig = dict(
dropout=dropout, ch=ch, z_channels=z_channels,
in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above
using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above
# resamp_with_conv=True, # always True, removed.
)
ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True
self.encoder = Encoder(double_z=False, **ddconfig)
self.decoder = Decoder(**ddconfig)
self.vocab_size = vocab_size
self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
self.quantize: VectorQuantizer2 = VectorQuantizer2(
vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
)
self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
if self.test_mode:
self.eval()
[p.requires_grad_(False) for p in self.parameters()]
# ===================== `forward` is only used in VAE training =====================
def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
VectorQuantizer2.forward
f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
# ===================== `forward` is only used in VAE training =====================
def fhat_to_img(self, f_hat: torch.Tensor):
return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl]
f = self.quant_conv(self.encoder(inp_img_no_grad))
return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)
def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
B = ms_idx_Bl[0].shape[0]
ms_h_BChw = []
for idx_Bl in ms_idx_Bl:
l = idx_Bl.shape[1]
pn = round(l ** 0.5)
ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)
def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
if last_one:
return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
else:
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]
def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
f = self.quant_conv(self.encoder(x))
ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
if last_one:
return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
else:
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]
def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
================================================
FILE: train.py
================================================
import gc
import os
import shutil
import sys
import time
import warnings
from functools import partial
import torch
from torch.utils.data import DataLoader
import dist
from utils import arg_util, misc
from utils.data import build_dataset
from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
from utils.misc import auto_resume
def build_everything(args: arg_util.Args):
# resume
auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
# create tensorboard logger
tb_lg: misc.TensorboardLogger
with_tb_lg = dist.is_master()
if with_tb_lg:
os.makedirs(args.tb_log_dir_path, exist_ok=True)
# noinspection PyTypeChecker
tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
tb_lg.flush()
else:
# noinspection PyTypeChecker
tb_lg = misc.DistLogger(None, verbose=False)
dist.barrier()
# log args
print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
print(f'initial args:\n{str(args)}')
# build data
if not args.local_debug:
print(f'[build PT data] ...\n')
num_classes, dataset_train, dataset_val = build_dataset(
args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
)
types = str((type(dataset_train).__name__, type(dataset_val).__name__))
ld_val = DataLoader(
dataset_val, num_workers=0, pin_memory=True,
batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
shuffle=False, drop_last=False,
)
del dataset_val
ld_train = DataLoader(
dataset=dataset_train, num_workers=args.workers, pin_memory=True,
generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,
batch_sampler=DistInfiniteBatchSampler(
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,
shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,
),
)
del dataset_train
[print(line) for line in auto_resume_info]
print(f'[dataloader multi processing] ...', end='', flush=True)
stt = time.time()
iters_train = len(ld_train)
ld_train = iter(ld_train)
# noinspection PyArgumentList
print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)
print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')
else:
num_classes = 1000
ld_val = ld_train = None
iters_train = 10
# build models
from torch.nn.parallel import DistributedDataParallel as DDP
from models import VAR, VQVAE, build_vae_var
from trainer import VARTrainer
from utils.amp_sc import AmpOptimizer
from utils.lr_control import filter_params
vae_local, var_wo_ddp = build_vae_var(
V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
device=dist.get_device(), patch_nums=args.patch_nums,
num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
flash_if_available=args.fuse, fused_if_available=args.fuse,
init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
)
vae_ckpt = 'vae_ch160v4096z32.pth'
if dist.is_local_master():
if not os.path.exists(vae_ckpt):
os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
dist.barrier()
vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')
# build optimizer
names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
'cls_token', 'start_token', 'task_token', 'cfg_uncond',
'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
'gamma', 'beta',
'ada_gss', 'moe_bias',
'scale_mul',
})
opt_clz = {
'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
}[args.opt.lower().strip()]
opt_kw = dict(lr=args.tlr, weight_decay=0)
print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')
var_optim = AmpOptimizer(
mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
grad_clip=args.tclip, n_gradient_accumulation=args.ac
)
del names, paras, para_groups
# build trainer
trainer = VARTrainer(
device=args.device, patch_nums=args.patch_nums, resos=args.resos,
vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
var_opt=var_optim, label_smooth=args.ls,
)
if trainer_state is not None and len(trainer_state):
trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again
del vae_local, var_wo_ddp, var, var_optim
if args.local_debug:
rng = torch.Generator('cpu')
rng.manual_seed(0)
B = 4
inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
label = torch.ones(B, dtype=torch.long)
me = misc.MetricLogger(delimiter=' ')
trainer.train_step(
it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
)
trainer.load_state_dict(trainer.state_dict())
trainer.train_step(
it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
)
print({k: meter.global_avg for k, meter in me.meters.items()})
args.dump_log(); tb_lg.flush(); tb_lg.close()
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
sys.stdout.close(), sys.stderr.close()
exit(0)
dist.barrier()
return (
tb_lg, trainer, start_ep, start_it,
iters_train, ld_train, ld_val
)
def main_training():
args: arg_util.Args = arg_util.init_dist_and_get_args()
if args.local_debug:
torch.autograd.set_detect_anomaly(True)
(
tb_lg, trainer,
start_ep, start_it,
iters_train, ld_train, ld_val
) = build_everything(args)
# train
start_time = time.time()
best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1
L_mean, L_tail = -1, -1
for ep in range(start_ep, args.ep):
if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
ld_train.sampler.set_epoch(ep)
if ep < 3:
# noinspection PyArgumentList
print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
tb_lg.set_step(ep * iters_train)
stats, (sec, remain_time, finish_time) = train_one_ep(
ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
)
L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
args.cur_ep = f'{ep+1}/{args.ep}'
args.remain_time, args.finish_time = remain_time, finish_time
AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
if is_val_and_also_saving:
val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
best_updated = best_val_loss_tail > val_loss_tail
best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')
if dist.is_local_master():
local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
print(f'[saving ckpt] ...', end='', flush=True)
torch.save({
'epoch': ep+1,
'iter': 0,
'trainer': trainer.state_dict(),
'args': args.state_dict(),
}, local_out_ckpt)
if best_updated:
shutil.copy(local_out_ckpt, local_out_ckpt_best)
print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
dist.barrier()
print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
args.dump_log(); tb_lg.flush()
total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
print('\n\n')
print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
print('\n\n')
del stats
del iters_train, ld_train
time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
print(f'final args:\n\n{str(args)}')
args.dump_log(); tb_lg.flush(); tb_lg.close()
dist.barrier()
def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
# import heavy packages after Dataloader object creation
from trainer import VARTrainer
from utils.lr_control import lr_wd_annealing
trainer: VARTrainer
step_cnt = 0
me = misc.MetricLogger(delimiter=' ')
me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]
header = f'[Ep]: [{ep:4d}/{args.ep}]'
if is_first_ep:
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
g_it, max_it = ep * iters_train, args.ep * iters_train
for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
g_it = ep * iters_train + it
if it < start_it: continue
if is_first_ep and it == start_it: warnings.resetwarnings()
inp = inp.to(args.device, non_blocking=True)
label = label.to(args.device, non_blocking=True)
args.cur_it = f'{it+1}/{iters_train}'
wp_it = args.wp * iters_train
min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
args.cur_lr, args.cur_wd = max_tlr, max_twd
if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this
if g_it <= wp_it: prog_si = args.pg0
elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1
else:
delta = len(args.patch_nums) - 1 - args.pg0
progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1
prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1
else:
prog_si = -1
stepping = (g_it + 1) % args.ac == 0
step_cnt += int(stepping)
grad_norm, scale_log2 = trainer.train_step(
it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
)
me.update(tlr=max_tlr)
tb_lg.set_step(step=g_it)
tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
if args.tclip > 0:
tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)
me.synchronize_between_processes()
return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost
class NullDDP(torch.nn.Module):
def __init__(self, module, *args, **kwargs):
super(NullDDP, self).__init__()
self.module = module
self.require_backward_grad_sync = False
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
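# ---- illustrative sketch, not called anywhere in this file ----
# The progressive-training branch of `train_one_ep` can be read as a pure function of the global
# iteration. `sketch_prog_si` is a hypothetical helper name introduced only for illustration; it
# mirrors the arithmetic in `train_one_ep` under the assumption that 0 < pg <= 1 and wp_it < max_it * pg.
def sketch_prog_si(g_it: int, wp_it: float, max_it: int, pg: float, pg0: int, num_scales: int) -> int:
    if not pg:                   # pg == 0.0: progressive training is disabled
        return -1
    if g_it <= wp_it:            # still inside lr warm-up: stay at the initial scale index
        return pg0
    if g_it >= max_it * pg:      # progressive phase is over: train on every scale
        return num_scales - 1
    delta = num_scales - 1 - pg0
    progress = min(max((g_it - wp_it) / (max_it * pg - wp_it), 0), 1)  # clipped to [0, 1]
    return pg0 + round(progress * delta)
# Usage (assumed values): sketch_prog_si(g_it=108, wp_it=10, max_it=1000, pg=0.5, pg0=4, num_scales=10)
# sits 20% of the way through the progressive phase and advances one scale past pg0.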
if __name__ == '__main__':
try: main_training()
finally:
dist.finalize()
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
sys.stdout.close(), sys.stderr.close()
================================================
FILE: trainer.py
================================================
import time
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import dist
from models import VAR, VQVAE, VectorQuantizer2
from utils.amp_sc import AmpOptimizer
from utils.misc import MetricLogger, TensorboardLogger
Ten = torch.Tensor
FTen = torch.Tensor
ITen = torch.LongTensor
BTen = torch.BoolTensor
class VARTrainer(object):
def __init__(
self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
var_opt: AmpOptimizer, label_smooth: float,
):
super(VARTrainer, self).__init__()
self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
self.quantize_local: VectorQuantizer2
self.var_wo_ddp: VAR = var_wo_ddp # after torch.compile
self.var_opt = var_opt
del self.var_wo_ddp.rng
self.var_wo_ddp.rng = torch.Generator(device=device)
self.label_smooth = label_smooth
self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
self.L = sum(pn * pn for pn in patch_nums)
self.last_l = patch_nums[-1] * patch_nums[-1]
self.loss_weight = torch.ones(1, self.L, device=device) / self.L
self.patch_nums, self.resos = patch_nums, resos
self.begin_ends = []
cur = 0
for i, pn in enumerate(patch_nums):
self.begin_ends.append((cur, cur + pn * pn))
cur += pn*pn
self.prog_it = 0
self.last_prog_si = -1
self.first_prog = True
@torch.no_grad()
def eval_ep(self, ld_val: DataLoader):
tot = 0
L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
stt = time.time()
training = self.var_wo_ddp.training
self.var_wo_ddp.eval()
for inp_B3HW, label_B in ld_val:
B, V = label_B.shape[0], self.vae_local.vocab_size
inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
label_B = label_B.to(dist.get_device(), non_blocking=True)
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
gt_BL = torch.cat(gt_idx_Bl, dim=1)
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
self.var_wo_ddp.forward
logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1])
acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
tot += B
self.var_wo_ddp.train(training)
stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
dist.allreduce(stats)
tot = round(stats[-1].item())
stats /= tot
L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
return L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt
def train_step(
self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
# if progressive training
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
if self.last_prog_si != prog_si:
if self.last_prog_si != -1: self.first_prog = False
self.last_prog_si = prog_si
self.prog_it = 0
self.prog_it += 1
prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
if self.first_prog: prog_wp = 1 # no prog warmup at first prog stage, as it's already solved in wp
if prog_si == len(self.patch_nums) - 1: prog_si = -1 # max prog, as if no prog
# forward
B, V = label_B.shape[0], self.vae_local.vocab_size
self.var.require_backward_grad_sync = stepping
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
gt_BL = torch.cat(gt_idx_Bl, dim=1)
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
with self.var_opt.amp_ctx:
self.var_wo_ddp.forward
logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
if prog_si >= 0: # in progressive training
bg, ed = self.begin_ends[prog_si]
assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
lw = self.loss_weight[:, :ed].clone()
lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
else: # not in progressive training
lw = self.loss_weight
loss = loss.mul(lw).sum(dim=-1).mean()
# backward
grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)
# log
pred_BL = logits_BLV.data.argmax(dim=-1)
if it == 0 or it in metric_lg.log_iters:
Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
if prog_si >= 0: # in progressive training
Ltail = acc_tail = -1
else: # not in progressive training
Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
grad_norm = grad_norm.item()
metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)
# log to tensorboard
if g_it == 0 or (g_it + 1) % 500 == 0:
prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
dist.allreduce(prob_per_class_is_chosen)
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
if dist.is_master():
if g_it == 0:
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
kw = dict(z_voc_usage=cluster_usage)
for si, (bg, ed) in enumerate(self.begin_ends):
if 0 <= prog_si < si: break
pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
ce = self.val_loss(pred, tar).item()
kw[f'acc_{self.resos[si]}'] = acc
kw[f'L_{self.resos[si]}'] = ce
tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
return grad_norm, scale_log2
def get_config(self):
return {
'patch_nums': self.patch_nums, 'resos': self.resos,
'label_smooth': self.label_smooth,
'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
}
def state_dict(self):
state = {'config': self.get_config()}
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
m = getattr(self, k)
if m is not None:
if hasattr(m, '_orig_mod'):
m = m._orig_mod
state[k] = m.state_dict()
return state
def load_state_dict(self, state, strict=True, skip_vae=False):
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
if skip_vae and 'vae' in k: continue
m = getattr(self, k)
if m is not None:
if hasattr(m, '_orig_mod'):
m = m._orig_mod
ret = m.load_state_dict(state[k], strict=strict)
if ret is not None:
missing, unexpected = ret
print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')
config: dict = state.pop('config', None)
# guard first: a checkpoint without a 'config' entry would otherwise fail on None.get(...)
if config is not None:
self.prog_it = config.get('prog_it', 0)
self.last_prog_si = config.get('last_prog_si', -1)
self.first_prog = config.get('first_prog', True)
for k, v in self.get_config().items():
if config.get(k, None) != v:
err = f'[VARTrainer.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
if strict: raise AttributeError(err)
else: print(err)
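# ---- illustrative sketch, not used by VARTrainer ----
# `VARTrainer.__init__` builds `begin_ends`, the [begin, end) token range each scale occupies in the
# flattened sequence of length L = sum(pn * pn). `sketch_begin_ends` is a hypothetical standalone
# helper that repeats that bookkeeping so the ranges can be inspected without building a trainer.
def sketch_begin_ends(patch_nums):
    begin_ends, cur = [], 0
    for pn in patch_nums:
        begin_ends.append((cur, cur + pn * pn))  # a pn x pn token map flattens to pn * pn tokens
        cur += pn * pn
    return begin_ends
# Usage: with the default patch_nums (1, 2, 3, 4, 5, 6, 8, 10, 13, 16) the last range covers the
# final 16 * 16 = 256 tokens, which is the slice the "tail" loss and accuracy are computed on.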
================================================
FILE: utils/amp_sc.py
================================================
import math
from typing import List, Optional, Tuple, Union
import torch
class NullCtx:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class AmpOptimizer:
def __init__(
self,
mixed_precision: int,
optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
grad_clip: float, n_gradient_accumulation: int = 1,
):
self.enable_amp = mixed_precision > 0
self.using_fp16_rather_bf16 = mixed_precision == 1
if self.enable_amp:
self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler
else:
self.amp_ctx = NullCtx()
self.scaler = None
self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad
self.grad_clip = grad_clip
self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation
def backward_clip_step(
self, stepping: bool, loss: torch.Tensor,
) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
# backward
loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
orig_norm = scaler_sc = None
if self.scaler is not None:
self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
else:
loss.backward(retain_graph=False, create_graph=False)
if stepping:
if self.scaler is not None: self.scaler.unscale_(self.optimizer)
if self.early_clipping:
orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
if self.scaler is not None:
self.scaler.step(self.optimizer)
scaler_sc: float = self.scaler.get_scale()
if scaler_sc > 32768.: # fp16 overflows above 65504 (its largest finite value), so letting the loss scale grow past 32768 is risky; cap it
self.scaler.update(new_scale=32768.)
else:
self.scaler.update()
try:
scaler_sc = float(math.log2(scaler_sc))
except Exception as e:
print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
raise e
else:
self.optimizer.step()
if self.late_clipping:
orig_norm = self.optimizer.global_grad_norm
self.optimizer.zero_grad(set_to_none=True)
return orig_norm, scaler_sc
def state_dict(self):
return {
'optimizer': self.optimizer.state_dict()
} if self.scaler is None else {
'scaler': self.scaler.state_dict(),
'optimizer': self.optimizer.state_dict()
}
def load_state_dict(self, state, strict=True):
if self.scaler is not None:
try: self.scaler.load_state_dict(state['scaler'])
except Exception as e: print(f'[fp16 load_state_dict err] {e}')
self.optimizer.load_state_dict(state['optimizer'])
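# ---- illustrative sketch, not used by AmpOptimizer ----
# `backward_clip_step` scales every micro-batch loss by r_accu = 1 / n_gradient_accumulation, and the
# caller only sets `stepping` on every n-th iteration. Because backward() is linear, the accumulated
# gradient then equals the gradient of the mean loss over the n micro-batches. `sketch_accumulate`
# is a hypothetical helper that shows the same bookkeeping with plain floats standing in for gradients.
def sketch_accumulate(losses, n):
    r_accu = 1 / n
    acc, steps = 0.0, []
    for g_it, loss in enumerate(losses):
        acc += loss * r_accu            # stands in for (loss * r_accu).backward()
        if (g_it + 1) % n == 0:         # same stepping rule as in train.py
            steps.append(acc)           # stands in for optimizer.step()
            acc = 0.0                   # stands in for optimizer.zero_grad()
    return steps
# Usage: sketch_accumulate([1.0, 3.0, 2.0, 4.0], n=2) yields one value per optimizer step,
# each the mean of the two micro-batch losses that fed it.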
================================================
FILE: utils/arg_util.py
================================================
import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
import torch
try:
from tap import Tap
except ImportError as e:
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
time.sleep(5)
raise e
import dist
class Args(Tap):
data_path: str = '/path/to/imagenet'
exp_name: str = 'text'
# VAE
vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
# VAR
tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
depth: int = 16 # VAR depth
# VAR initialization
ini: float = -1 # -1: automated model parameter initialization
hd: float = 0.02 # head.w *= hd
aln: float = 0.5 # the multiplier of ada_lin.w's initialization
alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
# VAR optimization
fp16: int = 0 # 1: using fp16, 2: bf16
tblr: float = 1e-4 # base lr
tlr: float = None # [automatically set; don't specify this] peak lr = ac * tblr * (glb_batch_size / 256)
twd: float = 0.05 # initial wd
twde: float = 0 # final wd, =twde or twd
tclip: float = 2. # <=0 for not using grad clip
ls: float = 0.0 # label smooth
bs: int = 768 # global batch size
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
ac: int = 1 # gradient accumulation
ep: int = 250
wp: float = 0
wp0: float = 0.005 # initial lr ratio at the beginning of lr warm up
wpe: float = 0.01 # final lr ratio at the end of training
sche: str = 'lin0' # lr schedule
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
afuse: bool = True # fused adamw
# other hps
saln: bool = False # whether to use shared adaln
anorm: bool = True # whether to use L2 normalized attention
fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
# data
pn: str = '1_2_3_4_5_6_8_10_13_16'
patch_size: int = 16
patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
hflip: bool = False # augmentation: horizontal flip
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
# progressive training
pg: float = 0.0 # >0: use progressive training during the first `pg` fraction of training (e.g. 0.8 = first 80%)
pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
pgwp: float = 0 # num of warmup epochs at each progressive stage
# set automatically at runtime
cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
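# subprocess.run (without check=True) does not raise on a non-zero exit code, so the '[unknown]' fallbacks are reachable outside a git checkout
branch: str = subprocess.run('git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_id: str = subprocess.run('git rev-parse HEAD', shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_msg: str = (subprocess.run('git log -1', shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]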
acc_mean: float = None # [automatically set; don't specify this]
acc_tail: float = None # [automatically set; don't specify this]
L_mean: float = None # [automatically set; don't specify this]
L_tail: float = None # [automatically set; don't specify this]
vacc_mean: float = None # [automatically set; don't specify this]
vacc_tail: float = None # [automatically set; don't specify this]
vL_mean: float = None # [automatically set; don't specify this]
vL_tail: float = None # [automatically set; don't specify this]
grad_norm: float = None # [automatically set; don't specify this]
cur_lr: float = None # [automatically set; don't specify this]
cur_wd: float = None # [automatically set; don't specify this]
cur_it: str = '' # [automatically set; don't specify this]
cur_ep: str = '' # [automatically set; don't specify this]
remain_time: str = '' # [automatically set; don't specify this]
finish_time: str = '' # [automatically set; don't specify this]
# environment
local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
log_txt_path: str = '...' # [automatically set; don't specify this]
last_ckpt_path: str = '...' # [automatically set; don't specify this]
tf32: bool = True # whether to use TensorFloat32
device: str = 'cpu' # [automatically set; don't specify this]
seed: int = None # seed
def seed_everything(self, benchmark: bool):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = benchmark
if self.seed is None:
torch.backends.cudnn.deterministic = False
else:
torch.backends.cudnn.deterministic = True
seed = self.seed * dist.get_world_size() + dist.get_rank()
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
if self.seed is None:
return None
g = torch.Generator()
g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
return g
local_debug: bool = 'KEVIN_LOCAL' in os.environ
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
def compile_model(self, m, fast):
if fast == 0 or self.local_debug:
return m
return torch.compile(m, mode={
1: 'reduce-overhead',
2: 'max-autotune',
3: 'default',
}[fast]) if hasattr(torch, 'compile') else m
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
d = (OrderedDict if key_ordered else dict)()
# self.as_dict() would contain methods, but we only need variables
for k in self.class_variables.keys():
if k not in {'device'}: # these are not serializable
d[k] = getattr(self, k)
return d
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
if isinstance(d, str): # for compatibility with old version
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
for k in d.keys():
try:
setattr(self, k, d[k])
except Exception as e:
print(f'k={k}, v={d[k]}')
raise e
@staticmethod
def set_tf32(tf32: bool):
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
def dump_log(self):
if not dist.is_local_master():
return
if '1/' in self.cur_ep: # first time to dump log
with open(self.log_txt_path, 'w') as fp:
json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
fp.write('\n')
log_dict = {}
for k, v in {
'it': self.cur_it, 'ep': self.cur_ep,
'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
'remain_time': self.remain_time, 'finish_time': self.finish_time,
}.items():
if hasattr(v, 'item'): v = v.item()
log_dict[k] = v
with open(self.log_txt_path, 'a') as fp:
fp.write(f'{log_dict}\n')
def __str__(self):
s = []
for k in self.class_variables.keys():
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
s.append(f' {k:20s}: {getattr(self, k)}')
s = '\n'.join(s)
return f'{{\n{s}\n}}\n'
def init_dist_and_get_args():
for i in range(len(sys.argv)):
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
del sys.argv[i]
break
args = Args(explicit_bool=True).parse_args(known_only=True)
if args.local_debug:
args.pn = '1_2_3'
args.seed = 1
args.aln = 1e-2
args.alng = 1e-5
args.saln = False
args.afuse = False
args.pg = 0.8
args.pg0 = 1
else:
if args.data_path == '/path/to/imagenet':
raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
# warn args.extra_args
if len(args.extra_args) > 0:
print(f'======================================================================================')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
print(f'======================================================================================\n\n')
# init torch distributed
from utils import misc
os.makedirs(args.local_out_dir_path, exist_ok=True)
misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
# set env
args.set_tf32(args.tf32)
args.seed_everything(benchmark=args.pg == 0)
# update args: data loading
args.device = dist.get_device()
if args.pn == '256':
args.pn = '1_2_3_4_5_6_8_10_13_16'
elif args.pn == '512':
args.pn = '1_2_3_4_6_9_13_18_24_32'
elif args.pn == '1024':
args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
args.data_load_reso = max(args.resos)
# update args: bs and lr
bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
args.batch_size = bs_per_gpu
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
args.workers = min(max(0, args.workers), args.batch_size)
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
args.twde = args.twde or args.twd
if args.wp == 0:
args.wp = args.ep * 1/50
# update args: progressive training
if args.pgwp == 0:
args.pgwp = args.ep * 1/300
if args.pg > 0:
args.sche = f'lin{args.pg:g}'
# update args: paths
args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
_reg_valid_name = re.compile(r'[^\w\-+,.]')
tb_name = _reg_valid_name.sub(
'_',
f'tb-VARd{args.depth}'
f'__pn{args.pn}'
f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
)
args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
return args
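
# --- Illustrative sketch (not part of the original file) ---
# The derived settings computed in init_dist_and_get_args can be reproduced
# with plain arithmetic. `derive_hparams` is a hypothetical helper, and
# world_size=8 below is an assumed GPU count chosen only for the example.

```python
def derive_hparams(bs, ac, world_size, tblr, pn, patch_size=16):
    """Mirror the 'update args' arithmetic for data loading, batch size, and lr."""
    patch_nums = tuple(map(int, pn.replace('-', '_').split('_')))
    resos = tuple(p * patch_size for p in patch_nums)
    batch_size = round(bs / ac / world_size)   # per-GPU batch size
    glb_batch_size = batch_size * world_size   # global batch size after rounding
    tlr = ac * tblr * glb_batch_size / 256     # lr scaled with the batch size
    return {
        'patch_nums': patch_nums,
        'resos': resos,
        'data_load_reso': max(resos),
        'batch_size': batch_size,
        'glb_batch_size': glb_batch_size,
        'tlr': tlr,
    }


hp = derive_hparams(bs=768, ac=1, world_size=8, tblr=1e-4,
                    pn='1_2_3_4_5_6_8_10_13_16')
```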
================================================
FILE: utils/data.py
================================================
import os.path as osp
import PIL.Image as PImage
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode, transforms
def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
return x.add(x).add_(-1)
def build_dataset(
data_path: str, final_reso: int,
hflip=False, mid_reso=1.125,
):
# build augmentations
mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso
train_aug, val_aug = [
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
transforms.RandomCrop((final_reso, final_reso)),
transforms.ToTensor(), normalize_01_into_pm1,
], [
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
transforms.CenterCrop((final_reso, final_reso)),
transforms.ToTensor(), normalize_01_into_pm1,
]
if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
# build dataset
train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
num_classes = 1000
print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
print_aug(train_aug, '[train]')
print_aug(val_aug, '[val]')
return num_classes, train_set, val_set
def pil_loader(path):
with open(path, 'rb') as f:
img: PImage.Image = PImage.open(f).convert('RGB')
return img
def print_aug(transform, label):
print(f'Transform {label} = ')
if hasattr(transform, 'transforms'):
for t in transform.transforms:
print(t)
else:
print(transform)
print('---------------------------\n')
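
# --- Illustrative sketch (not part of the original file) ---
# Two small calculations from build_dataset, shown on plain numbers instead of
# tensors: the intermediate resize target and the [0, 1] -> [-1, 1] mapping.
# `mid_resolution` and `normalize_scalar` are hypothetical names.

```python
def mid_resolution(final_reso: int, mid_reso: float = 1.125) -> int:
    """Shorter-edge size used by Resize before cropping to final_reso."""
    return round(mid_reso * final_reso)


def normalize_scalar(x: float) -> float:
    """Scalar analogue of normalize_01_into_pm1: x + x - 1 == 2x - 1."""
    return x + x - 1


mid_256 = mid_resolution(256)
endpoints = (normalize_scalar(0.0), normalize_scalar(0.5), normalize_scalar(1.0))
```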
================================================
FILE: utils/data_sampler.py
================================================
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class EvalDistributedSampler(Sampler):
def __init__(self, dataset, num_replicas, rank):
seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
beg, end = seps[:-1], seps[1:]
beg, end = beg[rank], end[rank]
self.indices = tuple(range(beg, end))
def __iter__(self):
return iter(self.indices)
def __len__(self) -> int:
return len(self.indices)
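
# --- Illustrative sketch (not part of the original file) ---
# EvalDistributedSampler splits the dataset into contiguous, non-overlapping
# shards, one per rank. The sketch below uses integer floor division as a
# stand-in for np.linspace(..., dtype=int); that substitution is an assumption
# and could differ from numpy in floating-point edge cases.
# `eval_shards` is a hypothetical helper.

```python
def eval_shards(dataset_len: int, num_replicas: int):
    """Return one contiguous range of indices per rank."""
    seps = [i * dataset_len // num_replicas for i in range(num_replicas + 1)]
    return [range(seps[r], seps[r + 1]) for r in range(num_replicas)]


shards = eval_shards(10, 3)
sizes = [len(s) for s in shards]
covered = [i for s in shards for i in s]
```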
class InfiniteBatchSampler(Sampler):
def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
self.dataset_len = dataset_len
self.batch_size = batch_size
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
self.max_p = self.iters_per_ep * batch_size
self.fill_last = fill_last
self.shuffle = shuffle
self.epoch = start_ep
self.same_seed_for_all_ranks = seed_for_all_rank
self.indices = self.gener_indices()
self.start_ep, self.start_it = start_ep, start_it
def gener_indices(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
indices = torch.randperm(self.dataset_len, generator=g).numpy()
else:
indices = torch.arange(self.dataset_len).numpy()
tails = self.batch_size - (self.dataset_len % self.batch_size)
if tails != self.batch_size and self.fill_last:
tails = indices[:tails]
np.random.shuffle(indices)
indices = np.concatenate((indices, tails))
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
# noinspection PyTypeChecker
return tuple(indices.tolist())
def __iter__(self):
self.epoch = self.start_ep
while True:
self.epoch += 1
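p = (self.start_it * self.batch_size) if self.epoch == self.start_ep + 1 else 0 # self.epoch was just incremented, so the first (resumed) epoch is start_ep + 1; only that epoch skips the start_it batches already consumed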
while p < self.max_p:
q = p + self.batch_size
yield self.indices[p:q]
p = q
if self.shuffle:
self.indices = self.gener_indices()
def __len__(self):
return self.iters_per_ep
class DistInfiniteBatchSampler(InfiniteBatchSampler):
def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
assert glb_batch_size % world_size == 0
self.world_size, self.rank = world_size, rank
self.dataset_len = dataset_len
self.glb_batch_size = glb_batch_size
self.batch_size = glb_batch_size // world_size
self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
self.fill_last = fill_last
self.shuffle = shuffle
self.repeated_aug = repeated_aug
self.epoch = start_ep
self.same_seed_for_all_ranks = same_seed_for_all_ranks
self.indices = self.gener_indices()
self.start_ep, self.start_it = start_ep, start_it
def gener_indices(self):
global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size is always 0 because glb_batch_size % world_size == 0
# print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
global_indices = torch.randperm(self.dataset_len, generator=g)
if self.repeated_aug > 1:
global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
else:
global_indices = torch.arange(self.dataset_len)
filling = global_max_p - global_indices.shape[0]
if filling > 0 and self.fill_last:
global_indices = torch.cat((global_indices, global_indices[:filling]))
# global_indices = tuple(global_indices.numpy().tolist())
seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
self.max_p = len(local_indices)
return local_indices
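
# --- Illustrative sketch (not part of the original file) ---
# How DistInfiniteBatchSampler pads the global index list and slices it per
# rank, shown for the shuffle=False path with plain lists. `rank_indices` is a
# hypothetical helper; integer floor division stands in for torch.linspace
# with dtype=torch.int, which is an assumption of this sketch.

```python
def rank_indices(dataset_len, glb_batch_size, world_size, rank, fill_last=True):
    """Return the indices one rank would iterate over in a single epoch."""
    iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
    global_max_p = iters_per_ep * glb_batch_size
    global_indices = list(range(dataset_len))      # no shuffling in this sketch
    filling = global_max_p - len(global_indices)
    if filling > 0 and fill_last:
        # wrap around so every global batch is full
        global_indices = global_indices + global_indices[:filling]
    n = len(global_indices)
    seps = [i * n // world_size for i in range(world_size + 1)]
    return global_indices[seps[rank]:seps[rank + 1]]


rank0 = rank_indices(dataset_len=10, glb_batch_size=4, world_size=2, rank=0)
rank1 = rank_indices(dataset_len=10, glb_batch_size=4, world_size=2, rank=1)
```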
================================================
FILE: utils/lr_control.py
================================================
import math
from pprint import pformat
from typing import Tuple, List, Dict, Union
import torch.nn
import dist
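def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
"""Set lr and weight decay on every param group: linear lr warmup, then the decay selected by `sche_type` ('cos', 'lin*', or 'exp'); weight decay follows a half-cycle cosine from `wd` to `wd_end`"""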
wp_it = round(wp_it)
if cur_it < wp_it:
cur_lr = wp0 + (1-wp0) * cur_it / wp_it
else:
pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
rest = 1 - pasd # [1, 0]
if sche_type == 'cos':
cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
elif sche_type == 'lin':
T = 0.15; max_rest = 1-T
if pasd < T: cur_lr = 1
else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
elif sche_type == 'lin0':
T = 0.05; max_rest = 1-T
if pasd < T: cur_lr = 1
else: cur_lr = wpe + (1-wpe) * rest / max_rest
elif sche_type == 'lin00':
cur_lr = wpe + (1-wpe) * rest
elif sche_type.startswith('lin'):
T = float(sche_type[3:]); max_rest = 1-T
wpe_mid = wpe + (1-wpe) * max_rest
wpe_mid = (1 + wpe_mid) / 2
if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
elif sche_type == 'exp':
T = 0.15; max_rest = 1-T
if pasd < T: cur_lr = 1
else:
expo = (pasd-T) / max_rest * math.log(wpe)
cur_lr = math.exp(expo)
else:
raise NotImplementedError(f'unknown sche_type {sche_type}')
cur_lr *= peak_lr
pasd = cur_it / (max_it-1)
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
inf = 1e6
min_lr, max_lr = inf, -1
min_wd, max_wd = inf, -1
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
max_lr = max(max_lr, param_group['lr'])
min_lr = min(min_lr, param_group['lr'])
param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
max_wd = max(max_wd, param_group['weight_decay'])
if param_group['weight_decay'] > 0:
min_wd = min(min_wd, param_group['weight_decay'])
if min_lr == inf: min_lr = -1
if min_wd == inf: min_wd = -1
return min_lr, max_lr, min_wd, max_wd
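
# --- Illustrative sketch (not part of the original file) ---
# The 'lin0' branch and the weight-decay schedule of lr_wd_annealing, written
# as pure functions that return the lr ratio (before multiplying by peak_lr)
# and the weight decay. `lin0_lr_ratio` and `cosine_wd` are hypothetical names;
# the formulas are copied from the function above.

```python
import math


def lin0_lr_ratio(cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
    """Linear warmup from wp0 to 1, a short plateau, then linear decay to wpe."""
    if cur_it < wp_it:
        return wp0 + (1 - wp0) * cur_it / wp_it
    pasd = (cur_it - wp_it) / (max_it - 1 - wp_it)  # progress after warmup, in [0, 1]
    rest = 1 - pasd
    T = 0.05                                        # plateau covers the first 5%
    if pasd < T:
        return 1.0
    return wpe + (1 - wpe) * rest / (1 - T)


def cosine_wd(cur_it, max_it, wd, wd_end):
    """Half-cycle cosine from wd (first iteration) to wd_end (last iteration)."""
    pasd = cur_it / (max_it - 1)
    return wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))


ratio_start = lin0_lr_ratio(0, wp_it=10, max_it=101)
ratio_peak = lin0_lr_ratio(10, wp_it=10, max_it=101)
ratio_end = lin0_lr_ratio(100, wp_it=10, max_it=101)
```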
def filter_params(model, nowd_keys=()) -> Tuple[
List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
para_groups, para_groups_dbg = {}, {}
names, paras = [], []
names_no_grad = []
count, numel = 0, 0
for name, para in model.named_parameters():
name = name.replace('_fsdp_wrapped_module.', '')
if not para.requires_grad:
names_no_grad.append(name)
continue # frozen weights
count += 1
numel += para.numel()
names.append(name)
paras.append(para)
if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
cur_wd_sc, group_name = 0., 'ND'
else:
cur_wd_sc, group_name = 1., 'D'
cur_lr_sc = 1.
if group_name not in para_groups:
para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
para_groups[group_name]['params'].append(para)
para_groups_dbg[group_name]['params'].append(name)
for g in para_groups_dbg.values():
g['params'] = pformat(', '.join(g['params']), width=200)
print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
for rk in range(dist.get_world_size()):
dist.barrier()
if dist.get_rank() == rk:
print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
print('')
assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
return names, paras, list(para_groups.values())
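
# --- Illustrative sketch (not part of the original file) ---
# The grouping rule inside filter_params, isolated as a pure function: 1-D
# parameters, biases, and names matching nowd_keys get no weight decay ('ND'),
# and everything else is decayed ('D'). `wd_group` and the parameter names
# used below are hypothetical.

```python
def wd_group(name: str, ndim: int, nowd_keys=()) -> str:
    """Return 'ND' for parameters excluded from weight decay, else 'D'."""
    if ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
        return 'ND'
    return 'D'


groups = {
    'blocks.0.attn.proj.weight': wd_group('blocks.0.attn.proj.weight', 2),
    'blocks.0.norm.weight': wd_group('blocks.0.norm.weight', 1),
    'head.bias': wd_group('head.bias', 1),
    'pos_start': wd_group('pos_start', 3, nowd_keys=('pos_start',)),
}
```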
================================================
FILE: utils/misc.py
================================================
import datetime
import functools
import glob
import os
import subprocess
import sys
import time
from collections import defaultdict, deque
from typing import Iterator, List, Tuple
import numpy as np
import pytz
import torch
import torch.distributed as tdist
import dist
from utils import arg_util
os_system = functools.partial(subprocess.call, shell=True)
def echo(info):
os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
def os_system_get_stdout(cmd):
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
cnt = 0
while True:
try:
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
except subprocess.TimeoutExpired:
cnt += 1
print(f'[fetch free_port file] timeout cnt={cnt}')
else:
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
def time_str(fmt='[%m-%d %H:%M:%S]'):
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
try:
dist.initialize(fork=False, timeout=timeout)
dist.barrier()
except RuntimeError:
print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
time.sleep(10)
if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
_change_builtin_print(dist.is_local_master())
if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)
def _change_builtin_print(is_master):
import builtins as __builtin__
builtin_print = __builtin__.print
if type(builtin_print) != type(open):
return
def prt(*args, **kwargs):
force = kwargs.pop('force', False)
clean = kwargs.pop('clean', False)
deeper = kwargs.pop('deeper', False)
if is_master or force:
if not clean:
f_back = sys._getframe().f_back
if deeper and f_back.f_back is not None:
f_back = f_back.f_back
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
else:
builtin_print(*args, **kwargs)
__builtin__.print = prt
class SyncPrint(object):
def __init__(self, local_output_dir, sync_stdout=True):
self.sync_stdout = sync_stdout
self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
existing = os.path.exists(fname)
self.file_stream = open(fname, 'a')
if existing:
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
self.file_stream.flush()
self.enabled = True
def write(self, message):
self.terminal_stream.write(message)
self.file_stream.write(message)
def flush(self):
self.terminal_stream.flush()
self.file_stream.flush()
def close(self):
if not self.enabled:
return
self.enabled = False
self.file_stream.flush()
self.file_stream.close()
if self.sync_stdout:
sys.stdout = self.terminal_stream
sys.stdout.flush()
else:
sys.stderr = self.terminal_stream
sys.stderr.flush()
def __del__(self):
self.close()
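
# --- Illustrative sketch (not part of the original file) ---
# SyncPrint is a tee: every write goes to the terminal stream and to a log
# file. The same pattern is shown here with in-memory streams so it can be
# run anywhere. `Tee` is a hypothetical, simplified stand-in that omits the
# restart banner and the close/restore logic.

```python
import io


class Tee:
    """Forward write() and flush() to every wrapped stream."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, message):
        for stream in self.streams:
            stream.write(message)

    def flush(self):
        for stream in self.streams:
            stream.flush()


terminal, logfile = io.StringIO(), io.StringIO()
tee = Tee(terminal, logfile)
print('hello', file=tee)
tee.flush()
```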
class DistLogger(object):
def __init__(self, lg, verbose):
self._lg, self._verbose = lg, verbose
@staticmethod
def do_nothing(*args, **kwargs):
pass
def __getattr__(self, attr: str):
return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing
class TensorboardLogger(object):
def __init__(self, log_dir, filename_suffix):
try: import tensorflow_io as tfio
except: pass
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
self.step = 0
def set_step(self, step=None):
if step is not None:
self.step = step
else:
self.step += 1
def update(self, head='scalar', step=None, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
# assert isinstance(v, (float, int)), type(v)
if step is None: # iter wise
it = self.step
if it == 0 or (it + 1) % 500 == 0:
if hasattr(v, 'item'): v = v.item()
self.writer.add_scalar(f'{head}/{k}', v, it)
else: # epoch wise
if hasattr(v, 'item'): v = v.item()
self.writer.add_scalar(f'{head}/{k}', v, step)
def log_tensor_as_distri(self, tag, tensor1d, step=None):
if step is None: # iter wise
step = self.step
loggable = step == 0 or (step + 1) % 500 == 0
else: # epoch wise
loggable = True
if loggable:
try:
self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
except Exception as e:
print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
def log_image(self, tag, img_chw, step=None):
if step is None: # iter wise
step = self.step
loggable = step == 0 or (step + 1) % 500 == 0
else: # epoch wise
loggable = True
if loggable:
self.writer.add_image(tag, img_chw, step, dataformats='CHW')
def flush(self):
self.writer.flush()
def close(self):
self.writer.close()
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=30, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
tdist.barrier()
tdist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
return np.median(self.deque) if len(self.deque) else 0
@property
def avg(self):
return sum(self.deque) / (len(self.deque) or 1)
@property
def global_avg(self):
return self.total / (self.count or 1)
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1] if len(self.deque) else 0
def time_preds(self, counts) -> Tuple[float, str, str]:
remain_secs = counts * self.median
return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
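
# --- Illustrative sketch (not part of the original file) ---
# SmoothedValue keeps two views of a metric: a sliding window (for median and
# avg) and running totals (for global_avg). `TinySmoothedValue` is a
# hypothetical stdlib-only reduction of the class above, without the
# cross-process synchronization.

```python
from collections import deque
from statistics import median


class TinySmoothedValue:
    def __init__(self, window_size=30):
        self.window = deque(maxlen=window_size)  # only the most recent values
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.window.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return median(self.window) if self.window else 0

    @property
    def global_avg(self):
        return self.total / (self.count or 1)


meter = TinySmoothedValue(window_size=3)
for v in (1.0, 2.0, 3.0, 4.0):
    meter.update(v)
```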
class MetricLogger(object):
def __init__(self, delimiter=' '):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
self.iter_end_t = time.time()
self.log_iters = []
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if hasattr(v, 'item'): v = v.item()
assert isinstance(v, (float, int)), type(v)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
if len(meter.deque):
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
self.log_iters.add(start_it)
if not header:
header = ''
start_time = time.time()
self.iter_end_t = time.time()
self.iter_time = SmoothedValue(fmt='{avg:.4f}')
self.data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(max_iters))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
log_msg = self.delimiter.join(log_msg)
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
for i in range(start_it, max_iters):
obj = next(itrt)
self.data_time.update(time.time() - self.iter_end_t)
yield i, obj
self.iter_time.update(time.time() - self.iter_end_t)
if i in self.log_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)), flush=True)
self.iter_end_t = time.time()
else:
if isinstance(itrt, int): itrt = range(itrt)
for i, obj in enumerate(itrt):
self.data_time.update(time.time() - self.iter_end_t)
yield i, obj
self.iter_time.update(time.time() - self.iter_end_t)
if i in self.log_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)), flush=True)
self.iter_end_t = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.3f} s / it)'.format(
header, total_time_str, total_time / max_iters), flush=True)
def glob_with_latest_modified_first(pattern, recursive=False):
return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)
def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
info = []
file = os.path.join(args.local_out_dir_path, pattern)
all_ckpt = glob_with_latest_modified_first(file)
if len(all_ckpt) == 0:
info.append(f'[auto_resume] no ckpt found @ {file}')
info.append(f'[auto_resume quit]')
return info, 0, 0, {}, {}
else:
info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
ckpt = torch.load(all_ckpt[0], map_location='cpu')
ep, it = ckpt['epoch'], ckpt['iter']
info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
return info, ep, it, ckpt['trainer'], ckpt['args']
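
# --- Illustrative sketch (not part of the original file) ---
# auto_resume picks the most recently modified checkpoint that matches the
# glob pattern. The ordering step can be tried on empty placeholder files;
# the file names and timestamps below are made up for the example, and
# `latest_first` is a hypothetical copy of glob_with_latest_modified_first.

```python
import glob
import os
import tempfile


def latest_first(pattern):
    """Return matching paths sorted by modification time, newest first."""
    return sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)


with tempfile.TemporaryDirectory() as d:
    old = os.path.join(d, 'ckpt-old.pth')
    new = os.path.join(d, 'ckpt-new.pth')
    for path, mtime in ((old, 1_000_000_000), (new, 2_000_000_000)):
        open(path, 'w').close()          # empty placeholder file
        os.utime(path, (mtime, mtime))   # set access and modification times
    found = [os.path.basename(p) for p in latest_first(os.path.join(d, 'ckpt*.pth'))]
```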
def create_npz_from_sample_folder(sample_folder: str):
"""
Builds a single .npz file from a folder of .png samples. Refer to DiT.
"""
import os, glob
import numpy as np
from tqdm import tqdm
from PIL import Image
samples = []
pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
with Image.open(png) as sample_pil:
sample_np = np.asarray(sample_pil).astype(np.uint8)
samples.append(sample_np)
samples = np.stack(samples)
assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
npz_path = f'{sample_folder}.npz'
np.savez(npz_path, arr_0=samples)
print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
return npz_path
SYMBOL INDEX (204 symbols across 16 files)
FILE: dist.py
function initialized (line 16) | def initialized():
function initialize (line 20) | def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, t...
function get_rank (line 52) | def get_rank():
function get_local_rank (line 56) | def get_local_rank():
function get_world_size (line 60) | def get_world_size():
function get_device (line 64) | def get_device():
function set_gpu_id (line 68) | def set_gpu_id(gpu_id: int):
function is_master (line 78) | def is_master():
function is_local_master (line 82) | def is_local_master():
function new_group (line 86) | def new_group(ranks: List[int]):
function barrier (line 92) | def barrier():
function allreduce (line 97) | def allreduce(t: torch.Tensor, async_op=False):
function allgather (line 109) | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], to...
function allgather_diff_shape (line 122) | def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch....
function broadcast (line 149) | def broadcast(t: torch.Tensor, src_rank) -> None:
function dist_fmt_vals (line 159) | def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[t...
function master_only (line 171) | def master_only(func):
function local_master_only (line 184) | def local_master_only(func):
function for_visualize (line 197) | def for_visualize(func):
function finalize (line 209) | def finalize():
FILE: models/__init__.py
function build_vae_var (line 9) | def build_vae_var(
FILE: models/basic_vae.py
function nonlinearity (line 14) | def nonlinearity(x):
function Normalize (line 18) | def Normalize(in_channels, num_groups=32):
class Upsample2x (line 22) | class Upsample2x(nn.Module):
method __init__ (line 23) | def __init__(self, in_channels):
method forward (line 27) | def forward(self, x):
class Downsample2x (line 31) | class Downsample2x(nn.Module):
method __init__ (line 32) | def __init__(self, in_channels):
method forward (line 36) | def forward(self, x):
class ResnetBlock (line 40) | class ResnetBlock(nn.Module):
method __init__ (line 41) | def __init__(self, *, in_channels, out_channels=None, dropout): # conv...
method forward (line 57) | def forward(self, x):
class AttnBlock (line 63) | class AttnBlock(nn.Module):
method __init__ (line 64) | def __init__(self, in_channels):
method forward (line 73) | def forward(self, x):
function make_attn (line 95) | def make_attn(in_channels, using_sa=True):
class Encoder (line 99) | class Encoder(nn.Module):
method __init__ (line 100) | def __init__(
method forward (line 144) | def forward(self, x):
class Decoder (line 163) | class Decoder(nn.Module):
method __init__ (line 164) | def __init__(
method forward (line 210) | def forward(self, z):
FILE: models/basic_var.py
function slow_attn (line 27) | def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p...
class FFN (line 33) | class FFN(nn.Module):
method __init__ (line 34) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 44) | def forward(self, x):
method extra_repr (line 54) | def extra_repr(self) -> str:
class SelfAttention (line 58) | class SelfAttention(nn.Module):
method __init__ (line 59) | def __init__(
method kv_caching (line 87) | def kv_caching(self, enable: bool): self.caching, self.cached_k, self....
method forward (line 90) | def forward(self, x, attn_bias):
method extra_repr (line 124) | def extra_repr(self) -> str:
class AdaLNSelfAttn (line 128) | class AdaLNSelfAttn(nn.Module):
method __init__ (line 129) | def __init__(
method forward (line 152) | def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
method extra_repr (line 161) | def extra_repr(self) -> str:
class AdaLNBeforeHead (line 165) | class AdaLNBeforeHead(nn.Module):
method __init__ (line 166) | def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
method forward (line 172) | def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
FILE: models/helpers.py
function sample_with_top_k_top_p_ (line 6) | def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, t...
function gumbel_softmax_with_rng (line 22) | def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: ...
function drop_path (line 39) | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by...
class DropPath (line 49) | class DropPath(nn.Module): # taken from timm
method __init__ (line 50) | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
method forward (line 55) | def forward(self, x):
method extra_repr (line 58) | def extra_repr(self):
FILE: models/quant.py
class VectorQuantizer2 (line 15) | class VectorQuantizer2(nn.Module):
method __init__ (line 17) | def __init__(
method eini (line 44) | def eini(self, eini):
method extra_repr (line 48) | def extra_repr(self) -> str:
method forward (line 52) | def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[tor...
method embed_to_fhat (line 107) | def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scal...
method f_to_idxBl_or_fhat (line 135) | def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_pa...
method idxBl_to_var_input (line 169) | def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torc...
method get_next_autoregressive_input (line 187) | def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch...
class Phi (line 199) | class Phi(nn.Conv2d):
method __init__ (line 200) | def __init__(self, embed_dim, quant_resi):
method forward (line 205) | def forward(self, h_BChw):
class PhiShared (line 209) | class PhiShared(nn.Module):
method __init__ (line 210) | def __init__(self, qresi: Phi):
method __getitem__ (line 214) | def __getitem__(self, _) -> Phi:
class PhiPartiallyShared (line 218) | class PhiPartiallyShared(nn.Module):
method __init__ (line 219) | def __init__(self, qresi_ls: nn.ModuleList):
method __getitem__ (line 225) | def __getitem__(self, at_from_0_to_1: float) -> Phi:
method extra_repr (line 228) | def extra_repr(self) -> str:
class PhiNonShared (line 232) | class PhiNonShared(nn.ModuleList):
method __init__ (line 233) | def __init__(self, qresi: List):
method __getitem__ (line 239) | def __getitem__(self, at_from_0_to_1: float) -> Phi:
method extra_repr (line 242) | def extra_repr(self) -> str:
FILE: models/var.py
class SharedAdaLin (line 15) | class SharedAdaLin(nn.Linear):
method forward (line 16) | def forward(self, cond_BD):
class VAR (line 21) | class VAR(nn.Module):
method __init__ (line 22) | def __init__(
method get_logits (line 118) | def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[to...
method autoregressive_infer_cfg (line 127) | def autoregressive_infer_cfg(
method forward (line 192) | def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch....
method init_weights (line 236) | def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_hea...
method extra_repr (line 288) | def extra_repr(self):
class VARHF (line 292) | class VARHF(VAR, PyTorchModelHubMixin):
method __init__ (line 295) | def __init__(
FILE: models/vqvae.py
class VQVAE (line 16) | class VQVAE(nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 56) | def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
method fhat_to_img (line 62) | def fhat_to_img(self, f_hat: torch.Tensor):
method img_to_idxBl (line 65) | def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Op...
method idxBl_to_img (line 69) | def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool...
method embed_to_img (line 78) | def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale...
method img_to_reconstructed_img (line 84) | def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[...
method load_state_dict (line 92) | def load_state_dict(self, state_dict: Dict[str, Any], strict=True, ass...
FILE: train.py
function build_everything (line 19) | def build_everything(args: arg_util.Args):
function main_training (line 171) | def main_training():
function train_one_ep (line 253) | def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_ut...
class NullDDP (line 320) | class NullDDP(torch.nn.Module):
method __init__ (line 321) | def __init__(self, module, *args, **kwargs):
method forward (line 326) | def forward(self, *args, **kwargs):
FILE: trainer.py
class VARTrainer (line 20) | class VARTrainer(object):
method __init__ (line 21) | def __init__(
method eval_ep (line 55) | def eval_ep(self, ld_val: DataLoader):
method train_step (line 86) | def train_step(
method get_config (line 162) | def get_config(self):
method state_dict (line 169) | def state_dict(self):
method load_state_dict (line 179) | def load_state_dict(self, state, strict=True, skip_vae=False):
FILE: utils/amp_sc.py
class NullCtx (line 7) | class NullCtx:
method __enter__ (line 8) | def __enter__(self):
method __exit__ (line 11) | def __exit__(self, exc_type, exc_val, exc_tb):
class AmpOptimizer (line 15) | class AmpOptimizer:
method __init__ (line 16) | def __init__(
method backward_clip_step (line 39) | def backward_clip_step(
method state_dict (line 77) | def state_dict(self):
method load_state_dict (line 85) | def load_state_dict(self, state, strict=True):
FILE: utils/arg_util.py
class Args (line 25) | class Args(Tap):
method seed_everything (line 113) | def seed_everything(self, benchmark: bool):
method get_different_generator_for_each_rank (line 129) | def get_different_generator_for_each_rank(self) -> Optional[torch.Gene...
method compile_model (line 139) | def compile_model(self, m, fast):
method state_dict (line 148) | def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
method load_state_dict (line 156) | def load_state_dict(self, d: Union[OrderedDict, dict, str]):
method set_tf32 (line 167) | def set_tf32(tf32: bool):
method dump_log (line 177) | def dump_log(self):
method __str__ (line 198) | def __str__(self):
function init_dist_and_get_args (line 207) | def init_dist_and_get_args():
FILE: utils/data.py
function normalize_01_into_pm1 (line 8) | def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (...
function build_dataset (line 12) | def build_dataset(
function pil_loader (line 41) | def pil_loader(path):
function print_aug (line 47) | def print_aug(transform, label):
FILE: utils/data_sampler.py
class EvalDistributedSampler (line 6) | class EvalDistributedSampler(Sampler):
method __init__ (line 7) | def __init__(self, dataset, num_replicas, rank):
method __iter__ (line 13) | def __iter__(self):
method __len__ (line 16) | def __len__(self) -> int:
class InfiniteBatchSampler (line 20) | class InfiniteBatchSampler(Sampler):
method __init__ (line 21) | def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_...
method gener_indices (line 33) | def gener_indices(self):
method __iter__ (line 51) | def __iter__(self):
method __len__ (line 63) | def __len__(self):
class DistInfiniteBatchSampler (line 67) | class DistInfiniteBatchSampler(InfiniteBatchSampler):
method __init__ (line 68) | def __init__(self, world_size, rank, dataset_len, glb_batch_size, same...
method gener_indices (line 84) | def gener_indices(self):
FILE: utils/lr_control.py
function lr_wd_annealing (line 10) | def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_...
function filter_params (line 68) | def filter_params(model, nowd_keys=()) -> Tuple[
FILE: utils/misc.py
function echo (line 20) | def echo(info):
function os_system_get_stdout (line 22) | def os_system_get_stdout(cmd):
function os_system_get_stdout_stderr (line 24) | def os_system_get_stdout_stderr(cmd):
function time_str (line 36) | def time_str(fmt='[%m-%d %H:%M:%S]'):
function init_distributed_mode (line 40) | def init_distributed_mode(local_out_path, only_sync_master=False, timeou...
function _change_builtin_print (line 54) | def _change_builtin_print(is_master):
class SyncPrint (line 78) | class SyncPrint(object):
method __init__ (line 79) | def __init__(self, local_output_dir, sync_stdout=True):
method write (line 90) | def write(self, message):
method flush (line 94) | def flush(self):
method close (line 98) | def close(self):
method __del__ (line 111) | def __del__(self):
class DistLogger (line 115) | class DistLogger(object):
method __init__ (line 116) | def __init__(self, lg, verbose):
method do_nothing (line 120) | def do_nothing(*args, **kwargs):
method __getattr__ (line 123) | def __getattr__(self, attr: str):
class TensorboardLogger (line 127) | class TensorboardLogger(object):
method __init__ (line 128) | def __init__(self, log_dir, filename_suffix):
method set_step (line 135) | def set_step(self, step=None):
method update (line 141) | def update(self, head='scalar', step=None, **kwargs):
method log_tensor_as_distri (line 155) | def log_tensor_as_distri(self, tag, tensor1d, step=None):
method log_image (line 167) | def log_image(self, tag, img_chw, step=None):
method flush (line 176) | def flush(self):
method close (line 179) | def close(self):
class SmoothedValue (line 183) | class SmoothedValue(object):
method __init__ (line 188) | def __init__(self, window_size=30, fmt=None):
method update (line 196) | def update(self, value, n=1):
method synchronize_between_processes (line 201) | def synchronize_between_processes(self):
method median (line 213) | def median(self):
method avg (line 217) | def avg(self):
method global_avg (line 221) | def global_avg(self):
method max (line 225) | def max(self):
method value (line 229) | def value(self):
method time_preds (line 232) | def time_preds(self, counts) -> Tuple[float, str, str]:
method __str__ (line 236) | def __str__(self):
class MetricLogger (line 245) | class MetricLogger(object):
method __init__ (line 246) | def __init__(self, delimiter=' '):
method update (line 252) | def update(self, **kwargs):
method __getattr__ (line 261) | def __getattr__(self, attr):
method __str__ (line 269) | def __str__(self):
method synchronize_between_processes (line 278) | def synchronize_between_processes(self):
method add_meter (line 282) | def add_meter(self, name, meter):
method log_every (line 285) | def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
function glob_with_latest_modified_first (line 340) | def glob_with_latest_modified_first(pattern, recursive=False):
function auto_resume (line 344) | def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[...
function create_npz_from_sample_folder (line 360) | def create_npz_from_sample_folder(sample_folder: str):