[
  {
    "path": ".gitignore",
    "content": "*.swp\n**/__pycache__/**\n**/.ipynb_checkpoints/**\n.DS_Store\n.idea/*\n.vscode/*\nllava/\n_vis_cached/\n_auto_*\nckpt/\nlog/\ntb*/\nimg*/\nlocal_output*\n*.pth\n*.pth.tar\n*.ckpt\n*.log\n*.txt\n*.ipynb\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 FoundationVision\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈\n\n<div align=\"center\">\n\n[![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://opensource.bytedance.com/gmpt/t2i/invite)&nbsp;\n[![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)&nbsp;\n[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)&nbsp;\n[![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjND
RGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)\n\n\n</div>\n<p align=\"center\" style=\"font-size: larger;\">\n  <a href=\"https://arxiv.org/abs/2404.02905\">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>\n</p>\n\n<div>\n  <p align=\"center\" style=\"font-size: larger;\">\n    <strong>NeurIPS 2024 Best Paper</strong>\n  </p>\n</div>\n\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755\" width=95%>\n</p>\n\n<br>\n\n## News\n* **2025-11:** We released our Text-to-Video generation model **InfinityStar**, based on VAR & Infinity. Please check [Infinity⭐️](https://github.com/FoundationVision/InfinityStar).\n* **2025-11:** 🎉 InfinityStar was accepted as **NeurIPS 2025 Oral.**\n* **2025-04:** 🎉 Infinity was accepted as **CVPR 2025 Oral.**\n* **2024-12:** 🏆 VAR received **NeurIPS 2024 Best Paper Award**.\n* **2024-12:** 🔥 We released our Text-to-Image research based on VAR. Please check [Infinity](https://github.com/FoundationVision/Infinity).\n* **2024-09:** VAR was accepted as **NeurIPS 2024 Oral** Presentation.\n* **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.\n\n## 🕹️ Try and Play with VAR!\n\n~~We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!~~\n\nWe provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with VAR Text-to-Image and generate images interactively. 
Enjoy the fun of visual autoregressive modeling!\n\nWe also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.\n\n[//]: # (<p align=\"center\">)\n[//]: # (<img src=\"https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png\" width=50%)\n[//]: # (<p>)\n\n\n## What's New?\n\n### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:\n\nVisual Autoregressive Modeling (VAR) redefines autoregressive learning on images as coarse-to-fine \"next-scale prediction\" or \"next-resolution prediction\", diverging from the standard raster-scan \"next-token prediction\".\n\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178\" width=93%>\n</p>\n\n### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d\" width=55%>\n</p>\n\n\n### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:\n\n\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804\" width=85%>\n</p>\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e\" width=85%>\n</p>\n\n\n### 🔥 Zero-shot generalizability🛠️:\n\n<p align=\"center\">\n<img src=\"https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a\" width=70%>\n</p>\n\n#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).\n\n\n## VAR zoo\nWe provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the 
following links:\n\n|   model    | reso. |   FID    | rel. cost | #params | HF weights🤗                                                                        |\n|:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|\n|  VAR-d16   |  256  |   3.55   |    0.4    |  310M   | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |\n|  VAR-d20   |  256  |   2.95   |    0.5    |  600M   | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |\n|  VAR-d24   |  256  |   2.33   |    0.6    |  1.0B   | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |\n|  VAR-d30   |  256  |   1.97   |     1     |  2.0B   | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |\n| VAR-d30-re |  256  | **1.80** |     1     |  2.0B   | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |\n| VAR-d36    |  512  | **2.63** |     -     |  2.3B   | [var_d36.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d36.pth) |\n\nYou can load these models and generate images using the code in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.\n\n\n## Installation\n\n1. Install `torch>=2.0.0`.\n2. Install other pip packages via `pip3 install -r requirements.txt`.\n3. Prepare the [ImageNet](http://image-net.org/) dataset.\n    <details>\n    <summary> Assume ImageNet is in `/path/to/imagenet`. 
It should be like this:</summary>\n\n    ```\n    /path/to/imagenet/:\n        train/:\n            n01440764: \n                many_images.JPEG ...\n            n01443537:\n                many_images.JPEG ...\n        val/:\n            n01440764:\n                ILSVRC2012_val_00000293.JPEG ...\n            n01443537:\n                ILSVRC2012_val_00000236.JPEG ...\n    ```\n   **NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**\n    </details>\n\n4. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).\n\n\n## Training Scripts\n\nTo train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following commands:\n```shell\n# d16, 256x256\ntorchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \\\n  --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1\n# d20, 256x256\ntorchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \\\n  --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1\n# d24, 256x256\ntorchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \\\n  --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01\n# d30, 256x256\ntorchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \\\n  --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08\n# d36-s, 512x512 (-s means saln=1, shared AdaLN)\ntorchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... 
train.py \\\n  --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08\n```\nA folder named `local_output` will be created to save the checkpoints and logs.\nYou can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.\n\nIf your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).\n\n## Sampling & Zero-shot Inference\n\nFor FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L360](utils/misc.py#L360).\nThen use [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and the reference ground-truth npz file for [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.\n\nNote that a relatively small `cfg=1.5` is used as a trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.\nWe'll provide the sampling script later.\n\n\n## Third-party Usage and Research\n\n***In this paragraph, we cross-link third-party repositories and research that use VAR and report results. 
You can let us know by raising an issue.***\n\n(`Note: please report accuracy numbers and provide trained models in your new repository so that others can get a sense of correctness and model behavior`)\n\n| **Time**     | **Research**                                                                                                                  | **Link**                                                           |\n|--------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------|\n| [5/12/2025]  | [ICML 2025]Continuous Visual Autoregressive Generation via Score Maximization                                                 | https://github.com/shaochenze/EAR                                  |\n| [5/8/2025]   | Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction                                        | https://github.com/icon-lab/FedGAT                                 |\n| [4/7/2025]   | FastVAR: Linear Visual Autoregressive Modeling via Cached Token Pruning                                                       | https://github.com/csguoh/FastVAR                                  |\n| [4/3/2025]   | VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning    | https://github.com/VARGPT-family/VARGPT-v1.1                       |\n| [3/31/2025]  | Training-Free Text-Guided Image Editing with Visual Autoregressive Model                                                      | https://github.com/wyf0912/AREdit                                  |\n| [3/17/2025]  | Next-Scale Autoregressive Models are Zero-Shot Single-Image Object View Synthesizers                                          | https://github.com/Shiran-Yuan/ArchonView                          |\n| [3/14/2025]  | Safe-VAR: Safe Visual Autoregressive Model for 
Text-to-Image Generative Watermarking                                          | https://arxiv.org/abs/2503.11324                                   |\n| [3/3/2025]   | [ICML 2025]Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator  | https://research.nvidia.com/labs/dir/ddo/                          |\n| [2/28/2025]  | Autoregressive Medical Image Segmentation via Next-Scale Mask Prediction                                                      | https://arxiv.org/abs/2502.20784                                   |\n| [2/27/2025]  | FlexVAR: Flexible Visual Autoregressive Modeling without Residual Prediction                                                  | https://github.com/jiaosiyu1999/FlexVAR                            |\n| [2/17/2025]  | MARS: Mesh AutoRegressive Model for 3D Shape Detailization                                                                    | https://arxiv.org/abs/2502.11390                                   |\n| [1/31/2025]  | [ICML 2025]Visual Autoregressive Modeling for Image Super-Resolution                                                          | https://github.com/quyp2000/VARSR                                  |\n| [1/21/2025]  | VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model                       | https://github.com/VARGPT-family/VARGPT                            |\n| [1/26/2025]  | [ICML 2025]Visual Generation Without Guidance                                                                                 | https://github.com/thu-ml/GFT                                      |\n| [12/30/2024] | Next Token Prediction Towards Multimodal Intelligence                                                                         | https://github.com/LMM101/Awesome-Multimodal-Next-Token-Prediction |\n| [12/30/2024] | Varformer: Adapting VAR’s Generative Prior for Image Restoration                                                
              | https://arxiv.org/abs/2412.21063                                   |\n| [12/22/2024] | [ICLR 2025]Distilled Decoding 1: One-step Sampling of Image Auto-regressive Models with Flow Matching                         | https://github.com/imagination-research/distilled-decoding         |\n| [12/19/2024] | FlowAR: Scale-wise Autoregressive Image Generation Meets Flow Matching                                                        | https://github.com/OliverRensu/FlowAR                              |\n| [12/13/2024] | 3D representation in 512-Byte: Variational tokenizer is the key for autoregressive 3D generation                              | https://github.com/sparse-mvs-2/VAT                                |\n| [12/9/2024]  | CARP: Visuomotor Policy Learning via Coarse-to-Fine Autoregressive Prediction                                                 | https://carp-robot.github.io/                                      |\n| [12/5/2024]  | [CVPR 2025]Infinity ∞: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis                            | https://github.com/FoundationVision/Infinity                       |\n| [12/5/2024]  | [CVPR 2025]Switti: Designing Scale-Wise Transformers for Text-to-Image Synthesis                                              | https://github.com/yandex-research/switti                          |\n| [12/4/2024]  | [CVPR 2025]TokenFlow🚀: Unified Image Tokenizer for Multimodal Understanding and Generation                                   | https://github.com/ByteFlow-AI/TokenFlow                           |\n| [12/3/2024]  | XQ-GAN🚀: An Open-source Image Tokenization Framework for Autoregressive Generation                                           | https://github.com/lxa9867/ImageFolder                             |\n| [11/28/2024] | [CVPR 2025]CoDe: Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient                                       | https://github.com/czg1225/CoDe                    
                |\n| [11/28/2024] | [CVPR 2025]Scalable Autoregressive Monocular Depth Estimation                                                                 | https://arxiv.org/abs/2411.11361                                   |\n| [11/27/2024] | [CVPR 2025]SAR3D: Autoregressive 3D Object Generation and Understanding via Multi-scale 3D VQVAE                              | https://github.com/cyw-3d/SAR3D                                    |\n| [11/26/2024] | LiteVAR: Compressing Visual Autoregressive Modelling with Efficient Attention and Quantization                                | https://arxiv.org/abs/2411.17178                                   |\n| [11/15/2024] | M-VAR: Decoupled Scale-wise Autoregressive Modeling for High-Quality Image Generation                                         | https://github.com/OliverRensu/MVAR                                |\n| [10/14/2024] | [ICLR 2025]HART: Efficient Visual Generation with Hybrid Autoregressive Transformer                                           | https://github.com/mit-han-lab/hart                                |\n| [10/12/2024] | [ICLR 2025 Oral]Toward Guidance-Free AR Visual Generation via Condition Contrastive Alignment                                 | https://github.com/thu-ml/CCA                                      |\n| [10/3/2024]  | [ICLR 2025]ImageFolder🚀: Autoregressive Image Generation with Folded Tokens                                                  | https://github.com/lxa9867/ImageFolder                             |\n| [07/25/2024] | ControlVAR: Exploring Controllable Visual Autoregressive Modeling                                                             | https://github.com/lxa9867/ControlVAR                              |\n| [07/3/2024]  | VAR-CLIP: Text-to-Image Generator with Visual Auto-Regressive Modeling                                                        | https://github.com/daixiangzi/VAR-CLIP                             |\n| [06/16/2024] | STAR: Scale-wise 
Text-to-image generation via Auto-Regressive representations                                                 | https://arxiv.org/abs/2406.10797                                   |\n\n\n## License\nThis project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.\n\n\n## Citation\nIf our work assists your research, feel free to give us a star ⭐ or cite us using:\n```\n@Article{VAR,\n      title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction}, \n      author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},\n      year={2024},\n      eprint={2404.02905},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n\n```\n@misc{Infinity,\n    title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis}, \n    author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu},\n    year={2024},\n    eprint={2412.04431},\n    archivePrefix={arXiv},\n    primaryClass={cs.CV},\n    url={https://arxiv.org/abs/2412.04431}, \n}\n```\n"
  },
  {
    "path": "dist.py",
    "content": "import datetime\nimport functools\nimport os\nimport sys\nfrom typing import List\nfrom typing import Union\n\nimport torch\nimport torch.distributed as tdist\nimport torch.multiprocessing as mp\n\n__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'\n__initialized = False\n\n\ndef initialized():\n    return __initialized\n\n\ndef initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):\n    global __device\n    if not torch.cuda.is_available():\n        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)\n        return\n    elif 'RANK' not in os.environ:\n        torch.cuda.set_device(gpu_id_if_not_distibuted)\n        __device = torch.empty(1).cuda().device\n        print(f'[dist initialize] env variable \"RANK\" is not set, use {__device} as the device', file=sys.stderr)\n        return\n    # then 'RANK' must exist\n    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()\n    local_rank = global_rank % num_gpus\n    torch.cuda.set_device(local_rank)\n    \n    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29\n    if mp.get_start_method(allow_none=True) is None:\n        method = 'fork' if fork else 'spawn'\n        print(f'[dist initialize] mp method={method}')\n        mp.set_start_method(method)\n    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))\n    \n    global __rank, __local_rank, __world_size, __initialized\n    __local_rank = local_rank\n    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()\n    __device = torch.empty(1).cuda().device\n    __initialized = True\n    \n    assert tdist.is_initialized(), 'torch.distributed is not initialized!'\n    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')\n\n\ndef get_rank():\n    return __rank\n\n\ndef get_local_rank():\n    return __local_rank\n\n\ndef get_world_size():\n    
return __world_size\n\n\ndef get_device():\n    return __device\n\n\ndef set_gpu_id(gpu_id: int):\n    if gpu_id is None: return\n    global __device\n    if isinstance(gpu_id, (str, int)):\n        torch.cuda.set_device(int(gpu_id))\n        __device = torch.empty(1).cuda().device\n    else:\n        raise NotImplementedError\n\n\ndef is_master():\n    return __rank == 0\n\n\ndef is_local_master():\n    return __local_rank == 0\n\n\ndef new_group(ranks: List[int]):\n    if __initialized:\n        return tdist.new_group(ranks=ranks)\n    return None\n\n\ndef barrier():\n    if __initialized:\n        tdist.barrier()\n\n\ndef allreduce(t: torch.Tensor, async_op=False):\n    if __initialized:\n        if not t.is_cuda:\n            cu = t.detach().cuda()\n            ret = tdist.all_reduce(cu, async_op=async_op)\n            t.copy_(cu.cpu())\n        else:\n            ret = tdist.all_reduce(t, async_op=async_op)\n        return ret\n    return None\n\n\ndef allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:\n    if __initialized:\n        if not t.is_cuda:\n            t = t.cuda()\n        ls = [torch.empty_like(t) for _ in range(__world_size)]\n        tdist.all_gather(ls, t)\n    else:\n        ls = [t]\n    if cat:\n        ls = torch.cat(ls, dim=0)\n    return ls\n\n\ndef allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:\n    if __initialized:\n        if not t.is_cuda:\n            t = t.cuda()\n        \n        t_size = torch.tensor(t.size(), device=t.device)\n        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]\n        tdist.all_gather(ls_size, t_size)\n        \n        max_B = max(size[0].item() for size in ls_size)\n        pad = max_B - t_size[0].item()\n        if pad:\n            pad_size = (pad, *t.size()[1:])\n            t = torch.cat((t, t.new_empty(pad_size)), dim=0)\n        \n        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]\n    
    tdist.all_gather(ls_padded, t)\n        ls = []\n        for t, size in zip(ls_padded, ls_size):\n            ls.append(t[:size[0].item()])\n    else:\n        ls = [t]\n    if cat:\n        ls = torch.cat(ls, dim=0)\n    return ls\n\n\ndef broadcast(t: torch.Tensor, src_rank) -> None:\n    if __initialized:\n        if not t.is_cuda:\n            cu = t.detach().cuda()\n            tdist.broadcast(cu, src=src_rank)\n            t.copy_(cu.cpu())\n        else:\n            tdist.broadcast(t, src=src_rank)\n\n\ndef dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:\n    if not initialized():\n        return torch.tensor([val]) if fmt is None else [fmt % val]\n    \n    ts = torch.zeros(__world_size)\n    ts[__rank] = val\n    allreduce(ts)\n    if fmt is None:\n        return ts\n    return [fmt % v for v in ts.cpu().numpy().tolist()]\n\n\ndef master_only(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if force or is_master():\n            ret = func(*args, **kwargs)\n        else:\n            ret = None\n        barrier()\n        return ret\n    return wrapper\n\n\ndef local_master_only(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if force or is_local_master():\n            ret = func(*args, **kwargs)\n        else:\n            ret = None\n        barrier()\n        return ret\n    return wrapper\n\n\ndef for_visualize(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        if is_master():\n            # with torch.no_grad():\n            ret = func(*args, **kwargs)\n        else:\n            ret = None\n        return ret\n    return wrapper\n\n\ndef finalize():\n    if __initialized:\n        tdist.destroy_process_group()\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from typing import Tuple\nimport torch.nn as nn\n\nfrom .quant import VectorQuantizer2\nfrom .var import VAR\nfrom .vqvae import VQVAE\n\n\ndef build_vae_var(\n    # Shared args\n    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default\n    # VQVAE args\n    V=4096, Cvae=32, ch=160, share_quant_resi=4,\n    # VAR args\n    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,\n    flash_if_available=True, fused_if_available=True,\n    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated\n) -> Tuple[VQVAE, VAR]:\n    heads = depth\n    width = depth * 64\n    dpr = 0.1 * depth/24\n    \n    # disable built-in initialization for speed\n    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):\n        setattr(clz, 'reset_parameters', lambda self: None)\n    \n    # build models\n    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)\n    var_wo_ddp = VAR(\n        vae_local=vae_local,\n        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,\n        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,\n        attn_l2_norm=attn_l2_norm,\n        patch_nums=patch_nums,\n        flash_if_available=flash_if_available, fused_if_available=fused_if_available,\n    ).to(device)\n    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)\n    \n    return vae_local, var_wo_ddp\n"
  },
  {
    "path": "models/basic_vae.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n# this file only provides the 2 modules used in VQVAE\n__all__ = ['Encoder', 'Decoder',]\n\n\n\"\"\"\nReferences: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py\n\"\"\"\n# swish\ndef nonlinearity(x):\n    return x * torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass Upsample2x(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)\n    \n    def forward(self, x):\n        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\n\n\nclass Downsample2x(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)\n    \n    def forward(self, x):\n        return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False,  # conv_shortcut: always False in VAE\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        \n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()\n        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n        if self.in_channels != self.out_channels:\n            
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n        else:\n            self.nin_shortcut = nn.Identity()\n    \n    def forward(self, x):\n        h = self.conv1(F.silu(self.norm1(x), inplace=True))\n        h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))\n        return self.nin_shortcut(x) + h\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.C = in_channels\n        \n        self.norm = Normalize(in_channels)\n        self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)\n        self.w_ratio = int(in_channels) ** (-0.5)\n        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n    \n    def forward(self, x):\n        qkv = self.qkv(self.norm(x))\n        B, _, H, W = qkv.shape  # should be B,3C,H,W\n        C = self.C\n        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)\n        \n        # compute attention\n        q = q.view(B, C, H * W).contiguous()\n        q = q.permute(0, 2, 1).contiguous()     # B,HW,C\n        k = k.view(B, C, H * W).contiguous()    # B,C,HW\n        w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW    w[B,i,j]=sum_c q[B,i,C]k[B,C,j]\n        w = F.softmax(w, dim=2)\n        \n        # attend to values\n        v = v.view(B, C, H * W).contiguous()\n        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)\n        h = torch.bmm(v, w)  # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]\n        h = h.view(B, C, H, W).contiguous()\n        \n        return x + self.proj_out(h)\n\n\ndef make_attn(in_channels, using_sa=True):\n    return AttnBlock(in_channels) if using_sa else nn.Identity()\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,\n        dropout=0.0, in_channels=3,\n        z_channels, double_z=False, 
using_sa=True, using_mid_sa=True,\n    ):\n        super().__init__()\n        self.ch = ch\n        self.num_resolutions = len(ch_mult)\n        self.downsample_ratio = 2 ** (self.num_resolutions - 1)\n        self.num_res_blocks = num_res_blocks\n        self.in_channels = in_channels\n        \n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)\n        \n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch * in_ch_mult[i_level]\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))\n                block_in = block_out\n                if i_level == self.num_resolutions - 1 and using_sa:\n                    attn.append(make_attn(block_in, using_sa=True))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                down.downsample = Downsample2x(block_in)\n            self.down.append(down)\n        \n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)\n        \n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)\n    \n    def forward(self, x):\n        # downsampling\n        h = self.conv_in(x)\n        for i_level in range(self.num_resolutions):\n          
  for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](h)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n            if i_level != self.num_resolutions - 1:\n                h = self.down[i_level].downsample(h)\n        \n        # middle\n        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))\n        \n        # end\n        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,\n        dropout=0.0, in_channels=3,  # in_channels: raw img channels\n        z_channels, using_sa=True, using_mid_sa=True,\n    ):\n        super().__init__()\n        self.ch = ch\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.in_channels = in_channels\n        \n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,) + tuple(ch_mult)\n        block_in = ch * ch_mult[self.num_resolutions - 1]\n        \n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)\n        \n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)\n        \n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                block.append(ResnetBlock(in_channels=block_in, 
