Repository: FoundationVision/VAR Branch: main Commit: 78b95394fc58 Files: 19 Total size: 148.6 KB Directory structure: gitextract_24dcq1k1/ ├── .gitignore ├── LICENSE ├── README.md ├── dist.py ├── models/ │ ├── __init__.py │ ├── basic_vae.py │ ├── basic_var.py │ ├── helpers.py │ ├── quant.py │ ├── var.py │ └── vqvae.py ├── train.py ├── trainer.py └── utils/ ├── amp_sc.py ├── arg_util.py ├── data.py ├── data_sampler.py ├── lr_control.py └── misc.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.swp **/__pycache__/** **/.ipynb_checkpoints/** .DS_Store .idea/* .vscode/* llava/ _vis_cached/ _auto_* ckpt/ log/ tb*/ img*/ local_output* *.pth *.pth.tar *.ckpt *.log *.txt *.ipynb ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 FoundationVision Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
[![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://opensource.bytedance.com/gmpt/t2i/invite)  [![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)  [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)  [![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjNDRGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)

Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction

NeurIPS 2024 Best Paper


## News * **2025-11:** We release our Text-to-Video generation model **InfinityStar**, based on VAR & Infinity; please check out [InfinityStar⭐️](https://github.com/FoundationVision/InfinityStar). * **2025-11:** 🎉 InfinityStar is accepted as **NeurIPS 2025 Oral.** * **2025-04:** 🎉 Infinity is accepted as **CVPR 2025 Oral.** * **2024-12:** 🏆 VAR received **NeurIPS 2024 Best Paper Award**. * **2024-12:** 🔥 We release our Text-to-Image research based on VAR; please check out [Infinity](https://github.com/FoundationVision/Infinity). * **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation. * **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released. ## 🕹️ Try and Play with VAR! ~~We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!~~ We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) where you can try VAR Text-to-Image generation interactively. Enjoy the fun of visual autoregressive modeling! We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR. [//]: # (

) [//]: # () ## What's New? ### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨: Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".

### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:

### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:

### 🔥 Zero-shot generalizability🛠️:

#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905). ## VAR zoo We provide VAR models for you to play with, which can be downloaded from the following links: | model | reso. | FID | rel. cost | #params | HF weights🤗 | |:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------| | VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) | | VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) | | VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) | | VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) | | VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) | | VAR-d36 | 512 | **2.63** | - | 2.3B | [var_d36.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d36.pth) | You can load these models to generate images using the code in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first. ## Installation 1. Install `torch>=2.0.0`. 2. Install other pip packages via `pip3 install -r requirements.txt`. 3. Prepare the [ImageNet](http://image-net.org/) dataset

Assume ImageNet is located at `/path/to/imagenet`; its layout should look like this: ``` /path/to/imagenet/: train/: n01440764: many_images.JPEG ... n01443537: many_images.JPEG ... val/: n01440764: ILSVRC2012_val_00000293.JPEG ... n01443537: ILSVRC2012_val_00000236.JPEG ... ``` **NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
4. (Optional) Install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30). ## Training Scripts To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following commands: ```shell # d16, 256x256 torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \ --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 # d20, 256x256 torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \ --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1 # d24, 256x256 torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \ --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01 # d30, 256x256 torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \ --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08 # d36-s, 512x512 (-s means saln=1, shared AdaLN) torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \ --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08 ``` A folder named `local_output` will be created to save the checkpoints and logs. You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`. If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)). 
## Sampling & Zero-shot Inference For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L360](utils/misc.py#L360). Then use [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and the reference ground-truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall. Note that a relatively small `cfg=1.5` is used as a trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**. We'll provide the sampling script later. ## Third-party Usage and Research ***In this paragraph, we cross-link third-party repositories or research that use VAR and report results.
You can let us know by raising an issue*** (`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`) | **Time** | **Research** | **Link** | |--------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------| | [5/12/2025] | [ICML 2025]Continuous Visual Autoregressive Generation via Score Maximization | https://github.com/shaochenze/EAR | | [5/8/2025] | Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction | https://github.com/icon-lab/FedGAT | | [4/7/2025] | FastVAR: Linear Visual Autoregressive Modeling via Cached Token Pruning | https://github.com/csguoh/FastVAR | | [4/3/2025] | VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning | https://github.com/VARGPT-family/VARGPT-v1.1 | | [3/31/2025] | Training-Free Text-Guided Image Editing with Visual Autoregressive Model | https://github.com/wyf0912/AREdit | | [3/17/2025] | Next-Scale Autoregressive Models are Zero-Shot Single-Image Object View Synthesizers | https://github.com/Shiran-Yuan/ArchonView | | [3/14/2025] | Safe-VAR: Safe Visual Autoregressive Model for Text-to-Image Generative Watermarking | https://arxiv.org/abs/2503.11324 | | [3/3/2025] | [ICML 2025]Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator | https://research.nvidia.com/labs/dir/ddo/ | | [2/28/2025] | Autoregressive Medical Image Segmentation via Next-Scale Mask Prediction | https://arxiv.org/abs/2502.20784 | | [2/27/2025] | FlexVAR: Flexible Visual Autoregressive Modeling without Residual Prediction | https://github.com/jiaosiyu1999/FlexVAR | | [2/17/2025] | MARS: Mesh AutoRegressive Model for 3D Shape Detailization | 
https://arxiv.org/abs/2502.11390 | | [1/31/2025] | [ICML 2025]Visual Autoregressive Modeling for Image Super-Resolution | https://github.com/quyp2000/VARSR | | [1/21/2025] | VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model | https://github.com/VARGPT-family/VARGPT | | [1/26/2025] | [ICML 2025]Visual Generation Without Guidance | https://github.com/thu-ml/GFT | | [12/30/2024] | Next Token Prediction Towards Multimodal Intelligence | https://github.com/LMM101/Awesome-Multimodal-Next-Token-Prediction | | [12/30/2024] | Varformer: Adapting VAR’s Generative Prior for Image Restoration | https://arxiv.org/abs/2412.21063 | | [12/22/2024] | [ICLR 2025]Distilled Decoding 1: One-step Sampling of Image Auto-regressive Models with Flow Matching | https://github.com/imagination-research/distilled-decoding | | [12/19/2024] | FlowAR: Scale-wise Autoregressive Image Generation Meets Flow Matching | https://github.com/OliverRensu/FlowAR | | [12/13/2024] | 3D representation in 512-Byte: Variational tokenizer is the key for autoregressive 3D generation | https://github.com/sparse-mvs-2/VAT | | [12/9/2024] | CARP: Visuomotor Policy Learning via Coarse-to-Fine Autoregressive Prediction | https://carp-robot.github.io/ | | [12/5/2024] | [CVPR 2025]Infinity ∞: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis | https://github.com/FoundationVision/Infinity | | [12/5/2024] | [CVPR 2025]Switti: Designing Scale-Wise Transformers for Text-to-Image Synthesis | https://github.com/yandex-research/switti | | [12/4/2024] | [CVPR 2025]TokenFlow🚀: Unified Image Tokenizer for Multimodal Understanding and Generation | https://github.com/ByteFlow-AI/TokenFlow | | [12/3/2024] | XQ-GAN🚀: An Open-source Image Tokenization Framework for Autoregressive Generation | https://github.com/lxa9867/ImageFolder | | [11/28/2024] | [CVPR 2025]CoDe: Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient | 
https://github.com/czg1225/CoDe | | [11/28/2024] | [CVPR 2025]Scalable Autoregressive Monocular Depth Estimation | https://arxiv.org/abs/2411.11361 | | [11/27/2024] | [CVPR 2025]SAR3D: Autoregressive 3D Object Generation and Understanding via Multi-scale 3D VQVAE | https://github.com/cyw-3d/SAR3D | | [11/26/2024] | LiteVAR: Compressing Visual Autoregressive Modelling with Efficient Attention and Quantization | https://arxiv.org/abs/2411.17178 | | [11/15/2024] | M-VAR: Decoupled Scale-wise Autoregressive Modeling for High-Quality Image Generation | https://github.com/OliverRensu/MVAR | | [10/14/2024] | [ICLR 2025]HART: Efficient Visual Generation with Hybrid Autoregressive Transformer | https://github.com/mit-han-lab/hart | | [10/12/2024] | [ICLR 2025 Oral]Toward Guidance-Free AR Visual Generation via Condition Contrastive Alignment | https://github.com/thu-ml/CCA | | [10/3/2024] | [ICLR 2025]ImageFolder🚀: Autoregressive Image Generation with Folded Tokens | https://github.com/lxa9867/ImageFolder | | [07/25/2024] | ControlVAR: Exploring Controllable Visual Autoregressive Modeling | https://github.com/lxa9867/ControlVAR | | [07/3/2024] | VAR-CLIP: Text-to-Image Generator with Visual Auto-Regressive Modeling | https://github.com/daixiangzi/VAR-CLIP | | [06/16/2024] | STAR: Scale-wise Text-to-image generation via Auto-Regressive representations | https://arxiv.org/abs/2406.10797 | ## License This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
## Citation If our work assists your research, feel free to give us a star ⭐ or cite us using: ``` @Article{VAR, title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction}, author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang}, year={2024}, eprint={2404.02905}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ``` @misc{Infinity, title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis}, author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu}, year={2024}, eprint={2412.04431}, archivePrefix={arXiv}, primaryClass={cs.CV}, url={https://arxiv.org/abs/2412.04431}, } ``` ================================================ FILE: dist.py ================================================ import datetime import functools import os import sys from typing import List from typing import Union import torch import torch.distributed as tdist import torch.multiprocessing as mp __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu' __initialized = False def initialized(): return __initialized def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30): global __device if not torch.cuda.is_available(): print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) return elif 'RANK' not in os.environ: torch.cuda.set_device(gpu_id_if_not_distibuted) __device = torch.empty(1).cuda().device print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr) return # then 'RANK' must exist global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() local_rank = global_rank % num_gpus torch.cuda.set_device(local_rank) # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 if mp.get_start_method(allow_none=True) is None: method = 'fork' if fork else 'spawn' 
print(f'[dist initialize] mp method={method}') mp.set_start_method(method) tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60)) global __rank, __local_rank, __world_size, __initialized __local_rank = local_rank __rank, __world_size = tdist.get_rank(), tdist.get_world_size() __device = torch.empty(1).cuda().device __initialized = True assert tdist.is_initialized(), 'torch.distributed is not initialized!' print(f'[lrk={get_local_rank()}, rk={get_rank()}]') def get_rank(): return __rank def get_local_rank(): return __local_rank def get_world_size(): return __world_size def get_device(): return __device def set_gpu_id(gpu_id: int): if gpu_id is None: return global __device if isinstance(gpu_id, (str, int)): torch.cuda.set_device(int(gpu_id)) __device = torch.empty(1).cuda().device else: raise NotImplementedError def is_master(): return __rank == 0 def is_local_master(): return __local_rank == 0 def new_group(ranks: List[int]): if __initialized: return tdist.new_group(ranks=ranks) return None def barrier(): if __initialized: tdist.barrier() def allreduce(t: torch.Tensor, async_op=False): if __initialized: if not t.is_cuda: cu = t.detach().cuda() ret = tdist.all_reduce(cu, async_op=async_op) t.copy_(cu.cpu()) else: ret = tdist.all_reduce(t, async_op=async_op) return ret return None def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: if __initialized: if not t.is_cuda: t = t.cuda() ls = [torch.empty_like(t) for _ in range(__world_size)] tdist.all_gather(ls, t) else: ls = [t] if cat: ls = torch.cat(ls, dim=0) return ls def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: if __initialized: if not t.is_cuda: t = t.cuda() t_size = torch.tensor(t.size(), device=t.device) ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] tdist.all_gather(ls_size, t_size) max_B = max(size[0].item() for size in ls_size) pad = max_B - t_size[0].item() if pad: pad_size = 
(pad, *t.size()[1:]) t = torch.cat((t, t.new_empty(pad_size)), dim=0) ls_padded = [torch.empty_like(t) for _ in range(__world_size)] tdist.all_gather(ls_padded, t) ls = [] for t, size in zip(ls_padded, ls_size): ls.append(t[:size[0].item()]) else: ls = [t] if cat: ls = torch.cat(ls, dim=0) return ls def broadcast(t: torch.Tensor, src_rank) -> None: if __initialized: if not t.is_cuda: cu = t.detach().cuda() tdist.broadcast(cu, src=src_rank) t.copy_(cu.cpu()) else: tdist.broadcast(t, src=src_rank) def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]: if not initialized(): return torch.tensor([val]) if fmt is None else [fmt % val] ts = torch.zeros(__world_size) ts[__rank] = val allreduce(ts) if fmt is None: return ts return [fmt % v for v in ts.cpu().numpy().tolist()] def master_only(func): @functools.wraps(func) def wrapper(*args, **kwargs): force = kwargs.pop('force', False) if force or is_master(): ret = func(*args, **kwargs) else: ret = None barrier() return ret return wrapper def local_master_only(func): @functools.wraps(func) def wrapper(*args, **kwargs): force = kwargs.pop('force', False) if force or is_local_master(): ret = func(*args, **kwargs) else: ret = None barrier() return ret return wrapper def for_visualize(func): @functools.wraps(func) def wrapper(*args, **kwargs): if is_master(): # with torch.no_grad(): ret = func(*args, **kwargs) else: ret = None return ret return wrapper def finalize(): if __initialized: tdist.destroy_process_group() ================================================ FILE: models/__init__.py ================================================ from typing import Tuple import torch.nn as nn from .quant import VectorQuantizer2 from .var import VAR from .vqvae import VQVAE def build_vae_var( # Shared args device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default # VQVAE args V=4096, Cvae=32, ch=160, share_quant_resi=4, # VAR args num_classes=1000, depth=16, shared_aln=False, 
attn_l2_norm=True, flash_if_available=True, fused_if_available=True, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1, # init_std < 0: automated ) -> Tuple[VQVAE, VAR]: heads = depth width = depth * 64 dpr = 0.1 * depth/24 # disable built-in initialization for speed for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d): setattr(clz, 'reset_parameters', lambda self: None) # build models vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device) var_wo_ddp = VAR( vae_local=vae_local, num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr, norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1, attn_l2_norm=attn_l2_norm, patch_nums=patch_nums, flash_if_available=flash_if_available, fused_if_available=fused_if_available, ).to(device) var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std) return vae_local, var_wo_ddp ================================================ FILE: models/basic_vae.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F # this file only provides the 2 modules used in VQVAE __all__ = ['Encoder', 'Decoder',] """ References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py """ # swish def nonlinearity(x): return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample2x(nn.Module): def __init__(self, in_channels): super().__init__() self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): return self.conv(F.interpolate(x, 
scale_factor=2, mode='nearest')) class Downsample2x(nn.Module): def __init__(self, in_channels): super().__init__() self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0)) class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity() self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) else: self.nin_shortcut = nn.Identity() def forward(self, x): h = self.conv1(F.silu(self.norm1(x), inplace=True)) h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True))) return self.nin_shortcut(x) + h class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.C = in_channels self.norm = Normalize(in_channels) self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0) self.w_ratio = int(in_channels) ** (-0.5) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): qkv = self.qkv(self.norm(x)) B, _, H, W = qkv.shape # should be B,3C,H,W C = self.C q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1) # compute attention q = q.view(B, C, H * W).contiguous() q = q.permute(0, 2, 1).contiguous() # B,HW,C k = k.view(B, C, H * W).contiguous() # B,C,HW w = torch.bmm(q, k).mul_(self.w_ratio) # 
B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j] w = F.softmax(w, dim=2) # attend to values v = v.view(B, C, H * W).contiguous() w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q) h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j] h = h.view(B, C, H, W).contiguous() return x + self.proj_out(h) def make_attn(in_channels, using_sa=True): return AttnBlock(in_channels) if using_sa else nn.Identity() class Encoder(nn.Module): def __init__( self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2, dropout=0.0, in_channels=3, z_channels, double_z=False, using_sa=True, using_mid_sa=True, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.downsample_ratio = 2 ** (self.num_resolutions - 1) self.num_res_blocks = num_res_blocks self.in_channels = in_channels # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout)) block_in = block_out if i_level == self.num_resolutions - 1 and using_sa: attn.append(make_attn(block_in, using_sa=True)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample2x(block_in) self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, 
padding=1) def forward(self, x): # downsampling h = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](h) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) if i_level != self.num_resolutions - 1: h = self.down[i_level].downsample(h) # middle h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h))) # end h = self.conv_out(F.silu(self.norm_out(h), inplace=True)) return h class Decoder(nn.Module): def __init__( self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2, dropout=0.0, in_channels=3, # in_channels: raw img channels z_channels, using_sa=True, using_mid_sa=True, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.in_channels = in_channels # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout)) block_in = block_out if i_level == self.num_resolutions-1 and using_sa: attn.append(make_attn(block_in, using_sa=True)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample2x(block_in) self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) 
self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, z): # z to block_in # middle h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z)))) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.conv_out(F.silu(self.norm_out(h), inplace=True)) return h ================================================ FILE: models/basic_var.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from models.helpers import DropPath, drop_path # this file only provides the 3 blocks used in VAR transformer __all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead'] # automatically import fused operators dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None try: from flash_attn.ops.layer_norm import dropout_add_layer_norm from flash_attn.ops.fused_dense import fused_mlp_func except ImportError: pass # automatically import faster attention implementations try: from xformers.ops import memory_efficient_attention except ImportError: pass try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq except ImportError: pass try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc except ImportError: def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0): attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL if attn_mask is not None: attn.add_(attn_mask) return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value class FFN(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True): 
super().__init__() self.fused_mlp_func = fused_mlp_func if fused_if_available else None out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = nn.GELU(approximate='tanh') self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity() def forward(self, x): if self.fused_mlp_func is not None: return self.drop(self.fused_mlp_func( x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias, activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None, )) else: return self.drop(self.fc2( self.act(self.fc1(x)) )) def extra_repr(self) -> str: return f'fused_mlp_func={self.fused_mlp_func is not None}' class SelfAttention(nn.Module): def __init__( self, block_idx, embed_dim=768, num_heads=12, attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True, ): super().__init__() assert embed_dim % num_heads == 0 self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64 self.attn_l2_norm = attn_l2_norm if self.attn_l2_norm: self.scale = 1 self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True) self.max_scale_mul = torch.log(torch.tensor(100)).item() else: self.scale = 0.25 / math.sqrt(self.head_dim) self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim)) self.register_buffer('zero_k_bias', torch.zeros(embed_dim)) self.proj = nn.Linear(embed_dim, embed_dim) self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() self.attn_drop: float = attn_drop self.using_flash = flash_if_available and flash_attn_func is not None self.using_xform = flash_if_available and 
memory_efficient_attention is not None # only used during inference self.caching, self.cached_k, self.cached_v = False, None, None def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None # NOTE: attn_bias is None during inference because kv cache is enabled def forward(self, x, attn_bias): B, L, C = x.shape qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) main_type = qkv.dtype # qkv: BL3Hc using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32 if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc if self.attn_l2_norm: scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1 q = F.normalize(q, dim=-1).mul(scale_mul) k = F.normalize(k, dim=-1) if self.caching: if self.cached_k is None: self.cached_k = k; self.cached_v = v else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat) dropout_p = self.attn_drop if self.training else 0.0 if using_flash: oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C) elif self.using_xform: oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C) else: oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C) return self.proj_drop(self.proj(oup)) # attn = (q @ k.transpose(-2, 
-1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL # attn = self.attn_drop(attn.softmax(dim=-1)) # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC def extra_repr(self) -> str: return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}' class AdaLNSelfAttn(nn.Module): def __init__( self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False, flash_if_available=False, fused_if_available=True, ): super(AdaLNSelfAttn, self).__init__() self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim self.C, self.D = embed_dim, cond_dim self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available) self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available) self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False) self.shared_aln = shared_aln if self.shared_aln: self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5) else: lin = nn.Linear(cond_dim, 6*embed_dim) self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) self.fused_add_norm_fn = None # NOTE: attn_bias is None during inference because kv cache is enabled def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim if self.shared_aln: gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C else: gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1)) x = x + 
self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used return x def extra_repr(self) -> str: return f'shared_aln={self.shared_aln}' class AdaLNBeforeHead(nn.Module): def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim super().__init__() self.C, self.D = C, D self.ln_wo_grad = norm_layer(C, elementwise_affine=False) self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C)) def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor): scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2) return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift) ================================================ FILE: models/helpers.py ================================================ import torch from torch import nn as nn from torch.nn import functional as F def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l) B, l, V = logits_BlV.shape if top_k > 0: idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True) logits_BlV.masked_fill_(idx_to_remove, -torch.inf) if top_p > 0: sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False) sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p) sorted_idx_to_remove[..., -1:] = False logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf) # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor) replacement = num_samples >= 0 num_samples = abs(num_samples) return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples) def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: 
int = -1, rng: torch.Generator = None) -> torch.Tensor: if rng is None: return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim) gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log()) gumbels = (logits + gumbels) / tau y_soft = gumbels.softmax(dim) if hard: index = y_soft.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: ret = y_soft return ret def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) return x * random_tensor class DropPath(nn.Module): # taken from timm def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) def extra_repr(self): return f'(drop_prob=...)' ================================================ FILE: models/quant.py ================================================ from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch from torch import distributed as tdist, nn as nn from torch.nn import functional as F import dist # this file only provides the VectorQuantizer2 used in VQVAE __all__ = ['VectorQuantizer2',] class VectorQuantizer2(nn.Module): # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25 def __init__( self, vocab_size, Cvae, using_znorm, beta: float = 0.25, default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # 
share_quant_resi: args.qsr ): super().__init__() self.vocab_size: int = vocab_size self.Cvae: int = Cvae self.using_znorm: bool = using_znorm self.v_patch_nums: Tuple[int] = v_patch_nums self.quant_resi_ratio = quant_resi if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))]) elif share_quant_resi == 1: # fully shared: only a single \phi for K scales self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) else: # partially shared: \phi_{1 to share_quant_resi} for K scales self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)])) self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0)) self.record_hit = 0 self.beta: float = beta self.embedding = nn.Embedding(self.vocab_size, self.Cvae) # only used for progressive training of VAR (not supported yet, will be tested and supported in the future) self.prog_si = -1 # progressive training: not supported yet, prog_si always -1 def eini(self, eini): if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini) elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size) def extra_repr(self) -> str: return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}' # ===================== `forward` is only used in VAE training ===================== def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]: dtype = f_BChw.dtype if dtype != torch.float32: f_BChw = f_BChw.float() B, C, H, W = f_BChw.shape f_no_grad = f_BChw.detach() f_rest = f_no_grad.clone() f_hat = 
torch.zeros_like(f_rest) with torch.cuda.amp.autocast(enabled=False): mean_vq_loss: torch.Tensor = 0.0 vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device) SN = len(self.v_patch_nums) for si, pn in enumerate(self.v_patch_nums): # from small to large # find the nearest embedding if self.using_znorm: rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C) rest_NC = F.normalize(rest_NC, dim=-1) idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1) else: rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C) d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False) d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size) idx_N = torch.argmin(d_no_grad, dim=1) hit_V = idx_N.bincount(minlength=self.vocab_size).float() if self.training: if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True) # calc loss idx_Bhw = idx_N.view(B, pn, pn) h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() h_BChw = self.quant_resi[si/(SN-1)](h_BChw) f_hat = f_hat + h_BChw f_rest -= h_BChw if self.training and dist.initialized(): handler.wait() if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V) elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1)) else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01)) self.record_hit += 1 vocab_hit_V.add_(hit_V) mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad) mean_vq_loss *= 1. 
/ SN f_hat = (f_hat.data - f_no_grad).add_(f_BChw) margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08 # margin = pn*pn / 100 if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)] else: usages = None return f_hat, usages, mean_vq_loss # ===================== `forward` is only used in VAE training ===================== def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]: ls_f_hat_BChw = [] B = ms_h_BChw[0].shape[0] H = W = self.v_patch_nums[-1] SN = len(self.v_patch_nums) if all_to_max_scale: f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32) for si, pn in enumerate(self.v_patch_nums): # from small to large h_BChw = ms_h_BChw[si] if si < len(self.v_patch_nums) - 1: h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic') h_BChw = self.quant_resi[si/(SN-1)](h_BChw) f_hat.add_(h_BChw) if last_one: ls_f_hat_BChw = f_hat else: ls_f_hat_BChw.append(f_hat.clone()) else: # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above) # WARNING: this should only be used for experimental purpose f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32) for si, pn in enumerate(self.v_patch_nums): # from small to large f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic') h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si]) f_hat.add_(h_BChw) if last_one: ls_f_hat_BChw = f_hat else: ls_f_hat_BChw.append(f_hat) return ls_f_hat_BChw def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]: # z_BChw is the feature from inp_img_no_grad B, C, H, W = f_BChw.shape f_no_grad = f_BChw.detach() f_rest = 
f_no_grad.clone() f_hat = torch.zeros_like(f_rest) f_hat_or_idx_Bl: List[torch.Tensor] = [] patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)] # from small to large assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})' SN = len(patch_hws) for si, (ph, pw) in enumerate(patch_hws): # from small to large if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1 # find the nearest embedding z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C) if self.using_znorm: z_NC = F.normalize(z_NC, dim=-1) idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1) else: d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False) d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size) idx_N = torch.argmin(d_no_grad, dim=1) idx_Bhw = idx_N.view(B, ph, pw) h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() h_BChw = self.quant_resi[si/(SN-1)](h_BChw) f_hat.add_(h_BChw) f_rest.sub_(h_BChw) f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw)) return f_hat_or_idx_Bl # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input ===================== def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor: next_scales = [] B = gt_ms_idx_Bl[0].shape[0] C = self.Cvae H = W = self.v_patch_nums[-1] SN = len(self.v_patch_nums) f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32) pn_next: int = self.v_patch_nums[0] for si in range(SN-1): if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break 
# progressive training: not supported yet, prog_si always -1 h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic') f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw)) pn_next = self.v_patch_nums[si+1] next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2)) return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32 # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input ===================== def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference HW = self.v_patch_nums[-1] if si != SN-1: h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic')) # conv after upsample f_hat.add_(h) return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area') else: h = self.quant_resi[si/(SN-1)](h_BChw) f_hat.add_(h) return f_hat, f_hat class Phi(nn.Conv2d): def __init__(self, embed_dim, quant_resi): ks = 3 super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks//2) self.resi_ratio = abs(quant_resi) def forward(self, h_BChw): return h_BChw.mul(1-self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio) class PhiShared(nn.Module): def __init__(self, qresi: Phi): super().__init__() self.qresi: Phi = qresi def __getitem__(self, _) -> Phi: return self.qresi class PhiPartiallyShared(nn.Module): def __init__(self, qresi_ls: nn.ModuleList): super().__init__() self.qresi_ls = qresi_ls K = len(qresi_ls) self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K) def __getitem__(self, at_from_0_to_1: float) -> Phi: return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()] def 
extra_repr(self) -> str: return f'ticks={self.ticks}' class PhiNonShared(nn.ModuleList): def __init__(self, qresi: List): super().__init__(qresi) # self.qresi = qresi K = len(qresi) self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K) def __getitem__(self, at_from_0_to_1: float) -> Phi: return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()) def extra_repr(self) -> str: return f'ticks={self.ticks}' ================================================ FILE: models/var.py ================================================ import math from functools import partial from typing import Optional, Tuple, Union import torch import torch.nn as nn from huggingface_hub import PyTorchModelHubMixin import dist from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_ from models.vqvae import VQVAE, VectorQuantizer2 class SharedAdaLin(nn.Linear): def forward(self, cond_BD): C = self.weight.shape[0] // 6 return super().forward(cond_BD).view(-1, 1, 6, C) # B16C class VAR(nn.Module): def __init__( self, vae_local: VQVAE, num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1, attn_l2_norm=False, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default flash_if_available=True, fused_if_available=True, ): super().__init__() # 0. 
hyperparameters assert embed_dim % num_heads == 0 self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads self.cond_drop_rate = cond_drop_rate self.prog_si = -1 # progressive training self.patch_nums: Tuple[int] = patch_nums self.L = sum(pn ** 2 for pn in self.patch_nums) self.first_l = self.patch_nums[0] ** 2 self.begin_ends = [] cur = 0 for i, pn in enumerate(self.patch_nums): self.begin_ends.append((cur, cur+pn ** 2)) cur += pn ** 2 self.num_stages_minus_1 = len(self.patch_nums) - 1 self.rng = torch.Generator(device=dist.get_device()) # 1. input (word) embedding quant: VectorQuantizer2 = vae_local.quantize self.vae_proxy: Tuple[VQVAE] = (vae_local,) self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,) self.word_embed = nn.Linear(self.Cvae, self.C) # 2. class embedding init_std = math.sqrt(1 / self.C / 3) self.num_classes = num_classes self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device()) self.class_emb = nn.Embedding(self.num_classes + 1, self.C) nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) # 3. absolute position embedding pos_1LC = [] for i, pn in enumerate(self.patch_nums): pe = torch.empty(1, pn*pn, self.C) nn.init.trunc_normal_(pe, mean=0, std=init_std) pos_1LC.append(pe) pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C assert tuple(pos_1LC.shape) == (1, self.L, self.C) self.pos_1LC = nn.Parameter(pos_1LC) # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid) self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C) nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) # 4. 
backbone blocks self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() norm_layer = partial(nn.LayerNorm, eps=norm_eps) self.drop_path_rate = drop_path_rate dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing) self.blocks = nn.ModuleList([ AdaLNSelfAttn( cond_dim=self.D, shared_aln=shared_aln, block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1], attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available, fused_if_available=fused_if_available, ) for block_idx in range(depth) ]) fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks] self.using_fused_add_norm_fn = any(fused_add_norm_fns) print( f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n' f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n' f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', end='\n\n', flush=True ) # 5. 
attention mask used in training (for masking out the future) # it won't be used in inference, since kv cache is enabled d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1) dT = d.transpose(1, 2) # dT: 11L lvl_1L = dT[:, 0].contiguous() self.register_buffer('lvl_1L', lvl_1L) attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L) self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous()) # 6. classifier head self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer) self.head = nn.Linear(self.C, self.V) def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]): if not isinstance(h_or_h_and_residual, torch.Tensor): h, resi = h_or_h_and_residual # fused_add_norm must be used h = resi + self.blocks[-1].drop_path(h) else: # fused_add_norm is not used h = h_or_h_and_residual return self.head(self.head_nm(h.float(), cond_BD).float()).float() @torch.no_grad() def autoregressive_infer_cfg( self, B: int, label_B: Optional[Union[int, torch.LongTensor]], g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0, more_smooth=False, ) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1] """ only used for inference, on autoregressive mode :param B: batch size :param label_B: imagenet label; if None, randomly sampled :param g_seed: random seed :param cfg: classifier-free guidance ratio :param top_k: top-k sampling :param top_p: top-p sampling :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking :return: if returns_vemb: list of embedding h_BChw := vae_embed(idx_Bl), else: list of idx_Bl """ if g_seed is None: rng = None else: self.rng.manual_seed(g_seed); rng = self.rng if label_B is None: label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B) 
elif isinstance(label_B, int): label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device) sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0)) lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l] cur_L = 0 f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1]) for b in self.blocks: b.attn.kv_caching(True) for si, pn in enumerate(self.patch_nums): # si: i-th segment ratio = si / self.num_stages_minus_1 # last_L = cur_L cur_L += pn*pn # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' cond_BD_or_gss = self.shared_ada_lin(cond_BD) x = next_token_map AdaLNSelfAttn.forward for b in self.blocks: x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None) logits_BlV = self.get_logits(x, cond_BD) t = cfg * ratio logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:] idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0] if not more_smooth: # this is the default case h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae else: # not used when evaluating FID/IS/Precision/Recall gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0) h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn) f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw) if si != self.num_stages_minus_1: # prepare for next stage next_token_map = 
next_token_map.view(B, self.Cvae, -1).transpose(1, 2) next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2] next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG for b in self.blocks: b.attn.kv_caching(False) return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1] def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor: # returns logits_BLV """ :param label_B: label_B :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae) :return: logits BLV, V is vocab_size """ bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L) B = x_BLCv_wo_first_l.shape[0] with torch.cuda.amp.autocast(enabled=False): label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B) sos = cond_BD = self.class_emb(label_B) sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1) if self.prog_si == 0: x_BLC = sos else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1) x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed] cond_BD_or_gss = self.shared_ada_lin(cond_BD) # hack: get the dtype if mixed precision is used temp = x_BLC.new_ones(8, 8) main_type = torch.matmul(temp, temp).dtype x_BLC = x_BLC.to(dtype=main_type) cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type) attn_bias = attn_bias.to(dtype=main_type) AdaLNSelfAttn.forward for i, b in enumerate(self.blocks): x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias) x_BLC = self.get_logits(x_BLC.float(), cond_BD) if self.prog_si == 0: if isinstance(self.word_embed, nn.Linear): x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0 else: s = 0 for p in self.word_embed.parameters(): if 
p.requires_grad: s += p.view(-1)[0] * 0 x_BLC[0, 0, 0] += s return x_BLC # logits BLV, V is vocab_size def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02): if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated print(f'[init_weights] {type(self).__name__} with {init_std=:g}') for m in self.modules(): with_weight = hasattr(m, 'weight') and m.weight is not None with_bias = hasattr(m, 'bias') and m.bias is not None if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight.data, std=init_std) if with_bias: m.bias.data.zero_() elif isinstance(m, nn.Embedding): nn.init.trunc_normal_(m.weight.data, std=init_std) if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_() elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): if with_weight: m.weight.data.fill_(1.) if with_bias: m.bias.data.zero_() # conv: VAR has no conv, only VQVAE has conv elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) if with_bias: m.bias.data.zero_() if init_head >= 0: if isinstance(self.head, nn.Linear): self.head.weight.data.mul_(init_head) self.head.bias.data.zero_() elif isinstance(self.head, nn.Sequential): self.head[-1].weight.data.mul_(init_head) self.head[-1].bias.data.zero_() if isinstance(self.head_nm, AdaLNBeforeHead): self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln) if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: self.head_nm.ada_lin[-1].bias.data.zero_() depth = len(self.blocks) for block_idx, sab in enumerate(self.blocks): sab: AdaLNSelfAttn sab.attn.proj.weight.data.div_(math.sqrt(2 * depth)) 
sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth)) if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None: nn.init.ones_(sab.ffn.fcg.bias) nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) if hasattr(sab, 'ada_lin'): sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln) sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma) if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None: sab.ada_lin[-1].bias.data.zero_() elif hasattr(sab, 'ada_gss'): sab.ada_gss.data[:, :, 2:].mul_(init_adaln) sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma) def extra_repr(self): return f'drop_path_rate={self.drop_path_rate:g}' class VARHF(VAR, PyTorchModelHubMixin): # repo_url="https://github.com/FoundationVision/VAR", # tags=["image-generation"]): def __init__( self, vae_kwargs, num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1, attn_l2_norm=False, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default flash_if_available=True, fused_if_available=True, ): vae_local = VQVAE(**vae_kwargs) super().__init__( vae_local=vae_local, num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate, attn_l2_norm=attn_l2_norm, patch_nums=patch_nums, flash_if_available=flash_if_available, fused_if_available=fused_if_available, ) ================================================ FILE: models/vqvae.py ================================================ """ References: - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110 - GumbelQuantize: 
https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213 - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14 """ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn from .basic_vae import Decoder, Encoder from .quant import VectorQuantizer2 class VQVAE(nn.Module): def __init__( self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0, beta=0.25, # commitment loss weight using_znorm=False, # whether to normalize when computing the nearest neighbors quant_conv_ks=3, # quant conv kernel size quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums) v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k] test_mode=True, ): super().__init__() self.test_mode = test_mode self.V, self.Cvae = vocab_size, z_channels # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml ddconfig = dict( dropout=dropout, ch=ch, z_channels=z_channels, in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above # resamp_with_conv=True, # always True, removed. 
) ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True self.encoder = Encoder(double_z=False, **ddconfig) self.decoder = Decoder(**ddconfig) self.vocab_size = vocab_size self.downsample = 2 ** (len(ddconfig['ch_mult'])-1) self.quantize: VectorQuantizer2 = VectorQuantizer2( vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta, default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi, ) self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2) self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2) if self.test_mode: self.eval() [p.requires_grad_(False) for p in self.parameters()] # ===================== `forward` is only used in VAE training ===================== def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss VectorQuantizer2.forward f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages) return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss # ===================== `forward` is only used in VAE training ===================== def fhat_to_img(self, f_hat: torch.Tensor): return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl] f = self.quant_conv(self.encoder(inp_img_no_grad)) return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums) def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]: B = ms_idx_Bl[0].shape[0] ms_h_BChw = [] for idx_Bl in ms_idx_Bl: l = idx_Bl.shape[1] pn = round(l ** 0.5) ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn)) return 
self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one) def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]: if last_one: return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1) else: return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)] def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]: f = self.quant_conv(self.encoder(x)) ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums) if last_one: return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1) else: return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw] def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False): if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]: state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) ================================================ FILE: train.py ================================================ import gc import os import shutil import sys import time import warnings from functools import partial import torch from torch.utils.data import DataLoader import dist from utils import arg_util, misc from utils.data import build_dataset from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler from utils.misc import auto_resume def build_everything(args: arg_util.Args): # resume auto_resume_info, start_ep, start_it, trainer_state, args_state = 
auto_resume(args, 'ar-ckpt*.pth') # create tensorboard logger tb_lg: misc.TensorboardLogger with_tb_lg = dist.is_master() if with_tb_lg: os.makedirs(args.tb_log_dir_path, exist_ok=True) # noinspection PyTypeChecker tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True) tb_lg.flush() else: # noinspection PyTypeChecker tb_lg = misc.DistLogger(None, verbose=False) dist.barrier() # log args print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}') print(f'initial args:\n{str(args)}') # build data if not args.local_debug: print(f'[build PT data] ...\n') num_classes, dataset_train, dataset_val = build_dataset( args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso, ) types = str((type(dataset_train).__name__, type(dataset_val).__name__)) ld_val = DataLoader( dataset_val, num_workers=0, pin_memory=True, batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()), shuffle=False, drop_last=False, ) del dataset_val ld_train = DataLoader( dataset=dataset_train, num_workers=args.workers, pin_memory=True, generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn, batch_sampler=DistInfiniteBatchSampler( dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks, shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it, ), ) del dataset_train [print(line) for line in auto_resume_info] print(f'[dataloader multi processing] ...', end='', flush=True) stt = time.time() iters_train = len(ld_train) ld_train = iter(ld_train) # noinspection PyArgumentList print(f' [dataloader multi processing](*) finished! 
({time.time()-stt:.2f}s)', flush=True, clean=True) print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}') else: num_classes = 1000 ld_val = ld_train = None iters_train = 10 # build models from torch.nn.parallel import DistributedDataParallel as DDP from models import VAR, VQVAE, build_vae_var from trainer import VARTrainer from utils.amp_sc import AmpOptimizer from utils.lr_control import filter_params vae_local, var_wo_ddp = build_vae_var( V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters device=dist.get_device(), patch_nums=args.patch_nums, num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm, flash_if_available=args.fuse, fused_if_available=args.fuse, init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini, ) vae_ckpt = 'vae_ch160v4096z32.pth' if dist.is_local_master(): if not os.path.exists(vae_ckpt): os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}') dist.barrier() vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True) vae_local: VQVAE = args.compile_model(vae_local, args.vfast) var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast) var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) print(f'[INIT] VAR model = {var_wo_ddp}\n\n') count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}' print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))])) print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n') # build optimizer names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={ 'cls_token', 'start_token', 'task_token', 'cfg_uncond', 
'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed', 'gamma', 'beta', 'ada_gss', 'moe_bias', 'scale_mul', }) opt_clz = { 'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse), 'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse), }[args.opt.lower().strip()] opt_kw = dict(lr=args.tlr, weight_decay=0) print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n') var_optim = AmpOptimizer( mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras, grad_clip=args.tclip, n_gradient_accumulation=args.ac ) del names, paras, para_groups # build trainer trainer = VARTrainer( device=args.device, patch_nums=args.patch_nums, resos=args.resos, vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var, var_opt=var_optim, label_smooth=args.ls, ) if trainer_state is not None and len(trainer_state): trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again del vae_local, var_wo_ddp, var, var_optim if args.local_debug: rng = torch.Generator('cpu') rng.manual_seed(0) B = 4 inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso) label = torch.ones(B, dtype=torch.long) me = misc.MetricLogger(delimiter=' ') trainer.train_step( it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg, inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20, ) trainer.load_state_dict(trainer.state_dict()) trainer.train_step( it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg, inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20, ) print({k: meter.global_avg for k, meter in me.meters.items()}) args.dump_log(); tb_lg.flush(); tb_lg.close() if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint): sys.stdout.close(), sys.stderr.close() exit(0) dist.barrier() return ( tb_lg, trainer, start_ep, start_it, iters_train, ld_train, ld_val ) def main_training(): args: arg_util.Args = arg_util.init_dist_and_get_args() if args.local_debug: 
torch.autograd.set_detect_anomaly(True) ( tb_lg, trainer, start_ep, start_it, iters_train, ld_train, ld_val ) = build_everything(args) # train start_time = time.time() best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1. best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1 L_mean, L_tail = -1, -1 for ep in range(start_ep, args.ep): if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'): ld_train.sampler.set_epoch(ep) if ep < 3: # noinspection PyArgumentList print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True) tb_lg.set_step(ep * iters_train) stats, (sec, remain_time, finish_time) = train_one_ep( ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer ) L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm'] best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean) if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail) args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm args.cur_ep = f'{ep+1}/{args.ep}' args.remain_time, args.finish_time = remain_time, finish_time AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail) is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep if is_val_and_also_saving: val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val) best_updated = best_val_loss_tail > val_loss_tail best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail) best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail) AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, 
vacc_tail=val_acc_tail) args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s') if dist.is_local_master(): local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth') local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth') print(f'[saving ckpt] ...', end='', flush=True) torch.save({ 'epoch': ep+1, 'iter': 0, 'trainer': trainer.state_dict(), 'args': args.state_dict(), }, local_out_ckpt) if best_updated: shutil.copy(local_out_ckpt, local_out_ckpt_best) print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True) dist.barrier() print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True) tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss) tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2)) args.dump_log(); tb_lg.flush() total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h' print('\n\n') print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})') print('\n\n') del stats del iters_train, ld_train time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3) args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60)) print(f'final args:\n\n{str(args)}') args.dump_log(); tb_lg.flush(); tb_lg.close() dist.barrier() def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer): # import heavy packages after Dataloader object creation from trainer import VARTrainer from utils.lr_control import lr_wd_annealing trainer: VARTrainer 
step_cnt = 0 me = misc.MetricLogger(delimiter=' ') me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}')) me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}')) [me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']] [me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']] header = f'[Ep]: [{ep:4d}/{args.ep}]' if is_first_ep: warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=UserWarning) g_it, max_it = ep * iters_train, args.ep * iters_train for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header): g_it = ep * iters_train + it if it < start_it: continue if is_first_ep and it == start_it: warnings.resetwarnings() inp = inp.to(args.device, non_blocking=True) label = label.to(args.device, non_blocking=True) args.cur_it = f'{it+1}/{iters_train}' wp_it = args.wp * iters_train min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe) args.cur_lr, args.cur_wd = max_tlr, max_twd if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this if g_it <= wp_it: prog_si = args.pg0 elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1 else: delta = len(args.patch_nums) - 1 - args.pg0 progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1 prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1 else: prog_si = -1 stepping = (g_it + 1) % args.ac == 0 step_cnt += int(stepping) grad_norm, scale_log2 = trainer.train_step( it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg, inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train, ) me.update(tlr=max_tlr) tb_lg.set_step(step=g_it) 
tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr) tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr) tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd) tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd) tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2) if args.tclip > 0: tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm) tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip) me.synchronize_between_processes() return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost class NullDDP(torch.nn.Module): def __init__(self, module, *args, **kwargs): super(NullDDP, self).__init__() self.module = module self.require_backward_grad_sync = False def forward(self, *args, **kwargs): return self.module(*args, **kwargs) if __name__ == '__main__': try: main_training() finally: dist.finalize() if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint): sys.stdout.close(), sys.stderr.close() ================================================ FILE: trainer.py ================================================ import time from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader import dist from models import VAR, VQVAE, VectorQuantizer2 from utils.amp_sc import AmpOptimizer from utils.misc import MetricLogger, TensorboardLogger Ten = torch.Tensor FTen = torch.Tensor ITen = torch.LongTensor BTen = torch.BoolTensor class VARTrainer(object): def __init__( self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...], vae_local: VQVAE, var_wo_ddp: VAR, var: DDP, var_opt: AmpOptimizer, label_smooth: float, ): super(VARTrainer, self).__init__() self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize self.quantize_local: VectorQuantizer2 self.var_wo_ddp: VAR = var_wo_ddp # after 
torch.compile self.var_opt = var_opt del self.var_wo_ddp.rng self.var_wo_ddp.rng = torch.Generator(device=device) self.label_smooth = label_smooth self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none') self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean') self.L = sum(pn * pn for pn in patch_nums) self.last_l = patch_nums[-1] * patch_nums[-1] self.loss_weight = torch.ones(1, self.L, device=device) / self.L self.patch_nums, self.resos = patch_nums, resos self.begin_ends = [] cur = 0 for i, pn in enumerate(patch_nums): self.begin_ends.append((cur, cur + pn * pn)) cur += pn*pn self.prog_it = 0 self.last_prog_si = -1 self.first_prog = True @torch.no_grad() def eval_ep(self, ld_val: DataLoader): tot = 0 L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0 stt = time.time() training = self.var_wo_ddp.training self.var_wo_ddp.eval() for inp_B3HW, label_B in ld_val: B, V = label_B.shape[0], self.vae_local.vocab_size inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True) label_B = label_B.to(dist.get_device(), non_blocking=True) gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW) gt_BL = torch.cat(gt_idx_Bl, dim=1) x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl) self.var_wo_ddp.forward logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l) L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1]) acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l) tot += B self.var_wo_ddp.train(training) stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot]) dist.allreduce(stats) tot = round(stats[-1].item()) stats /= tot L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist() return 
L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt def train_step( self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger, inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float, ) -> Tuple[Optional[Union[Ten, float]], Optional[float]]: # if progressive training self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si if self.last_prog_si != prog_si: if self.last_prog_si != -1: self.first_prog = False self.last_prog_si = prog_si self.prog_it = 0 self.prog_it += 1 prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01) if self.first_prog: prog_wp = 1 # no prog warmup at first prog stage, as it's already solved in wp if prog_si == len(self.patch_nums) - 1: prog_si = -1 # max prog, as if no prog # forward B, V = label_B.shape[0], self.vae_local.vocab_size self.var.require_backward_grad_sync = stepping gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW) gt_BL = torch.cat(gt_idx_Bl, dim=1) x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl) with self.var_opt.amp_ctx: self.var_wo_ddp.forward logits_BLV = self.var(label_B, x_BLCv_wo_first_l) loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1) if prog_si >= 0: # in progressive training bg, ed = self.begin_ends[prog_si] assert logits_BLV.shape[1] == gt_BL.shape[1] == ed lw = self.loss_weight[:, :ed].clone() lw[:, bg:ed] *= min(max(prog_wp, 0), 1) else: # not in progressive training lw = self.loss_weight loss = loss.mul(lw).sum(dim=-1).mean() # backward grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping) # log pred_BL = logits_BLV.data.argmax(dim=-1) if it == 0 or it in metric_lg.log_iters: Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item() acc_mean = (pred_BL == gt_BL).float().mean().item() * 100 if prog_si >= 0: # in progressive training Ltail = acc_tail = -1 else: # not in progressive training Ltail = self.val_loss(logits_BLV.data[:, 
-self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item() acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100 grad_norm = grad_norm.item() metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm) # log to tensorboard if g_it == 0 or (g_it + 1) % 500 == 0: prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float() dist.allreduce(prob_per_class_is_chosen) prob_per_class_is_chosen /= prob_per_class_is_chosen.sum() cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100 if dist.is_master(): if g_it == 0: tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000) tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000) kw = dict(z_voc_usage=cluster_usage) for si, (bg, ed) in enumerate(self.begin_ends): if 0 <= prog_si < si: break pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1) acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100 ce = self.val_loss(pred, tar).item() kw[f'acc_{self.resos[si]}'] = acc kw[f'L_{self.resos[si]}'] = ce tb_lg.update(head='AR_iter_loss', **kw, step=g_it) tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it) self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1 return grad_norm, scale_log2 def get_config(self): return { 'patch_nums': self.patch_nums, 'resos': self.resos, 'label_smooth': self.label_smooth, 'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog, } def state_dict(self): state = {'config': self.get_config()} for k in ('var_wo_ddp', 'vae_local', 'var_opt'): m = getattr(self, k) if m is not None: if hasattr(m, '_orig_mod'): m = m._orig_mod state[k] = m.state_dict() return state def load_state_dict(self, state, strict=True, skip_vae=False): for k in ('var_wo_ddp', 'vae_local', 'var_opt'): if skip_vae and 'vae' in k: continue m = 
getattr(self, k) if m is not None: if hasattr(m, '_orig_mod'): m = m._orig_mod ret = m.load_state_dict(state[k], strict=strict) if ret is not None: missing, unexpected = ret print(f'[VARTrainer.load_state_dict] {k} missing: {missing}') print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}') config: dict = state.pop('config', None) self.prog_it = config.get('prog_it', 0) self.last_prog_si = config.get('last_prog_si', -1) self.first_prog = config.get('first_prog', True) if config is not None: for k, v in self.get_config().items(): if config.get(k, None) != v: err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})' if strict: raise AttributeError(err) else: print(err) ================================================ FILE: utils/amp_sc.py ================================================ import math from typing import List, Optional, Tuple, Union import torch class NullCtx: def __enter__(self): pass def __exit__(self, exc_type, exc_val, exc_tb): pass class AmpOptimizer: def __init__( self, mixed_precision: int, optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter], grad_clip: float, n_gradient_accumulation: int = 1, ): self.enable_amp = mixed_precision > 0 self.using_fp16_rather_bf16 = mixed_precision == 1 if self.enable_amp: self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True) self.scaler = torch.cuda.amp.GradScaler(init_scale=2. 
** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler else: self.amp_ctx = NullCtx() self.scaler = None self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad self.grad_clip = grad_clip self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm') self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation def backward_clip_step( self, stepping: bool, loss: torch.Tensor, ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]: # backward loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation orig_norm = scaler_sc = None if self.scaler is not None: self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) else: loss.backward(retain_graph=False, create_graph=False) if stepping: if self.scaler is not None: self.scaler.unscale_(self.optimizer) if self.early_clipping: orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip) if self.scaler is not None: self.scaler.step(self.optimizer) scaler_sc: float = self.scaler.get_scale() if scaler_sc > 32768.: # fp16 will overflow when >65536, so multiply 32768 could be dangerous self.scaler.update(new_scale=32768.) 
else: self.scaler.update() try: scaler_sc = float(math.log2(scaler_sc)) except Exception as e: print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True) raise e else: self.optimizer.step() if self.late_clipping: orig_norm = self.optimizer.global_grad_norm self.optimizer.zero_grad(set_to_none=True) return orig_norm, scaler_sc def state_dict(self): return { 'optimizer': self.optimizer.state_dict() } if self.scaler is None else { 'scaler': self.scaler.state_dict(), 'optimizer': self.optimizer.state_dict() } def load_state_dict(self, state, strict=True): if self.scaler is not None: try: self.scaler.load_state_dict(state['scaler']) except Exception as e: print(f'[fp16 load_state_dict err] {e}') self.optimizer.load_state_dict(state['optimizer']) ================================================ FILE: utils/arg_util.py ================================================ import json import os import random import re import subprocess import sys import time from collections import OrderedDict from typing import Optional, Union import numpy as np import torch try: from tap import Tap except ImportError as e: print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True) print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True) time.sleep(5) raise e import dist class Args(Tap): data_path: str = '/path/to/imagenet' exp_name: str = 'text' # VAE vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune' # VAR tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune' depth: int = 16 # VAR depth # VAR initialization ini: float = -1 # -1: automated model parameter initialization hd: float = 0.02 # head.w *= hd aln: float = 0.5 # the multiplier of ada_lin.w's initialization alng: float = 1e-5 # the multiplier of 
ada_lin.w[gamma channels]'s initialization # VAR optimization fp16: int = 0 # 1: using fp16, 2: bf16 tblr: float = 1e-4 # base lr tlr: float = None # lr = base lr * (bs / 256) twd: float = 0.05 # initial wd twde: float = 0 # final wd, =twde or twd tclip: float = 2. # <=0 for not using grad clip ls: float = 0.0 # label smooth bs: int = 768 # global batch size batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8 glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size() ac: int = 1 # gradient accumulation ep: int = 250 wp: float = 0 wp0: float = 0.005 # initial lr ratio at the begging of lr warm up wpe: float = 0.01 # final lr ratio at the end of training sche: str = 'lin0' # lr schedule opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work afuse: bool = True # fused adamw # other hps saln: bool = False # whether to use shared adaln anorm: bool = True # whether to use L2 normalized attention fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc. 
# data pn: str = '1_2_3_4_5_6_8_10_13_16' patch_size: int = 16 patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_'))) resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums) data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso hflip: bool = False # augmentation: horizontal flip workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader # progressive training pg: float = 0.0 # >0 for use progressive training during [0%, this] of training pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc pgwp: float = 0 # num of warmup epochs at each progressive stage # would be automatically set in runtime cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this] branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this] commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this] commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this] acc_mean: float = None # [automatically set; don't specify this] acc_tail: float = None # [automatically set; don't specify this] L_mean: float = None # [automatically set; don't specify this] L_tail: float = None # [automatically set; don't specify this] vacc_mean: float = None # [automatically set; don't specify this] vacc_tail: float = None # [automatically set; don't specify this] vL_mean: float = None # 
[automatically set; don't specify this] vL_tail: float = None # [automatically set; don't specify this] grad_norm: float = None # [automatically set; don't specify this] cur_lr: float = None # [automatically set; don't specify this] cur_wd: float = None # [automatically set; don't specify this] cur_it: str = '' # [automatically set; don't specify this] cur_ep: str = '' # [automatically set; don't specify this] remain_time: str = '' # [automatically set; don't specify this] finish_time: str = '' # [automatically set; don't specify this] # environment local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this] tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this] log_txt_path: str = '...' # [automatically set; don't specify this] last_ckpt_path: str = '...' # [automatically set; don't specify this] tf32: bool = True # whether to use TensorFloat32 device: str = 'cpu' # [automatically set; don't specify this] seed: int = None # seed def seed_everything(self, benchmark: bool): torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = benchmark if self.seed is None: torch.backends.cudnn.deterministic = False else: torch.backends.cudnn.deterministic = True seed = self.seed * dist.get_world_size() + dist.get_rank() os.environ['PYTHONHASHSEED'] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) same_seed_for_all_ranks: int = 0 # this is only for distributed sampler def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation if self.seed is None: return None g = torch.Generator() g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank()) return g local_debug: bool = 'KEVIN_LOCAL' in os.environ dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ def compile_model(self, m, fast): if fast == 0 or 
self.local_debug: return m return torch.compile(m, mode={ 1: 'reduce-overhead', 2: 'max-autotune', 3: 'default', }[fast]) if hasattr(torch, 'compile') else m def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]: d = (OrderedDict if key_ordered else dict)() # self.as_dict() would contain methods, but we only need variables for k in self.class_variables.keys(): if k not in {'device'}: # these are not serializable d[k] = getattr(self, k) return d def load_state_dict(self, d: Union[OrderedDict, dict, str]): if isinstance(d, str): # for compatibility with old version d: dict = eval('\n'.join([l for l in d.splitlines() if ' 0: print(f'======================================================================================') print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}') print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================') print(f'======================================================================================\n\n') # init torch distributed from utils import misc os.makedirs(args.local_out_dir_path, exist_ok=True) misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30) # set env args.set_tf32(args.tf32) args.seed_everything(benchmark=args.pg == 0) # update args: data loading args.device = dist.get_device() if args.pn == '256': args.pn = '1_2_3_4_5_6_8_10_13_16' elif args.pn == '512': args.pn = '1_2_3_4_6_9_13_18_24_32' elif args.pn == '1024': args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64' args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_'))) args.resos = tuple(pn * args.patch_size for pn in args.patch_nums) args.data_load_reso = max(args.resos) # update args: bs and lr bs_per_gpu = round(args.bs / args.ac / dist.get_world_size()) args.batch_size = bs_per_gpu args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size() args.workers = min(max(0, args.workers), args.batch_size) args.tlr = 
args.ac * args.tblr * args.glb_batch_size / 256 args.twde = args.twde or args.twd if args.wp == 0: args.wp = args.ep * 1/50 # update args: progressive training if args.pgwp == 0: args.pgwp = args.ep * 1/300 if args.pg > 0: args.sche = f'lin{args.pg:g}' # update args: paths args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt') args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth') _reg_valid_name = re.compile(r'[^\w\-+,.]') tb_name = _reg_valid_name.sub( '_', f'tb-VARd{args.depth}' f'__pn{args.pn}' f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}' ) args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name) return args ================================================ FILE: utils/data.py ================================================ import os.path as osp import PIL.Image as PImage from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS from torchvision.transforms import InterpolationMode, transforms def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1 return x.add(x).add_(-1) def build_dataset( data_path: str, final_reso: int, hflip=False, mid_reso=1.125, ): # build augmentations mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso train_aug, val_aug = [ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso transforms.RandomCrop((final_reso, final_reso)), transforms.ToTensor(), normalize_01_into_pm1, ], [ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso transforms.CenterCrop((final_reso, final_reso)), transforms.ToTensor(), normalize_01_into_pm1, ] if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip()) train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug) # build dataset train_set = DatasetFolder(root=osp.join(data_path, 'train'), 
loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug) val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug) num_classes = 1000 print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}') print_aug(train_aug, '[train]') print_aug(val_aug, '[val]') return num_classes, train_set, val_set def pil_loader(path): with open(path, 'rb') as f: img: PImage.Image = PImage.open(f).convert('RGB') return img def print_aug(transform, label): print(f'Transform {label} = ') if hasattr(transform, 'transforms'): for t in transform.transforms: print(t) else: print(transform) print('---------------------------\n') ================================================ FILE: utils/data_sampler.py ================================================ import numpy as np import torch from torch.utils.data.sampler import Sampler class EvalDistributedSampler(Sampler): def __init__(self, dataset, num_replicas, rank): seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int) beg, end = seps[:-1], seps[1:] beg, end = beg[rank], end[rank] self.indices = tuple(range(beg, end)) def __iter__(self): return iter(self.indices) def __len__(self) -> int: return len(self.indices) class InfiniteBatchSampler(Sampler): def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0): self.dataset_len = dataset_len self.batch_size = batch_size self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size self.max_p = self.iters_per_ep * batch_size self.fill_last = fill_last self.shuffle = shuffle self.epoch = start_ep self.same_seed_for_all_ranks = seed_for_all_rank self.indices = self.gener_indices() self.start_ep, self.start_it = start_ep, start_it def gener_indices(self): if self.shuffle: g = torch.Generator() g.manual_seed(self.epoch + self.same_seed_for_all_ranks) indices = 
torch.randperm(self.dataset_len, generator=g).numpy() else: indices = torch.arange(self.dataset_len).numpy() tails = self.batch_size - (self.dataset_len % self.batch_size) if tails != self.batch_size and self.fill_last: tails = indices[:tails] np.random.shuffle(indices) indices = np.concatenate((indices, tails)) # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop) # noinspection PyTypeChecker return tuple(indices.tolist()) def __iter__(self): self.epoch = self.start_ep while True: self.epoch += 1 p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0 while p < self.max_p: q = p + self.batch_size yield self.indices[p:q] p = q if self.shuffle: self.indices = self.gener_indices() def __len__(self): return self.iters_per_ep class DistInfiniteBatchSampler(InfiniteBatchSampler): def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0): assert glb_batch_size % world_size == 0 self.world_size, self.rank = world_size, rank self.dataset_len = dataset_len self.glb_batch_size = glb_batch_size self.batch_size = glb_batch_size // world_size self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size self.fill_last = fill_last self.shuffle = shuffle self.repeated_aug = repeated_aug self.epoch = start_ep self.same_seed_for_all_ranks = same_seed_for_all_ranks self.indices = self.gener_indices() self.start_ep, self.start_it = start_ep, start_it def gener_indices(self): global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0 # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}') if self.shuffle: g = torch.Generator() g.manual_seed(self.epoch + self.same_seed_for_all_ranks) global_indices = torch.randperm(self.dataset_len, generator=g) if self.repeated_aug > 1: global_indices 
= global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p] else: global_indices = torch.arange(self.dataset_len) filling = global_max_p - global_indices.shape[0] if filling > 0 and self.fill_last: global_indices = torch.cat((global_indices, global_indices[:filling])) # global_indices = tuple(global_indices.numpy().tolist()) seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int) local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist() self.max_p = len(local_indices) return local_indices ================================================ FILE: utils/lr_control.py ================================================ import math from pprint import pformat from typing import Tuple, List, Dict, Union import torch.nn import dist def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001): """Decay the learning rate with half-cycle cosine after warmup""" wp_it = round(wp_it) if cur_it < wp_it: cur_lr = wp0 + (1-wp0) * cur_it / wp_it else: pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1] rest = 1 - pasd # [1, 0] if sche_type == 'cos': cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd)) elif sche_type == 'lin': T = 0.15; max_rest = 1-T if pasd < T: cur_lr = 1 else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe elif sche_type == 'lin0': T = 0.05; max_rest = 1-T if pasd < T: cur_lr = 1 else: cur_lr = wpe + (1-wpe) * rest / max_rest elif sche_type == 'lin00': cur_lr = wpe + (1-wpe) * rest elif sche_type.startswith('lin'): T = float(sche_type[3:]); max_rest = 1-T wpe_mid = wpe + (1-wpe) * max_rest wpe_mid = (1 + wpe_mid) / 2 if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest elif sche_type == 'exp': T = 0.15; max_rest = 1-T if pasd < T: cur_lr = 1 else: expo = (pasd-T) / max_rest * math.log(wpe) cur_lr = 
math.exp(expo) else: raise NotImplementedError(f'unknown sche_type {sche_type}') cur_lr *= peak_lr pasd = cur_it / (max_it-1) cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd)) inf = 1e6 min_lr, max_lr = inf, -1 min_wd, max_wd = inf, -1 for param_group in optimizer.param_groups: param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned max_lr = max(max_lr, param_group['lr']) min_lr = min(min_lr, param_group['lr']) param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1) max_wd = max(max_wd, param_group['weight_decay']) if param_group['weight_decay'] > 0: min_wd = min(min_wd, param_group['weight_decay']) if min_lr == inf: min_lr = -1 if min_wd == inf: min_wd = -1 return min_lr, max_lr, min_wd, max_wd def filter_params(model, nowd_keys=()) -> Tuple[ List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]] ]: para_groups, para_groups_dbg = {}, {} names, paras = [], [] names_no_grad = [] count, numel = 0, 0 for name, para in model.named_parameters(): name = name.replace('_fsdp_wrapped_module.', '') if not para.requires_grad: names_no_grad.append(name) continue # frozen weights count += 1 numel += para.numel() names.append(name) paras.append(para) if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys): cur_wd_sc, group_name = 0., 'ND' else: cur_wd_sc, group_name = 1., 'D' cur_lr_sc = 1. 
if group_name not in para_groups: para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc} para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc} para_groups[group_name]['params'].append(para) para_groups_dbg[group_name]['params'].append(name) for g in para_groups_dbg.values(): g['params'] = pformat(', '.join(g['params']), width=200) print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n') for rk in range(dist.get_world_size()): dist.barrier() if dist.get_rank() == rk: print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True) print('') assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n' return names, paras, list(para_groups.values()) ================================================ FILE: utils/misc.py ================================================ import datetime import functools import glob import os import subprocess import sys import time from collections import defaultdict, deque from typing import Iterator, List, Tuple import numpy as np import pytz import torch import torch.distributed as tdist import dist from utils import arg_util os_system = functools.partial(subprocess.call, shell=True) def echo(info): os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"') def os_system_get_stdout(cmd): return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') def os_system_get_stdout_stderr(cmd): cnt = 0 while True: try: sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30) except subprocess.TimeoutExpired: cnt += 1 print(f'[fetch free_port file] timeout cnt={cnt}') else: return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8') def time_str(fmt='[%m-%d %H:%M:%S]'): return 
datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt) def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30): try: dist.initialize(fork=False, timeout=timeout) dist.barrier() except RuntimeError: print(f'{">"*75} NCCL Error {"<"*75}', flush=True) time.sleep(10) if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True) _change_builtin_print(dist.is_local_master()) if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path): sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False) def _change_builtin_print(is_master): import builtins as __builtin__ builtin_print = __builtin__.print if type(builtin_print) != type(open): return def prt(*args, **kwargs): force = kwargs.pop('force', False) clean = kwargs.pop('clean', False) deeper = kwargs.pop('deeper', False) if is_master or force: if not clean: f_back = sys._getframe().f_back if deeper and f_back.f_back is not None: f_back = f_back.f_back file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs) else: builtin_print(*args, **kwargs) __builtin__.print = prt class SyncPrint(object): def __init__(self, local_output_dir, sync_stdout=True): self.sync_stdout = sync_stdout self.terminal_stream = sys.stdout if sync_stdout else sys.stderr fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt') existing = os.path.exists(fname) self.file_stream = open(fname, 'a') if existing: self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n') self.file_stream.flush() self.enabled = True def write(self, message): self.terminal_stream.write(message) self.file_stream.write(message) def flush(self): self.terminal_stream.flush() self.file_stream.flush() def close(self): if not self.enabled: return self.enabled = False 
self.file_stream.flush() self.file_stream.close() if self.sync_stdout: sys.stdout = self.terminal_stream sys.stdout.flush() else: sys.stderr = self.terminal_stream sys.stderr.flush() def __del__(self): self.close() class DistLogger(object): def __init__(self, lg, verbose): self._lg, self._verbose = lg, verbose @staticmethod def do_nothing(*args, **kwargs): pass def __getattr__(self, attr: str): return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing class TensorboardLogger(object): def __init__(self, log_dir, filename_suffix): try: import tensorflow_io as tfio except: pass from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix) self.step = 0 def set_step(self, step=None): if step is not None: self.step = step else: self.step += 1 def update(self, head='scalar', step=None, **kwargs): for k, v in kwargs.items(): if v is None: continue # assert isinstance(v, (float, int)), type(v) if step is None: # iter wise it = self.step if it == 0 or (it + 1) % 500 == 0: if hasattr(v, 'item'): v = v.item() self.writer.add_scalar(f'{head}/{k}', v, it) else: # epoch wise if hasattr(v, 'item'): v = v.item() self.writer.add_scalar(f'{head}/{k}', v, step) def log_tensor_as_distri(self, tag, tensor1d, step=None): if step is None: # iter wise step = self.step loggable = step == 0 or (step + 1) % 500 == 0 else: # epoch wise loggable = True if loggable: try: self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step) except Exception as e: print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}') def log_image(self, tag, img_chw, step=None): if step is None: # iter wise step = self.step loggable = step == 0 or (step + 1) % 500 == 0 else: # epoch wise loggable = True if loggable: self.writer.add_image(tag, img_chw, step, dataformats='CHW') def flush(self): self.writer.flush() def close(self): self.writer.close() class SmoothedValue(object): """Track a series of values and provide 
access to smoothed values over a window or the global series average. """ def __init__(self, window_size=30, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') tdist.barrier() tdist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): return np.median(self.deque) if len(self.deque) else 0 @property def avg(self): return sum(self.deque) / (len(self.deque) or 1) @property def global_avg(self): return self.total / (self.count or 1) @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] if len(self.deque) else 0 def time_preds(self, counts) -> Tuple[float, str, str]: remain_secs = counts * self.median return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs)) def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) class MetricLogger(object): def __init__(self, delimiter=' '): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter self.iter_end_t = time.time() self.log_iters = [] def update(self, **kwargs): for k, v in kwargs.items(): if v is None: continue if hasattr(v, 'item'): v = v.item() # assert isinstance(v, (float, int)), type(v) assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = 
[] for name, meter in self.meters.items(): if len(meter.deque): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, start_it, max_iters, itrt, print_freq, header=None): self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist()) self.log_iters.add(start_it) if not header: header = '' start_time = time.time() self.iter_end_t = time.time() self.iter_time = SmoothedValue(fmt='{avg:.4f}') self.data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(max_iters))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ] log_msg = self.delimiter.join(log_msg) if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'): for i in range(start_it, max_iters): obj = next(itrt) self.data_time.update(time.time() - self.iter_end_t) yield i, obj self.iter_time.update(time.time() - self.iter_end_t) if i in self.log_iters: eta_seconds = self.iter_time.global_avg * (max_iters - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) print(log_msg.format( i, max_iters, eta=eta_string, meters=str(self), time=str(self.iter_time), data=str(self.data_time)), flush=True) self.iter_end_t = time.time() else: if isinstance(itrt, int): itrt = range(itrt) for i, obj in enumerate(itrt): self.data_time.update(time.time() - self.iter_end_t) yield i, obj self.iter_time.update(time.time() - self.iter_end_t) if i in self.log_iters: eta_seconds = self.iter_time.global_avg * (max_iters - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) print(log_msg.format( i, max_iters, eta=eta_string, meters=str(self), time=str(self.iter_time), data=str(self.data_time)), flush=True) self.iter_end_t = time.time() total_time = 
time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.3f} s / it)'.format( header, total_time_str, total_time / max_iters), flush=True) def glob_with_latest_modified_first(pattern, recursive=False): return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True) def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]: info = [] file = os.path.join(args.local_out_dir_path, pattern) all_ckpt = glob_with_latest_modified_first(file) if len(all_ckpt) == 0: info.append(f'[auto_resume] no ckpt found @ {file}') info.append(f'[auto_resume quit]') return info, 0, 0, {}, {} else: info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...') ckpt = torch.load(all_ckpt[0], map_location='cpu') ep, it = ckpt['epoch'], ckpt['iter'] info.append(f'[auto_resume success] resume from ep{ep}, it{it}') return info, ep, it, ckpt['trainer'], ckpt['args'] def create_npz_from_sample_folder(sample_folder: str): """ Builds a single .npz file from a folder of .png samples. Refer to DiT. """ import os, glob import numpy as np from tqdm import tqdm from PIL import Image samples = [] pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG')) assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000' for png in tqdm(pngs, desc='Building .npz file from samples (png only)'): with Image.open(png) as sample_pil: sample_np = np.asarray(sample_pil).astype(np.uint8) samples.append(sample_np) samples = np.stack(samples) assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3) npz_path = f'{sample_folder}.npz' np.savez(npz_path, arr_0=samples) print(f'Saved .npz file to {npz_path} [shape={samples.shape}].') return npz_path