out_channels=block_out, dropout=dropout))\n                block_in = block_out\n                if i_level == self.num_resolutions-1 and using_sa:\n                    attn.append(make_attn(block_in, using_sa=True))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample2x(block_in)\n            self.up.insert(0, up)  # prepend to get consistent order\n        \n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)\n    \n    def forward(self, z):\n        # z to block_in\n        # middle\n        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))\n        \n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](h)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n        \n        # end\n        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))\n        return h\n"
  },
  {
    "path": "models/basic_var.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom models.helpers import DropPath, drop_path\n\n\n# this file only provides the 3 blocks used in VAR transformer\n__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']\n\n\n# automatically import fused operators\ndropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None\ntry:\n    from flash_attn.ops.layer_norm import dropout_add_layer_norm\n    from flash_attn.ops.fused_dense import fused_mlp_func\nexcept ImportError: pass\n# automatically import faster attention implementations\ntry: from xformers.ops import memory_efficient_attention\nexcept ImportError: pass\ntry: from flash_attn import flash_attn_func              # qkv: BLHc, ret: BLHcq\nexcept ImportError: pass\ntry: from torch.nn.functional import scaled_dot_product_attention as slow_attn    # q, k, v: BHLc\nexcept ImportError:\n    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):\n        attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL\n        if attn_mask is not None: attn.add_(attn_mask)\n        return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value\n\n\nclass FFN(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):\n        super().__init__()\n        self.fused_mlp_func = fused_mlp_func if fused_if_available else None\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = nn.GELU(approximate='tanh')\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()\n    \n    def forward(self, x):\n        if self.fused_mlp_func is not None:\n            return 
self.drop(self.fused_mlp_func(\n                x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,\n                activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,\n                heuristic=0, process_group=None,\n            ))\n        else:\n            return self.drop(self.fc2( self.act(self.fc1(x)) ))\n    \n    def extra_repr(self) -> str:\n        return f'fused_mlp_func={self.fused_mlp_func is not None}'\n\n\nclass SelfAttention(nn.Module):\n    def __init__(\n        self, block_idx, embed_dim=768, num_heads=12,\n        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,\n    ):\n        super().__init__()\n        assert embed_dim % num_heads == 0\n        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64\n        self.attn_l2_norm = attn_l2_norm\n        if self.attn_l2_norm:\n            self.scale = 1\n            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)\n            self.max_scale_mul = torch.log(torch.tensor(100)).item()\n        else:\n            self.scale = 0.25 / math.sqrt(self.head_dim)\n        \n        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)\n        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))\n        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))\n        \n        self.proj = nn.Linear(embed_dim, embed_dim)\n        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()\n        self.attn_drop: float = attn_drop\n        self.using_flash = flash_if_available and flash_attn_func is not None\n        self.using_xform = flash_if_available and memory_efficient_attention is not None\n        \n        # only used during inference\n        self.caching, self.cached_k, 
self.cached_v = False, None, None\n    \n    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None\n    \n    # NOTE: attn_bias is None during inference because kv cache is enabled\n    def forward(self, x, attn_bias):\n        B, L, C = x.shape\n        \n        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)\n        main_type = qkv.dtype\n        # qkv: BL3Hc\n        \n        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32\n        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc\n        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2               # q or k or v: BHLc\n        \n        if self.attn_l2_norm:\n            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()\n            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1\n            q = F.normalize(q, dim=-1).mul(scale_mul)\n            k = F.normalize(k, dim=-1)\n        \n        if self.caching:\n            if self.cached_k is None: self.cached_k = k; self.cached_v = v\n            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)\n        \n        dropout_p = self.attn_drop if self.training else 0.0\n        if using_flash:\n            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)\n        elif self.using_xform:\n            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)\n        
else:\n            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)\n        \n        return self.proj_drop(self.proj(oup))\n        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL\n        # attn = self.attn_drop(attn.softmax(dim=-1))\n        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)     # BHLL @ BHLc = BHLc => BLHc => BLC\n    \n    def extra_repr(self) -> str:\n        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'\n\n\nclass AdaLNSelfAttn(nn.Module):\n    def __init__(\n        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,\n        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,\n        flash_if_available=False, fused_if_available=True,\n    ):\n        super(AdaLNSelfAttn, self).__init__()\n        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim\n        self.C, self.D = embed_dim, cond_dim\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)\n        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)\n        \n        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)\n        self.shared_aln = shared_aln\n        if self.shared_aln:\n            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)\n        else:\n            lin = nn.Linear(cond_dim, 6*embed_dim)\n            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)\n        \n        self.fused_add_norm_fn = None\n    \n    # NOTE: attn_bias is None during inference because kv cache is enabled\n    def forward(self, x, cond_BD, attn_bias):   # C: embed_dim, D: cond_dim\n        if self.shared_aln:\n            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C\n        else:\n            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)\n        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))\n        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used\n        return x\n    \n    def extra_repr(self) -> str:\n        return f'shared_aln={self.shared_aln}'\n\n\nclass AdaLNBeforeHead(nn.Module):\n    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim\n        super().__init__()\n        self.C, self.D = C, D\n        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)\n        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))\n    \n    def forward(self, 
x_BLC: torch.Tensor, cond_BD: torch.Tensor):\n        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)\n        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)\n"
  },
  {
    "path": "models/helpers.py",
    "content": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\ndef sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)\n    B, l, V = logits_BlV.shape\n    if top_k > 0:\n        idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)\n        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)\n    if top_p > 0:\n        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)\n        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)\n        sorted_idx_to_remove[..., -1:] = False\n        logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)\n    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)\n    replacement = num_samples >= 0\n    num_samples = abs(num_samples)\n    return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)\n\n\ndef gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:\n    if rng is None:\n        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)\n    \n    gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())\n    gumbels = (logits + gumbels) / tau\n    y_soft = gumbels.softmax(dim)\n    \n    if hard:\n        index = y_soft.max(dim, keepdim=True)[1]\n        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)\n        ret = y_hard - y_soft.detach() + y_soft\n    else:\n        ret = y_soft\n    return ret\n\n\ndef drop_path(x, drop_prob: float = 0., 
training: bool = False, scale_by_keep: bool = True):    # taken from timm\n    if drop_prob == 0. or not training: return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0 and scale_by_keep:\n        random_tensor.div_(keep_prob)\n    return x * random_tensor\n\n\nclass DropPath(nn.Module):  # taken from timm\n    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n        self.scale_by_keep = scale_by_keep\n    \n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n    \n    def extra_repr(self):\n        return f'drop_prob={self.drop_prob}'\n"
  },
  {
    "path": "models/quant.py",
    "content": "from typing import List, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import distributed as tdist, nn as nn\nfrom torch.nn import functional as F\n\nimport dist\n\n\n# this file only provides the VectorQuantizer2 used in VQVAE\n__all__ = ['VectorQuantizer2',]\n\n\nclass VectorQuantizer2(nn.Module):\n    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25\n    def __init__(\n        self, vocab_size, Cvae, using_znorm, beta: float = 0.25,\n        default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4,  # share_quant_resi: args.qsr\n    ):\n        super().__init__()\n        self.vocab_size: int = vocab_size\n        self.Cvae: int = Cvae\n        self.using_znorm: bool = using_znorm\n        self.v_patch_nums: Tuple[int] = v_patch_nums\n        \n        self.quant_resi_ratio = quant_resi\n        if share_quant_resi == 0:   # non-shared: \\phi_{1 to K} for K scales\n            self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])\n        elif share_quant_resi == 1: # fully shared: only a single \\phi for K scales\n            self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())\n        else:                       # partially shared: \\phi_{1 to share_quant_resi} for K scales\n            self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))\n        \n        self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))\n        self.record_hit = 0\n        \n        self.beta: float = beta\n        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)\n        \n        # only used for progressive training of VAR (not supported yet, will be tested and supported 
in the future)\n        self.prog_si = -1   # progressive training: not supported yet, prog_si always -1\n    \n    def eini(self, eini):\n        if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)\n        elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)\n    \n    def extra_repr(self) -> str:\n        return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta}  |  S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'\n    \n    # ===================== `forward` is only used in VAE training =====================\n    def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:\n        dtype = f_BChw.dtype\n        if dtype != torch.float32: f_BChw = f_BChw.float()\n        B, C, H, W = f_BChw.shape\n        f_no_grad = f_BChw.detach()\n        \n        f_rest = f_no_grad.clone()\n        f_hat = torch.zeros_like(f_rest)\n        \n        with torch.cuda.amp.autocast(enabled=False):\n            mean_vq_loss: torch.Tensor = 0.0\n            vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)\n            SN = len(self.v_patch_nums)\n            for si, pn in enumerate(self.v_patch_nums): # from small to large\n                # find the nearest embedding\n                if self.using_znorm:\n                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)\n                    rest_NC = F.normalize(rest_NC, dim=-1)\n                    idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)\n                else:\n                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)\n                    d_no_grad = 
torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)\n                    d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)\n                    idx_N = torch.argmin(d_no_grad, dim=1)\n                \n                hit_V = idx_N.bincount(minlength=self.vocab_size).float()\n                if self.training:\n                    if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)\n                \n                # calc loss\n                idx_Bhw = idx_N.view(B, pn, pn)\n                h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()\n                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)\n                f_hat = f_hat + h_BChw\n                f_rest -= h_BChw\n                \n                if self.training and dist.initialized():\n                    handler.wait()\n                    if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)\n                    elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))\n                    else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))\n                    self.record_hit += 1\n                vocab_hit_V.add_(hit_V)\n                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)\n            \n            mean_vq_loss *= 1. 
/ SN\n            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)\n        \n        margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08\n        # margin = pn*pn / 100\n        if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)]\n        else: usages = None\n        return f_hat, usages, mean_vq_loss\n    # ===================== `forward` is only used in VAE training =====================\n    \n    def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:\n        ls_f_hat_BChw = []\n        B = ms_h_BChw[0].shape[0]\n        H = W = self.v_patch_nums[-1]\n        SN = len(self.v_patch_nums)\n        if all_to_max_scale:\n            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)\n            for si, pn in enumerate(self.v_patch_nums): # from small to large\n                h_BChw = ms_h_BChw[si]\n                if si < len(self.v_patch_nums) - 1:\n                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')\n                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)\n                f_hat.add_(h_BChw)\n                if last_one: ls_f_hat_BChw = f_hat\n                else: ls_f_hat_BChw.append(f_hat.clone())\n        else:\n            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)\n            # WARNING: this should only be used for experimental purpose\n            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)\n            for si, pn in enumerate(self.v_patch_nums): # from small to large\n                f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')\n                h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si])\n                f_hat.add_(h_BChw)\n       
         if last_one: ls_f_hat_BChw = f_hat\n                else: ls_f_hat_BChw.append(f_hat)\n        \n        return ls_f_hat_BChw\n    \n    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad\n        B, C, H, W = f_BChw.shape\n        f_no_grad = f_BChw.detach()\n        f_rest = f_no_grad.clone()\n        f_hat = torch.zeros_like(f_rest)\n        \n        f_hat_or_idx_Bl: List[torch.Tensor] = []\n        \n        patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)]    # from small to large\n        assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'\n        \n        SN = len(patch_hws)\n        for si, (ph, pw) in enumerate(patch_hws): # from small to large\n            if 0 <= self.prog_si < si: break    # progressive training: not supported yet, prog_si always -1\n            # find the nearest embedding\n            z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)\n            if self.using_znorm:\n                z_NC = F.normalize(z_NC, dim=-1)\n                idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)\n            else:\n                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)\n                d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)\n                idx_N = torch.argmin(d_no_grad, dim=1)\n            \n            idx_Bhw = idx_N.view(B, ph, pw)\n            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else 
self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()\n            h_BChw = self.quant_resi[si/(SN-1)](h_BChw)\n            f_hat.add_(h_BChw)\n            f_rest.sub_(h_BChw)\n            f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))\n        \n        return f_hat_or_idx_Bl\n    \n    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================\n    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:\n        next_scales = []\n        B = gt_ms_idx_Bl[0].shape[0]\n        C = self.Cvae\n        H = W = self.v_patch_nums[-1]\n        SN = len(self.v_patch_nums)\n        \n        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)\n        pn_next: int = self.v_patch_nums[0]\n        for si in range(SN-1):\n            if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break   # progressive training: not supported yet, prog_si always -1\n            h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')\n            f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))\n            pn_next = self.v_patch_nums[si+1]\n            next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))\n        return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32\n    \n    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================\n    def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference\n        HW = self.v_patch_nums[-1]\n        if si != SN-1:\n            h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic'))     # conv after 
upsample\n            f_hat.add_(h)\n            return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')\n        else:\n            h = self.quant_resi[si/(SN-1)](h_BChw)\n            f_hat.add_(h)\n            return f_hat, f_hat\n\n\nclass Phi(nn.Conv2d):\n    def __init__(self, embed_dim, quant_resi):\n        ks = 3\n        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks//2)\n        self.resi_ratio = abs(quant_resi)\n    \n    def forward(self, h_BChw):\n        return h_BChw.mul(1-self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)\n\n\nclass PhiShared(nn.Module):\n    def __init__(self, qresi: Phi):\n        super().__init__()\n        self.qresi: Phi = qresi\n    \n    def __getitem__(self, _) -> Phi:\n        return self.qresi\n\n\nclass PhiPartiallyShared(nn.Module):\n    def __init__(self, qresi_ls: nn.ModuleList):\n        super().__init__()\n        self.qresi_ls = qresi_ls\n        K = len(qresi_ls)\n        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)\n    \n    def __getitem__(self, at_from_0_to_1: float) -> Phi:\n        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]\n    \n    def extra_repr(self) -> str:\n        return f'ticks={self.ticks}'\n\n\nclass PhiNonShared(nn.ModuleList):\n    def __init__(self, qresi: List):\n        super().__init__(qresi)\n        # self.qresi = qresi\n        K = len(qresi)\n        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)\n    \n    def __getitem__(self, at_from_0_to_1: float) -> Phi:\n        return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())\n    \n    def extra_repr(self) -> str:\n        return f'ticks={self.ticks}'\n"
  },
  {
    "path": "models/var.py",
    "content": "import math\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom huggingface_hub import PyTorchModelHubMixin\n\nimport dist\nfrom models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn\nfrom models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_\nfrom models.vqvae import VQVAE, VectorQuantizer2\n\n\nclass SharedAdaLin(nn.Linear):\n    def forward(self, cond_BD):\n        C = self.weight.shape[0] // 6\n        return super().forward(cond_BD).view(-1, 1, 6, C)   # B16C\n\n\nclass VAR(nn.Module):\n    def __init__(\n        self, vae_local: VQVAE,\n        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,\n        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,\n        attn_l2_norm=False,\n        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default\n        flash_if_available=True, fused_if_available=True,\n    ):\n        super().__init__()\n        # 0. hyperparameters\n        assert embed_dim % num_heads == 0\n        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size\n        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads\n        \n        self.cond_drop_rate = cond_drop_rate\n        self.prog_si = -1   # progressive training\n        \n        self.patch_nums: Tuple[int] = patch_nums\n        self.L = sum(pn ** 2 for pn in self.patch_nums)\n        self.first_l = self.patch_nums[0] ** 2\n        self.begin_ends = []\n        cur = 0\n        for i, pn in enumerate(self.patch_nums):\n            self.begin_ends.append((cur, cur+pn ** 2))\n            cur += pn ** 2\n        \n        self.num_stages_minus_1 = len(self.patch_nums) - 1\n        self.rng = torch.Generator(device=dist.get_device())\n        \n        # 1. 
input (word) embedding\n        quant: VectorQuantizer2 = vae_local.quantize\n        self.vae_proxy: Tuple[VQVAE] = (vae_local,)\n        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)\n        self.word_embed = nn.Linear(self.Cvae, self.C)\n        \n        # 2. class embedding\n        init_std = math.sqrt(1 / self.C / 3)\n        self.num_classes = num_classes\n        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())\n        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)\n        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)\n        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))\n        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)\n        \n        # 3. absolute position embedding\n        pos_1LC = []\n        for i, pn in enumerate(self.patch_nums):\n            pe = torch.empty(1, pn*pn, self.C)\n            nn.init.trunc_normal_(pe, mean=0, std=init_std)\n            pos_1LC.append(pe)\n        pos_1LC = torch.cat(pos_1LC, dim=1)     # 1, L, C\n        assert tuple(pos_1LC.shape) == (1, self.L, self.C)\n        self.pos_1LC = nn.Parameter(pos_1LC)\n        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)\n        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)\n        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)\n        \n        # 4. 
backbone blocks\n        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()\n        \n        norm_layer = partial(nn.LayerNorm, eps=norm_eps)\n        self.drop_path_rate = drop_path_rate\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule (linearly increasing)\n        self.blocks = nn.ModuleList([\n            AdaLNSelfAttn(\n                cond_dim=self.D, shared_aln=shared_aln,\n                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],\n                attn_l2_norm=attn_l2_norm,\n                flash_if_available=flash_if_available, fused_if_available=fused_if_available,\n            )\n            for block_idx in range(depth)\n        ])\n        \n        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]\n        self.using_fused_add_norm_fn = any(fused_add_norm_fns)\n        print(\n            f'\\n[constructor]  ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \\n'\n            f'    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\\n'\n            f'    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',\n            end='\\n\\n', flush=True\n        )\n        \n        # 5. 
attention mask used in training (for masking out the future)\n        #    it won't be used in inference, since kv cache is enabled\n        d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)\n        dT = d.transpose(1, 2)    # dT: 11L\n        lvl_1L = dT[:, 0].contiguous()\n        self.register_buffer('lvl_1L', lvl_1L)\n        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)\n        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())\n        \n        # 6. classifier head\n        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)\n        self.head = nn.Linear(self.C, self.V)\n    \n    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):\n        if not isinstance(h_or_h_and_residual, torch.Tensor):\n            h, resi = h_or_h_and_residual   # fused_add_norm must be used\n            h = resi + self.blocks[-1].drop_path(h)\n        else:                               # fused_add_norm is not used\n            h = h_or_h_and_residual\n        return self.head(self.head_nm(h.float(), cond_BD).float()).float()\n    \n    @torch.no_grad()\n    def autoregressive_infer_cfg(\n        self, B: int, label_B: Optional[Union[int, torch.LongTensor]],\n        g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,\n        more_smooth=False,\n    ) -> torch.Tensor:   # returns reconstructed image (B, 3, H, W) in [0, 1]\n        \"\"\"\n        only used for inference, on autoregressive mode\n        :param B: batch size\n        :param label_B: imagenet label; if None, randomly sampled\n        :param g_seed: random seed\n        :param cfg: classifier-free guidance ratio\n        :param top_k: top-k sampling\n        :param top_p: top-p sampling\n        :param more_smooth: smoothing the pred using gumbel softmax; only used in 
visualization, not used in FID/IS benchmarking\n        :return: reconstructed images of shape (B, 3, H, W), with values in [0, 1]\n        \"\"\"\n        if g_seed is None: rng = None\n        else: self.rng.manual_seed(g_seed); rng = self.rng\n        \n        if label_B is None:\n            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)\n        elif isinstance(label_B, int):\n            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)\n        \n        sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))\n        \n        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC\n        next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]\n        \n        cur_L = 0\n        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])\n        \n        for b in self.blocks: b.attn.kv_caching(True)\n        for si, pn in enumerate(self.patch_nums):   # si: i-th segment\n            ratio = si / self.num_stages_minus_1\n            # last_L = cur_L\n            cur_L += pn*pn\n            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'\n            cond_BD_or_gss = self.shared_ada_lin(cond_BD)\n            x = next_token_map\n            AdaLNSelfAttn.forward\n            for b in self.blocks:\n                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)\n            logits_BlV = self.get_logits(x, cond_BD)\n            \n            t = cfg * ratio\n            logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]\n            \n            idx_Bl = 
sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]\n            if not more_smooth: # this is the default case\n                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)   # B, l, Cvae\n            else:   # not used when evaluating FID/IS/Precision/Recall\n                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)   # refer to mask-git\n                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)\n            \n            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)\n            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)\n            if si != self.num_stages_minus_1:   # prepare for next stage\n                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)\n                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]\n                next_token_map = next_token_map.repeat(2, 1, 1)   # double the batch sizes due to CFG\n        \n        for b in self.blocks: b.attn.kv_caching(False)\n        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)   # de-normalize, from [-1, 1] to [0, 1]\n    \n    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor:  # returns logits_BLV\n        \"\"\"\n        :param label_B: label_B\n        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)\n        :return: logits BLV, V is vocab_size\n        \"\"\"\n        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)\n        B = x_BLCv_wo_first_l.shape[0]\n        with torch.cuda.amp.autocast(enabled=False):\n            label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B)\n          
  sos = cond_BD = self.class_emb(label_B)\n            sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)\n            \n            if self.prog_si == 0: x_BLC = sos\n            else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)\n            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC;  pos: 1LC\n        \n        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]\n        cond_BD_or_gss = self.shared_ada_lin(cond_BD)\n        \n        # hack: get the dtype if mixed precision is used\n        temp = x_BLC.new_ones(8, 8)\n        main_type = torch.matmul(temp, temp).dtype\n        \n        x_BLC = x_BLC.to(dtype=main_type)\n        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)\n        attn_bias = attn_bias.to(dtype=main_type)\n        \n        AdaLNSelfAttn.forward\n        for i, b in enumerate(self.blocks):\n            x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)\n        x_BLC = self.get_logits(x_BLC.float(), cond_BD)\n        \n        if self.prog_si == 0:\n            if isinstance(self.word_embed, nn.Linear):\n                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0\n            else:\n                s = 0\n                for p in self.word_embed.parameters():\n                    if p.requires_grad:\n                        s += p.view(-1)[0] * 0\n                x_BLC[0, 0, 0] += s\n        return x_BLC    # logits BLV, V is vocab_size\n    \n    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):\n        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5     # init_std < 0: automated\n        \n        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')\n        for m in self.modules():\n            with_weight = hasattr(m, 'weight') and m.weight is not None\n            
with_bias = hasattr(m, 'bias') and m.bias is not None\n            if isinstance(m, nn.Linear):\n                nn.init.trunc_normal_(m.weight.data, std=init_std)\n                if with_bias: m.bias.data.zero_()\n            elif isinstance(m, nn.Embedding):\n                nn.init.trunc_normal_(m.weight.data, std=init_std)\n                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()\n            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):\n                if with_weight: m.weight.data.fill_(1.)\n                if with_bias: m.bias.data.zero_()\n            # conv: VAR has no conv, only VQVAE has conv\n            elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):\n                if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)\n                else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)\n                if with_bias: m.bias.data.zero_()\n        \n        if init_head >= 0:\n            if isinstance(self.head, nn.Linear):\n                self.head.weight.data.mul_(init_head)\n                self.head.bias.data.zero_()\n            elif isinstance(self.head, nn.Sequential):\n                self.head[-1].weight.data.mul_(init_head)\n                self.head[-1].bias.data.zero_()\n        \n        if isinstance(self.head_nm, AdaLNBeforeHead):\n            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)\n            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:\n                self.head_nm.ada_lin[-1].bias.data.zero_()\n        \n        depth = len(self.blocks)\n        for block_idx, sab in enumerate(self.blocks):\n            sab: AdaLNSelfAttn\n            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))\n            
sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))\n            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:\n                nn.init.ones_(sab.ffn.fcg.bias)\n                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)\n            if hasattr(sab, 'ada_lin'):\n                sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)\n                sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)\n                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:\n                    sab.ada_lin[-1].bias.data.zero_()\n            elif hasattr(sab, 'ada_gss'):\n                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)\n                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)\n    \n    def extra_repr(self):\n        return f'drop_path_rate={self.drop_path_rate:g}'\n\n\nclass VARHF(VAR, PyTorchModelHubMixin):\n            # repo_url=\"https://github.com/FoundationVision/VAR\",\n            # tags=[\"image-generation\"]):\n    def __init__(\n        self,\n        vae_kwargs,\n        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,\n        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,\n        attn_l2_norm=False,\n        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default\n        flash_if_available=True, fused_if_available=True,\n    ):\n        vae_local = VQVAE(**vae_kwargs)\n        super().__init__(\n            vae_local=vae_local,\n            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,\n            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,\n            attn_l2_norm=attn_l2_norm,\n            patch_nums=patch_nums,\n            flash_if_available=flash_if_available, fused_if_available=fused_if_available,\n        )\n"
  },
  {
    "path": "models/vqvae.py",
    "content": "\"\"\"\nReferences:\n- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110\n- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213\n- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14\n\"\"\"\nfrom typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom .basic_vae import Decoder, Encoder\nfrom .quant import VectorQuantizer2\n\n\nclass VQVAE(nn.Module):\n    def __init__(\n        self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,\n        beta=0.25,              # commitment loss weight\n        using_znorm=False,      # whether to normalize when computing the nearest neighbors\n        quant_conv_ks=3,        # quant conv kernel size\n        quant_resi=0.5,         # 0.5 means \\phi(x) = 0.5conv(x) + (1-0.5)x\n        share_quant_resi=4,     # use 4 \\phi layers for K scales: partially-shared \\phi\n        default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)\n        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]\n        test_mode=True,\n    ):\n        super().__init__()\n        self.test_mode = test_mode\n        self.V, self.Cvae = vocab_size, z_channels\n        # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml\n        ddconfig = dict(\n            dropout=dropout, ch=ch, z_channels=z_channels,\n            in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,   # from vq-f16/config.yaml above\n            using_sa=True, using_mid_sa=True,                           # from vq-f16/config.yaml 
above\n            # resamp_with_conv=True,   # always True, removed.\n        )\n        ddconfig.pop('double_z', None)  # only KL-VAE should use double_z=True\n        self.encoder = Encoder(double_z=False, **ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        \n        self.vocab_size = vocab_size\n        self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)\n        self.quantize: VectorQuantizer2 = VectorQuantizer2(\n            vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,\n            default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,\n        )\n        self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)\n        self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)\n        \n        if self.test_mode:\n            self.eval()\n            [p.requires_grad_(False) for p in self.parameters()]\n    \n    # ===================== `forward` is only used in VAE training =====================\n    def forward(self, inp, ret_usages=False):   # -> rec_B3HW, idx_N, loss\n        VectorQuantizer2.forward\n        f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)\n        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss\n    # ===================== `forward` is only used in VAE training =====================\n    \n    def fhat_to_img(self, f_hat: torch.Tensor):\n        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)\n    \n    def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]:    # return List[Bl]\n        f = self.quant_conv(self.encoder(inp_img_no_grad))\n        return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)\n    \n    def 
idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:\n        B = ms_idx_Bl[0].shape[0]\n        ms_h_BChw = []\n        for idx_Bl in ms_idx_Bl:\n            l = idx_Bl.shape[1]\n            pn = round(l ** 0.5)\n            ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))\n        return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)\n    \n    def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:\n        if last_one:\n            return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)\n        else:\n            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]\n    \n    def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:\n        f = self.quant_conv(self.encoder(x))\n        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)\n        if last_one:\n            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)\n        else:\n            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]\n    \n    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):\n        if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:\n            state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV\n        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)\n"
  },
  {
    "path": "train.py",
    "content": "import gc\nimport os\nimport shutil\nimport sys\nimport time\nimport warnings\nfrom functools import partial\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nimport dist\nfrom utils import arg_util, misc\nfrom utils.data import build_dataset\nfrom utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler\nfrom utils.misc import auto_resume\n\n\ndef build_everything(args: arg_util.Args):\n    # resume\n    auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')\n    # create tensorboard logger\n    tb_lg: misc.TensorboardLogger\n    with_tb_lg = dist.is_master()\n    if with_tb_lg:\n        os.makedirs(args.tb_log_dir_path, exist_ok=True)\n        # noinspection PyTypeChecker\n        tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str(\"%m%d_%H%M\")}'), verbose=True)\n        tb_lg.flush()\n    else:\n        # noinspection PyTypeChecker\n        tb_lg = misc.DistLogger(None, verbose=False)\n    dist.barrier()\n    \n    # log args\n    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')\n    print(f'initial args:\\n{str(args)}')\n    \n    # build data\n    if not args.local_debug:\n        print(f'[build PT data] ...\\n')\n        num_classes, dataset_train, dataset_val = build_dataset(\n            args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,\n        )\n        types = str((type(dataset_train).__name__, type(dataset_val).__name__))\n        \n        ld_val = DataLoader(\n            dataset_val, num_workers=0, pin_memory=True,\n            batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),\n            shuffle=False, drop_last=False,\n        )\n        del dataset_val\n        \n        ld_train = DataLoader(\n            dataset=dataset_train, 
num_workers=args.workers, pin_memory=True,\n            generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,\n            batch_sampler=DistInfiniteBatchSampler(\n                dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,\n                shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,\n            ),\n        )\n        del dataset_train\n        \n        [print(line) for line in auto_resume_info]\n        print(f'[dataloader multi processing] ...', end='', flush=True)\n        stt = time.time()\n        iters_train = len(ld_train)\n        ld_train = iter(ld_train)\n        # noinspection PyArgumentList\n        print(f'     [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)\n        print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')\n    \n    else:\n        num_classes = 1000\n        ld_val = ld_train = None\n        iters_train = 10\n    \n    # build models\n    from torch.nn.parallel import DistributedDataParallel as DDP\n    from models import VAR, VQVAE, build_vae_var\n    from trainer import VARTrainer\n    from utils.amp_sc import AmpOptimizer\n    from utils.lr_control import filter_params\n    \n    vae_local, var_wo_ddp = build_vae_var(\n        V=4096, Cvae=32, ch=160, share_quant_resi=4,        # hard-coded VQVAE hyperparameters\n        device=dist.get_device(), patch_nums=args.patch_nums,\n        num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,\n        flash_if_available=args.fuse, fused_if_available=args.fuse,\n        init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,\n    )\n    \n    vae_ckpt = 'vae_ch160v4096z32.pth'\n    if dist.is_local_master():\n        if not 
os.path.exists(vae_ckpt):\n            os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')\n    dist.barrier()\n    vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)\n    \n    vae_local: VQVAE = args.compile_model(vae_local, args.vfast)\n    var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)\n    var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)\n    \n    print(f'[INIT] VAR model = {var_wo_ddp}\\n\\n')\n    count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'\n    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))\n    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\\n\\n')\n    \n    # build optimizer\n    names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={\n        'cls_token', 'start_token', 'task_token', 'cfg_uncond',\n        'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',\n        'gamma', 'beta',\n        'ada_gss', 'moe_bias',\n        'scale_mul',\n    })\n    opt_clz = {\n        'adam':  partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),\n        'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),\n    }[args.opt.lower().strip()]\n    opt_kw = dict(lr=args.tlr, weight_decay=0)\n    print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\\n')\n    \n    var_optim = AmpOptimizer(\n        mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,\n        grad_clip=args.tclip, n_gradient_accumulation=args.ac\n    )\n    del names, paras, para_groups\n    \n    # build trainer\n    trainer = VARTrainer(\n        device=args.device, patch_nums=args.patch_nums, 
resos=args.resos,\n        vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,\n        var_opt=var_optim, label_smooth=args.ls,\n    )\n    if trainer_state is not None and len(trainer_state):\n        trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again\n    del vae_local, var_wo_ddp, var, var_optim\n    \n    if args.local_debug:\n        rng = torch.Generator('cpu')\n        rng.manual_seed(0)\n        B = 4\n        inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)\n        label = torch.ones(B, dtype=torch.long)\n        \n        me = misc.MetricLogger(delimiter='  ')\n        trainer.train_step(\n            it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,\n            inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,\n        )\n        trainer.load_state_dict(trainer.state_dict())\n        trainer.train_step(\n            it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,\n            inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,\n        )\n        print({k: meter.global_avg for k, meter in me.meters.items()})\n        \n        args.dump_log(); tb_lg.flush(); tb_lg.close()\n        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):\n            sys.stdout.close(), sys.stderr.close()\n        exit(0)\n    \n    dist.barrier()\n    return (\n        tb_lg, trainer, start_ep, start_it,\n        iters_train, ld_train, ld_val\n    )\n\n\ndef main_training():\n    args: arg_util.Args = arg_util.init_dist_and_get_args()\n    if args.local_debug:\n        torch.autograd.set_detect_anomaly(True)\n    \n    (\n        tb_lg, trainer,\n        start_ep, start_it,\n        iters_train, ld_train, ld_val\n    ) = build_everything(args)\n    \n    # train\n    start_time = time.time()\n    best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.\n    best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, 
best_val_acc_tail = 999, 999, -1, -1\n    \n    L_mean, L_tail = -1, -1\n    for ep in range(start_ep, args.ep):\n        if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):\n            ld_train.sampler.set_epoch(ep)\n            if ep < 3:\n                # noinspection PyArgumentList\n                print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)\n        tb_lg.set_step(ep * iters_train)\n        \n        stats, (sec, remain_time, finish_time) = train_one_ep(\n            ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer\n        )\n        \n        L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']\n        best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)\n        if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)\n        args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm\n        args.cur_ep = f'{ep+1}/{args.ep}'\n        args.remain_time, args.finish_time = remain_time, finish_time\n        \n        AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)\n        is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep\n        if is_val_and_also_saving:\n            val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)\n            best_updated = best_val_loss_tail > val_loss_tail\n            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)\n            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)\n            AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, 
vacc_tail=val_acc_tail)\n            args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail\n            print(f' [*] [ep{ep}]  (val {tot})  Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f},  Val cost: {cost:.2f}s')\n            \n            if dist.is_local_master():\n                local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')\n                local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')\n                print(f'[saving ckpt] ...', end='', flush=True)\n                torch.save({\n                    'epoch':    ep+1,\n                    'iter':     0,\n                    'trainer':  trainer.state_dict(),\n                    'args':     args.state_dict(),\n                }, local_out_ckpt)\n                if best_updated:\n                    shutil.copy(local_out_ckpt, local_out_ckpt_best)\n                print(f'     [saving ckpt](*) finished!  
@ {local_out_ckpt}', flush=True, clean=True)\n            dist.barrier()\n        \n        print(    f'     [ep{ep}]  (training )  Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}),  Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f},  Remain: {remain_time},  Finish: {finish_time}', flush=True)\n        tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)\n        tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))\n        args.dump_log(); tb_lg.flush()\n    \n    total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'\n    print('\\n\\n')\n    print(f'  [*] [PT finished]  Total cost: {total_time},   Lm: {best_L_mean:.3f} ({L_mean}),   Lt: {best_L_tail:.3f} ({L_tail})')\n    print('\\n\\n')\n    \n    del stats\n    del iters_train, ld_train\n    time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)\n    \n    args.remain_time, args.finish_time = '-', time.strftime(\"%Y-%m-%d %H:%M\", time.localtime(time.time() - 60))\n    print(f'final args:\\n\\n{str(args)}')\n    args.dump_log(); tb_lg.flush(); tb_lg.close()\n    dist.barrier()\n\n\ndef train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):\n    # import heavy packages after Dataloader object creation\n    from trainer import VARTrainer\n    from utils.lr_control import lr_wd_annealing\n    trainer: VARTrainer\n    \n    step_cnt = 0\n    me = misc.MetricLogger(delimiter='  ')\n    me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))\n    me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))\n    [me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]\n    [me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]\n    header = f'[Ep]: [{ep:4d}/{args.ep}]'\n    \n    if is_first_ep:\n        
warnings.filterwarnings('ignore', category=DeprecationWarning)\n        warnings.filterwarnings('ignore', category=UserWarning)\n    g_it, max_it = ep * iters_train, args.ep * iters_train\n    \n    for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):\n        g_it = ep * iters_train + it\n        if it < start_it: continue\n        if is_first_ep and it == start_it: warnings.resetwarnings()\n        \n        inp = inp.to(args.device, non_blocking=True)\n        label = label.to(args.device, non_blocking=True)\n        \n        args.cur_it = f'{it+1}/{iters_train}'\n        \n        wp_it = args.wp * iters_train\n        min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)\n        args.cur_lr, args.cur_wd = max_tlr, max_twd\n        \n        if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this\n            if g_it <= wp_it: prog_si = args.pg0\n            elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1\n            else:\n                delta = len(args.patch_nums) - 1 - args.pg0\n                progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1\n                prog_si = args.pg0 + round(progress * delta)    # from args.pg0 to len(args.patch_nums)-1\n        else:\n            prog_si = -1\n        \n        stepping = (g_it + 1) % args.ac == 0\n        step_cnt += int(stepping)\n        \n        grad_norm, scale_log2 = trainer.train_step(\n            it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,\n            inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,\n        )\n        \n        me.update(tlr=max_tlr)\n        tb_lg.set_step(step=g_it)\n        tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)\n        
tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)\n        tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)\n        tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)\n        tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)\n        \n        if args.tclip > 0:\n            tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)\n            tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)\n    \n    me.synchronize_between_processes()\n    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15)  # +15: other cost\n\n\nclass NullDDP(torch.nn.Module):\n    def __init__(self, module, *args, **kwargs):\n        super(NullDDP, self).__init__()\n        self.module = module\n        self.require_backward_grad_sync = False\n    \n    def forward(self, *args, **kwargs):\n        return self.module(*args, **kwargs)\n\n\nif __name__ == '__main__':\n    try: main_training()\n    finally:\n        dist.finalize()\n        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):\n            sys.stdout.close(), sys.stderr.close()\n"
  },
  {
    "path": "trainer.py",
    "content": "import time\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader\n\nimport dist\nfrom models import VAR, VQVAE, VectorQuantizer2\nfrom utils.amp_sc import AmpOptimizer\nfrom utils.misc import MetricLogger, TensorboardLogger\n\nTen = torch.Tensor\nFTen = torch.Tensor\nITen = torch.LongTensor\nBTen = torch.BoolTensor\n\n\nclass VARTrainer(object):\n    def __init__(\n        self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],\n        vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,\n        var_opt: AmpOptimizer, label_smooth: float,\n    ):\n        super(VARTrainer, self).__init__()\n        \n        self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize\n        self.quantize_local: VectorQuantizer2\n        self.var_wo_ddp: VAR = var_wo_ddp  # after torch.compile\n        self.var_opt = var_opt\n        \n        del self.var_wo_ddp.rng\n        self.var_wo_ddp.rng = torch.Generator(device=device)\n        \n        self.label_smooth = label_smooth\n        self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')\n        self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')\n        self.L = sum(pn * pn for pn in patch_nums)\n        self.last_l = patch_nums[-1] * patch_nums[-1]\n        self.loss_weight = torch.ones(1, self.L, device=device) / self.L\n        \n        self.patch_nums, self.resos = patch_nums, resos\n        self.begin_ends = []\n        cur = 0\n        for i, pn in enumerate(patch_nums):\n            self.begin_ends.append((cur, cur + pn * pn))\n            cur += pn*pn\n        \n        self.prog_it = 0\n        self.last_prog_si = -1\n        self.first_prog = True\n    \n    @torch.no_grad()\n    def eval_ep(self, ld_val: DataLoader):\n        tot = 0\n        L_mean, L_tail, acc_mean, acc_tail = 
0, 0, 0, 0\n        stt = time.time()\n        training = self.var_wo_ddp.training\n        self.var_wo_ddp.eval()\n        for inp_B3HW, label_B in ld_val:\n            B, V = label_B.shape[0], self.vae_local.vocab_size\n            inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)\n            label_B = label_B.to(dist.get_device(), non_blocking=True)\n            \n            gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)\n            gt_BL = torch.cat(gt_idx_Bl, dim=1)\n            x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)\n            \n            self.var_wo_ddp.forward\n            logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)\n            L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B\n            L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B\n            acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1])\n            acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)\n            tot += B\n        self.var_wo_ddp.train(training)\n        \n        stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])\n        dist.allreduce(stats)\n        tot = round(stats[-1].item())\n        stats /= tot\n        L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()\n        return L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt\n    \n    def train_step(\n        self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,\n        inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,\n    ) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:\n        # if progressive training\n        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si\n        if self.last_prog_si != prog_si:\n    
        if self.last_prog_si != -1: self.first_prog = False\n            self.last_prog_si = prog_si\n            self.prog_it = 0\n        self.prog_it += 1\n        prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)\n        if self.first_prog: prog_wp = 1    # no prog warmup at first prog stage, as it's already solved in wp\n        if prog_si == len(self.patch_nums) - 1: prog_si = -1    # max prog, as if no prog\n        \n        # forward\n        B, V = label_B.shape[0], self.vae_local.vocab_size\n        self.var.require_backward_grad_sync = stepping\n        \n        gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)\n        gt_BL = torch.cat(gt_idx_Bl, dim=1)\n        x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)\n        \n        with self.var_opt.amp_ctx:\n            self.var_wo_ddp.forward\n            logits_BLV = self.var(label_B, x_BLCv_wo_first_l)\n            loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)\n            if prog_si >= 0:    # in progressive training\n                bg, ed = self.begin_ends[prog_si]\n                assert logits_BLV.shape[1] == gt_BL.shape[1] == ed\n                lw = self.loss_weight[:, :ed].clone()\n                lw[:, bg:ed] *= min(max(prog_wp, 0), 1)\n            else:               # not in progressive training\n                lw = self.loss_weight\n            loss = loss.mul(lw).sum(dim=-1).mean()\n        \n        # backward\n        grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)\n        \n        # log\n        pred_BL = logits_BLV.data.argmax(dim=-1)\n        if it == 0 or it in metric_lg.log_iters:\n            Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()\n            acc_mean = (pred_BL == gt_BL).float().mean().item() * 100\n            if prog_si >= 0:    # in progressive training\n                Ltail = acc_tail = -1\n            else:               # 
not in progressive training\n                Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()\n                acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100\n            grad_norm = grad_norm.item()\n            metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)\n        \n        # log to tensorboard\n        if g_it == 0 or (g_it + 1) % 500 == 0:\n            prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()\n            dist.allreduce(prob_per_class_is_chosen)\n            prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()\n            cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100\n            if dist.is_master():\n                if g_it == 0:\n                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)\n                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)\n                kw = dict(z_voc_usage=cluster_usage)\n                for si, (bg, ed) in enumerate(self.begin_ends):\n                    if 0 <= prog_si < si: break\n                    pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)\n                    acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100\n                    ce = self.val_loss(pred, tar).item()\n                    kw[f'acc_{self.resos[si]}'] = acc\n                    kw[f'L_{self.resos[si]}'] = ce\n                tb_lg.update(head='AR_iter_loss', **kw, step=g_it)\n                tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)\n        \n        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1\n        return grad_norm, scale_log2\n    \n    def get_config(self):\n        return {\n            'patch_nums':   self.patch_nums, 
'resos': self.resos,\n            'label_smooth': self.label_smooth,\n            'prog_it':      self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,\n        }\n    \n    def state_dict(self):\n        state = {'config': self.get_config()}\n        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):\n            m = getattr(self, k)\n            if m is not None:\n                if hasattr(m, '_orig_mod'):\n                    m = m._orig_mod\n                state[k] = m.state_dict()\n        return state\n    \n    def load_state_dict(self, state, strict=True, skip_vae=False):\n        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):\n            if skip_vae and 'vae' in k: continue\n            m = getattr(self, k)\n            if m is not None:\n                if hasattr(m, '_orig_mod'):\n                    m = m._orig_mod\n                ret = m.load_state_dict(state[k], strict=strict)\n                if ret is not None:\n                    missing, unexpected = ret\n                    print(f'[VARTrainer.load_state_dict] {k} missing:  {missing}')\n                    print(f'[VARTrainer.load_state_dict] {k} unexpected:  {unexpected}')\n        \n        config: Optional[dict] = state.pop('config', None)\n        if config is not None:\n            # restore the progressive-training state only when the checkpoint carries a config;\n            # calling config.get(...) before this check would raise AttributeError on None\n            self.prog_it = config.get('prog_it', 0)\n            self.last_prog_si = config.get('last_prog_si', -1)\n            self.first_prog = config.get('first_prog', True)\n            for k, v in self.get_config().items():\n                if config.get(k, None) != v:\n                    err = f'[VAR.load_state_dict] config mismatch:  this.{k}={v} (ckpt.{k}={config.get(k, None)})'\n                    if strict: raise AttributeError(err)\n                    else: print(err)\n"
  },
  {
    "path": "utils/amp_sc.py",
    "content": "import math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\n\n\nclass NullCtx:\n    def __enter__(self):\n        pass\n    \n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nclass AmpOptimizer:\n    def __init__(\n        self,\n        mixed_precision: int,\n        optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],\n        grad_clip: float, n_gradient_accumulation: int = 1,\n    ):\n        self.enable_amp = mixed_precision > 0\n        self.using_fp16_rather_bf16 = mixed_precision == 1\n        \n        if self.enable_amp:\n            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)\n            self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler\n        else:\n            self.amp_ctx = NullCtx()\n            self.scaler = None\n        \n        self.optimizer, self.names, self.paras = optimizer, names, paras   # paras have been filtered so everyone requires grad\n        self.grad_clip = grad_clip\n        self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')\n        self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')\n        \n        self.r_accu = 1 / n_gradient_accumulation   # r_accu == 1.0 / n_gradient_accumulation\n    \n    def backward_clip_step(\n        self, stepping: bool, loss: torch.Tensor,\n    ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:\n        # backward\n        loss = loss.mul(self.r_accu)   # r_accu == 1.0 / n_gradient_accumulation\n        orig_norm = scaler_sc = None\n        if self.scaler is not None:\n            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)\n        else:\n            loss.backward(retain_graph=False, 
create_graph=False)\n        \n        if stepping:\n            if self.scaler is not None: self.scaler.unscale_(self.optimizer)\n            if self.early_clipping:\n                orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)\n            \n            if self.scaler is not None:\n                self.scaler.step(self.optimizer)\n                scaler_sc: float = self.scaler.get_scale()\n                if scaler_sc > 32768.: # fp16 will overflow when >65536, so multiply 32768 could be dangerous\n                    self.scaler.update(new_scale=32768.)\n                else:\n                    self.scaler.update()\n                try:\n                    scaler_sc = float(math.log2(scaler_sc))\n                except Exception as e:\n                    print(f'[scaler_sc = {scaler_sc}]\\n' * 15, flush=True)\n                    raise e\n            else:\n                self.optimizer.step()\n            \n            if self.late_clipping:\n                orig_norm = self.optimizer.global_grad_norm\n            \n            self.optimizer.zero_grad(set_to_none=True)\n        \n        return orig_norm, scaler_sc\n    \n    def state_dict(self):\n        return {\n            'optimizer': self.optimizer.state_dict()\n        } if self.scaler is None else {\n            'scaler': self.scaler.state_dict(),\n            'optimizer': self.optimizer.state_dict()\n        }\n    \n    def load_state_dict(self, state, strict=True):\n        if self.scaler is not None:\n            try: self.scaler.load_state_dict(state['scaler'])\n            except Exception as e: print(f'[fp16 load_state_dict err] {e}')\n        self.optimizer.load_state_dict(state['optimizer'])\n"
  },
  {
    "path": "utils/arg_util.py",
    "content": "import json\nimport os\nimport random\nimport re\nimport subprocess\nimport sys\nimport time\nfrom collections import OrderedDict\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\n\ntry:\n    from tap import Tap\nexcept ImportError as e:\n    print(f'`>>>>>>>> from tap import Tap` failed, please run:      pip3 install typed-argument-parser     <<<<<<<<', file=sys.stderr, flush=True)\n    print(f'`>>>>>>>> from tap import Tap` failed, please run:      pip3 install typed-argument-parser     <<<<<<<<', file=sys.stderr, flush=True)\n    time.sleep(5)\n    raise e\n\nimport dist\n\n\nclass Args(Tap):\n    data_path: str = '/path/to/imagenet'\n    exp_name: str = 'text'\n    \n    # VAE\n    vfast: int = 0      # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'\n    # VAR\n    tfast: int = 0      # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'\n    depth: int = 16     # VAR depth\n    # VAR initialization\n    ini: float = -1     # -1: automated model parameter initialization\n    hd: float = 0.02    # head.w *= hd\n    aln: float = 0.5    # the multiplier of ada_lin.w's initialization\n    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization\n    # VAR optimization\n    fp16: int = 0           # 1: using fp16, 2: bf16\n    tblr: float = 1e-4      # base lr\n    tlr: float = None       # lr = base lr * (bs / 256)\n    twd: float = 0.05       # initial wd\n    twde: float = 0         # final wd, =twde or twd\n    tclip: float = 2.       
# <=0 for not using grad clip\n    ls: float = 0.0         # label smooth\n    \n    bs: int = 768           # global batch size\n    batch_size: int = 0     # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())\n    glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()\n    ac: int = 1             # gradient accumulation\n    \n    ep: int = 250           # number of training epochs\n    wp: float = 0           # lr warmup epochs; 0: automatically set to ep / 50\n    wp0: float = 0.005      # initial lr ratio at the beginning of lr warmup\n    wpe: float = 0.01       # final lr ratio at the end of training\n    sche: str = 'lin0'      # lr schedule\n    \n    opt: str = 'adamw'      # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work\n    afuse: bool = True      # fused adamw\n    \n    # other hps\n    saln: bool = False      # whether to use shared adaln\n    anorm: bool = True      # whether to use L2 normalized attention\n    fuse: bool = True       # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.\n    \n    # data\n    pn: str = '1_2_3_4_5_6_8_10_13_16'\n    patch_size: int = 16\n    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))\n    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)\n    \n    data_load_reso: int = None  # [automatically set; don't specify this] would be max(patch_nums) * patch_size\n    mid_reso: float = 1.125     # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso\n    hflip: bool = False         # augmentation: horizontal flip\n    workers: int = 0        # num workers; 0: auto, -1: don't use multiprocessing in DataLoader\n    \n    # progressive training\n    pg: float = 0.0         # >0 for use 
progressive training during [0%, this] of training\n    pg0: int = 4            # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc\n    pgwp: float = 0         # num of warmup epochs at each progressive stage\n    \n    # would be automatically set in runtime\n    cmd: str = ' '.join(sys.argv[1:])  # [automatically set; don't specify this]\n    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]\n    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]\n    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()    # [automatically set; don't specify this]\n    acc_mean: float = None      # [automatically set; don't specify this]\n    acc_tail: float = None      # [automatically set; don't specify this]\n    L_mean: float = None        # [automatically set; don't specify this]\n    L_tail: float = None        # [automatically set; don't specify this]\n    vacc_mean: float = None     # [automatically set; don't specify this]\n    vacc_tail: float = None     # [automatically set; don't specify this]\n    vL_mean: float = None       # [automatically set; don't specify this]\n    vL_tail: float = None       # [automatically set; don't specify this]\n    grad_norm: float = None     # [automatically set; don't specify this]\n    cur_lr: float = None        # [automatically set; don't specify this]\n    cur_wd: float = None        # [automatically set; don't specify this]\n    cur_it: str = ''            # [automatically set; don't specify this]\n    cur_ep: str = ''            # [automatically set; don't specify this]\n    remain_time: str = ''       # [automatically set; don't specify this]\n    finish_time: 
str = ''       # [automatically set; don't specify this]\n    \n    # environment\n    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]\n    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]\n    log_txt_path: str = '...'           # [automatically set; don't specify this]\n    last_ckpt_path: str = '...'         # [automatically set; don't specify this]\n    \n    tf32: bool = True       # whether to use TensorFloat32\n    device: str = 'cpu'     # [automatically set; don't specify this]\n    seed: int = None        # seed\n    def seed_everything(self, benchmark: bool):\n        torch.backends.cudnn.enabled = True\n        torch.backends.cudnn.benchmark = benchmark\n        if self.seed is None:\n            torch.backends.cudnn.deterministic = False\n        else:\n            torch.backends.cudnn.deterministic = True\n            seed = self.seed * dist.get_world_size() + dist.get_rank()\n            os.environ['PYTHONHASHSEED'] = str(seed)\n            random.seed(seed)\n            np.random.seed(seed)\n            torch.manual_seed(seed)\n            if torch.cuda.is_available():\n                torch.cuda.manual_seed(seed)\n                torch.cuda.manual_seed_all(seed)\n    same_seed_for_all_ranks: int = 0     # this is only for distributed sampler\n    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:   # for random augmentation\n        if self.seed is None:\n            return None\n        g = torch.Generator()\n        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())\n        return g\n    \n    local_debug: bool = 'KEVIN_LOCAL' in os.environ\n    dbg_nan: bool = False   # 'KEVIN_LOCAL' in os.environ\n    \n    def compile_model(self, m, fast):\n        if fast == 0 or self.local_debug:\n            return m\n        return torch.compile(m, mode={\n            1: 'reduce-overhead',\n  
          2: 'max-autotune',\n            3: 'default',\n        }[fast]) if hasattr(torch, 'compile') else m\n    \n    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:\n        d = (OrderedDict if key_ordered else dict)()\n        # self.as_dict() would contain methods, but we only need variables\n        for k in self.class_variables.keys():\n            if k not in {'device'}:     # these are not serializable\n                d[k] = getattr(self, k)\n        return d\n    \n    def load_state_dict(self, d: Union[OrderedDict, dict, str]):\n        if isinstance(d, str):  # for compatibility with old version\n            d: dict = eval('\\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))\n        for k in d.keys():\n            try:\n                setattr(self, k, d[k])\n            except Exception as e:\n                print(f'k={k}, v={d[k]}')\n                raise e\n    \n    @staticmethod\n    def set_tf32(tf32: bool):\n        if torch.cuda.is_available():\n            torch.backends.cudnn.allow_tf32 = bool(tf32)\n            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n            if hasattr(torch, 'set_float32_matmul_precision'):\n                torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')\n            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')\n            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')\n    \n    def dump_log(self):\n        if not dist.is_local_master():\n            return\n        if '1/' in self.cur_ep: # first time to dump log\n            with open(self.log_txt_path, 'w') as fp:\n                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 
'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)\n                fp.write('\\n')\n        \n        log_dict = {}\n        for k, v in {\n            'it': self.cur_it, 'ep': self.cur_ep,\n            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,\n            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,\n            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,\n            'remain_time': self.remain_time, 'finish_time': self.finish_time,\n        }.items():\n            if hasattr(v, 'item'): v = v.item()\n            log_dict[k] = v\n        with open(self.log_txt_path, 'a') as fp:\n            fp.write(f'{log_dict}\\n')\n    \n    def __str__(self):\n        s = []\n        for k in self.class_variables.keys():\n            if k not in {'device', 'dbg_ks_fp'}:     # these are not serializable\n                s.append(f'  {k:20s}: {getattr(self, k)}')\n        s = '\\n'.join(s)\n        return f'{{\\n{s}\\n}}\\n'\n\n\ndef init_dist_and_get_args():\n    for i in range(len(sys.argv)):\n        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):\n            del sys.argv[i]\n            break\n    args = Args(explicit_bool=True).parse_args(known_only=True)\n    if args.local_debug:\n        args.pn = '1_2_3'\n        args.seed = 1\n        args.aln = 1e-2\n        args.alng = 1e-5\n        args.saln = False\n        args.afuse = False\n        args.pg = 0.8\n        args.pg0 = 1\n    else:\n        if args.data_path == '/path/to/imagenet':\n            raise ValueError(f'{\"*\"*40}  please specify --data_path=/path/to/imagenet  {\"*\"*40}')\n    \n    # warn args.extra_args\n    if len(args.extra_args) > 0:\n        print(f'======================================================================================')\n        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS 
===========================\\n{args.extra_args}')\n        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')\n        print(f'======================================================================================\\n\\n')\n    \n    # init torch distributed\n    from utils import misc\n    os.makedirs(args.local_out_dir_path, exist_ok=True)\n    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)\n    \n    # set env\n    args.set_tf32(args.tf32)\n    args.seed_everything(benchmark=args.pg == 0)\n    \n    # update args: data loading\n    args.device = dist.get_device()\n    if args.pn == '256':\n        args.pn = '1_2_3_4_5_6_8_10_13_16'\n    elif args.pn == '512':\n        args.pn = '1_2_3_4_6_9_13_18_24_32'\n    elif args.pn == '1024':\n        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'\n    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))\n    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)\n    args.data_load_reso = max(args.resos)\n    \n    # update args: bs and lr\n    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())\n    args.batch_size = bs_per_gpu\n    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()\n    args.workers = min(max(0, args.workers), args.batch_size)\n    \n    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256\n    args.twde = args.twde or args.twd\n    \n    if args.wp == 0:\n        args.wp = args.ep * 1/50\n    \n    # update args: progressive training\n    if args.pgwp == 0:\n        args.pgwp = args.ep * 1/300\n    if args.pg > 0:\n        args.sche = f'lin{args.pg:g}'\n    \n    # update args: paths\n    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')\n    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')\n    _reg_valid_name = re.compile(r'[^\\w\\-+,.]')\n    tb_name = _reg_valid_name.sub(\n        '_',\n        
f'tb-VARd{args.depth}'\n        f'__pn{args.pn}'\n        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'\n    )\n    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)\n    \n    return args\n"
  },
  {
    "path": "utils/data.py",
    "content": "import os.path as osp\n\nimport PIL.Image as PImage\nfrom torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS\nfrom torchvision.transforms import InterpolationMode, transforms\n\n\ndef normalize_01_into_pm1(x):  # normalize x from [0, 1] to [-1, 1] by (x*2) - 1\n    return x.add(x).add_(-1)\n\n\ndef build_dataset(\n    data_path: str, final_reso: int,\n    hflip=False, mid_reso=1.125,\n):\n    # build augmentations\n    mid_reso = round(mid_reso * final_reso)  # first resize to mid_reso, then crop to final_reso\n    train_aug, val_aug = [\n        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso\n        transforms.RandomCrop((final_reso, final_reso)),\n        transforms.ToTensor(), normalize_01_into_pm1,\n    ], [\n        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso\n        transforms.CenterCrop((final_reso, final_reso)),\n        transforms.ToTensor(), normalize_01_into_pm1,\n    ]\n    if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())\n    train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)\n    \n    # build dataset\n    train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)\n    val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)\n    num_classes = 1000\n    print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')\n    print_aug(train_aug, '[train]')\n    print_aug(val_aug, '[val]')\n    \n    return num_classes, train_set, val_set\n\n\ndef pil_loader(path):\n    with open(path, 'rb') as f:\n        img: PImage.Image = PImage.open(f).convert('RGB')\n    return img\n\n\ndef print_aug(transform, label):\n    print(f'Transform {label} = ')\n    if hasattr(transform, 
'transforms'):\n        for t in transform.transforms:\n            print(t)\n    else:\n        print(transform)\n    print('---------------------------\\n')\n"
  },
  {
    "path": "utils/data_sampler.py",
    "content": "import numpy as np\nimport torch\nfrom torch.utils.data.sampler import Sampler\n\n\nclass EvalDistributedSampler(Sampler):\n    def __init__(self, dataset, num_replicas, rank):\n        seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)\n        beg, end = seps[:-1], seps[1:]\n        beg, end = beg[rank], end[rank]\n        self.indices = tuple(range(beg, end))\n    \n    def __iter__(self):\n        return iter(self.indices)\n    \n    def __len__(self) -> int:\n        return len(self.indices)\n\n\nclass InfiniteBatchSampler(Sampler):\n    def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):\n        self.dataset_len = dataset_len\n        self.batch_size = batch_size\n        self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size\n        self.max_p = self.iters_per_ep * batch_size\n        self.fill_last = fill_last\n        self.shuffle = shuffle\n        self.epoch = start_ep\n        self.same_seed_for_all_ranks = seed_for_all_rank\n        self.indices = self.gener_indices()\n        self.start_ep, self.start_it = start_ep, start_it\n    \n    def gener_indices(self):\n        if self.shuffle:\n            g = torch.Generator()\n            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)\n            indices = torch.randperm(self.dataset_len, generator=g).numpy()\n        else:\n            indices = torch.arange(self.dataset_len).numpy()\n        \n        tails = self.batch_size - (self.dataset_len % self.batch_size)\n        if tails != self.batch_size and self.fill_last:\n            tails = indices[:tails]\n            np.random.shuffle(indices)\n            indices = np.concatenate((indices, tails))\n        \n        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)\n        # noinspection PyTypeChecker\n        return 
tuple(indices.tolist())\n    \n    def __iter__(self):\n        self.epoch = self.start_ep\n        while True:\n            # when resuming, skip the batches already consumed in the starting epoch\n            p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0\n            while p < self.max_p:\n                q = p + self.batch_size\n                yield self.indices[p:q]\n                p = q\n            # advance the epoch only after a full pass, so the start_ep check above can match on the first pass\n            self.epoch += 1\n            if self.shuffle:\n                self.indices = self.gener_indices()\n    \n    def __len__(self):\n        return self.iters_per_ep\n\n\nclass DistInfiniteBatchSampler(InfiniteBatchSampler):\n    def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):\n        assert glb_batch_size % world_size == 0\n        self.world_size, self.rank = world_size, rank\n        self.dataset_len = dataset_len\n        self.glb_batch_size = glb_batch_size\n        self.batch_size = glb_batch_size // world_size\n        \n        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size\n        self.fill_last = fill_last\n        self.shuffle = shuffle\n        self.repeated_aug = repeated_aug\n        self.epoch = start_ep\n        self.same_seed_for_all_ranks = same_seed_for_all_ranks\n        self.indices = self.gener_indices()\n        self.start_ep, self.start_it = start_ep, start_it\n    \n    def gener_indices(self):\n        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0\n        # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')\n        if self.shuffle:\n            g = torch.Generator()\n            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)\n            global_indices = torch.randperm(self.dataset_len, generator=g)\n            if self.repeated_aug > 1:\n                global_indices = 
global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]\n        else:\n            global_indices = torch.arange(self.dataset_len)\n        # if fill_last is set, pad with the leading indices so the last iteration also gets a full global batch\n        filling = global_max_p - global_indices.shape[0]\n        if filling > 0 and self.fill_last:\n            global_indices = torch.cat((global_indices, global_indices[:filling]))\n        \n        # split the global index list into world_size contiguous shards and keep only this rank's shard\n        seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)\n        local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()\n        self.max_p = len(local_indices)\n        return local_indices\n"
  },
  {
    "path": "utils/lr_control.py",
    "content": "import math\nfrom pprint import pformat\nfrom typing import Tuple, List, Dict, Union\n\nimport torch.nn\n\nimport dist\n\n\ndef lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):\n    \"\"\"Decay the learning rate with half-cycle cosine after warmup\"\"\"\n    wp_it = round(wp_it)\n    \n    if cur_it < wp_it:\n        cur_lr = wp0 + (1-wp0) * cur_it / wp_it\n    else:\n        pasd = (cur_it - wp_it) / (max_it-1 - wp_it)   # [0, 1]\n        rest = 1 - pasd     # [1, 0]\n        if sche_type == 'cos':\n            cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))\n        elif sche_type == 'lin':\n            T = 0.15; max_rest = 1-T\n            if pasd < T: cur_lr = 1\n            else: cur_lr = wpe + (1-wpe) * rest / max_rest  # 1 to wpe\n        elif sche_type == 'lin0':\n            T = 0.05; max_rest = 1-T\n            if pasd < T: cur_lr = 1\n            else: cur_lr = wpe + (1-wpe) * rest / max_rest\n        elif sche_type == 'lin00':\n            cur_lr = wpe + (1-wpe) * rest\n        elif sche_type.startswith('lin'):\n            T = float(sche_type[3:]); max_rest = 1-T\n            wpe_mid = wpe + (1-wpe) * max_rest\n            wpe_mid = (1 + wpe_mid) / 2\n            if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T\n            else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest\n        elif sche_type == 'exp':\n            T = 0.15; max_rest = 1-T\n            if pasd < T: cur_lr = 1\n            else:\n                expo = (pasd-T) / max_rest * math.log(wpe)\n                cur_lr = math.exp(expo)\n        else:\n            raise NotImplementedError(f'unknown sche_type {sche_type}')\n    \n    cur_lr *= peak_lr\n    pasd = cur_it / (max_it-1)\n    cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))\n    \n    inf = 1e6\n    min_lr, max_lr = inf, -1\n    min_wd, max_wd = inf, -1\n    for param_group in optimizer.param_groups:\n        
param_group['lr'] = cur_lr * param_group.get('lr_sc', 1)    # 'lr_sc' could be assigned\n        max_lr = max(max_lr, param_group['lr'])\n        min_lr = min(min_lr, param_group['lr'])\n        \n        param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)\n        max_wd = max(max_wd, param_group['weight_decay'])\n        if param_group['weight_decay'] > 0:\n            min_wd = min(min_wd, param_group['weight_decay'])\n\n    if min_lr == inf: min_lr = -1\n    if min_wd == inf: min_wd = -1\n    return min_lr, max_lr, min_wd, max_wd\n\n\ndef filter_params(model, nowd_keys=()) -> Tuple[\n    List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]\n]:\n    para_groups, para_groups_dbg = {}, {}\n    names, paras = [], []\n    names_no_grad = []\n    count, numel = 0, 0\n    for name, para in model.named_parameters():\n        name = name.replace('_fsdp_wrapped_module.', '')\n        if not para.requires_grad:\n            names_no_grad.append(name)\n            continue  # frozen weights\n        count += 1\n        numel += para.numel()\n        names.append(name)\n        paras.append(para)\n        \n        if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):\n            cur_wd_sc, group_name = 0., 'ND'\n        else:\n            cur_wd_sc, group_name = 1., 'D'\n        cur_lr_sc = 1.\n        if group_name not in para_groups:\n            para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}\n            para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}\n        para_groups[group_name]['params'].append(para)\n        para_groups_dbg[group_name]['params'].append(name)\n    \n    for g in para_groups_dbg.values():\n        g['params'] = pformat(', '.join(g['params']), width=200)\n    \n    print(f'[get_param_groups] param_groups = \\n{pformat(para_groups_dbg, indent=2, width=240)}\\n')\n    \n    for rk in 
range(dist.get_world_size()):\n        dist.barrier()\n        if dist.get_rank() == rk:\n            print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)\n    print('')\n    \n    assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \\n{pformat(names_no_grad, indent=2, width=240)}\\n'\n    return names, paras, list(para_groups.values())\n"
  },
  {
    "path": "utils/misc.py",
    "content": "import datetime\nimport functools\nimport glob\nimport os\nimport subprocess\nimport sys\nimport time\nfrom collections import defaultdict, deque\nfrom typing import Iterator, List, Tuple\n\nimport numpy as np\nimport pytz\nimport torch\nimport torch.distributed as tdist\n\nimport dist\nfrom utils import arg_util\n\nos_system = functools.partial(subprocess.call, shell=True)\ndef echo(info):\n    os_system(f'echo \"[$(date \"+%m-%d-%H:%M:%S\")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}\"')\ndef os_system_get_stdout(cmd):\n    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')\ndef os_system_get_stdout_stderr(cmd):\n    cnt = 0\n    while True:\n        try:\n            sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)\n        except subprocess.TimeoutExpired:\n            cnt += 1\n            print(f'[fetch free_port file] timeout cnt={cnt}')\n        else:\n            return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')\n\n\ndef time_str(fmt='[%m-%d %H:%M:%S]'):\n    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)\n\n\ndef init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):\n    try:\n        dist.initialize(fork=False, timeout=timeout)\n        dist.barrier()\n    except RuntimeError:\n        print(f'{\">\"*75}  NCCL Error  {\"<\"*75}', flush=True)\n        time.sleep(10)\n    \n    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)\n    _change_builtin_print(dist.is_local_master())\n    if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):\n        sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)\n\n\ndef _change_builtin_print(is_master):\n    import builtins as __builtin__\n  
  \n    builtin_print = __builtin__.print\n    if type(builtin_print) != type(open):\n        return\n    \n    def prt(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        clean = kwargs.pop('clean', False)\n        deeper = kwargs.pop('deeper', False)\n        if is_master or force:\n            if not clean:\n                f_back = sys._getframe().f_back\n                if deeper and f_back.f_back is not None:\n                    f_back = f_back.f_back\n                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]\n                builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)\n            else:\n                builtin_print(*args, **kwargs)\n    \n    __builtin__.print = prt\n\n\nclass SyncPrint(object):\n    def __init__(self, local_output_dir, sync_stdout=True):\n        self.sync_stdout = sync_stdout\n        self.terminal_stream = sys.stdout if sync_stdout else sys.stderr\n        fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')\n        existing = os.path.exists(fname)\n        self.file_stream = open(fname, 'a')\n        if existing:\n            self.file_stream.write('\\n'*7 + '='*55 + f'   RESTART {time_str()}   ' + '='*55 + '\\n')\n        self.file_stream.flush()\n        self.enabled = True\n    \n    def write(self, message):\n        self.terminal_stream.write(message)\n        self.file_stream.write(message)\n    \n    def flush(self):\n        self.terminal_stream.flush()\n        self.file_stream.flush()\n    \n    def close(self):\n        if not self.enabled:\n            return\n        self.enabled = False\n        self.file_stream.flush()\n        self.file_stream.close()\n        if self.sync_stdout:\n            sys.stdout = self.terminal_stream\n            sys.stdout.flush()\n        else:\n            sys.stderr = self.terminal_stream\n            sys.stderr.flush()\n    \n    def __del__(self):\n        
self.close()\n\n\nclass DistLogger(object):\n    def __init__(self, lg, verbose):\n        self._lg, self._verbose = lg, verbose\n    \n    @staticmethod\n    def do_nothing(*args, **kwargs):\n        pass\n    \n    def __getattr__(self, attr: str):\n        return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing\n\n\nclass TensorboardLogger(object):\n    def __init__(self, log_dir, filename_suffix):\n        try: import tensorflow_io as tfio\n        except: pass\n        from torch.utils.tensorboard import SummaryWriter\n        self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)\n        self.step = 0\n    \n    def set_step(self, step=None):\n        if step is not None:\n            self.step = step\n        else:\n            self.step += 1\n    \n    def update(self, head='scalar', step=None, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            # assert isinstance(v, (float, int)), type(v)\n            if step is None:  # iter wise\n                it = self.step\n                if it == 0 or (it + 1) % 500 == 0:\n                    if hasattr(v, 'item'): v = v.item()\n                    self.writer.add_scalar(f'{head}/{k}', v, it)\n            else:  # epoch wise\n                if hasattr(v, 'item'): v = v.item()\n                self.writer.add_scalar(f'{head}/{k}', v, step)\n    \n    def log_tensor_as_distri(self, tag, tensor1d, step=None):\n        if step is None:  # iter wise\n            step = self.step\n            loggable = step == 0 or (step + 1) % 500 == 0\n        else:  # epoch wise\n            loggable = True\n        if loggable:\n            try:\n                self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)\n            except Exception as e:\n                print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')\n    \n    def log_image(self, tag, img_chw, step=None):\n        if step is 
None:  # iter wise\n            step = self.step\n            loggable = step == 0 or (step + 1) % 500 == 0\n        else:  # epoch wise\n            loggable = True\n        if loggable:\n            self.writer.add_image(tag, img_chw, step, dataformats='CHW')\n    \n    def flush(self):\n        self.writer.flush()\n    \n    def close(self):\n        self.writer.close()\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n    \n    def __init__(self, window_size=30, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n    \n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n    \n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        tdist.barrier()\n        tdist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n    \n    @property\n    def median(self):\n        return np.median(self.deque) if len(self.deque) else 0\n    \n    @property\n    def avg(self):\n        return sum(self.deque) / (len(self.deque) or 1)\n    \n    @property\n    def global_avg(self):\n        return self.total / (self.count or 1)\n    \n    @property\n    def max(self):\n        # guard the empty window like median/value do, so __str__ cannot raise before the first update\n        return max(self.deque) if len(self.deque) else 0\n    \n    @property\n    def value(self):\n        return self.deque[-1] if len(self.deque) else 0\n    \n    def time_preds(self, counts) -> Tuple[float, str, str]:\n        remain_secs = counts * self.median\n        return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime(\"%Y-%m-%d %H:%M\", 
time.localtime(time.time() + remain_secs))\n    \n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter='  '):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n        self.iter_end_t = time.time()\n        self.log_iters = []\n    \n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if hasattr(v, 'item'): v = v.item()\n            # assert isinstance(v, (float, int)), type(v)\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n    \n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n    \n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            if len(meter.deque):\n                loss_str.append(\n                    \"{}: {}\".format(name, str(meter))\n                )\n        return self.delimiter.join(loss_str)\n    \n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n    \n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n    \n    def log_every(self, start_it, max_iters, itrt, print_freq, header=None):\n        self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())\n        self.log_iters.add(start_it)\n        if not header:\n            header = ''\n        start_time = time.time()\n        self.iter_end_t = time.time()\n        self.iter_time = 
SmoothedValue(fmt='{avg:.4f}')\n        self.data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(max_iters))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        log_msg = self.delimiter.join(log_msg)\n        \n        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):\n            for i in range(start_it, max_iters):\n                obj = next(itrt)\n                self.data_time.update(time.time() - self.iter_end_t)\n                yield i, obj\n                self.iter_time.update(time.time() - self.iter_end_t)\n                if i in self.log_iters:\n                    eta_seconds = self.iter_time.global_avg * (max_iters - i)\n                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                    print(log_msg.format(\n                        i, max_iters, eta=eta_string,\n                        meters=str(self),\n                        time=str(self.iter_time), data=str(self.data_time)), flush=True)\n                self.iter_end_t = time.time()\n        else:\n            if isinstance(itrt, int): itrt = range(itrt)\n            for i, obj in enumerate(itrt):\n                self.data_time.update(time.time() - self.iter_end_t)\n                yield i, obj\n                self.iter_time.update(time.time() - self.iter_end_t)\n                if i in self.log_iters:\n                    eta_seconds = self.iter_time.global_avg * (max_iters - i)\n                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                    print(log_msg.format(\n                        i, max_iters, eta=eta_string,\n                        meters=str(self),\n                        time=str(self.iter_time), data=str(self.data_time)), flush=True)\n                
self.iter_end_t = time.time()\n        \n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{}   Total time:      {}   ({:.3f} s / it)'.format(\n            header, total_time_str, total_time / max_iters), flush=True)\n\n\ndef glob_with_latest_modified_first(pattern, recursive=False):\n    return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)\n\n\ndef auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:\n    info = []\n    file = os.path.join(args.local_out_dir_path, pattern)\n    all_ckpt = glob_with_latest_modified_first(file)\n    if len(all_ckpt) == 0:\n        info.append(f'[auto_resume] no ckpt found @ {file}')\n        info.append(f'[auto_resume quit]')\n        return info, 0, 0, {}, {}\n    else:\n        info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')\n        ckpt = torch.load(all_ckpt[0], map_location='cpu')\n        ep, it = ckpt['epoch'], ckpt['iter']\n        info.append(f'[auto_resume success] resume from ep{ep}, it{it}')\n        return info, ep, it, ckpt['trainer'], ckpt['args']\n\n\ndef create_npz_from_sample_folder(sample_folder: str):\n    \"\"\"\n    Builds a single .npz file from a folder of .png samples. 
Refer to DiT.\n    \"\"\"\n    # os, glob and numpy are already imported at module level; tqdm and PIL are only needed here\n    from tqdm import tqdm\n    from PIL import Image\n    \n    samples = []\n    # de-duplicate and sort: on case-insensitive platforms both patterns can match the same file\n    pngs = sorted(set(glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))))\n    assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'\n    for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):\n        with Image.open(png) as sample_pil:\n            sample_np = np.asarray(sample_pil).astype(np.uint8)\n        samples.append(sample_np)\n    samples = np.stack(samples)  # np.stack requires all images to share one shape\n    # expect 50,000 RGB images, i.e. shape (N, H, W, 3)\n    assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)\n    npz_path = f'{sample_folder}.npz'\n    np.savez(npz_path, arr_0=samples)\n    print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')\n    return npz_path\n"
  }
]