[
  {
    "path": ".github/workflows/docker_build.yml",
    "content": "name: docker-build\n\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - \"main\"\n    paths:\n      - \"docker/Dockerfile\"\n\njobs:\n  build-Open-Sora:\n    runs-on: ubuntu-latest\n    steps:\n      -\n        name: Checkout\n        uses: actions/checkout@v4\n      -\n        name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n      -\n        name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      -\n        name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      -\n        name: Build and push Open-Sora image\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: ./docker/Dockerfile\n          push: true\n          platforms: linux/amd64, linux/arm64, linux/s390x, linux/ppc64le\n          tags: ${{ secrets.DOCKERHUB_USERNAME }}/open-sora"
  },
  {
    "path": ".gitignore",
    "content": "ucf101_stride4x4x4\n__pycache__\n*.mp4\n.ipynb_checkpoints\n*.pth\nUCF-101/\nresults/\nbuild/\nopensora.egg-info/\nwandb/\n.idea\n*.ipynb\n*.jpg\n*.mp3\n*.safetensors\n*.png\n*.gif\n*.pt\ncache_dir/\ntest*\nsample_video*/\n512*\n720*\n1024*\n*debug*\nprivate*\n.deepspeed_env\n256*\nsample_image*/\ntaming*\n*test*\nsft*\nflash*\n65x256*\nalpha_vae\n*node*\ncache/\nOpen-Sora-Plan_models/\nsample_image*cfg*\n*tmp*\n*pymp*\ncheck.py\nbucket.py\nwhileinf.py\nvalidation_dir/\nruns/\nsamples/\ninpaint*/\nbs32x8x1*\nbs4x8x16_*\n*.zip\n*validation/\nbs1x8x32*\nbs16x8x1*\nbs8x8x2*\nbs8x8x1*\nbs8x8x8*\nbs1x8x16*\nchecklora.py\ndim4todim8.py\n*vae8_any*320x320*\ntraining_log*txt\nfilter_motion*\njson2*.py\nmotionfun*\nres_dist*\nfilter_json_aes_m*\nstage2*.json\nkernel_meta\nge_check_op.json\nWFVAE_DISTILL_FORMAL\nread_video*\nbs32x8x2*\njson2json*\nmakenpu_json*\n*make_small_json*\n*schedule_noise*\ngpu_profiling*\ngyy_dense*\ntorchelasti*\n*VEnhancer*\n*spdemo*\ni2v.txt\n*run_i2v*\n*curope*\nany*\n*nomotion*\nlog*\n*svg\n*k8s*\n*rf*\n*lzj*\nfinal*\nopensora/train/*debug.py\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) Rabbitpre Intelligence Ltd\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "\n\n<h1 align=\"left\"> <a href=\"\">Open-Sora Plan</a></h1>\n\nThis project aims to create a simple and scalable repo to reproduce [Sora](https://openai.com/sora) (OpenAI, but we prefer to call it \"ClosedAI\").\n\nWith the strength of the open-source community, we hope to reproduce Sora. The project was jointly initiated by the PKU-Rabbitpre AIGC Joint Lab, with substantial contributions from Rabbitpre, Huawei, Peng Cheng Laboratory, and open-source community partners.\n\nThe current V1.5 release is **trained entirely on Huawei Ascend (a pure Ascend edition)**. Pull requests and usage are welcome!\n\nWe are rapidly iterating on new versions and welcome more collaborators and algorithm engineers to join us: [Algorithm Engineer Recruitment - Rabbitpre Intelligence.pdf](https://github.com/user-attachments/files/19107972/-.pdf)\n\n<h5 align=\"left\">\n\n[![arXiv](https://img.shields.io/badge/Arxiv-Open--Sora%20Plan-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2412.00131)\n[![arXiv](https://img.shields.io/badge/Arxiv-Helios-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2603.04379)\n[![arXiv](https://img.shields.io/badge/Arxiv-WF--VAE-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2411.17459)\n[![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/LICENSE)  <br>\n[![Discord badge](https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&amp)](https://discord.gg/DFZg5678)\n[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&amp)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues/53#issuecomment-1987226516)\n[![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1795018003345510687) \n[![Modelers](https://img.shields.io/badge/%E9%AD%94%E4%B9%90-%E6%A8%A1%E5%9E%8B%E4%BD%93%E9%AA%8C-blue)](https://modelers.cn/spaces/MindSpore-Lab/Open_Sora_Plan) <br>\n[![GitHub repo stars](https://img.shields.io/github/stars/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/stargazers)&#160;\n[![GitHub repo forks](https://img.shields.io/github/forks/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Forks)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/network)&#160;\n[![GitHub repo watchers](https://img.shields.io/github/watchers/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Watchers)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/watchers)&#160;\n[![GitHub repo size](https://img.shields.io/github/repo-size/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Repo%20Size)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/archive/refs/heads/main.zip) <br>\n[![GitHub repo contributors](https://img.shields.io/github/contributors-anon/PKU-YuanGroup/Open-Sora-Plan?style=flat&label=Contributors)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/graphs/contributors) \n[![GitHub Commit](https://img.shields.io/github/commit-activity/m/PKU-YuanGroup/Open-Sora-Plan?label=Commit)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commits/main/)\n[![Pr](https://img.shields.io/github/issues-pr-closed-raw/PKU-YuanGroup/Open-Sora-Plan.svg?label=Merged+PRs&color=green)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls)\n[![GitHub issues](https://img.shields.io/github/issues/PKU-YuanGroup/Open-Sora-Plan?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues?q=is%3Aopen+is%3Aissue)\n[![GitHub closed issues](https://img.shields.io/github/issues-closed/PKU-YuanGroup/Open-Sora-Plan?color=success&label=Issues)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues?q=is%3Aissue+is%3Aclosed)\n</h5>\n<a href=\"https://trendshift.io/repositories/8280\" target=\"_blank\"><img src=\"https://trendshift.io/api/badge/repositories/8280\" alt=\"PKU-YuanGroup%2FOpen-Sora-Plan | Trendshift\" style=\"width: 250px; height: 55px;\" width=\"250\" height=\"55\"/></a>\n<h5 align=\"left\"> If you like our project, please give us a star ⭐ on GitHub for the latest updates.
</h5>\n\n\n# 📣 News\n\n* **[2026.03.08]** 👋👋👋 We introduce [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model that achieves minute-scale, high-quality video synthesis at **19.5 FPS on a single H100** GPU, without relying on conventional long-video anti-drifting strategies or standard video acceleration techniques. Check out the [Technical Report](https://huggingface.co/papers/2603.04379)!\n* **[2025.06.05]** 🔥🔥🔥 We released version 1.5.0, our most powerful model yet! By introducing a **higher-compression WFVAE** and an improved sparse DiT architecture, **SUV**, we achieve performance **comparable to HunyuanVideo (Open-Source)** with an 8B-scale model and 40 million video samples. Version 1.5.0 is **trained and run for inference entirely on Ascend 910-series accelerators**. Please check the [mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit) branch for our new code and [Report-v1.5.0.md](docs/Report-v1.5.0.md) for our report. The GPU version is coming soon.\n* **[2024.12.03]** ⚡️ We released our [arXiv paper](https://arxiv.org/abs/2412.00131) for v1.3 and the WF-VAE [paper](https://arxiv.org/abs/2411.17459). A more powerful version is coming soon.\n* **[2024.10.16]** 🎉 We released version 1.3.0, featuring **WFVAE**, a **prompt refiner**, a **data filtering strategy**, **sparse attention**, and a **bucket training strategy**. It also supports 93x480p generation within **24 GB of VRAM**. More details can be found in our latest [report](docs/Report-v1.3.0.md).\n* **[2024.08.13]** 🎉 We are launching the Open-Sora Plan v1.2.0 **I2V** model, which is based on Open-Sora Plan v1.2.0. It supports image-to-video generation and transition generation (video generation conditioned on the starting and ending frames). Check out the Image-to-Video section in this [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.2.0.md#training-image-to-video-diffusion-model).\n* **[2024.07.24]** 🔥🔥🔥 v1.2.0 is here! It uses a 3D full-attention architecture instead of 2+1D. We released a true 3D video diffusion model trained on 4-second 720p videos. Check out our latest [report](docs/Report-v1.2.0.md).\n* **[2024.05.27]** 🎉 We are launching Open-Sora Plan v1.1.0, which significantly improves video quality and length and is fully open source! Please check out our latest [report](docs/Report-v1.1.0.md). Thanks to [ShareGPT4Video](https://sharegpt4video.github.io/) for its ability to annotate long videos.\n* **[2024.04.09]** 🤝 Excited to share our latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), which learns real-world physics knowledge from time-lapse videos.\n* **[2024.04.07]** 🎉🎉🎉 Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.0.0.md). Thanks to Huawei for supporting us with NPUs.\n* **[2024.03.27]** 🚀🚀🚀 We released the report for [VideoCausalVAE](docs/CausalVideoVAE.md), which supports both images and videos. The text-to-video model is on the way.\n* **[2024.03.01]** 🤗 We launched a plan to reproduce Sora, called Open-Sora Plan! Feel free to **watch** 👀 this repository for the latest updates.\n\n# 😍 Gallery\n\nText-to-Video Generation of Open-Sora Plan v1.5.0.\n### Youtube:\n[![Demo Video of Open-Sora Plan V1.5.0](https://github.com/user-attachments/assets/130bbba2-3ded-4092-92ef-b65b673cb1a6)](https://youtu.be/IiWTdx2EHCY)\n### Bilibili:\n[![Demo Video of Open-Sora Plan V1.5.0](https://github.com/user-attachments/assets/130bbba2-3ded-4092-92ef-b65b673cb1a6)](https://www.bilibili.com/video/BV1X77tzxE3b/)\n\n# 😮 Highlights\n\nOpen-Sora Plan shows excellent performance in video generation.\n\n### 🔥 WFVAE with higher performance and compression\n- Despite its 8×8×8 downsampling rate, it achieves higher PSNR than the VAE used in Wan2.1, and the higher compression lowers the training cost of the DiT built on top of it.\n\n### 🚀 More powerful sparse DiT\n- Our more powerful sparse attention architecture, SUV, achieves performance close to a dense DiT while providing a speedup of over 35%.\n\n<p align=\"center\">\n    <img src=\"https://s21.ax1x.com/2024/07/22/pk7cob8.png\" width=\"650\" style=\"margin-bottom: 0.2;\"/>\n</p>\n\n# 🐳 Resource\n\n| Version | Architecture |  Diffusion Model | CausalVideoVAE | Data | Prompt Refiner |\n|:---|:---|:---|:---|:---|:---|\n| v1.5.0 | SUV (Skiparse 3D) | [121x576x1024](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/model_ema.pt)[5] | [Anysize_8x8x8_32dim](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/wfvae_888_dim32.ckpt) | - | - |\n| v1.3.0 [4] | Skiparse 3D | [Anysize in 93x640x640](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640)[3], [Anysize in 93x640x640_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640_i2v)[3] | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae) | [prompt_refiner](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) | [checkpoint](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) |\n| v1.2.0 | Dense 3D | [93x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p), [29x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p)[1], [93x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p)[1,2], [29x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p), [1x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p), [93x480p_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p_i2v) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/vae) | [Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) | - |\n| v1.1.0 | 2+1D | [221x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512), [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/vae) | [Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0) | - |\n| v1.0.0 | 2+1D | [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512), [65x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256), [17x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/vae) | [Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0) | - |\n\n> [1] Please note that the weights for v1.2.0 29×720p and 93×480p were trained on Panda70M and have not undergone final high-quality data fine-tuning, so they may produce watermarks.\n\n> [2] We fine-tuned the 93×720p model for 3.5k steps to obtain 93×480p for community research use.\n\n> [3] The model is trained on arbitrary resolutions with stride=32, so keep the inference resolution a multiple of 32. The number of frames must be 4n+1, e.g. 93, 77, 61, 45, 29, or 1 (image).\n\n> [4] Model weights are also available at [OpenMind](https://modelers.cn/models/linbin/Open-Sora-Plan-v1.3.0) and [WiseModel](https://wisemodel.cn/models/PKU-YUAN/Open-Sora-Plan-v1.3.0).\n\n> [5] The current model weights are only compatible with the NPU + MindSpeed-MM framework. Model weights are also available at [Modelers](https://modelers.cn/models/PKU-YUAN-Group/Open-Sora-Plan-v1.5.0/tree/main/MindSpeed).\n\n> [!Warning]\n>\n> <div align=\"left\">\n> <b>\n> 🚨 As of version 1.2.0, we no longer support 2+1D models.\n> </b>\n> </div>\n\n# ⚙️ How to start\n\n### GPU\nComing soon...\n### NPU\nPlease check out the **[mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit)** branch and follow the README.md for configuration.\n\n# 📖 Technical report\nPlease check [Report-v1.5.0.md](docs/Report-v1.5.0.md).\n\n# 💡 How to Contribute\nWe greatly appreciate your contributions to the Open-Sora Plan open-source community and your help in making it even better!\n\nFor more details, please refer to the [Contribution Guidelines](docs/Contribution_Guidelines.md).\n\n# 👍 Acknowledgement and Related Work\n* [Allegro](https://github.com/rhymes-ai/Allegro): Allegro is a powerful text-to-video model, built on our Open-Sora Plan, that generates high-quality videos of up to 6 seconds at 15 FPS and 720p resolution from simple text input. The significance of open source is becoming increasingly tangible.\n* [Latte](https://github.com/Vchitect/Latte): A wonderful 2+1D video generation model.\n* [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis.\n* [ShareGPT4Video](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4Video): Improving Video Understanding and Generation with Better Captions.\n* [VideoGPT](https://github.com/wilson1yan/VideoGPT): Video Generation using VQ-VAE and Transformers.\n* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.\n* [FiT](https://github.com/whlzy/FiT): Flexible Vision Transformer for Diffusion Model.\n* [Positional Interpolation](https://arxiv.org/abs/2306.15595): Extending Context Window of Large Language Models via Positional Interpolation.\n\n\n# 🔒 License\n* See [LICENSE](LICENSE) for details.\n\n## ✨ Star History\n\n[![Star History](https://api.star-history.com/svg?repos=PKU-YuanGroup/Open-Sora-Plan)](https://star-history.com/#PKU-YuanGroup/Open-Sora-Plan&Date)\n\n\n# ✏️ Citing\n\n\n```bibtex\n@article{lin2024open,\n  title={Open-Sora Plan: Open-Source Large Video Generation Model},\n  author={Lin, Bin and Ge, Yunyang and Cheng, Xinhua and Li, Zongjian and Zhu, Bin and Wang, Shaodong and He, Xianyi and Ye, Yang and Yuan, Shenghai and Chen, Liuhan and others},\n  journal={arXiv preprint arXiv:2412.00131},\n  year={2024}\n}\n```\n```bibtex\n@article{helios,\n  title={Helios: Real Real-Time Long Video Generation Model},\n  author={Yuan, Shenghai and Yin, Yuanyang and Li, Zongjian and Huang, Xinwei and Yang, Xiao and Yuan, Li},\n  journal={arXiv preprint arXiv:2603.04379},\n  year={2026}\n}\n```\n```bibtex\n@article{li2024wf,\n  title={WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model},\n  author={Li, Zongjian and Lin, Bin and Ye, Yang and Chen, Liuhan and Cheng, Xinhua and Yuan, Shenghai and Yuan, Li},\n  journal={arXiv preprint arXiv:2411.17459},\n  year={2024}\n}\n```\n\n# 🤝 Community contributors\n\n<a href=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/graphs/contributors\">\n  <img src=\"https://contrib.rocks/image?repo=PKU-YuanGroup/Open-Sora-Plan\" />\n</a>\n\n"
  },
  {
    "path": "docs/Contribution_Guidelines.md",
    "content": "# Contributing to the Open-Sora Plan Community\n\nThe Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by its members, the community invites you to contribute and help elevate it to new heights!\n\n## Submitting a Pull Request (PR)\n\nAs a contributor, before submitting your request, kindly follow these guidelines:\n\n1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.\n\n2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and clone your fork to your local machine.\n\n   ```bash\n   git clone [your-forked-repository-url]\n   ```\n\n3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:\n\n   ```bash\n   git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan\n   ```\n\n4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.\n\n   ```bash\n   # Fetch the latest code from the upstream repository\n   git fetch upstream\n   \n   # Switch to the main branch\n   git checkout main\n   \n   # Merge upstream/main into the local main branch to synchronize it with the upstream\n   git merge upstream/main\n   \n   # Push the synced local main branch to your forked repository\n   git push origin main\n   ```\n\n\n   > Note: Sync the code from the main repository before each submission.\n\n5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.\n\n   ```bash\n   git checkout -b my-docs-branch main\n   ```\n\n6. While making modifications and committing changes, adhere to our [Commit Message Format](#commit-message-format).\n\n   ```bash\n   git commit -m \"[docs]: xxxx\"\n   ```\n\n7. Push your changes to your GitHub repository.\n\n   ```bash\n   git push origin my-docs-branch\n   ```\n\n8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.\n\n## Commit Message Format\n\nCommit messages must include both `<type>` and `<summary>` sections.\n\n```text\n[<type>]: <summary>\n  │        │\n  │        └─⫸ Briefly describe your changes, without ending with a period.\n  │\n  └─⫸ Commit Type: |docs|feat|fix|refactor|\n```\n\n### Type\n\n* **docs**: Modify or add documents.\n* **feat**: Introduce a new feature.\n* **fix**: Fix a bug.\n* **refactor**: Restructure code, excluding new features or bug fixes.\n\n### Summary\n\nDescribe modifications in English, without ending with a period.\n\n> e.g., git commit -m \"[docs]: add a contributing.md file\"\n\nThis guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely thank the MiniSora authors for their awesome template.\n"
  },
  {
    "path": "docs/Prompt_Refiner.md",
    "content": "## Data\n\nWe have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). The details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner).\n\nThe dataset is a JSON file with the following structure.\n\n```json\n[\n  {\n    \"instruction\": \"Refine the sentence: \\\"A newly married couple sharing a piece of there wedding cake.\\\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.\",\n    \"input\": \"\",\n    \"output\": \"The newlywed couple, dressed in elegant attire...\"\n  },\n  ...\n]\n```\n\n## Train\n\n`--data_path` is the path to the prepared JSON file.\n`--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.\n`--lora_out_path` is the path where the LoRA model will be saved.\n\n```bash\ncd opensora/models/prompt_refiner\nCUDA_VISIBLE_DEVICES=0 python train.py \\\n    --data_path path/to/data.json \\\n    --model_path path/to/llama_model \\\n    --lora_out_path path/to/save/lora_model\n```\n\n## Merge\n\n`--base_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.\n`--lora_in_path` is the directory containing the trained LoRA model.\n`--lora_out_path` is the path for the merged model.\n\n```bash\ncd opensora/models/prompt_refiner\nCUDA_VISIBLE_DEVICES=0 python merge.py \\\n    --base_path path/to/llama_model \\\n    --lora_in_path path/to/save/lora_model \\\n    --lora_out_path path/to/save/merge_model\n```\n\n## Inference\n\n`--model_path` is the directory containing the weights (either the LLaMA 3.1 weights or the merged LoRA weights), including `config.json` and the weight files.\n`--prompt` is the text you want to refine.\n\n```bash\ncd opensora/models/prompt_refiner\nCUDA_VISIBLE_DEVICES=0 python merge.py \\\n    --model_path path/to/save/merge_model \\\n    --prompt \"A newly married couple sharing a piece of their wedding cake.\"\n```"
  },
  {
    "path": "docs/Report-v1.0.0-cn.md",
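    "content": "# Technical Report v1.0.0\n\nIn March 2024, we launched Open-Sora-Plan, an open-source initiative that aims to reproduce OpenAI's [Sora](https://openai.com/sora). As a foundational open-source framework, it supports training video generation models, including unconditional video generation, class-conditional video generation, and text-to-video generation.\n\n**Today, we are excited to present Open-Sora-Plan v1.0.0, which greatly improves video generation quality and text control.**\n\nCompared with previous video generation models, Open-Sora-Plan v1.0.0 offers the following improvements:\n\n1. **Efficient training and inference with CausalVideoVAE**. We compress videos temporally and spatially by 4×8×8.\n2. **Joint image-video training for better visual quality**. CausalVideoVAE treats the first frame as an image, so it naturally supports encoding both images and videos. This allows the diffusion model to extract more spatio-temporal detail and improve quality.\n\n\n### Open-Source Release\nWe open-source Open-Sora-Plan to promote the further development of the video generation community. Code, data, and models are publicly available.\n- Online demos: Hugging Face [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0), [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512), and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb). Thanks to [@camenduru](https://github.com/camenduru) for strongly supporting our work! 🤝\n- Code: all training scripts and sampling code.\n- Models: both the diffusion model and CausalVideoVAE, available [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).\n- Data: all raw videos and their captions, available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).\n\n## Results\n\nOpen-Sora-Plan v1.0.0 supports joint image-video training. Here we show video and image reconstruction as well as generation:\n\n**Video reconstruction** at 720×1280. Due to GitHub's size limits, the original videos are hosted here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68\n\n**Image reconstruction** at 1536×1024\n\n<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9\" width=\"45%\"/> <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5\" width=\"45%\"/>\n\n**Text-to-video generation** at 65×1024×1024\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011\n\n**Text-to-video generation** at 65×512×512\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e\n\n\n**Text-to-image generation** at 512×512\n\n![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6)\n\n## Detailed Technical Report\n\n### CausalVideoVAE\n\n#### Model Structure\n\n![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)\n\nThe causal VAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the image VAE can be seamlessly applied to the video VAE, the model structure is designed as follows:\n\n1. **CausalConv3D**: Converting Conv2D into CausalConv3D enables joint training on images and videos. CausalConv3D treats the first frame specially, because it cannot access subsequent frames. For more details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145\n\n2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) for expanding Conv2D to Conv3D: average initialization and center initialization. We instead adopt a specific initialization method (tail initialization), which ensures that the model can directly reconstruct images, and even videos, without any training.\n   \n#### Training Details\n\n<img width=\"833\" alt=\"image\" src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14\">\n\nWe show the loss curves of the two initialization methods at 17×256×256. The yellow curve is the loss with tail initialization, and the blue curve is the loss with center initialization. As shown, tail initialization yields a better loss curve. In addition, we found that center initialization causes errors to accumulate, leading to collapse over long training runs.\n\n#### Inference Tricks\nAlthough the VAE stays frozen throughout diffusion training, we still cannot afford the cost of CausalVideoVAE. In our experiments, 80 GB of GPU memory can run half-precision inference on only a single 256×512×512 or 32×1024×1024 video, which limits our ability to scale to longer and higher-resolution videos. We therefore adopt tile convolution, which can run inference on videos of arbitrary duration or resolution with nearly constant memory.\n\n### Data Construction\nWe define a high-quality video dataset by two core principles: (1) no content-irrelevant watermarks and (2) high-quality text annotations.\n\n**For principle 1**, we crawled about 40k videos from open-source websites (CC0 license): 1,234 from [mixkit](https://mixkit.co/), 7,408 from [pexels](https://www.pexels.com/), and 31,616 from [pixabay](https://pixabay.com/). We cut these videos into about 434k video clips using the scene-change splitting script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md). According to our splitting results, 99% of the videos crawled from these websites contain a single scene. We also found that over 60% of the crawled data are landscape videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).\n\n**For principle 2**, large amounts of high-quality text annotations are hard to crawl directly from the web, so we use mature image captioning models to obtain high-quality dense captions. We ran an ablation on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is a model specialized for captioning, while the latter is a general-purpose multimodal model. In our ablation, their captioning performance was similar. However, their inference speeds on an A800 differ greatly: 40 s/it with a batch size of 12 for ShareGPT4V-Captioner-7B versus 15 s/it with a batch size of 1 for LLaVA-1.6-34B. We open-source all [text annotations and raw videos](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).\n\n| Model | Average length | Max | Standard deviation |\n|---|---|---|---|\n| ShareGPT4V-Captioner-7B | 170.0827524529121 |  467 | 53.689967539537776 | \n| LLaVA-1.6-34B | 141.75851073472666 |  472 | 48.52492072346965 | \n\n### Training the Diffusion Model\nSimilar to previous work, we adopt a multi-stage cascaded training approach, which consumed a total of 2,048 A800 GPU hours. We found that joint image training significantly accelerates model convergence and enhances visual quality, which is consistent with [Latte](https://github.com/Vchitect/Latte). Our training costs are listed below.\n\n| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |\n|---|---|---|---|---|\n| Training video size | 17×256×256 |  65×256×256 | 65×512×512 |  65×1024×1024 | \n| Compute (#A800 GPUs × #hours) | 32 × 40 |  32 × 18 |  32 × 6 |  In training | \n| Weights | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) |  [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) |  In training | \n| Logs | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) |  [wandb](https://api.wandb.ai/links/linbin/t2g53sew) |  [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | In training | \n| Training data | ~40k videos |  ~40k videos |  ~40k videos |  ~40k videos | \n\n## Next Version Preview\n### CausalVideoVAE\nThe CausalVideoVAE v1.0.0 we released has two main shortcomings: **motion blur** and **gridding artifacts**. We have made a series of improvements to CausalVideoVAE that lower its inference cost and make it more powerful. We tentatively call it the preview version; it will be released in the next version.\n\n**1-minute 720×1280 video reconstruction**. Due to GitHub's limits, we host the videos here: [original video](https://streamable.com/u4onbb), [reconstructed video](https://streamable.com/qt8ncc).\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b\n\nWe randomly selected 100 samples from the Kinetics-400 validation set for evaluation. The results are shown below:\n\n|  | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |\n|---|---|---|---|---|\n| v1.0.0 | 0.829 |  0.106 |  27.171 |  0.119 | \n| Preview | 0.877 |  0.064 |  29.695 |  0.070 | \n\n#### Motion Blur\n\n| **v1.0.0** | **Preview** |\n| --- | --- |\n| ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d)  | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c)  |\n\n#### Gridding Artifacts\n\n| **v1.0.0** | **Preview** |\n| --- | --- |\n| ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658)  | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7)  |\n\n### Data Construction\n\n**Data sources**: As mentioned above, over 60% of our dataset consists of landscape videos, which means our open-domain video generation capability is limited. Most current large-scale open-source datasets are crawled from YouTube; although they contain many videos, we are concerned about whether their quality is good enough. We will therefore continue to collect high-quality datasets and welcome recommendations from the open-source community.\n\n**Caption generation pipeline**: As the duration of our training videos increases, we need a more effective video captioning method than multimodal image models. We are developing a new video captioning pipeline that supports long videos well. Stay tuned.\n\n### Training the Diffusion Model\nAlthough v1.0.0 shows promising results, we are still some distance from Sora. Our upcoming work will focus on three areas:\n\n1. **Training with dynamic resolution and duration**: We aim to develop techniques that can train the model at different resolutions and durations, making training more flexible and adaptable.\n\n2. **Longer video generation**: We will explore ways to extend the model's generation capability so that it can produce longer videos, beyond the current limits.\n\n3. **More conditional control**: We aim to strengthen the model's conditional control, giving users more options and more control over the generated videos.\n\nIn addition, by closely inspecting the generated videos, we found some implausible spots and abnormal flows which, as mentioned above, are caused by the insufficient performance of CausalVideoVAE. In future experiments, we will retrain a diffusion model with a stronger VAE.\n"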
  },
  {
    "path": "docs/Report-v1.0.0.md",
    "content": "# Report v1.0.0\n\nIn March 2024, we launched a plan called Open-Sora-Plan, which aims to reproduce the OpenAI [Sora](https://openai.com/sora) through an open-source framework. As a foundational open-source framework, it enables training of video generation models, including Unconditioned Video Generation, Class Video Generation, and Text-to-Video Generation.\n\n**Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities.**\n\nCompared with previous video generation model, Open-Sora-Plan v1.0.0 has several improvements:\n\n1. **Efficient training and inference with CausalVideoVAE**. We apply a spatial-temporal compression to the videos by 4×8×8.\n2. **Joint image-video training for better quality**. Our CausalVideoVAE considers the first frame as an image, allowing for the simultaneous encoding of both images and videos in a natural manner. This allows the diffusion model to grasp more spatial-visual details to improve visual quality.\n\n### Open-Source Release\nWe open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model are made publicly available.\n- Demo: Hugging Face demo [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0). 
🤝 Enjoy the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb) created by [@camenduru](https://github.com/camenduru), who generously supports our research!\n- Code: All training and sampling scripts.\n- Model: Both the diffusion model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).\n- Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).\n\n## Gallery\n\nOpen-Sora-Plan v1.0.0 supports joint training of images and videos. Here, we present its video/image reconstruction and generation capabilities:\n\n### CausalVideoVAE Reconstruction\n\n**Video Reconstruction** with 720×1280. Since GitHub cannot host large videos, we have uploaded them here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8). 
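For a sense of scale, the sketch below computes the latent grid produced by the 4×8×8 compression described above. It assumes, as our own reading rather than a statement in this report, that CausalVideoVAE encodes the first frame by itself and compresses the remaining frames 4× in time, which is why clip lengths of the form 4k+1 (such as 17 and 65) fit exactly. `latent_shape` is a hypothetical helper, not part of the Open-Sora-Plan code.

```python
def latent_shape(frames, height, width, t_stride=4, s_stride=8):
    # Hypothetical helper. Assumes frame 0 is encoded on its own (as an image)
    # and the remaining frames are downsampled t_stride times in time.
    if frames < 1 or (frames - 1) % t_stride != 0:
        raise ValueError('frame count must equal t_stride * k + 1')
    if height % s_stride or width % s_stride:
        raise ValueError('height and width must be divisible by s_stride')
    latent_frames = 1 + (frames - 1) // t_stride
    return latent_frames, height // s_stride, width // s_stride


print(latent_shape(65, 512, 512))   # (17, 64, 64)
print(latent_shape(17, 720, 1280))  # (5, 90, 160)
print(latent_shape(1, 1024, 1536))  # (1, 128, 192): a single image
```

Under this assumption, a single image is simply the 1-frame case, which is what makes joint image-video encoding straightforward.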
\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68\n\n**Image Reconstruction** with 1536×1024.\n\n<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9\" width=\"45%\"/> <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5\" width=\"45%\"/>\n\n**Text-to-Video Generation** with 65×1024×1024\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011\n\n**Text-to-Video Generation** with 65×512×512\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e\n\n**Text-to-Image Generation** with 512×512\n\n![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6)\n\n## Detailed Technical Report\n\n### CausalVideoVAE\n\n#### Model Structure\n\n![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)\n\nThe CausalVideoVAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the image VAE can be applied directly to the video VAE, the model structure is designed as follows:\n\n1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training on image and video data. CausalConv3D treats the first frame specially: because the convolution is causal, each frame sees only itself and earlier frames, so the first frame is encoded without any context from the rest of the clip, just like a standalone image. For details, see https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145\n\n2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. 
Instead, we use a different method, tail initialization, which lets the model reconstruct images, and even videos, directly without any training.\n\n#### Training Details\n\n<img width=\"833\" alt=\"image\" src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14\">\n\nWe present the loss curves of the two initialization methods on 17×256×256 clips. The yellow curve shows the loss with tail initialization, and the blue curve shows the loss with center initialization. As the graph shows, tail initialization yields a better loss curve. We also found that center initialization leads to error accumulation, which causes collapse over longer durations.\n\n#### Inference Tricks\nAlthough the VAE is frozen during diffusion training, running CausalVideoVAE is still expensive. With 80GB of GPU memory, we can only process videos of at most 256×512×512 or 32×1024×1024 in half precision, which limits our ability to scale to longer and higher-resolution videos. We therefore adopt tiled convolution, which lets us process videos of arbitrary duration or resolution with nearly constant memory usage.\n\n### Data Construction\nWe define a high-quality video dataset based on two core principles: (1) no content-unrelated watermarks, and (2) high-quality, dense captions.\n\n**For principle 1**, we crawled approximately 40,000 videos released under the CC0 license from open-source websites: 1,234 from [mixkit](https://mixkit.co/), 7,408 from [pexels](https://www.pexels.com/), and 31,616 from [pixabay](https://pixabay.com/). These videos contain no content-unrelated watermarks. 
Using the scene-transition detection and splitting script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we divided these videos into approximately 434,000 clips. Our splitting results show that 99% of the videos obtained from these sources contain a single scene. We also observed that over 60% of the crawled data consists of landscape videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).\n\n**For principle 2**, it is difficult to crawl large quantities of high-quality dense captions directly from the internet, so we use mature image-captioning models to generate them. We ran ablation experiments on two large multimodal models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is designed specifically for caption generation, while the latter is a general-purpose multimodal model. The ablations showed that their caption quality is comparable, but their inference speed on an A800 GPU differs significantly: 40 s/it at batch size 12 for ShareGPT4V-Captioner-7B (about 3.3 s per sample) versus 15 s/it at batch size 1 for LLaVA-1.6-34B. We open-source all annotations [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0). Caption statistics are shown below; we set the model's maximum text length to 300, which covers almost 99% of the samples.\n\n| Name | Avg length | Max | Std |\n|---|---|---|---|\n| ShareGPT4V-Captioner-7B | 170.08 | 467 | 53.69 |\n| LLaVA-1.6-34B | 141.76 | 472 | 48.52 |\n\n### Training Diffusion Model\nSimilar to previous work, we employ a multi-stage cascaded training approach, which consumes a total of 2,528 A800 GPU hours. 
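As a quick consistency check, the 2,528 GPU-hour total matches the first three columns of the compute row in the training card below (32 A800 GPUs for 40, 22, and 17 hours); Stage 4, still in progress, is not counted. The stage keys are illustrative labels only.

```python
# Compute row of the training card: (number of A800 GPUs, hours) per stage.
# Stage 4 is still training and is therefore not included.
stages = {
    'stage1_17x256x256': (32, 40),
    'stage2_65x256x256': (32, 22),
    'stage3_65x512x512': (32, 17),
}

gpu_hours = {name: gpus * hours for name, (gpus, hours) in stages.items()}
total_gpu_hours = sum(gpu_hours.values())

print(gpu_hours)        # per-stage GPU hours: 1280, 704, 544
print(total_gpu_hours)  # 2528
```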
We found that joint training with images significantly accelerates model convergence and improves visual quality, consistent with the findings of [Latte](https://github.com/Vchitect/Latte). Below is our training card:\n\n| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |\n|---|---|---|---|---|\n| Training Video Size | 17×256×256 |  65×256×256 | 65×512×512 |  65×1024×1024 | \n| Compute (#A800 GPUs × #Hours) | 32 × 40 |  32 × 22 |  32 × 17 |  Under training | \n| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) |  [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) |  Under training | \n| Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) |  [wandb](https://api.wandb.ai/links/linbin/t2g53sew) |  [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training | \n| Training Data | ~40k videos |  ~40k videos |  ~40k videos |  ~40k videos | \n\n## Next Release Preview\n### CausalVideoVAE\nThe released version of CausalVideoVAE (v1.0.0) has two main drawbacks: **motion blurring** and **gridding effect**. We have made a series of improvements to CausalVideoVAE that reduce its inference cost and enhance its performance. We refer to this enhanced version as the \"preview version,\" and it will be released in the next update. Reconstructions from the preview version are shown below:\n\n**1 min Video Reconstruction with 720×1280**. Since GitHub cannot host large videos, we have uploaded them here: [original video](https://streamable.com/u4onbb), [reconstruction video](https://streamable.com/qt8ncc). 
\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b\n\nWe randomly selected 100 samples from the Kinetics-400 validation set for evaluation; the results are shown in the following table:\n\n| Version | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |\n|---|---|---|---|---|\n| v1.0.0 | 0.829 |  0.106 |  27.171 |  0.119 | \n| Preview | 0.877 |  0.064 |  29.695 |  0.070 | \n\n#### Motion Blurring\n\n| **v1.0.0** | **Preview** |\n| --- | --- |\n| ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d)  | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c)  |\n\n#### Gridding Effect\n\n| **v1.0.0** | **Preview** |\n| --- | --- |\n| ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658)  | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7)  |\n\n### Data Construction\n\n**Data source**. As mentioned earlier, over 60% of our dataset consists of landscape videos, which limits our ability to generate videos in other domains. Most current large-scale open-source datasets are obtained by scraping platforms such as YouTube. Although these datasets provide a vast quantity of videos, we have concerns about the quality of the videos themselves. We will therefore continue to collect high-quality datasets and welcome recommendations from the open-source community. We are launching the Open-Sora-Dataset project; see [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset) for details.\n\n**Caption Generation Pipeline**. 
As videos grow longer, we need more efficient methods for video caption generation than relying solely on large multimodal image models. We are developing a new video caption generation pipeline that robustly supports long videos, and we look forward to sharing more details soon. Stay tuned!\n\n### Training Diffusion Model\nAlthough v1.0.0 has shown promising results, we acknowledge that there is still a gap between our model and Sora. In our upcoming work, we will focus on three aspects:\n\n1. **Training support for dynamic resolution and duration**: We aim to develop techniques for training models at varying resolutions and durations, making the training process more flexible and adaptable.\n\n2. **Support for longer video generation**: We will explore methods to extend the generation capabilities of our models so that they can produce videos longer than the current limit.\n\n3. **Enhanced conditional control**: We aim to strengthen the conditional control capabilities of our models, giving users more options and more control over the generated videos.\n\nFurthermore, careful inspection of the generated videos reveals some implausible speckles and abnormal flow. As mentioned earlier, these can be attributed to the limited performance of CausalVideoVAE. In future experiments, we plan to retrain a diffusion model with a more powerful version of CausalVideoVAE to address these issues.\n"
  },
  {
    "path": "docs/Report-v1.1.0.md",
    "content": "# Report v1.1.0\n\nIn April 2024, we launched Open-Sora-Plan v1.0.0, featuring a simple and efficient design along with remarkable performance in text-to-video generation. Its data and model have already been adopted as a foundation in numerous research projects.\n\n**Today, we are excited to present Open-Sora-Plan v1.1.0, which significantly improves video generation quality and duration.**\n\nCompared with the previous version, Open-Sora-Plan v1.1.0 brings the following improvements:\n\n1. **Better compressed visual representations**. We optimized the CausalVideoVAE architecture, which now has stronger performance and higher inference efficiency.\n2. **Higher-quality, longer videos**. We used higher-quality visual data and captions generated by [ShareGPT4Video](https://sharegpt4video.github.io/), enabling the model to better understand how the world works.\n\nAlong with these performance improvements, Open-Sora-Plan v1.1.0 keeps the minimalist design and data efficiency of v1.0.0. Remarkably, we found that v1.1.0 performs similarly to the Sora base model, indicating that the evolution of our model is consistent with the scaling law demonstrated by Sora.\n\n### Open-Source Release\nWe open-source Open-Sora-Plan to facilitate future development of video generation in the community. 
Code, data, and models are made publicly available.\n- Demo: Hugging Face demo [here](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0).\n- Code: All training and sampling scripts.\n- Model: Both the diffusion model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0).\n- Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0).\n\n## Gallery\n\n### 221×512×512 Text-to-Video Generation\n\n| 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) |\n| --- | --- | --- | --- |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/6d18f344-f7da-44eb-9e07-77813f6b5e90\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/71f75e72-e9ee-4ce7-b8ea-d2d45a6f367e\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/80430eae-a3b4-4f24-b448-0db2919327d6\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/dc217894-a8c8-4174-a42c-acd2811f61f5\" width=224> |\n| This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage ... |  a cat wearing sunglasses and working as a lifeguard at pool. |  Photorealistic closeup video of two pirate ships battling each other as they sail ... | A movie trailer featuring the adventures of the 30 year old spaceman wearing a red wool ...  
|\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/f9c5823f-aa03-40ee-8335-684684f5c842\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/f207d798-8988-45b0-b836-347e499ee000\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/d40c38dc-9f26-4591-8163-c7089e1553e3\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/67b40ce1-135d-4ae2-a1ee-5a0866461eb2\" width=224> |\n| A snowy forest landscape with a dirt road running through it. The road is flanked by ... |  Drone shot along the Hawaii jungle coastline, sunny day. Kayaks in the water. | Alpacas wearing knit wool sweaters, graffiti background, sunglasses.  | The camera rotates around a large stack of vintage televisions all showing different ...  |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/42542b4e-b1b8-49b8-ada4-7bcfde8e1453\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/4eee2619-a5ca-4a32-b350-e65d6220c8f7\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/e6152c93-4edf-4569-8d17-1effd87a7780\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/ec78d855-07c7-4421-896a-3c46a83ec129\" width=224> |\n| A drone camera circles around a beautiful historic church built on a rocky outcropping ... | Aerial view of Santorini during the blue hour, showcasing the stunning architecture ...  |  A robot dog explores the surface of Mars, kicking up red dust as it investigates  ... | An aerial shot of a lighthouse standing tall on a rocky cliff, its beacon cutting ...  
|\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/98897215-eae9-49f3-8fdb-df1d4b74d435\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/10e0f93d-925f-4b38-8205-7d89f49195f1\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/5453cf37-29ac-423d-9fb2-05f23416ca3e\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/77928824-6705-4d83-a7b4-43fdb74790bf\" width=224> |\n| 3D animation of a small, round, fluffy creature with big, expressive eyes explores ... |  A corgi vlogging itself in tropical Maui. |  A single drop of liquid metal falls from a floating orb, landing on a mirror-like ... | The video presents an abstract composition centered around a hexagonal shape adorned ...  |\n\n### 65×512×512 Text-to-Video Generation\n\n| 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) |\n| --- | --- | --- | --- |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/a0601f20-579c-4e2e-832c-5763546718cc\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/55229eca-de3a-476b-930b-13a35eb5db30\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/74b9e7a1-0fa4-4f0d-8faf-0c84d97b11b5\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/66b06822-2652-453e-80fb-dd2988b730ce\" width=224> |\n| Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. | 3D animation of a small, round, fluffy creature with big, expressive eyes explores a ...   |  A corgi vlogging itself in tropical Maui. |  In a studio, there is a painting depicting a ship sailing through the rough sea. 
|\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/4d27ff13-e725-4602-bf17-90df2c0d8005\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/049fd4db-f2fe-4633-ab62-1dda8268e090\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/aba01259-60f2-49ef-aa33-e738dc8c9a49\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/98178397-61b0-4200-9ab9-9bef2728ed98\" width=224> |\n| A robot dog trots down a deserted alley at night, its metallic paws clinking softly ... | A solitary spider weaves its web in a quiet corner. The web shimmers and glows with ...  |  A lone surfer rides a massive wave, skillfully maneuvering through the surf. The water ... |  A solitary cheetah sprints across the savannah, its powerful muscles propelling it ... |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/3964dfcd-d1b4-406b-916c-d4a702184a27\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2a5fd1c0-9304-46e4-af35-5ffcc718bf08\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/73e1304a-6241-4e19-9dde-f2b1032edefd\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/c82b822b-cadd-4f11-bcc3-29e3651c02c0\" width=224> |\n| A solitary astronaut plants a flag on an alien planet covered in crystal formations ... |  At dawn's first light, a spaceship slowly exits the edge of the galaxy against a ...|  A dapper puppy in a miniature suit, basking in the afternoon sun, adjusting his tie ... |  A wise old elephant painting abstract art with its trunk, each stroke a burst of color ... 
|\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/8edfa249-272c-4773-8728-12686527771e\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/33338d8e-5ea7-4b57-9e94-97f3ee404033\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/227fa6b6-801b-438b-9b30-cd3a4e0a7f2f\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/b5d5ebee-ab13-4e32-8b79-e6cb8f08d4a0\" width=224> |\n| In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two ... | A Shiba Inu dog wearing a beret and black turtleneck.  | A painting of a boat on water comes to life, with waves crashing and the boat becoming ...  | Many spotted jellyfish pulsating under water. Their bodies are transparent and glowing ...  |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/fe8e9450-5a80-435d-b050-6b2fe11cdf53\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/7d9335df-817d-479d-9ea7-615cb50c66b8\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/34c1dda6-e420-4edd-bbcd-38ff8e542ec6\" width=224> |  |\n| An animated hedgehog with distinctive spiky hair and large eyes is seen exploring a ... | An animated rabbit in a playful pink snowboarding outfit is carving its way down a ...  | A person clad in a space suit with a helmet and equipped with a chest light and arm ...  
|   |\n\n### 65×512×512 Video Editing\n\n| generated 65×512×512 (2.7s) | edited 65×512×512 (2.7s) |\n| --- | --- |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/edb8e8c2-5eef-4c90-85fb-6adb035067c3\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/32d93845-8904-4f8f-832e-37eba2ceb542\" width=224>  |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/a70d0d91-0d61-4aa4-9520-6e4a6c477f12\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/7913e06f-1e0b-4d06-8233-72c882c6abfe\" width=224>  |\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/0e614ec0-fba0-4f42-a343-d4607966dd40\" width=224> | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/c531d930-410c-4614-890d-bae8013f33c2\" width=224>  |\n\n### 512×512 Text-to-Image Generation\n\n <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/e44b7f8a-5da2-49c2-87c4-52ea680ad43b\" width=512> \n\n## Detailed Technical Report\n\n### CausalVideoVAE\n\n#### Model Structure\n\nAs the number of frames increases, the encoder overhead of CausalVideoVAE grows. When training with 257 frames, 80GB of VRAM is not enough for the VAE to encode the video. We therefore reduced the number of CausalConv3D layers, retaining CausalConv3D only in the last two stages of the encoder. This change significantly lowers the overhead while maintaining nearly the same performance. 
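To make CausalConv3D and tail initialization concrete, the toy sketch below treats each frame as a 1-D row and convolves over time and space. Two details are our assumptions, not statements from this report: causal padding repeats the first frame kernel_size - 1 times at the front, and tail initialization (carried over from v1.0.0) places the pretrained 2D kernel on the last temporal tap, the current frame, with zeros on the earlier taps. All function names are hypothetical.

```python
def conv1d_valid(row, kernel):
    # Spatial cross-correlation (valid mode) over a single frame.
    k = len(kernel)
    return [sum(w * row[i + j] for j, w in enumerate(kernel))
            for i in range(len(row) - k + 1)]


def causal_conv3d_toy(frames, kernel3d):
    # kernel3d[j] is the spatial kernel applied to temporal tap j.
    # Assumed causal padding: repeat frame 0 (kt - 1) times at the front,
    # so output t only sees frames 0..t.
    kt = len(kernel3d)
    padded = [frames[0]] * (kt - 1) + list(frames)
    out = []
    for t in range(len(frames)):
        acc = None
        for j, spatial in enumerate(kernel3d):
            part = conv1d_valid(padded[t + j], spatial)
            acc = part if acc is None else [a + b for a, b in zip(acc, part)]
        out.append(acc)
    return out


def tail_init(kernel2d, kt):
    # Assumed tail initialization: zeros on earlier taps, the pretrained
    # 2D kernel on the last tap (the current frame).
    zeros = [[0.0] * len(kernel2d) for _ in range(kt - 1)]
    return zeros + [list(kernel2d)]


pretrained = [0.5, -1.0, 0.5]  # stand-in for a pretrained 2D kernel
video = [[1.0, 2.0, 4.0, 8.0], [0.0, 1.0, 0.0, 1.0], [3.0, 3.0, 3.0, 3.0]]

inflated = causal_conv3d_toy(video, tail_init(pretrained, kt=3))
per_frame = [conv1d_valid(frame, pretrained) for frame in video]
print(inflated == per_frame)  # True: behaves like the 2D layer frame by frame

generic = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
full = causal_conv3d_toy(video, generic)
image_only = causal_conv3d_toy(video[:1], generic)
print(full[0] == image_only[0])  # True: frame 0 never sees later frames
```

Under these assumptions, a tail-initialized layer starts out exactly equivalent to the pretrained per-frame convolution, which is consistent with the v1.0.0 observation that the model can reconstruct images and videos before any training, and the first output frame is computed the same way whether the input is a single image or a full clip.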
Note that we modified only the encoder; the decoder still retains all CausalConv3D layers, because training the diffusion model does not require the decoder.\n\n<img width=\"722\" alt=\"vaemodel\" src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/b6f8a638-cfb2-40f3-94be-45af8dbad18e\">\n\nWe compare the computational overhead of the two versions by benchmarking the encoder's forward pass on an H100.\n\n| Version | 129×256×256 |   | 257×256×256 | | 513×256×256 | |\n|---|---|---|---|---|---|---|\n|  |  Peak Mem. |  Speed  | Peak Mem. |  Speed  |Peak Mem. |  Speed  |\n| v1.0.0 |  22G |   2.9 it/s  | OOM |  -   | OOM |   -  |\n| v1.1.0 |  18G |  4.9 it/s   | 34G  |  2.5 it/s   | 61G |   1.2 it/s   |\n\n#### Temporal Module\n\n<img width=\"480\" alt=\"vaemodel\" src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/7c8a4263-b7d1-4edc-a60c-801ae9b4344f\">\n\nIn v1.0.0, our temporal module consisted only of TemporalAvgPool, which loses high-frequency information in the video, such as details and edges. To address this, we improved the module in v1.1.0. As shown in the figure above, we introduced a convolution branch and added learnable weights, allowing the branches to decouple different features. Without the CausalConv3D branch, reconstructed videos are very blurry, whereas without the TemporalAvgPool branch, they become very sharp.\n\n| Setting | SSIM↑ | LPIPS↓ | PSNR↑ |\n|---|---|---|---|\n| Base | 0.850 |  0.091 |  28.047 |\n| + Frames | 0.868 |  0.070 |  28.829 | \n| + Reset mixed factor | 0.873 |  0.070 |  29.140 | \n\n#### Training Details\n\nSimilar to v1.0.0, we initialized from the Latent Diffusion VAE and used tail initialization. For CausalVideoVAE, we trained for 100k steps in the first stage with a video shape of 9×256×256. We then increased the frame count from 9 to 25 and found that this significantly improved the model's performance. 
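The learnable mixed factor of the temporal module can be pictured as a sigmoid-weighted blend of the TemporalAvgPool branch and the convolution branch. The toy sketch below works on one scalar per frame. Giving α = sigmoid(mixed factor) to the pooling branch is our reading of the report (a high α favors low-frequency content); the kernel values and all function names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def avg_pool_time(frames, stride=2):
    # Low-frequency branch: mean of each non-overlapping group of frames.
    return [sum(frames[i:i + stride]) / stride
            for i in range(0, len(frames) - stride + 1, stride)]


def conv_time(frames, kernel, stride=2):
    # Learnable branch: strided temporal convolution (weights are illustrative).
    k = len(kernel)
    return [sum(w * frames[i + j] for j, w in enumerate(kernel))
            for i in range(0, len(frames) - k + 1, stride)]


def mixed_temporal_downsample(frames, mixed_factor, kernel):
    # alpha = sigmoid(mixed_factor) weights the pooling branch (assumed);
    # 1 - alpha weights the convolution branch.
    alpha = sigmoid(mixed_factor)
    pooled = avg_pool_time(frames)
    convolved = conv_time(frames, kernel)
    return [alpha * p + (1 - alpha) * c for p, c in zip(pooled, convolved)]


print(round(sigmoid(0.5), 4))  # 0.6225
flicker = [0.0, 1.0, 0.0, 1.0]   # alternating frames: pure high-frequency motion
print(avg_pool_time(flicker))    # [0.5, 0.5]: the alternation is averaged away
print(mixed_temporal_downsample(flicker, 0.5, [-1.0, 1.0]))
```

With pooling alone, the alternating frames collapse to a flat 0.5, the kind of high-frequency loss attributed to TemporalAvgPool; the convolution branch keeps a signal of the change. Under this parameterization, the α ≈ 0.88 reached during the first two stages corresponds to a mixed factor of about 2 (ln(0.88 / 0.12) ≈ 1.99).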
It is important to note that the mixed factor was enabled in both the first and second stages; its weight α = sigmoid(mixed factor) reached 0.88 by the end of training, indicating that the model tended to retain low-frequency information. In the third stage, we reinitialized the mixed factor to 0.5 (sigmoid(0.5) ≈ 0.6225), which further improved the model's performance.\n\n#### Loss Function\n\nWe found that a GAN loss helps retain high-frequency information and alleviates grid artifacts. Switching from a 2D GAN to a 3D GAN brings further improvements.\n\n| GAN Loss/Step | SSIM↑ | LPIPS↓ | PSNR↑ |\n|---|---|---|---|\n| 2D/80k | 0.879 |  0.068 |  29.480 |\n| 3D/80k | 0.882 |  0.067 |  29.890 | \n\n#### Inference Tricks\nAs in v1.0.0, encoding long videos in a single pass is memory-intensive, so we rely on tiling. We introduced **temporal rollback tiled convolution**, a tiling approach designed specifically for CausalVideoVAE: every window except the first discards its first frame, because the first frame of a window is treated as an image, whereas the remaining frames should be treated as video frames.\n\n<img width=\"633\" alt=\"tiled_temp\" src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/0a06011e-1d6c-410a-9f1c-82c4122a018a\">\n\nWe measured speed on an H100 with a window size of 65×256×256.\n\n| Version | 129×256×256 |   | 257×256×256 | | 513×256×256 | |\n|---|---|---|---|---|---|---|\n|  |  Peak Mem. |  Speed  | Peak Mem. |  Speed  |Peak Mem. |  Speed  |\n| 4×8×8 |  10G |   1.3 s/it  | 10G | 2.6 s/it   | 10G |   5.3 s/it  |\n\n### Data Construction\nSince Open-Sora-Plan supports joint training of images and videos, our data collection is divided into two parts: images and videos. Images do not need to originate from videos; they come from independent datasets. 
We spent approximately 32×240 H100 hours generating image and video captions, and all of this is **open source**!\n\n#### Image-Text Collection Pipeline\nWe obtained 11 million image-text pairs from [Pixart-Alpha](https://huggingface.co/datasets/PixArt-alpha/SAM-LLaVA-Captions10M), with captions generated by [LLaVA](https://github.com/haotian-liu/LLaVA). We also used the high-quality OCR dataset [Anytext-3M](https://github.com/tyxsspa/AnyText), which pairs each image with its OCR text. Because these captions were insufficient to describe the entire image, we used [InternVL-1.5](https://github.com/OpenGVLab/InternVL) to generate supplementary descriptions. Since T5 only supports English, we kept only the English data, which makes up about half of the full dataset. Furthermore, we selected high-quality images from [Laion-5B](https://laion.ai/blog/laion-5b/) to improve the quality of generated humans, keeping watermark-free images of people with high resolution and high aesthetic scores.\n\nHere, we open-source the prompts used for InternVL-1.5:\n```\n# for anytext-3m\nCombine this rough caption: \"{}\", analyze the image in a comprehensive and detailed manner. \"{}\" can be recognized in the image.\n# for human-160k\nAnalyze the image in a comprehensive and detailed manner.\n```\n\n| Name | Image Source | Text Captioner | Num. pairs |\n|---|---|---|---|\n| SAM-11M | [SAM](https://ai.meta.com/datasets/segment-anything/) |  [LLaVA](https://github.com/haotian-liu/LLaVA) |  11,185,255 |\n| Anytext-3M-en | [Anytext](https://github.com/tyxsspa/AnyText) |  [InternVL-1.5](https://github.com/OpenGVLab/InternVL) |  1,886,137 | \n| Human-160k | [Laion](https://laion.ai/blog/laion-5b/) |  [InternVL-1.5](https://github.com/OpenGVLab/InternVL) |  162,094 | \n\n#### Video-Text Collection Pipeline\nIn v1.0.0, we sampled one frame from each video to generate captions. 
However, as videos grew longer, a single frame could no longer adequately describe the content or temporal motion of the whole video. We therefore used a video captioner to caption each entire clip. Specifically, we used [ShareGPT4Video](https://sharegpt4video.github.io/), which captures temporal information and describes the content of the entire video. The v1.1.0 video dataset comprises approximately 3k hours, compared with only 300 hours in v1.0.0. As before, we have open-sourced all text annotations and videos (both under the CC0 license), which can be found [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main).\n\n| Name | Hours | Num. frames | Num. pairs |\n|---|---|---|---|\n| [Mixkit](https://mixkit.co/) | 42.0h |  65 |  54,735 |\n|   |  |  513 |  1,997 | \n| [Pixabay](https://pixabay.com/) | 353.3h |  65 | 601,513 |\n|   |  |  513 |  51,483 | \n| [Pexels](https://www.pexels.com/) | 2561.9h |  65 |  3,832,666 |\n|   |  |  513 |  271,782 | \n\n### Training Diffusion Model\nSimilar to our previous work, we employed a multi-stage cascaded training method. Our training card is given after the stage descriptions below.\n\n#### Stage 1\n\nWe initially expected the diffusion model to keep improving with longer training. Surprisingly, however, the [logs](https://api.wandb.ai/links/linbin/o76j03j4) show that videos generated at 50k steps were of higher quality than those at 70-100k steps. Extensive sampling further revealed that checkpoints at 40-60k steps outperformed those at 80-100k steps. In terms of data, 50k steps correspond to approximately 2 epochs of training. It is currently unclear whether this is due to overfitting on a small dataset or to the limited capacity of the 2+1D model.\n\n#### Stage 2\n\nIn the second stage, we trained on Huawei Ascend compute. Training and inference for this stage were fully supported by Huawei. 
We conducted sequence-parallel training and inference on a large-scale cluster, distributing each sample across eight ranks. Models trained on Huawei Ascend can also be loaded onto GPUs and generate videos of the same quality.\n\n#### Stage 3\n\nIn the third stage, we further increased the frame count to 513 frames, approximately 21 seconds at 24 FPS. This stage presents several challenges, such as maintaining temporal consistency in the 2+1D model over long durations and determining whether the current amount of data is sufficient. We are still training the model for this stage and continuously monitoring its progress.\n\n| Name | Stage 1 | Stage 2 | Stage 3 |\n|---|---|---|---|\n| Training Video Size | 65×512×512 |  221×512×512 | 513×512×512 |\n| Compute (#Devices × #Hours) | 80 H100 × 72 | 512 Ascend × 72 |  Under Training |\n| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512) | Under Training |\n| Log | [wandb](https://api.wandb.ai/links/linbin/o76j03j4) | - |  - |\n| Training Data | ~3k hours videos + 13M images |  |  |\n\n### Video Editing\n\nThe recently proposed [ReVideo](https://mc-e.github.io/project/ReVideo/) achieves accurate video editing by modifying the first frame and applying motion control within the edited area. Although it achieves excellent editing performance, its editing length is limited by its base model, [SVD](https://github.com/Stability-AI/generative-models). Open-Sora, as a foundation model for long-video generation, can compensate for this limitation. We are currently collaborating with the ReVideo team to use Open-Sora as the base model for long-video editing. Some preliminary results are shown [here](). \n\nThe initial version still needs improvement in several aspects. 
In the future, we will continue to explore integration with ReVideo to develop improved long-video editing models.\n\n## Failure Cases and Discussion\n\nDespite the promising results of v1.1.0, there remains a gap between our model and Sora. Here, we present some failure cases and discuss them.\n\n### CausalVideoVAE\n\nDespite the significant performance improvement of the VAE in v1.1.0 over v1.0.0, we still encounter failures in challenging cases, such as sand dunes and leaves. The video on the left shows the reconstruction with 4× temporal downsampling, while the video on the right uses 2× temporal downsampling. Both exhibit jitter when reconstructing fine-grained features. This indicates that reducing temporal downsampling alone cannot fully resolve the jitter issue.\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11\n\nhttps://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11\n\n### Diffusion Model\n\n#### Semantic distortion\n\nOn the left is a video generated by v1.1.0 showing puppies in the snow. In this video, the puppies' heads exhibit semantic distortion, indicating that the model struggles to correctly identify which head belongs to which dog. On the right are videos generated by Sora's [base model](https://openai.com/index/video-generation-models-as-world-simulators/) at 1×, 4×, and 32× compute. We observe that Sora's early base model also experienced semantic distortion issues. 
This suggests that we may achieve better results by scaling up the model and increasing the amount of training data.\n\nPrompt: A litter of golden retriever puppies playing in the snow.Their heads pop out of the snow, covered in.\n\n| Ours | Sora Base×1 | Sora Base×4 | Sora Base×32 |\n|---|---|---|---|\n| <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1d456168-afad-4e22-ae3b-fc28eca935e8\" width=224>  |<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/c4ca99d9-9492-45c8-a75e-6efe21c330aa\" width=224>  |<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/b7b52894-f58b-4e64-858b-015247108b8b\" width=224>  | <img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/dcaf793b-1da4-4cc1-9c17-8a46d55e80e6\" width=224> |\n\n#### Limited dynamics\n\nThe primary difference between videos and images lies in their dynamic nature: objects undergo a series of changes across consecutive frames. However, many videos generated by v1.1.0 still show limited dynamics. Upon reviewing a large number of training videos, we found that while web-crawled videos have high visual quality, they are often filled with meaningless close-up shots. These close-ups typically show minimal movement or are even static. On the left, we present a generated video of a bird, while on the right is a training video we found, which is almost static. The dataset contains many similar videos from stock footage sites.\n\nPrompt: This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side,giving the impression of it looking regal and majestic. 
The background is blurred,drawing attention to the bird's striking appearance.\n\n\n| Ours | Raw video |\n|---|---|\n|<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/7ffb6bc6-b52c-488e-9f29-d7d90bda44d6\" width=224>  |<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/95fbbca4-e206-42f6-8c6b-af1063c442c6\" width=224>  | \n\n#### Negative prompt\n\nWe found that using negative prompts can significantly improve video quality, even though we did not explicitly tag the training data with different labels. On the left is a video sampled using a negative prompt, while on the right is a video generated without a negative prompt. This suggests that we may need to incorporate more prior knowledge into the training data. For example, when a video has a watermark, we should note \"watermark\" in the corresponding caption. When a video's bitrate is too low, we should add tags such as \"low quality\" or \"blurry\" to distinguish it from high-quality videos. We believe that explicitly injecting these priors can help the model differentiate between the vast amount of pretraining data (low quality) and the smaller amount of fine-tuning data (high quality), thereby generating higher-quality videos.\n\nPrompt: A litter of golden retriever puppies playing in the snow.Their heads pop out of the snow, covered in.\n\nNegative Prompt: distorted, discontinuous, ugly, blurry, low resolution, motionless, static, low quality\n\n\n| With Negative Prompt | Without Negative Prompt |\n|---|---|\n|<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1d456168-afad-4e22-ae3b-fc28eca935e8\" width=224>  |<img src=\"https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/7ad17f96-bfab-455a-830d-0daebccaf6fb\" width=224>  | \n\n## Future Work\n\nIn our future work, we will focus on two main areas: (1) data scaling and (2) model design. 
Once we have a robust baseline model, we will extend it to handle variable durations and conditional control.\n\n### Data Scaling\n\n#### Data source\n\nAs mentioned earlier, our dataset is entirely sourced from stock footage websites. Although these videos are of high quality, many consist of close-up shots of specific areas and therefore contain little motion. We believe this is one of the main reasons for the limited dynamics observed. Therefore, we will continue to collect datasets from diverse sources to address this issue.\n\n#### Data volume\n\nIn v1.1.0, our dataset comprises only ~3k hours of video. We are actively collecting more data and anticipate that the video dataset for the next version will reach ~100k hours. We welcome recommendations from the open-source community for additional datasets.\n\n### Model Design\n\n#### CausalVideoVAE\nIn our internal testing, we found that even without temporal downsampling, the jitter in reconstructing fine-grained features cannot be completely eliminated. Therefore, we need to reconsider how to mitigate video jitter as much as possible while still supporting both images and videos. We will introduce a more powerful CausalVideoVAE in the next version.\n\n#### Diffusion Model\nIn v1.1.0, we found that 2+1D models can generate high-quality short videos. However, long videos tend to exhibit discontinuities and inconsistencies. Therefore, we will explore alternative model architectures to address this issue.\n"
  },
  {
    "path": "docs/Report-v1.2.0.md",
    "content": "# Report v1.2.0\n\nIn May 2024, we launched Open-Sora-Plan v1.1.0, featuring a 2+1D model architecture that could be quickly utilized for exploratory training in text-to-video generation tasks. However, when handling dense visual tokens, the 2+1D architecture cannot model the spatial and temporal dimensions jointly. Therefore, we transitioned to **a 3D full attention architecture**, which better captures joint spatial-temporal features. Although this version is experimental, it moves our video generation architecture in a new direction, which is why we release it as v1.2.0.\n\nCompared to our previous versions, Open-Sora-Plan v1.2.0 offers the following improvements:\n\n1. **Better compressed visual representations**. We optimized the structure of CausalVideoVAE, which now delivers enhanced performance and higher inference efficiency.\n2. **Better video generation architecture**. Instead of 2+1D, we use a diffusion model with a 3D full attention architecture, which provides a better understanding of the world.\n\n\n### Open-Source Release\nWe open-source Open-Sora-Plan to facilitate future development of video generation in the community. Code, data, and models are publicly available.\n- Code: All training scripts and sample scripts.\n- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0).\n- Data: Filtered data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0).\n\n\n## Gallery\n\n93×1280×720 Text-to-Video Generation. 
The videos have been compressed for playback on GitHub.\n\n<table class=\"center\">\n<tr>\n  <td><video src=\"https://github.com/user-attachments/assets/1c84bc92-d585-46c9-ae7c-e5f79cefea88\" autoplay></video></td>\n</tr>\n</table>\n\n## Detailed Technical Report\n\n### CausalVideoVAE\n\n#### Model Structure\n\nThe VAE in version 1.2.0 maintains the overall architecture of the previous version but merges the temporal and spatial downsampling layers. In version 1.1.0, we performed spatial downsampling (stride=1,2,2) followed by temporal downsampling (stride=2,1,1). In version 1.2.0, we perform spatial and temporal downsampling simultaneously (stride=2,2,2) and spatial-temporal upsampling in the decoder (interpolate_factor=2,2,2).\n\nBecause no additional convolutions are introduced for downsampling and upsampling, this design inherits the weights of the SD2.1 VAE more seamlessly, giving our VAE a better initialization.\n\n<img src=\"https://s21.ax1x.com/2024/07/24/pkHrHx0.png\" width=768>\n\n\n#### Training Details\n\n\nAs with v1.1.0, we initialize from the [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse) using tail initialization. We perform the first phase of training on the Kinetics-400 video dataset, then use the EMA weights from this phase to initialize the second phase, which is fine-tuned on high-quality data (collected in v1.1.0). 
All training is conducted on 25-frame 256×256 videos using **one A100 node**.\n\n\n| Training stage | Dataset | Training steps |\n|---|---|---|\n| 1 | K400 | 200,000 |\n| 2 | collected in v1.1.0 | 450,000 |\n\n\n#### Evaluation\n\nWe evaluated our VAE on the validation sets of two video datasets, [WebVid](https://github.com/m-bain/webvid) and [Panda70M](https://github.com/snap-research/Panda-70M/), and compared it with our [v1.1.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md), [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse), SVD VAE, [CV-VAE](https://github.com/AILab-CVC/CV-VAE), and [Open-Sora's VAE](https://github.com/hpcaitech/Open-Sora). The WebVid validation set contains 5k videos, while the Panda70M validation set has 6k videos. The videos were resized to 256 pixels on the short side, center-cropped to 256×256, and then 33 consecutive frames were extracted. We report PSNR, SSIM, and LPIPS, and measure encoding speed on an A100 GPU. 
The results are as follows:\n\n\n**WebVid**\n\n| Model | Compression Ratio | PSNR↑ | SSIM↑ | LPIPS↓ |\n|---|---|---|---|---|\n| SD2.1 VAE | 1x8x8 | 30.19 | 0.8379 | 0.0568 |\n| SVD VAE | 1x8x8 | <ins>31.15</ins> | <ins>0.8686</ins> | **0.0547** |\n| CV-VAE | 4x8x8 | 30.76 | 0.8566 | 0.0803 |\n| Open-Sora VAE | 4x8x8 | 31.12 | 0.8569 | 0.1003 |\n| Open-Sora Plan v1.1 | 4x8x8 | 30.26 | 0.8597 | <ins>0.0551</ins> |\n| Open-Sora Plan v1.2 | 4x8x8 | **31.16** | **0.8694** | 0.0586 |\n\n**Panda70M**\n\n| Model | Compression Ratio | PSNR↑ | SSIM↑ | LPIPS↓ |\n|---|---|---|---|---|\n| SD2.1 VAE | 1x8x8 | 30.40 | 0.8894 | 0.0396 |\n| SVD VAE | 1x8x8 | <ins>31.00</ins> | **0.9058** | **0.0379** |\n| CV-VAE | 4x8x8 | 29.57 | 0.8795 | 0.0673 |\n| Open-Sora VAE | 4x8x8 | **31.06** | 0.8969 | 0.0666 |\n| Open-Sora Plan v1.1 | 4x8x8 | 29.16 | 0.8844 | 0.0481 |\n| Open-Sora Plan v1.2 | 4x8x8 | 30.49 | <ins>0.8970</ins> | <ins>0.0454</ins> |\n\n**Encode Time on A100**\n\n| Input Size | CV-VAE | Open-Sora | Open-Sora Plan v1.1 | Open-Sora Plan v1.2 |\n|---|---|---|---|---|\n| 33x256x256 | 0.186 | 0.147 | <ins>0.104</ins> | **0.102** |\n| 81x256x256 | 0.465 | 0.357 | <ins>0.243</ins> | **0.242** |\n\n### Training Text-to-Video Diffusion Model\n\n#### Model Structure\n\nThe most significant change is that we **replaced all 2+1D Transformer blocks with 3D full attention blocks**. Each video is first processed by a patch embedding layer, which downsamples the spatial dimensions by a factor of 2. The video is then flattened into a one-dimensional sequence across the frame, width, and height dimensions. We replaced [T5-XXL](https://huggingface.co/DeepFloyd/t5-v1_1-xxl) with [mT5-XXL](https://huggingface.co/google/mt5-xxl) to enhance multilingual adaptation. 
Additionally, we incorporated RoPE.\n\n\n#### Sequence Parallelism\n\nDue to the high computational complexity of 3D full attention, we must distribute a single video across multiple GPUs for parallel processing when training on long-duration, high-resolution videos. We control the number of GPUs used per video sample by adjusting the batch size on a node. For example, with `sp_size=8` and `train_sp_batch_size=4`, each sample is split across 8 / 4 = 2 GPUs. **We support sequence parallelism for both training and inference**.\n\n**Training on 93×720p** (speed measured on H100):\n\n| GPUs (sp_size) | Batch size | SP enabled | train_sp_batch_size | Speed | Steps per day |\n|---|---|---|---|---|---|\n|8|8|×|-|100s/step|~850|\n|8|-|√|4|53s/step|~1600|\n|8|-|√|2|27s/step|~3200|\n\n**Inference at 720p** (speed measured on H100):\n\n| Size | 1 GPU | 8 GPUs |\n|---|---|---|\n|29×720p|420s/100step|80s/100step|\n|93×720p|3400s/100step|450s/100step|\n\n#### Dynamic training\n\nDeep neural networks are typically trained on batched inputs. For efficient hardware processing, batch shapes are fixed, which fixes the data size. This requires either cropping or padding inputs to a uniform size, and both have drawbacks: cropping discards content and degrades performance, while padding wastes computation. There are generally three approaches to training with arbitrary token counts: Patch n' Pack, bucket, and pad-mask.\n\n\n\n**Patch n' Pack** ([NaViT](https://arxiv.org/abs/2307.06304)): This method bypasses the fixed sequence length limitation by packing tokens from multiple samples into a single sequence. It supports variable-resolution images while preserving their aspect ratios, thereby reducing training time and improving performance and flexibility. 
However, this method requires significant code modifications and must be re-adapted whenever a different model architecture is explored, which is costly in a field where model designs are still unsettled.\n\n\n**Bucket** ([Pixart-alpha](https://arxiv.org/abs/2310.00426), [Open-Sora](https://github.com/hpcaitech/Open-Sora)): This method groups data of different resolutions into buckets and samples each batch from a single bucket, so that all samples in a batch share the same resolution. It requires minimal changes to the model code, mainly adjusting the data sampling strategy.\n\n**Pad-mask** ([FiT](https://arxiv.org/abs/2402.12376), our v1.0/v1.1): This method sets a maximum resolution, pads all data to this resolution, and generates a corresponding mask. Although the approach is straightforward, it is computationally inefficient.\n\nWe believe that current video generation models are still in an exploratory phase, and extensive modifications to model code during this period incur unnecessary development costs. The pad-mask method, in turn, wastes substantial compute on video, where computation is dense. Ultimately, we chose the bucket strategy, which requires no modifications to the model code. Next, we explain how our bucket strategy supports arbitrary lengths and resolutions. For simplicity, we use video duration as an example:\n\n<img src=\"https://s21.ax1x.com/2024/07/24/pkHr4aQ.png\" width=768>\n\n\nWe define a megabatch as the total data processed in a single step across all GPUs. A megabatch can be divided into multiple batches, with each batch corresponding to the data processed by a single GPU.\n\n**Sort by frame**: The first step is to count the frames of every video and sort the videos by frame count. The goal is to place similar samples next to each other; sorting is one simple way to achieve this.\n\n**Group megabatch**: Next, all data is divided into groups, each forming a megabatch. 
Since all data is pre-sorted, most videos within a megabatch have the same number of frames. However, there will always be boundary cases, such as a megabatch that contains both 61-frame and 1-frame videos.\n\n**Re-organize megabatch**: These boundary megabatches make up only a small proportion of the total. We re-organize each of them by randomly replacing its minority samples with samples of the majority frame count, so that every sample in the megabatch has the same number of frames.\n\n**Shuffle megabatch**: To ensure data randomness, we shuffle both within each megabatch and across megabatches.\n\nTo support dynamic resolutions as well, we simply use each sample's (frame × height × width) in place of its frame count. This ensures that every GPU processes data of the same size at each step, preventing situations where GPU1 waits for GPU0 to finish processing a longer video. Moreover, the strategy is entirely decoupled from the model code, serving as a plug-and-play video sampling strategy.\n\n\n#### Training stage\n\nSimilar to previous work, we use a multi-stage training approach. With the 3D DiT architecture, all parameters can be transferred from images to videos without loss. To understand the full training cost, we train all parameters of the diffusion model from scratch. We therefore first train a text-to-image model, using the training strategy from [Pixart-alpha](https://arxiv.org/abs/2310.00426).\n\nThe video model is initialized with weights from a 480p image model. We first train on 29-frame 480p videos and then adapt the weights to 720p resolution; both stages use approximately 6 million higher-quality (HQ) samples from Panda70M, filtered for aesthetic quality and motion. Finally, we fine-tune on 93-frame 720p videos using an even higher-quality subset of 1 million Panda70M samples together with filtered data collected in v1.1.0. Below is our training card. 
We release the annotation files [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json).\n\n| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 | Stage 5 |\n|---|---|---|---|---|---|\n| Training Video Size | 1×320×240 | 1×640×480 | 29×640×480 | 29×1280×720 | 93×1280×720 |\n| Training Steps | 146k | 200k | 30k | 21k | 3k |\n| Compute (#Num × #Hours) | 32 Ascend × 81 | 32 Ascend × 142 | 128 Ascend × 38 | 256 H100 × 64 | 256 H100 × 84 |\n| Checkpoint | - | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p) |\n| Log | - | - | [wandb](https://api.wandb.ai/links/1471742727-Huawei/trdu2kba) | [wandb](https://api.wandb.ai/links/linbin/vvxvcd7s) | [wandb](https://api.wandb.ai/links/linbin/easg3qkl) |\n| Training Data | [10M SAM](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/sam_image_11185255_resolution.json) | 5M internal image data | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [1M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ1M.json) and [100k HQ data](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json) (collected in v1.1.0) |\n\nAdditionally, we fine-tuned the final 93×720p model for 3.5k steps to obtain a [93×480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p) model for community research use.\n\n### Training Image-to-Video Diffusion Model\n\n#### Model Structure\n\n<img 
src=\"https://s21.ax1x.com/2024/08/12/pApZZJf.png\">\n\nTo reuse the weights of the Text-to-Video model, our Image-to-Video model is inspired by the Stable Diffusion Inpainting Model and adopts a frame-level inpainting strategy. By combining three types of input (the original noise, the masked video, and the mask) under different control-frame conditions, our model can generate coherent videos while remaining flexible in how it is used.\n\nCompared with the Text-to-Video denoiser, the Inpainting Model's denoiser differs only in the number of input channels of its `conv in` layer. To preserve the model's prior knowledge, the masked video and mask are introduced through zero-initialized weights.\n\nUnder the 2+1D structure, however, this kind of frame conditioning is prone to unstable results. We believe this is due to the 2+1D structure's inability to establish long-range dependencies: relying solely on attention along the temporal dimension makes it difficult to capture how information changes under frame control. In Text-to-Video tasks, this phenomenon is less evident because all frames share the same text prompt embedding. In Image-to-Video tasks, however, simply concatenating images along the channel dimension does not ensure that the model accurately captures changes between frames. The model cannot simply copy image information from the input channels to reduce the loss, and because the 2+1D structure only exchanges information along the temporal axis, the model struggles to discern which information from the control frames it can use, especially when frames differ significantly. Therefore, without shared image-semantic information, the control-frame information might not be effectively conveyed to each frame.\n\n##### About Semantic Adapter\n\nIn previous models based on the 2+1D U-Net architecture, the control frames must be fed into a CLIP model to obtain semantic embeddings. These semantic embeddings are then injected into the denoiser through cross-attention. 
The structure that extracts CLIP embeddings and injects them into the denoiser is commonly referred to as a semantic adapter.\n\nSemantic adapters are common in 2+1D architectures, and papers such as [DynamiCrafter](https://arxiv.org/abs/2310.12190) have pointed out that incorporating a semantic adapter helps keep the generated videos stable. We believe this is because the 2+1D structure cannot establish long-range dependencies, and relying solely on attention along the temporal dimension makes it difficult to capture how information changes under frame control. In the Text-to-Video task, this phenomenon is less evident because all frames share the same text prompt embedding. In the Image-to-Video task, however, without shared semantic information, the control-frame information may not be effectively transferred to each individual frame.\n\n<center>\n<figure>\n    <img src=\"https://github.com/user-attachments/assets/06df193a-fe89-42c1-8c01-fb3b7c2be0e3\" height=400 />\n\t<img src=\"https://github.com/user-attachments/assets/09906df4-ab9f-443d-8d38-512e16075b0c\" height=400 />\n</figure>\n</center>\n\nWe compared the Inpainting Model under the 2+1D structure (Open-Sora Plan v1.1, left in the figure) and the 3D structure (Open-Sora Plan v1.2, right in the figure). With the same number of optimization steps, unstable visual results occurred significantly more often with the 2+1D structure than with the 3D structure. Even at convergence, the visual stability of the 2+1D structure remained inferior to that of the 3D structure, and was even worse than that of the 3D structure in its early training stages.\n\n## Future Work and Discussion\n\n#### CausalVideoVAE\nWe observed that high-frequency motion information in videos tends to exhibit jitter, and increasing training duration and data volume does not significantly alleviate this issue. 
In videos, compressing the temporal duration while keeping the original latent dimension can lead to significant information loss. A more robust VAE will be released in the next version.\n\n#### Diffusion Model\nWe replaced T5 with mT5 to enhance multilingual capabilities, but this capability is limited because our training data is currently English-only. The multilingual ability comes primarily from the mT5 embedding space. In the next steps, we will explore additional text encoders and expand the data.\n\nOur model maintains character consistency well, likely because Panda70M is a character-centric dataset. However, it still performs poorly on text consistency and object generalization. We suspect this may be due to the limited amount of data the model has seen, as suggested by the loss not having converged in the final stage. **We hope to collaborate with the open-source community to optimize the 3D DiT architecture.**"
  },
  {
    "path": "docs/Report-v1.3.0.md",
    "content": "# Report v1.3.0\n\nIn August 2024, we released Open-Sora-Plan v1.2.0, transitioning to a 3D full attention architecture, which enhanced the capture of joint spatial-temporal features. However, its substantial computational cost made it unsustainable, and the lack of a clear training strategy hindered steady progress along a focused path.\n\nIn version 1.3.0, Open-Sora-Plan introduces the following five key features:\n\n**1. A more powerful and cost-efficient WFVAE.** We decompose video into several sub-bands using wavelet transforms, naturally capturing information across different frequency bands, leading to more efficient and robust VAE learning.\n\n**2. Prompt Refiner.** A large language model designed to refine short text inputs.\n\n**3. High-quality data cleaning strategy.** The cleaned Panda70M dataset retains only 27% of the original data.\n\n**4. DiT with new sparse attention.** A more cost-effective and efficient learning approach.\n\n**5. Dynamic resolution and dynamic duration.** This enables more efficient utilization of videos with varying lengths (treating a single frame as an image).\n\n### Open-Source Release\nWe open-source Open-Sora-Plan to facilitate future development of video generation in the community. Code, data, and models will be made publicly available.\n- Code: All training scripts and sample scripts.\n- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0).\n- Data: The prompt refiner data is [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).\n\n## Gallery\n\nText & Image to Video Generation. 
\n\n[![Demo Video of Open-Sora Plan V1.3](https://github.com/user-attachments/assets/4ff1d873-3dde-4905-a907-dbff51174c20)](https://www.bilibili.com/video/BV1KR2fYPEF5/?spm_id_from=333.999.0.0&vd_source=cfda99203e659100629b465161f1d87d)\n\n## Detailed Technical Report\n\n### WF-VAE\n\nAs video generation models move toward higher resolutions and longer durations, the computational cost of video VAEs grows rapidly and becomes unsustainable. Most related work addresses this by using tiling to reduce inference memory consumption. However, in high-resolution, long-duration scenarios, tiling significantly increases inference time. Additionally, since tiling is lossy for latents, it can lead to visual artifacts such as shadows or flickering in the generated videos. To address these problems, we introduce WFVAE.\n\n#### Model Structure\n\n<center>\n<figure>\n\t<img width=\"899\" alt=\"SCR-20241023-tzct\" src=\"https://github.com/user-attachments/assets/03615e1d-2633-4247-af0b-d93e2a935e3e\">\n</figure>\n</center>\n\nThe compression rate fundamentally determines the quality of VAE-reconstructed videos. We analyzed the energy and entropy of the subbands obtained through the wavelet transform and found that most of the energy in videos is concentrated in the low-frequency bands. Moreover, replacing the `LLL` subband of the VAE-reconstructed video with the original video's `LLL` subband significantly improved the spatiotemporal quality of the videos.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/533666a6-05be-4584-8b14-86f01d0471dd\" height=250 />\n</figure>\n</center>\n\nIn previous VAE architectures, the lack of a \"highway\" for transmitting the dominant energy during video compression meant that this pathway had to be gradually established during model training, leading to redundancy in model parameters and structure. 
Therefore, in our model design, we created a more efficient transmission path for the LLL subband energy, significantly simplifying the model architecture, reducing inference time, and lowering memory consumption.\n\n#### Training Details\n\nMore details will be provided in the forthcoming paper.\n\n#### Ablation Study\n\nOur ablation experiments were conducted on 8×H100 GPUs using the K400 training and validation sets, with the latent dimension fixed at 4. Reconstruction metrics continued to improve as the number of model parameters increased, suggesting room for further gains. GroupNorm was unstable during training, performing worse than LayerNorm on PSNR but better on LPIPS.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/ed880143-72d1-4316-a1d4-5fdfc5ed155a\" height=200 />\n\t<img src=\"https://github.com/user-attachments/assets/303954c3-73ee-44f3-9897-d3d14b37b27e\" height=200 />\n</figure>\n</center>\n\n#### Performance\n\nThe following metrics were measured on an H100 in float32 precision. For fairness, tiling was disabled and all models ran direct inference.\n\n<center>\n<figure>\n\t<img width=\"765\" alt=\"SCR-20241023-tzwz\" src=\"https://github.com/user-attachments/assets/f7d4f225-5d22-4152-90ad-32716884ae6c\">\n</figure>\n</center>\n\n\n#### Evaluation\n\nWe evaluated PSNR and LPIPS on the Panda70M test set at 256 pixels and 33 frames. 
In the open-source WF-VAE-S (8-dim), our encoder was distilled from the 8-dim OD-VAE, resulting in some metric degradation compared to direct training.\n\n\n| Latent Dim | Model | Params | PSNR | LPIPS |\n|---|---|---|---|---|\n| 4 | OD-VAE (our VAE in v1.2.0) | 94M + 144M | 30.311 | 0.043 |\n| 4 | WFVAE-S | 38M + 108M | 30.579 | 0.044 |\n| 8 | WFVAE-S (distilled) | 38M + 108M | 31.764 | 0.050 |\n\n#### Causal Cache\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/59cb0543-225b-45a3-a4a6-429e5e753167\" height=200 />\n</figure>\n</center>\n\n\nTo address the issues of tiling, we replaced GroupNorm with LayerNorm and introduced a novel method called **Causal Cache**, enabling lossless temporal block-wise inference.\n\nUnlike GroupNorm, whose statistics span all frames, LayerNorm normalizes over channels at each position, so splitting the input along the temporal dimension does not change its output. Building on this and on the properties of CausalConv3D, we achieve lossless inference by chunking along the temporal dimension. In each CausalConv3D layer, we cache the last few frames of the current chunk so that the convolution window slides continuously into the next temporal chunk, which makes chunked processing lossless. As illustrated, we use a kernel size of 3 and a stride of 1 as an example:\n\n**Initial Chunk (chunk idx=0):** For the first temporal chunk, we apply standard causal padding, which supports joint processing of images and videos. After the convolution, we cache the last two frames of this chunk (kernel size minus 1) in the causal cache for the next chunk's inference.\n\n**Subsequent Chunks (chunk idx=1 and beyond):** Starting from the second temporal chunk, we no longer use causal padding. Instead, we prepend the cached frames from the previous chunk to the current chunk. We again cache the last two frames of the current input for use in subsequent chunks.\n\n### Prompt Refiner\n\nUser-provided captions are typically fewer than 10 words, whereas the text annotations in the current training data are often dense. 
This inconsistency between training and inference may result in poor visual quality and weak text alignment. We categorize captions into four types:\n\n(1) Short captions from real user input; we collected 11k from [COCO](https://cocodataset.org/#home).\n\n(2) Captions composed of multiple tags; we collected 5k from [DiffusionDB](https://github.com/poloclub/diffusiondb).\n\n(3) Medium-length captions generated by large language models; 3k sourced from [JourneyDB](https://github.com/JourneyDB/JourneyDB).\n\n(4) Ultra-long, surrealist captions, sourced from Sora/Vidu/Pika/Veo and approximately 0.5k generated by GPT.\n\nWe used ChatGPT to rewrite the above captions with the following instruction:\n\n```\nrewrite the sentence to contain subject description action, scene description. \nOptional: camera language, light and shadow, atmosphere and\nconceive some additional actions to make the sentence more dynamic,\nmake sure it is a fluent sentence, not nonsense.\n```\n\nFinally, we fine-tuned [Llama 3.1](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) with LoRA; training took just 30 minutes on a single H100. We trained for only 1 epoch with a batch size of 32 and a LoRA rank of 64. The training log can be found [here](https://api.wandb.ai/links/1471742727-Huawei/p5xmkft5), and we open-sourced the data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).\n\n### Data Construction\n\nWe randomly sampled videos from the original Panda70m dataset and found that many were static, contained multiple subtitles, or suffered from motion blur. In addition, the Panda70m captions did not always describe the video content accurately. 
To address this, we designed a video filtering pipeline that retained approximately 27% of the videos.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/90f9d386-ff2e-465a-b013-a9e7151afaf8\" height=400 />\n</figure>\n</center>\n\n\n#### Jump Cut and Detect Motion\n\nWe computed [LPIPS](https://github.com/richzhang/PerceptualSimilarity) between frames sampled with frame skipping to measure inter-frame semantic similarity, treating anomalous values as cut points and taking the mean as the motion score. Videos with motion scores below 0.001 were nearly static, while those above 0.3 exhibited significant jitter and flicker. After applying this method, we manually reviewed 2k videos and concluded that the cut detection accuracy was sufficient for pre-training.\n\n#### OCR\n\nWe estimated the average position of subtitles on common video platforms to be around 18%. Consequently, we set the maximum crop threshold to 20% of the video's original dimensions and used [EasyOCR](https://github.com/JaidedAI/EasyOCR) to detect subtitles, sampling one frame per second. However, not all subtitles or printed text appear near the frame edges; this method may miss text in central areas, such as in advertisement videos or speeches. Nonetheless, the presence of text does not necessarily mean a video should be filtered out, since text can be meaningful in certain contexts. We leave such judgments to the aesthetic filter.\n\n#### Aesthetic\n\nAs before, we used the [LAION aesthetic predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor) for evaluation. Based on the visualization [website](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html), we chose a score of 4.75 as the threshold, which effectively filters out videos with excessive text while retaining those with high aesthetic quality. 
For data with a score above 6.25, we add an extra aesthetic prompt to the caption, such as `A high-aesthetic scene, `.\n\n#### Video Quality\n\nSome old photos and videos have very low bit rates, so they look blurry and mosaic-like even at 480P. Aesthetic filtering struggles to exclude them because the predictor resizes images to 224 pixels. We therefore wanted a metric of absolute video quality that is independent of the visual content and focuses solely on compression artifacts, low bit rates, and jitter. We used the technical quality score from [DOVER](https://github.com/VQAssessment/DOVER) and excluded videos with scores below 0.\n\n#### Recheck Motion\n\nIn videos with subtitles, changing subtitles can make the motion scores inaccurate. Therefore, we re-evaluated the motion scores and discarded static videos.\n\n#### Captioning\n\nWe used [Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) for video annotation with the following prompt:\n\n```\nPlease describe the content of this video in as much detail as possible, \nincluding the objects, scenery, animals, characters, and camera movements within the video. \nDo not include '\\n' in your response. \nPlease start the description with the video content directly. 
\nPlease describe the content of the video and the changes that occur, in chronological order.\n```\n\nHowever, the 7B model tends to begin its descriptions with boilerplate prefixes such as \"This video\" or \"The video\". We compiled a list of these irrelevant opening strings and removed them from the captions:\n\n```\n    'The video depicts ', \n    'The video captures ', \n    'In the video, ', \n    'The video showcases ', \n    'The video features ', \n    'The video is ', \n    'The video appears to be ', \n    'The video shows ', \n    'The video begins with ', \n    'The video displays ', \n    'The video begins in ', \n    'The video consists of ', \n    'The video opens with ', \n    'The video opens on ', \n    'The video appears to capture ', \n    'The video appears to show ', \n    \"The video appears to depict \", \n    \"The video opens in \", \n    \"The video appears to focus closely on \", \n    \"The video starts with \", \n    \"The video begins inside \", \n    \"The video presents \", \n    \"The video takes place in \", \n    \"The video appears to showcase \", \n    \"The video appears to display \", \n    \"The video appears to focus on \", \n    \"The video appears to feature \"\n```\n\n### Training Text-to-Video Diffusion Model\n\n#### Framework\n\n##### Skiparse (Skip-Sparse) Attention\n\nIn video generation models, alternating 2+1D spatial and temporal blocks is a common design, but such models lack long-range spatiotemporal modeling, which limits their performance ceiling. Consequently, models like [CogVideoX](https://arxiv.org/abs/2408.06072), [Meta Movie Gen](https://ai.meta.com/research/movie-gen/), and Open-Sora Plan v1.2 employ **Full 3D Attention** in the denoiser, achieving substantially better visual fidelity and motion quality than 2+1D models. However, this approach computes attention across all tokens of each clip, which significantly raises training costs. 
For instance, training the 2.7-billion-parameter Open-Sora Plan v1.2 model takes **100 seconds per step at 93x720p and over 15 seconds per step at 93x480p**, which severely constrains scalability under limited computational resources.\n\nTo accelerate training while maintaining adequate performance, we propose **Skiparse (Skip-Sparse) Attention**. Specifically, under a fixed sparse ratio $$k$$, we select the candidate tokens for attention using two alternating skip-gather methods. This keeps the attention global while substantially reducing FLOPs, enabling faster training of 3D Attention models. In our experiments, applying Skiparse with sparse ratio $$k=4$$ to a 2.7B model reduced training time to **42 seconds per step at 93x720p and 8 seconds per step at 93x480p**.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/186377ca-26b2-4f0f-af42-ae6c846eebcb\" />\n</figure>\n</center>\n\n**Skiparse DiT modifies only the Attention component** within the Transformer Block, alternating between two kinds of Skip Sparse Transformer Blocks. With sparse ratio $$k$$, the sequence length in each attention operation drops to $$\\frac{1}{k}$$ of the original while the batch size grows $$k$$-fold, so the theoretical complexity of self-attention falls to $$\\frac{1}{k}$$ of the original; cross-attention complexity remains unchanged. Because GPUs/NPUs process the larger batch in parallel, a $$k$$-fold batch does not make the computation $$k$$ times slower, so the observed speedup can exceed this theoretical estimate.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/80f9470a-8afe-4588-a22c-e8c576fea9b6\" />\n</figure>\n</center>\n\nIn Single Skip mode, the elements at positions $$[0, k, 2k, 3k, ...]$$, $$[1, k+1, 2k+1, 3k+1, ...]$$, ..., $$[k-1, 2k-1, 3k-1, ...]$$ are grouped into scopes, with each list forming one scope. 
The figure above illustrates this structure for $$k=2$$. The idea is straightforward: each token attends only to tokens whose positions differ from its own by a multiple of $$k$$, so consecutive tokens in a scope are separated by $$k-1$$ other tokens.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/5880f667-7a06-4e7f-8e44-2e1cfb9209b8\" />\n</figure>\n</center>\n\nIn Group Skip mode, the elements at positions $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$, $$[(k, k+1, ..., 2k-1), (k^2+k, k^2+k+1, ..., k^2+2k-1), (2k^2+k, 2k^2+k+1, ..., 2k^2+2k-1), ...]$$, ..., $$[(k^2-k, k^2-k+1, ..., k^2-1), (2k^2-k, 2k^2-k+1, ..., 2k^2-1), (3k^2-k, 3k^2-k+1, ..., 3k^2-1), ...]$$ are grouped into scopes, with each list forming one scope. The indices look complicated, but the figure above makes the pattern easier to follow.\n\nIn this pattern, we first **group adjacent tokens** into segments of length $$k$$, then **bundle these groups** with other groups that are separated from them by $$k-1$$ groups into a single scope. For example, in $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$, each set of indices in parentheses is a group, and each group is linked to the next group after skipping $$k-1$$ groups, forming one scope.\n\nSince the last index of the first group is $$k-1$$ and the $$k-1$$ skipped groups contain $$k(k-1)$$ tokens, the first token of the next linked group is at index $$k-1+k(k-1)+1=k^2$$. Following this pattern, you can determine the indices of every scope in this configuration.\n\n##### Why \"Skiparse\"?\n\nThe 2+1D DiT models temporal relationships only along the time axis of a single spatial location, which limits performance both in theory and in practice. In real-world scenarios, changes at a specific spatial location are typically influenced not by prior content at that same location but by content across all spatial locations at preceding times. 
This constraint makes it difficult for 2+1D DiT to model complex physical dynamics accurately.\n\nFull 3D Attention is global attention: any spatial position at any time can access information from any other position at any time, which aligns well with real-world physics. However, it is time-consuming and inefficient, because visual information contains considerable redundancy and attention across all spatiotemporal tokens is unnecessary.\n\n**An ideal spatiotemporal modeling approach should use attention that minimizes the overhead of redundant visual information while capturing the complexity of the dynamic physical world**. Reducing redundancy requires avoiding connections among all tokens, yet global spatiotemporal attention remains essential for modeling complex physical interactions.\n\nTo balance the efficiency of 2+1D attention with the strong spatiotemporal modeling of Full 3D attention, we developed Skiparse Attention. It provides global spatiotemporal attention within each block, and every block has the same “receptive field”. The \"group\" operation also introduces a degree of locality, which suits visual tasks.\n\nInterestingly, once you understand the Skiparse Attention mechanism, you’ll notice that **the attention in 2+1D DiT corresponds to a sparse ratio of $$k=HW$$ (since $$T \\ll HW$$, the \"skip\" in Group Skip becomes negligible), while Full 3D DiT corresponds to a sparse ratio of $$k=1$$.** In Skiparse Attention, $$k$$ is typically chosen close to 1 and far smaller than $$HW$$, which makes it a 3D Attention that approaches the effectiveness of Full 3D Attention.\n\nIn Skiparse Attention, Single Skip is a straightforward operation that is easy to understand. In Group Skip, the grouping is also intuitive and serves to model local information. 
However, **Group Skip involves not only grouping but also skipping**, in particular between groups, and this part is often overlooked. As a result, Skiparse Attention is frequently confused with Skip + Window Attention. The key difference lies in the $$2N+1$$ blocks: Window Attention only groups adjacent tokens and does not skip between groups. The figure below illustrates the differences among these attention methods; it shows the attention scopes for self-attention only, with dark tokens marking the tokens involved in each attention computation.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/62d0c75a-7e1d-458e-9faf-ae394e8ddd34\" />\n</figure>\n</center>\n\nTo understand why nearly global attention is necessary, and why Skiparse Attention theoretically approximates Full 3D Attention more closely than other common methods, we introduce the concept of **Average Attention Distance**. For any two tokens, if the minimum number of attention operations needed to connect them is $$m$$, their attention distance is $$m$$. The average attention distance of an attention method is the mean of the attention distances over all token pairs in a tensor, and it reflects the method's overall connectivity efficiency.\n\nFor example, in Full 3D Attention, any token can connect with any other token in a single attention operation, so the average attention distance is 1.\n\nFor 2+1D Attention and the other two methods, the analysis is somewhat more complex, though still straightforward. In all configurations above, any two different tokens are connected with an attention distance of either 1 or 2 (note that we define the attention distance between a token and itself as zero). 
Thus, for each of these methods, we first count the tokens with an attention distance of 1; every remaining token (other than the token itself) has an attention distance of 2, which yields the average attention distance.\n\nIn the $$2N$$ Block, attention operates over the $$(H, W)$$ dimensions, so all tokens in the same frame have an attention distance of 1. In the $$2N+1$$ Block, attention operates along the $$(T)$$ dimension, so all tokens at the same spatial location also have an attention distance of 1. The two scopes share only the token itself, so their union contains $$HW + T - 1$$ tokens; excluding the token itself, $$(HW + T - 1) - 1 = HW + T - 2$$ tokens have an attention distance of 1.\n\nTherefore, in 2+1D Attention, the average attention distance (AVG Attention Distance) is:\n\n$$\n\\begin{aligned}\n\td&=\\frac{1}{THW}\\left[ 1\\times 0+\\left( HW+T-2 \\right) \\times 1+\\left[ THW-\\left( HW+T-1 \\right) \\right] \\times 2 \\right]\\\\\n\t&=2-\\left( \\frac{1}{T}+\\frac{1}{HW} \\right)\\\\\n\\end{aligned}\n$$\n\nIn Skip+Window Attention, aside from the token itself, there are $$\\frac{THW}{k} - 1$$ tokens with an attention distance of 1 in the $$2N$$ Block and $$k - 1$$ tokens with an attention distance of 1 in the $$2N+1$$ Block; the two scopes again share only the token itself. Thus, the total number of tokens with an attention distance of 1 is $$\\frac{THW}{k} + k - 2$$.\n\nTherefore, in Skip+Window Attention, the average attention distance (AVG Attention Distance) is:\n\n$$\n\\begin{aligned}\n\td&=\\frac{1}{THW}\\left[ 1\\times 0+\\left( \\frac{THW}{k}+k-2 \\right) \\times 1+\\left[ THW-\\left( \\frac{THW}{k}+k-1 \\right) \\right] \\times 2 \\right]\\\\\n\t&=2-\\left( \\frac{1}{k}+\\frac{k}{THW} \\right)\\\\\n\\end{aligned}\n$$\n\nIn Skiparse Attention, aside from the token itself, $$\\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N$$ Block, and $$\\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N+1$$ Block. 
Notably, $$\\frac{THW}{k^2} - 1$$ tokens have an attention distance of 1 in both blocks and must not be counted twice.\n\nTherefore, in Skiparse Attention, the average attention distance (AVG Attention Distance) is:\n\n$$\n\\begin{aligned}\n\td&=\\frac{1}{THW}\\left[ 1\\times 0+\\left[ \\frac{2THW}{k}-2-\\left( \\frac{THW}{k^2}-1 \\right) \\right] \\times 1+\\left[ THW-\\left( \\frac{2THW}{k}-\\frac{THW}{k^2} \\right) \\right] \\times 2 \\right]\\\\\n\t&=2-\\frac{2}{k}+\\frac{1}{k^2}-\\frac{1}{THW}\\\\\n\t&\\approx 2-\\frac{2}{k}+\\frac{1}{k^2}\\quad \\left( 1\\ll THW \\right)\\\\\n\\end{aligned}\n$$\n\nIn fact, in the Group Skip of the $$2N+1$$ Block, the actual sequence length is $$k\\lceil \\frac{THW}{k^2} \\rceil$$ rather than $$\\frac{THW}{k}$$. The calculation above assumes the ideal case where $$k \\ll THW$$ and $$k^2$$ divides $$THW$$ exactly, so that $$k\\lceil \\frac{THW}{k^2} \\rceil = k \\cdot \\frac{THW}{k^2} = \\frac{THW}{k}$$. In practice, excessively large $$k$$ values are avoided, so this derivation is a reasonably accurate approximation.\n\nSpecifically, when $$k = HW$$ and padding is disregarded, Single Skip attends across frames at the same spatial location, and, since $$T \\ll HW$$, Group Skip reduces to window attention with a window size of $$HW$$, i.e., attention within each frame. Given that padding does not affect the final computation, Skiparse Attention is equivalent to 2+1D Attention when $$k = HW$$.\n\nFor the commonly used resolution of 93x512x512, a causal VAE with a 4x8x8 compression rate (which maps 93 frames to $$(93-1)/4+1=24$$ latent frames) and a DiT with a 1x2x2 patch embedding give a latent shape of 24x32x32 before attention. 
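\n\nAs a quick worked check, substituting $$T=24$$ and $$H=W=32$$ (so $$HW=1024$$ and $$THW=24576$$) into the closed-form expressions above, with $$k=4$$ where applicable, gives the values below. The Skiparse line uses the approximation that drops the $$\\frac{1}{THW}$$ term. All three values agree with the corresponding entries in the tables that follow.\n\n$$\n\\begin{aligned}\n\td_{\\text{2+1D}}&=2-\\left( \\frac{1}{24}+\\frac{1}{1024} \\right)\\approx 2-0.0417-0.0010\\approx 1.957\\\\\n\td_{\\text{Skip+Window}}&=2-\\left( \\frac{1}{4}+\\frac{4}{24576} \\right)\\approx 2-0.2500-0.0002\\approx 1.750\\\\\n\td_{\\text{Skiparse}}&\\approx 2-\\frac{2}{4}+\\frac{1}{16}=1.5625\\approx 1.563\\\\\n\\end{aligned}\n$$\n\n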
The AVG Attention Distance of each method is then as follows:\n\n|                        | Full 3D Attention | 2+1D  Attention |\n| ---------------------- | ----------------- | --------------- |\n| AVG Attention Distance | 1                 | 1.957           |\n\n|                        | Skip + Window Attention(k=2) | Skip + Window Attention(k=4) | Skip + Window Attention(k=6) | Skip + Window Attention(k=8) |\n| ---------------------- | ---------------------------- | ---------------------------- | ---------------------------- | ---------------------------- |\n| AVG Attention Distance | 1.500                        | 1.750                        | 1.833                        | 1.875                        |\n\n|                        | Skiparse Attention(k=2) | Skiparse Attention(k=4) | Skiparse Attention(k=6) | Skiparse Attention(k=8) |\n| ---------------------- | ----------------------- | ----------------------- | ----------------------- | ----------------------- |\n| AVG Attention Distance | 1.250                   | 1.563                   | 1.694                   | 1.766                   |\n\nFor 2+1D Attention, the average attention distance is 1.957, larger than that of Skip + Window Attention and Skiparse Attention at commonly used sparse ratios. Although Skip + Window Attention has a shorter average attention distance than 2+1D Attention, its modeling capacity remains limited by the locality of attention in its 2N+1 blocks. 
Skiparse Attention, with the shortest average attention distance, applies global attention in both 2N and 2N+1 blocks, making its spatiotemporal modeling capabilities closer to Full 3D Attention than the other two non-Full 3D methods.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/80ca6d70-5033-454b-883f-11d12d140360\" width=600/>\n</figure>\n</center>\n\nThe figure above shows how Skiparse Attention’s AVG Attention Distance changes with sparse ratio $$k$$.\n\nWe can summarize the characteristics of these attention types as follows:\n\n|                                    | Full 3D Attention | 2+1D  Attention                  | Skip + Window Attention                 | Skiparse Attention                                           |\n| ---------------------------------- | ----------------- | -------------------------------- | --------------------------------------- | ------------------------------------------------------------ |\n| Speed                              | Slow              | Fast                             | Depending on $$k$$                    | Depending on $$k$$                                         |\n| Spatiotemporal modeling capability | Strong            | Weak                             | Weak                                    | Approaches Full 3D                                           |\n| Is attention global?               
| Yes               | No                               | Half of the attention blocks are global | Yes                                                          |\n| Computation load per block         | Equal             | Not Equal                        | Not Equal                               | Equal                                                        |\n| AVG Attention Distance             | 1                 | $$2-(\\frac{1}{T}+\\frac{1}{HW})$$ | $$2-(\\frac{1}{k}+\\frac{k}{THW})$$       | $$2-\\frac{2}{k}+\\frac{1}{k^2},1<k\\ll THW$$                   |\n\nConsidering both computational load and AVG Attention Distance, we use Skiparse with $$k = 4$$ and replace the first and last two blocks with Full 3D Attention to improve performance.\n\nOverall, we retained the v1.2 architecture and added the Skiparse Attention module.\n\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/af21a577-e5a8-46ac-8be4-0cd08cddb6c6\" height=350 />\n</figure>\n</center>\n\n#### Dynamic training\n\nWe kept the bucket strategy from v1.2.0: the shape of each video is predefined for training, a sampler aggregates samples of the same shape, and the dataloader then fetches data according to these aggregated indices.\n\nIn our early implementation, we specified `--max_width`, `--max_height`, `--min_width`, and `--min_height`. Although this allows arbitrary resolutions within a given range, it can easily lead to OOM errors during video training. For instance, with the maximum dimension set to 720, a 720P (720×1280) video is scaled to 405×720, whereas a square video larger than 720 is scaled to 720×720, which contains considerably more pixels. To prevent OOM, GPU memory must be reserved for this worst case, yet most videos are non-square, so much of the reserved memory goes unused and computation is wasted. 
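\n\nTo quantify the gap: for a fixed VAE compression rate and patch size, the number of latent tokens is roughly proportional to the number of pixels. Under a 720-pixel cap, a square video therefore produces about 1.78 times as many tokens as a 720P (16:9) video:\n\n$$\n\\frac{720\\times 720}{405\\times 720}=\\frac{518400}{291600}\\approx 1.78\n$$\n\n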
Therefore, we recommend using `--max_token` and `--min_token` to bound the number of tokens instead, since the token count directly determines the memory and compute cost of a Transformer.\n\n#### Training scheduler\n\nWe replaced the eps-prediction loss with a v-prediction loss and enabled ZeroSNR. Videos are resampled to 16 FPS for training.\n\n**Stage 1**: We initialized from the image weights of version 1.2.0 and trained on images at a resolution of 1x320x320. The goal of this stage was to adapt the dense 3D attention model into a sparse attention model. Fine-tuning took approximately 100k steps with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily the SAM data used in version 1.2.0.\n\n\n**Stage 2**: We trained the model jointly on images and videos at a maximum resolution of 93x320x320 for approximately 300k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily the SAM data used in version 1.2.0, while the video data was the unfiltered Panda70m. In fact, the model had nearly converged by around 100k steps, and continuing to 300k steps brought no significant gains. We then performed data cleaning and caption rewriting; further data analysis is discussed at the end.\n\n**Stage 3**: We fine-tuned the model on our filtered Panda70m dataset at a fixed resolution of 93x352x640 for approximately 30k steps, with a batch size of 1024 and a learning rate of 1e-5.\n\n### Training Image-to-Video Diffusion Model\n\n#### Framework\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/41e22292-8d8b-469e-940a-6e5ae00bf620\" />\n</figure>\n</center>\n\nIn terms of framework, Open-Sora Plan v1.3 continues to use the Inpainting model architecture from Open-Sora Plan v1.2.\n\n#### Data processing\n\nFor data processing, Open-Sora Plan v1.3 introduces two new mask types: an all-1 mask and an all-0 mask. 
This brings the total number of mask types in the Inpainting Model to six.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/f31b222e-811c-49b9-839c-f72fb85c4ee4\" />\n</figure>\n</center>\n\nIn the figure above, black indicates retained frames and white indicates discarded frames. The corresponding frame strategies are as follows:\n\n- **Clear**: Retain all frames.\n- **T2V**: Discard all frames.\n- **I2V**: Retain only the first frame; discard the rest.\n- **Transition**: Retain only the first and last frames; discard the rest.\n- **Continuation**: Retain the first $$n$$ frames; discard the rest.\n- **Random**: Retain $$n$$ randomly selected frames; discard the rest.\n\n#### Progressive training\n\nOpen-Sora Plan v1.3 trains on more data and uses a progressive training approach to help the model learn frame-based inpainting tasks.\n\nBecause the Inpainting Model supports various mask inputs, and different masks correspond to tasks of different difficulty, we first teach the model simple tasks such as random masks so that it develops a basic frame-based inpainting capability, and then gradually increase the proportion of more challenging tasks. At every training stage, we ensure that at least 5% of the data the model sees are T2V samples, to strengthen the model's understanding of prompts.\n\nThe model weights are initialized from the T2V model with zero initialization. The batch size is fixed at 256 and the learning rate at 1e-5, using a two-stage training approach.\n\n**Stage 1**: Any resolution and duration within 93x102400 (320x320), using low-quality data that is not filtered for motion or aesthetics:\n\n(1) Step 1: t2v 10%, continuation 40%, random mask 40%, clear 10%. 
Ensure that at least 50% of the frames are retained during continuation and random mask, training with 4 million samples.\n\n(2) Step 2: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 25% of the frames are retained during continuation and random mask, training with 4 million samples.\n\n(3) Step 3: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples.\n\n(4) Step 4: t2v 10%, continuation 25%, random mask 60%, clear 5%. Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples.\n\n(5) Step 5: t2v 10%, continuation 25%, random mask 60%, clear 5%, training with 8 million samples.\n\n(6) Step 6: t2v 10%, continuation 10%, random mask 20%, i2v 40%, transition 20%, training with 16 million samples.\n\n(7) Step 7: t2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 10 million samples.\n\n**Stage 2:** Any resolution and duration within 93x236544 (e.g., 480x480, 640x352, 352x640), using high-quality data filtered for motion and aesthetics:\n\nt2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 15 million samples.\n\n#### About the Semantic Adapter\n\nWe conducted further experiments on the Semantic Adapter module and compared Image-to-Video quality with different image encoders, including [CLIP](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) and [DINOv2](https://huggingface.co/timm/vit_large_patch14_dinov2.lvd142m). We also tried injecting image embeddings directly into cross-attention, as well as extracting features from the image encoder with a Q-Former before injecting them into cross-attention. \n\nNone of these strategies brought significant performance improvements; their impact on video quality was much smaller than that of the dataset. 
Therefore, we decided not to include the Semantic Adapter in Open-Sora Plan v1.3.\n\n#### Noise Injection Strategy for Conditional Images\n\nWork such as [CogVideoX](https://arxiv.org/abs/2408.06072) and [Stable Video Diffusion](https://stability.ai/stable-video) has shown that adding a certain amount of noise to the conditional images can improve the generalization of I2V models and produce a larger range of motion. We therefore adopt this strategy in Open-Sora Plan v1.3, in the same way as [CogVideoX](https://arxiv.org/abs/2408.06072).\n\n### The implementation of Skiparse Attention\n\nSkiparse is easy to understand in theory and straightforward to implement. The implementation mainly relies on a rearrange operation that shortens the sequence length of the latents before they enter `F.scaled_dot_product_attention()`; nothing else is modified. For simplicity, the following discussion covers only the self-attention part and omits the attention mask.\n\nThe pseudocode for Single Skip is as follows:\n\n```python\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n# x.shape: (B, N, C); N must be divisible by sparse_k\ndef single_skip_rearrange(x, sparse_k):\n\t# token i joins sub-sequence i % sparse_k; sub-sequences are stacked on the batch dim\n\treturn rearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k)\n\ndef reverse_sparse(x, sparse_k):\n\t# inverse of single_skip_rearrange: restore the original token order\n\treturn rearrange(x, '(k b) g d -> b (g k) d', k=sparse_k)\n\n# Q, K, V (projections) and add_rope (RoPE) are placeholders defined elsewhere\nq, k, v = Q(x), K(x), V(x)\n# RoPE is applied while tokens are still in their original order\nq = add_rope(q)\nk = add_rope(k)\nq = single_skip_rearrange(q, sparse_k)\nk = single_skip_rearrange(k, sparse_k)\nv = single_skip_rearrange(v, sparse_k)\nhidden_states = F.scaled_dot_product_attention(q, k, v)\noutput = reverse_sparse(hidden_states, sparse_k)\n```\n\nThe core of the Skiparse operation is \"rearranging the sequence\", which corresponds to the Single Skip operation in the pseudocode:\n\n```python\nrearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k)\n```\n\nThis operation can be understood as a combination of a reshape and a transpose:\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/e42c4dd5-ee95-42a8-b8c6-8cb803cd7e12\" height=300/>\n</figure>\n</center>\n\nIn this way, $$k$$ sub-sequences are created and $$k$$ is moved into the batch dimension, so the attention operation computes the sub-sequences in parallel.\n\nOnce Single Skip is understood, Group Skip is easy to follow: it simply adds a grouping step before the skip. Its pseudocode is as follows:\n\n```python\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n# x.shape: (B, N, C); N must be divisible by sparse_k ** 2\ndef group_skip_rearrange(x, sparse_k):\n\t# token i = n*k*k + m*k + j joins sub-sequence m = (i // k) % k\n\treturn rearrange(x, 'b (n m k) d -> (m b) (n k) d', m=sparse_k, k=sparse_k)\n\ndef reverse_sparse(x, sparse_k):\n\t# inverse of group_skip_rearrange: restore the original token order\n\treturn rearrange(x, '(m b) (n k) d -> b (n m k) d', m=sparse_k, k=sparse_k)\n\n# Q, K, V (projections) and add_rope (RoPE) are placeholders defined elsewhere\nq, k, v = Q(x), K(x), V(x)\n# RoPE is applied while tokens are still in their original order\nq = add_rope(q)\nk = add_rope(k)\nq = group_skip_rearrange(q, sparse_k)\nk = group_skip_rearrange(k, sparse_k)\nv = group_skip_rearrange(v, sparse_k)\nhidden_states = F.scaled_dot_product_attention(q, k, v)\noutput = reverse_sparse(hidden_states, sparse_k)\n```\n\nEvery $$k^2$$ tokens form one repeating unit, and every $$k$$ tokens form a group. The following figure illustrates the case $$k=3$$:\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/5e777862-d03c-4c7e-8ffc-1e1e9234b84e\"/>\n</figure>\n</center>\n\nNote that RoPE must be applied before the Skiparse rearrangement, not after it, because the rearranged sequence no longer preserves the tokens' original spatial positions.\n\n\n## Future Work and Discussion\n\n### CausalVideoVAE\nFor videos, increasing the compression ratio while keeping the original latent dimension leads to significant information loss. Increasing the latent dimension to reach higher compression ratios has therefore become a trend. A more advanced VAE will be released in the next version.\n\n### Diffusion Model\nThe current 2B model in version 1.3.0 shows performance saturation during the later stages of training. 
However, it does not perform well in understanding physical laws (e.g., a cup overflowing with milk, a car moving forward, or a person walking). We have four hypotheses regarding this issue:\n\n#### The current data domain is too narrow.\n\nWe randomly sampled 2,000 videos from Panda70M and verified them manually: fewer than 1% featured cars in motion, and fewer than 10 showed people walking. Approximately 80% of the videos consist of half-body conversations with multiple people in front of the camera. We therefore speculate that the narrow data domain of Panda70M restricts the range of scenarios the model can generate. We plan to collect more data in the next version.\n\n#### Joint training of images and videos\n\nModels such as [Open-Sora v1.2](https://github.com/hpcaitech/Open-Sora), [EasyAnimate v4](https://github.com/aigc-apps/EasyAnimate), and [Vchitect-2.0](https://github.com/Vchitect/Vchitect-2.0) can easily generate videos of high visual quality, possibly because they directly inherit image weights ([Pixart-Sigma](https://pixart-alpha.github.io/PixArt-sigma-project/), [HunyuanDiT](https://github.com/Tencent/HunyuanDiT), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)). They then train the model on a small amount of video data so that it learns how 2D images evolve along the time axis. In contrast, we trained our image model from scratch on only 10M-level data, which is far from sufficient. We are considering two training strategies: (1) start joint training from scratch, with images significantly outnumbering videos; (2) first train a high-quality image model, then switch to joint training with a higher proportion of videos. 
Considering the learning path and training costs, the second approach may offer more decoupling, while the first aligns better with scaling laws.\n\n#### The model still needs to scale\n\nComparing [CogVideoX-2B](https://github.com/THUDM/CogVideo) with its 5B variant, we can clearly see that the 5B model understands more physical laws than the 2B model. We speculate that instead of spending excessive effort designing for smaller models, it may be more effective to leverage scaling laws to solve these issues. In the next version, we will scale up the model to explore the boundaries of video generation.\n\nWe currently have two plans: one is to continue using the DeepSpeed/FSDP approach, sharding the EMA and text encoder across ranks with ZeRO-3, which is sufficient for training 10-15B models. The other is to adopt [MindSpeed](https://gitee.com/ascend/MindSpeed) for various parallel strategies, enabling us to scale the model up to 30B.\n\n#### Supervised loss in training\n\nWhether flow-based models are more suitable than v-prediction models remains uncertain and requires further ablation studies.\n\n### How else can \"Skiparse\" skip?\n\nThe sparse method we use is theoretically and practically straightforward; however, its implementation treats the original video data purely as a one-dimensional sequence, neglecting 2D spatial priors. We therefore extended Skiparse to Skiparse-2D, which is better suited to 2D visual data.\n\n<center>\n<figure>\n\t<img src=\"https://github.com/user-attachments/assets/44bd5284-b4c0-4a9d-9f2e-5acbb2e3450f\" height=500/>\n</figure>\n</center>\n\nIn Skiparse-2D, a sparse ratio of $$k$$ represents the sparsity along the $$h$$ or $$w$$ direction. In terms of the number of tokens involved in attention computation, Skiparse-2D with sparse ratio $$k$$ is therefore equivalent to Skiparse-1D with sparse ratio $$k^2$$.\n\nWe conducted basic experiments comparing Skiparse-1D and Skiparse-2D. 
Under identical experimental settings, Skiparse-2D showed no improvement over Skiparse-1D in terms of loss or sampling results. Additionally, Skiparse-2D is less flexible to implement than Skiparse-1D. Therefore, we opted to use Skiparse-1D for training in Open-Sora Plan v1.3.\n\nNevertheless, given our limited experimentation, the feasibility of Skiparse-2D remains worth exploring. Intuitively, Skiparse-2D better matches the spatial characteristics of visual data, and as the sparse ratio $$k$$ increases, it increasingly resembles 2+1D attention. We therefore encourage interested researchers in the community to explore this direction further.\n"
  },
  {
    "path": "docs/Report-v1.5.0.md",
    "content": "## Report v1.5.0\r\n\r\nIn October 2024, we released Open-Sora Plan v1.3.0, introducing the sparse attention structure Skiparse Attention to the field of video generation for the first time. We also adopted the efficient WFVAE, significantly reducing encoding time and memory usage during training.\r\n\r\nIn Open-Sora Plan v1.5.0, we introduce several key updates to the framework:\r\n\r\n1. Improved Sparse DiT, SUV. Building on Skiparse Attention, we extend sparse DiT into a U-shaped sparse structure. This design preserves the speed advantage of sparse DiT while enabling it to achieve performance comparable to dense DiT.\r\n\r\n2. Higher-compression WFVAE. In Open-Sora Plan v1.5.0, we explore a WFVAE with an 8×8×8 downsampling rate. It performs comparably to the 4×8×8 VAEs widely adopted in the community while halving the latent shape and thus shortening the attention sequence length.\r\n\r\n3. Data and model scaling. In Open-Sora Plan v1.5.0, we collect 1.1 billion high-quality images and 40 million high-quality videos. The model is scaled up to 8.5 billion parameters, resulting in strong overall performance.\r\n\r\n4. Simplified Adaptive Gradient Clipping strategy. Compared to the more complex batch-dropping method in version 1.3.0, version 1.5.0 simply maintains an adaptive gradient norm threshold for clipping, making it more compatible with various parallel training strategies.\r\n\r\nOpen-Sora Plan v1.5.0 is trained and run for inference entirely on Ascend 910-series accelerators, using the mindspeed-mm framework to support parallel training strategies.\r\n\r\n### Open-Source Release\r\n\r\nOpen-Sora Plan v1.5.0 is open-sourced with the following components:\r\n\r\n1. All training and inference code. 
You can also find the implementation of Open-Sora Plan v1.5.0 in the official [MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM) repository.\r\n\r\n2. The WFVAE weights with 8×8×8 compression, along with the 8.5B SUV denoiser weights.\r\n\r\n## Detailed Technical Report\r\n\r\n### Data collection and processing\r\n\r\nOur dataset includes 1.1B images from [Recap-DataComp-1B](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), and [LAION-Aesthetics](https://laion.ai/blog/laion-aesthetics/), with no filtering applied aside from resolution checks. The video data are drawn from [Panda-70M](https://github.com/snap-research/Panda-70M) and internal sources, and filtered using the same protocol as in Open-Sora Plan v1.3.0, yielding 40M high-quality videos.\r\n\r\n### Adaptive Grad Clipping\r\n\r\nIn Open-Sora Plan v1.3.0, we introduced an Adaptive Grad Clipping strategy based on discarding batches with abnormal gradients. While highly stable, this method involves overly complex execution logic. In Open-Sora Plan v1.5.0, we simplify the strategy by maintaining the gradient norm threshold via an exponential moving average (EMA). Gradients exceeding the threshold are clipped accordingly. 
This approach effectively extends the fixed threshold of 1.0, which is commonly used in large-scale models, into a dynamic, training-dependent threshold.\r\n\r\n```python\r\n'''\r\n\tmoving_avg_max_grad_norm: the maximum gradient norm maintained via EMA\r\n\tmoving_avg_max_grad_norm_var: the variance of the maximum gradient norm maintained via EMA\r\n\tclip_threshold: the gradient clipping threshold computed using the 3-sigma rule\r\n\tema_decay: the EMA decay coefficient, typically set to 0.99\r\n\tgrad_norm: total gradient norm at the current step\r\n\tgrads: list of gradient tensors at the current step (p.grad for each parameter p)\r\n'''\r\n# 3-sigma rule: EMA mean plus three EMA standard deviations\r\nclip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5)\r\nif grad_norm <= clip_threshold:\r\n    # Below the threshold: update the parameters normally and update both EMA statistics.\r\n    moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm\r\n    max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2\r\n    moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var\r\n    # update weights...\r\nelse:\r\n    # Above the threshold: scale the gradients down so their total norm equals clip_threshold, then update the parameters.\r\n    # The EMA statistics are left unchanged on this step.\r\n    clip_coef = grad_norm / clip_threshold\r\n    for g in grads:\r\n        g.div_(clip_coef)  # in-place division of each gradient tensor\r\n    # update weights...\r\n```\r\n\r\nCompared to the strategy in v1.3.0, this approach is simpler to implement and effectively addresses the loss spikes that occur in the later stages of diffusion training, when the gradient norm is well below 1.0 and a fixed threshold of 1.0 would not catch spikes that stay below it.\r\n\r\n### WFVAE with 8x8x8 compression\r\n\r\nIn version 1.5.0, we increase the temporal compression rate of the VAE from 4× to 8×, reducing the latent shape to half that of the previous version. 
This enables the generation of videos with higher frame counts.\r\n\r\n| Model             | THW(C)        | PSNR         | LPIPS         | rFVD         |\r\n| ----------------- | ------------- | ------------ | ------------- | ------------ |\r\n| CogVideoX         | 4x8x8 (16)    | <u>36.38</u> | 0.0243        | <u>50.33</u> |\r\n| StepVideo         | 8x16x16 (16)  | 33.61        | 0.0337        | 113.68       |\r\n| LTXVideo          | 8x32x32 (128) | 33.84        | 0.0380        | 150.87       |\r\n| Wan2.1            | 4x8x8 (16)    | 35.77        | **0.0197**    | **46.05**    |\r\n| Ours (WF-VAE-M)   | 8x8x8 (32)    | **36.91**    | <u>0.0205</u> | 52.53        |\r\n\r\n**Tested on an open-domain dataset with 1K samples.**\r\n\r\nFor more details on WFVAE, please refer to [WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459).\r\n\r\n### Training Text-to-Video Diffusion Model\r\n\r\n#### Framework - SUV: A Sparse U-shaped Diffusion Transformer For Fast Video Generation\r\n\r\nIn Open-Sora Plan v1.3.0, we discussed the strengths and weaknesses of Full 3D Attention and 2+1D Attention. Based on their characteristics, we proposed Skiparse Attention, a novel global sparse attention mechanism.\r\n\r\nUnder a predefined sparsity $k$, Skiparse Attention selects subsequences whose length is $\\frac{1}{k}$ of the original sequence, alternating between Single Skip and Group Skip patterns, for attention interaction. This design approximates the effect of Full 3D Attention. As the sparsity increases, the selected positions become more widely spaced; as it decreases, the positions become more concentrated. Regardless of the sparsity, Skiparse Attention remains global.\r\n\r\nIn Open-Sora Plan v1.5.0, we interpret this sparse interaction pattern as a form of token-level information downsampling. 
Sparser Skiparse Attention performs more semantic-level interactions, while denser Skiparse Attention captures fine-grained information. Following the multi-scale design principle in neural networks, we introduce Skiparse Attention with U-shaped sparsity variation: low-sparsity Skiparse Attention is used in shallow layers, with Full 3D Attention applied at the shallowest layer, and high-sparsity Skiparse Attention in deeper layers. Inspired by the UNet architecture, we further incorporate long skip connections between stages with identical sparsity. This U-shaped DiT architecture based on Skiparse Attention is referred to as **SUV**.\r\n\r\n![SUV](https://github.com/user-attachments/assets/6eb54e37-7077-4746-a4c6-9b7165dd48fe)\r\n\r\nIn Open-Sora Plan v1.5.0, we adopt an SUV architecture based on MMDiT. Skiparse Attention is applied to the video latents, while the text embeddings are only repeated to align with the skiparse-processed latent shape, without any sparsification.\r\n\r\nThe SUV architecture offers the following advantages:\r\n\r\n1. SUV is the first sparsification method proven effective for video generation. Our ablation studies show that it achieves performance comparable to dense DiT with a similar number of training steps. Moreover, it can be applied during both pretraining and inference. On the Ascend 910B platform at a video shape of 121×576×1024, SUV inference is over 35% faster than Dense DiT, and the attention operation alone is over 45% faster.\r\n\r\n2. Unlike UNet structures, which explicitly downsample feature maps and thereby lose information, the U-shaped structure of SUV operates on attention. The shape of the feature map remains unchanged, preserving information while altering only the granularity of token-level interactions.\r\n\r\n3. Skiparse Attention and SUV only change the attention computation during the forward pass instead of modifying model weights. 
This allows dynamic adjustment of sparsity throughout training: lower sparsity for image or low-resolution video training, and higher sparsity for high-resolution video training. As a result, FLOPs grow approximately linearly with sequence length.\r\n\r\nA more detailed analysis of the SUV architecture will be released in a future arXiv update.\r\n\r\n#### Training Stage\r\n\r\nOur training consists of two stages: Text-to-Image and Text-to-Video.\r\n\r\n#### Text-to-Image\r\n\r\nPrevious studies have shown that image weights trained on synthetic data may negatively impact video training. Therefore, in the v1.5.0 update, we choose to train image weights using a much larger corpus of real-world data, totaling 1.1B images. Since image data come in various resolutions, whereas videos are primarily in a 9:16 aspect ratio, we adopt multi-resolution training for images using five common aspect ratios, (1,1), (3,4), (4,3), (9,16), and (16,9), along with the Min-Max Token Strategy. In contrast, video training is conducted at a fixed 9:16 resolution.\r\n\r\nThe difference between Skiparse Attention and Full Attention lies in the token sequences involved in the forward computation; the required weights remain identical. Therefore, we can first train the model using Dense MMDiT with Full 3D Attention, and then fine-tune it to the Sparse MMDiT mode after sufficient training.\r\n\r\n**Image-Stage-1:** Training is conducted using 512 Ascend 910B accelerators. We train a randomly initialized Dense MMDiT on 256²-pixel images with multi-resolution enabled. The learning rate is set to 1e-4, with a batch size of 8096. This stage runs for a total of 225k steps.\r\n\r\n**Image-Stage-2:** Training is conducted using 384 Ascend 910B accelerators. We train on 384²-pixel images with multi-resolution still enabled. 
The learning rate remains 1e-4, the batch size is 6144, and training lasts for 150k steps.\r\n\r\n**Image-Stage-3:** Training is conducted using 256 Ascend 910B accelerators. We train on images at a fixed resolution of 288×512. The learning rate is 1e-4, the batch size is 4096, and training lasts for 110k steps. This stage completes the Dense MMDiT training.\r\n\r\n**Image-Stage-4:** Training is conducted using 256 Ascend 910B accelerators. We initialize the SUV model using the pretrained weights from Dense MMDiT, with skip connections zero-initialized so that the model produces non-noise outputs from the start. In practice, zero-shot inference shows that the generated images already contain meaningful low-frequency structure. Our experiments confirm that fine-tuning from Dense DiT to SUV converges quickly. This stage uses a fixed resolution of 288×512, a learning rate of 1e-4, a batch size of 4096, and is trained for approximately 160k steps.\r\n\r\n#### Text-to-Video\r\n\r\nFor video training, we fix the aspect ratio at 9:16 and train solely on video data instead of jointly training with image data. All training in this stage is performed using 512 Ascend 910B accelerators.\r\n\r\n**Video-Stage-1:** Starting from the SUV weights pretrained during the Text-to-Image phase, we train on videos with a shape of 57×288×512 for about 40k steps. The setup includes a learning rate of 6e-5, TP/SP parallelism of 2, gradient accumulation set to 2, a micro batch size of 2, and a global batch size of 1024. Videos are trained at 24 fps, representing approximately 2.4 seconds (57/24 ≈ 2.4s) of content per sample. This stage marks the initial adaptation from image-based to video-based weights, for which shorter video clips are intentionally selected to ensure stable initialization.\r\n\r\n**Video-Stage-2:** We further train on videos with a shape of 57×288×512 for 45k steps, keeping the learning rate, TP/SP parallelism, and gradient accumulation settings unchanged. 
However, the training frame rate is reduced to 12 fps, corresponding to ~4.8 seconds of video content per sample (57/12 ≈ 4.8s). This stage aims to enhance temporal learning without increasing sequence length, serving as preparation for later high-frame-count training.\r\n\r\n**Video-Stage-3:** We train on videos with a shape of 121×288×512 for approximately 25k steps. The learning rate is adjusted to 4e-5, with TP/SP parallelism set to 4, gradient accumulation steps set to 2, a micro batch size of 4, and a global batch size of 1024. In this stage, we revert to a training frame rate of 24 fps.\r\n\r\n**Video-Stage-4:** We conduct training on videos with a shape of 121×576×1024 for a total of 16k + 9k steps. The learning rates are set to 2e-5 and 1e-5 for the two phases, respectively. TP/SP parallelism is configured as 4, with gradient accumulation steps set to 4, a micro batch size of 1, and a global batch size of 512.\r\n\r\n**Video-Stage-5:** We train on a high-quality subset of the dataset for 5k steps, using a learning rate of 1e-5. 
TP/SP parallelism is set to 4, with gradient accumulation steps of 4, a micro batch size of 1, and a global batch size of 512.\r\n\r\n#### Performance on VBench\r\n\r\n| Model                      | Parameters | Total Score   | Quality Score | Semantic Score | **Aesthetic Quality** |\r\n| -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- |\r\n| Mochi-1                    | 10B        | 80.13%        | 82.64%        | 70.08%         | 56.94%                |\r\n| CogVideoX-2B               | 2B         | 80.91%        | 82.18%        | 75.83%         | 60.82%                |\r\n| CogVideoX-5B               | 5B         | 81.61%        | 82.75%        | 77.04%         | 61.98%                |\r\n| Step-Video-T2V             | 30B        | 81.83%        | <u>84.46%</u> | 71.28%         | 61.23%                |\r\n| CogVideoX1.5-5B            | 5B         | 82.17%        | 82.78%        | **79.76%**     | 62.79%                |\r\n| Gen-3                      | -          | 82.32%        | 84.11%        | 75.17%         | <u>63.34%</u>         |\r\n| HunyuanVideo (Open-Source) | 13B        | **83.24%**    | **85.09%**    | 75.82%         | 60.36%                |\r\n| Open-Sora Plan v1.5.0      | 8B         | <u>83.02%</u> | 84.24%        | <u>78.18%</u>  | **66.89%**            |\r\n\r\n### Training Image-to-Video Diffusion Model\r\n\r\nComing Soon...\r\n\r\n### Future Work\r\n\r\nCurrently, open-source models such as Wan2.1 have achieved performance comparable to closed-source commercial counterparts. 
Given the gap in computing resources and data availability compared to industry-scale efforts, the future development of the Open-Sora Plan will focus on the following directions:\r\n\r\n1. Latents Cache.\r\n\r\nIn the training process of Text-to-Video models, the data must be processed through two key modules, the Variational Autoencoder (VAE) and the Text Encoder, to extract features from both videos/images and their corresponding prompts. These encoded features serve as inputs to the training model. However, in existing industry practice, feature encoding is redundantly performed on the multimodal training dataset during every training epoch. This leads to additional computational overhead and significantly prolongs the total training time.\r\n\r\nSpecifically, in conventional training pipelines, the VAE and Text Encoder modules are typically kept resident in GPU memory to perform feature encoding in real time during each epoch. While this ensures on-the-fly encoding, it also results in persistently high GPU memory usage, becoming a major bottleneck for training efficiency. This issue is exacerbated when handling large-scale datasets or complex models, where memory constraints further limit model capacity and training speed.\r\n\r\nTo address this issue, we propose an optimization strategy that replaces repeated feature computation with feature lookup. The core idea is to decouple feature encoding from model training. Specifically, before training or during the first training epoch, we compute the most computationally expensive text prompt features and store them in external high-performance storage. During subsequent training, the model directly loads these precomputed features from storage, avoiding redundant encoding operations. 
This design significantly reduces computational overhead and GPU memory usage, allowing more memory to be allocated to model training.\r\n\r\nUsing the configuration below, we compare the training time per epoch and per step before and after applying the feature caching strategy. Experimental results show that storing precomputed features reduces multi-epoch training time by approximately 30% and frees up around 20% of GPU memory.\r\n\r\n| **Configuration** |                 **Details**                 |\r\n| :---------------: | :-----------------------------------------: |\r\n|       Model       | Open-Sora Plan v1.5.0 (2B-level parameters) |\r\n|      Dataset      |         100K images and 10K videos          |\r\n|   Accelerators    |             8× Nvidia A800 GPUs             |\r\n|  Feature Storage  |         Huawei OceanStor AI Storage         |\r\n\r\nTest cases:\r\n\r\n| **Training Stage** | **Test Type**          | **Batch Size** | **Time per Step** | **Time per Epoch** | **Memory Usage** |\r\n| ------------------ | ---------------------- | -------------- | ----------------- | ------------------ | ---------------- |\r\n| Low-Res Images     | General Method         | 64             | 6.53s             | 21 min 12s         | 56 GB            |\r\n|                    | Feature Caching Method | 64             | 4.10s             | 13 min 19s         | 40 GB            |\r\n|                    | General Method         | 128            | 12.78s            | 20 min 39s         | 74 GB            |\r\n|                    | Feature Caching Method | 128            | 7.81s             | 12 min 38s         | 50 GB            |\r\n| Low-Res Videos     | General Method         | 8              | 8.90s             | 26 min 23s         | 68 GB            |\r\n|                    | Feature Caching Method | 8              | 7.78s             | 23 min 05s         | 51 GB            |\r\n| High-Res Videos    | General Method         | 4       
       | 17.00s            | 101 min            | 71 GB            |\r\n|                    | Feature Caching Method | 4              | 16.00s            | 97 min             | 57 GB            |\r\n\r\n2. Improved DiT pretraining with sparse or linear attention. In v1.3.0, we introduced the community's first DiT pretrained with sparse attention. This was extended in v1.5.0 into the SUV architecture, enabling sparse DiT to achieve performance comparable to its dense counterpart. While sparse and linear attention have demonstrated significant success in the LLM domain, their application in video generation remains underexplored. In future versions, we plan to further investigate the integration of sparse and linear attention into video generation models.\r\n\r\n3. MoE-based DiT. Since the release of [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), the MoE (Mixture-of-Experts) paradigm has become a common approach for scaling LLMs to larger parameter sizes. Currently, open-source video generation models are capped at around 14B parameters, which is still relatively small compared to the 100B+ scales in the LLM field. Incorporating MoE into the DiT architecture, and exploring its combination with sparse and linear attention, is a future direction under consideration by the Open-Sora Plan team.\r\n\r\n4. Unified video generation models for both generation and understanding. The March update of GPT-4o demonstrates that unified architectures combining generation and understanding can offer fundamentally different capabilities compared to purely generative models. In the video domain, we should similarly anticipate the potential breakthroughs that such unified generative models might bring.\r\n\r\n5. Enhancing Image-to-Video generation models. Current approaches in this field still largely follow either the SVD paradigm or the inpainting-based paradigm adopted since Open-Sora Plan v1.2.0. 
Both approaches require extensive fine-tuning of pretrained Text-to-Video models. From a practical standpoint, Text-to-Video is more aligned with academic exploration, while Image-to-Video is more relevant to real-world production scenarios. As a result, developing a new paradigm for Image-to-Video will be a key focus for the Open-Sora Plan team moving forward.\r\n\r\n"
  },
  {
    "path": "docs/Report-v1.5.0_cn.md",
    "content": "## Report v1.5.0\r\n\r\n在2024年的10月，我们发布了Open-Sora Plan v1.3.0，第一次将一种稀疏化的attention结构——skiparse attention引入video generation领域。同时，我们采用了高效的WFVAE，使得训练时的编码时间和显存占用大大降低。\r\n\r\n在Open-Sora Plan v1.5.0中，Open-Sora Plan引入了几个关键的更新：\r\n\r\n1、更好的sparse dit——SUV。在skiparse attention的基础上，我们将sparse dit扩展至U形变化的稀疏结构，使得在保持速度优势的基础上sparse dit可以取得和dense dit相近的性能。\r\n\r\n2、更高压缩率的WFVAE。在Open-Sora Plan v1.5.0中，我们尝试了8x8x8下采样率的WFVAE，它在性能上媲美社区中广泛存在的4x8x8下采样率的VAE的同时latent shape减半，降低attention序列长度。\r\n\r\n3、data和model scaling。在Open-Sora Plan v1.5.0中，我们收集了1.1B的高质量图片数据和40m的高质量视频数据，并将模型大小scale到8.5B，使最终得到的模型呈现出不俗的性能。\r\n\r\n4、更简易的Adaptive Grad Clipping。相比于version 1.3.0中较复杂的丢弃污点batch的策略，在version 1.5.0中我们简单地维护一个adaptive的grad norm threshold并clipping，以此更适应各种并行策略的需要。\r\n\r\nOpen-Sora Plan v.1.5.0全程在昇腾910系列加速卡上完成训练和推理，并采用mindspeed-mm训练框架适配并行策略。\r\n\r\n### Open-Source Release\r\n\r\nOpen-Sora Plan v1.5.0的开源包括：\r\n\r\n1、所有训练和推理代码。你也可以在[MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM)官方仓库找到open-sora plan v1.5.0版本的实现。\r\n\r\n2、8x8x8下采样的WFVAE权重以及8.5B的SUV去噪器权重。\r\n\r\n## Detailed Technical Report\r\n\r\n### Data collection and processing\r\n\r\n我们共收集了来自Recap-DataComp-1B、Coyo700M、Laion-aesthetic的共1.1B图片数据。对于图片数据，我们不进行除了分辨率之外的筛选。我们的视频数据来自于Panda70M以及其他自有数据。对于视频数据，我们采用与Open-Sora Plan v1.3.0相同的处理策略进行筛选，最终数据量为40m的高质量视频数据。\r\n\r\n### Adaptive Grad Clipping\r\n\r\n在Open-Sora Plan v1.3.0中，我们介绍了一种基于丢弃梯度异常batch的Adaptive Grad Clipping策略，这种策略具有很高的稳定性，但是执行逻辑过于复杂。因此，在Open-Sora Plan v1.5.0中，我们选择将该策略进行优化，采用EMA方式维护grad norm的threshold，并在grad norm超过该threshold时裁剪到threshold以下。该策略本质上是将大模型领域常用的1.0常数grad norm threshold扩展为一个随着训练进程动态变化的threshold。\r\n\r\n```python\r\n'''\r\n\tmoving_avg_max_grad_norm: EMA方式维护的最大grad norm\r\n\tmoving_avg_max_grad_norm_var: EMA方式维护的最大grad norm的方差\r\n\tclip_threshold: 根据3 sigma策略计算得到的梯度裁剪阈值\r\n\tema_decay: EMA衰减系数，一般为0.99\r\n\tgrad_norm: 当前step的grad norm\r\n'''\r\nclip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5)\r\nif grad_norm <= clip_threshold:\r\n    
# grad norm小于裁剪阈值，则该step参数正常更新，同时更新维护的moving_avg_max_grad_norm 和 moving_avg_max_grad_norm_var\r\n    moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm\r\n    max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2\r\n    moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var\r\n    参数更新...\r\nelse:\r\n    # grad norm大于裁剪阈值，则先裁剪grad使grad norm减少至clip_threshold，再进行参数更新。\r\n    clip_coef = grad_norm / clip_threshold\r\n    grads = clip(grads, clip_coef) # 裁剪grads\r\n    参数更新...\r\n```\r\n\r\n该策略相较于v1.3.0中策略实现更简单，且能够很好应对diffusion训练后期grad norm远小于1.0时仍存在loss spike的问题。\r\n\r\n### WFVAE with 8x8x8 downsampling\r\n\r\n在V1.5.0版本中，我们将VAE的时间压缩率从4倍压缩提高至8倍压缩，使得对于同样原始尺寸的视频，latent shape减少为先前版本的一半，这使得我们可以实现更高帧数的视频生成。\r\n\r\n| Model             | THW(C)        | PSNR         | LPIPS         | rFVD         |\r\n| ----------------- | ------------- | ------------ | ------------- | ------------ |\r\n| CogVideoX         | 4x8x8 (16)    | <u>36.38</u> | 0.0243        | <u>50.33</u> |\r\n| StepVideo         | 8x16x16 (16)  | 33.61        | 0.0337        | 113.68       |\r\n| LTXVideo          | 8x32x32 (128) | 33.84        | 0.0380        | 150.87       |\r\n| Wan2.1            | 4x8x8 (16)    | 35.77        | **0.0197**    | **46.05**    |\r\n| Ours （WF-VAE-M） | 8x8x8 (32)    | **36.91**    | <u>0.0205</u> | 52.53        |\r\n\r\n**Test on an open-domain dataset with 1K samples.**\r\n\r\nWFVAE详情请见[WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459)\r\n\r\n### Training Text-to-Video Diffusion Model\r\n\r\n#### Framework —— SUV: A Sparse U-shaped Diffusion Transformer For Fast Video Generation\r\n\r\n在Open-Sora Plan v1.3.0中，我们讨论了Full 3D Attention以及2+1D Attention的优劣，并综合他们的特点提出了Skiparse Attention——一种新型的global sparse attention。\r\n\r\n在一个事先指定的sparse ratio $k$ 下，Skiparse Attention按照Single Skip - Group Skip交替的方式选定原序列长度 
$\\frac{1}{k}$ 的子序列进行attention交互，以此达到近似Full 3D Attention的效果。在Skiparse Attention中，sparse ratio越大，子序列在原序列中的位置越稀疏；sparse ratio越小，子序列在原序列中的位置越密集。但无论sparse ratio为多少，Skiparse Attention总是global的。\r\n\r\n在Open-Sora Plan v1.5.0中，我们将这种稀疏交互方式看作一种token上的信息下采样，越稀疏的Skiparse Attention是一种更偏语义级的信息交互，越密集的Skiparse Attention是一种更偏细粒度的信息交互。遵循神经网络中多尺度设计的准则，我们在网络中引入U形变化稀疏度的Skiparse Attention，即浅层采用稀疏度低的Skiparse Attention，并在最浅层使用Full 3D Attention，深层采用稀疏度高的Skiparse Attention。特别的，类比UNet的设计，我们在相同稀疏度的Stage之间引入了Long Skip Connection。我们将这种U形变化的基于Skiparse Attention的DiT称之为SUV。\r\n\r\n![SUV](https://github.com/user-attachments/assets/6eb54e37-7077-4746-a4c6-9b7165dd48fe)\r\n\r\n在Open-Sora Plan v1.5.0中我们采用了基于MMDiT的SUV架构。对于video latents，我们对其进行skiparse attention操作，对于text embedding，我们仅对其进行repeat以对齐skiparse后的latent shape而不进行任何稀疏化操作。\r\n\r\nSUV架构存在以下优点：\r\n\r\n1、SUV是首个在视频生成模型上验证有效的稀疏化方法，在我们的消融实验中表明其在同样训练步数下可以达到接近dense dit的性能，且可以同时应用于预训练和推理中。在910B测试平台下，在121x576x1024的视频shape下，SUV的推理速度相比Dense DiT提升35%以上，其中Attn部分速度提升45%以上。\r\n\r\n2、相较于UNet结构对feature map进行显式的下采样造成了信息损失，SUV的U形结构作用在Attention上，feature map的shape并没有发生变化，即信息并未发生损失，改变的只是token间信息交互的粒度。\r\n\r\n3、Skiparse Attention及SUV不改变权重大小，只改变forward时attention的计算方式。这使得我们可以随着训练进程动态调整稀疏度，在图片训练或低分辨率视频训练时采用较低的稀疏度，在高分辨率视频训练时提高稀疏度，获得随序列长度近似线性增长的FLOPS。\r\n\r\n对SUV架构更细致的分析，将会在后续更新至arxiv。\r\n\r\n#### Training Stage\r\n\r\n我们的训练包括Text-to-Image和Text-to-Video两个阶段。\r\n\r\n#### Text-to-Image\r\n\r\n先前的工作表明从合成数据训练得到的图像权重可能会影响视频训练时的效果。因此，在v1.5.0更新中，我们选择在更大的真实数据域内训练图像权重。我们收集了共1.1B的图片数据进行训练。由于图片存在多种不同的分辨率，而视频主要为9：16分辨率，因此我们选择在训练图片权重时开启多分辨率（5个常见宽高比：(1,1), (3,4), (4,3), (9,16), (16,9) ）及Min-Max token Strategy训练，而在训练视频时采用固定9：16的宽高比固定分辨率训练。\r\n\r\nSkiparse Attention与Full Attention的区别在于前向过程中参与计算的token序列不同，所需要的权重变量则完全相同。因此，我们可以先用Full 3D Attention的Dense MMDiT做训练，并在训练充分后Fine-tune至Sparse MMDiT模式。\r\n\r\n**Image-Stage-1:** 采用512张Ascend 910B进行训练。 我们采用随机初始化的Dense MMDiT在256^2px级别分辨率的图片上训练，开启多分辨率。学习率为1e-4，batch size为8096。在这个阶段我们总共训练了225k steps。\r\n\r\n**Image-Stage-2:** 采用384张Ascend 
910B进行训练。在384^px级别的图片上训练，开启多分辨率训练。学习率为1e-4，batch size为6144，共训练150k step。\r\n\r\n**Image-Stage-3:** 采用256张Ascend 910B进行训练。固定288x512分辨率训练。学习率为1e-4，batch size为4096，共训练110k step。Dense MMDiT阶段训练完成。\r\n\r\n**Image-Stage-4:** 采用256张Ascend 910B进行训练。采用Dense MMDiT的权重初始化SUV，其中skip connection采用零初始化，保证初始SUV权重能够推出非噪声图片。事实上，zero shot推理得到的图片具备一定的低频信息，我们验证了Dense DiT到SUV的finetune可以很快达成。该阶段固定分辨率为288x512，学习率为1e-4，batch size为4096，共训练约160k step。\r\n\r\n#### Text-to-Video\r\n\r\n在训练视频时，我们采用的宽高比固定为9：16，且并未采用视频图像联合训练，而是仅用视频数据做训练。以下训练均在512张Ascend 910B上完成。\r\n\r\n**Video-Stage-1:** 继承Text-to-Image阶段得到的SUV权重，我们在57x288x512的视频上训练了大约40k step，学习率为6e-5，TP/SP并行度为2，学习率为6e-5，梯度累积次数为2， micro batch size为2，global batch size为1024。在这个阶段，我们采用的train fps为24，即大约57/24≈2.4s的视频内容。该阶段作为图片权重到视频权重迁移的第一个阶段，我们选择了较短的视频训练作为良好的初始化。\r\n\r\n**Video-Stage-2:** 我们同样在57x288x512的视频上训练45k step，学习率、TP/SP并行度和梯度累积设置保持不变，但是train fps更改为12，即对应的原视频长度为57/12≈4.8s的内容。该阶段旨在不增加序列长度的同时提高对时序的学习，为后续高帧数训练阶段做准备。\r\n\r\n**Video-Stage-3:** 我们在121x288x512的视频上训练约25k step，学习率调整为4e-5、TP/SP并行度设置为4，梯度累积次数设置为2，micro batch size为4，global batch size为1024。在这个阶段我们重新采用train fps为24。\r\n\r\n**Video-Stage-4:** 在121x576x1024的视频上共训练16k + 9k step，学习率分别为2e-5和1e-5，TP/SP并行度设置为4，梯度累积次数设置为4，micro batch size为1，global batch size为512。\r\n\r\n**Video-Stage-5:** 我们选择数据中的高质量子集训练了5k step，学习率为1e-5，TP/SP并行度设置为4，梯度累积次数设置为4，micro batch size为1，global batch size为512。\r\n\r\n #### Performance on Vbench\r\n\r\n| Model                      | Parameters | Total Score   | Quality Score | Semantic Score | **aesthetic quality** |\r\n| -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- |\r\n| Mochi-1                    | 10B        | 80.13%        | 82.64%        | 70.08%         | 56.94%                |\r\n| CogvideoX-2B               | 2B         | 80.91%        | 82.18%        | 75.83%         | 60.82%                |\r\n| CogvideoX-5B               | 5B         | 81.61%        | 82.75%        | 77.04%         | 61.98%      
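          |\r\n| Step-Video-T2V             | 30B        | 81.83%        | <u>84.46%</u> | 71.28%         | 61.23%                |\r\n| CogVideoX1.5-5B            | 5B         | 82.17%        | 82.78%        | **79.76%**     | 62.79%                |\r\n| Gen-3                      | -          | 82.32%        | 84.11%        | 75.17%         | <u>63.34%</u>         |\r\n| HunyuanVideo (Open-Source) | 13B        | **83.24%**    | **85.09%**    | 75.82%         | 60.36%                |\r\n| Open-Sora Plan v1.5.0      | 8B         | <u>83.02%</u> | 84.24%        | <u>78.18%</u>  | **66.89%**            |\r\n\r\n### Training Image-to-Video Diffusion Model\r\n\r\nComing soon...\r\n\r\n### Future Work\r\n\r\nThe open-source community now has models whose performance rivals closed-source commercial products, such as Wan2.1. Because our compute and data are still limited compared with those of companies, the Open-Sora Plan team will focus on the following directions:\r\n\r\n1. Latents Cache.\r\n\r\nWhen training a Text2Video model, the training data must pass through two key modules, a variational autoencoder (VAE) and a text encoder, which encode the videos/images and their prompts into features. These encoded features are then fed to the model as training inputs. In common industry training pipelines, however, the multimodal training set is re-encoded in every epoch, which adds computational overhead and significantly lengthens overall training time.\r\n\r\nSpecifically, in a conventional training pipeline the VAE and the text encoder usually stay resident in GPU memory so that features can be encoded on the fly in every epoch. Although this keeps feature encoding online, it also keeps GPU memory usage high and becomes one of the main bottlenecks on training efficiency. With large datasets or complex models in particular, memory pressure makes the problem worse, limiting both the model size and the training speed.\r\n\r\nTo address this, we propose an optimization that replaces feature recomputation with lookup. The core idea is to decouple feature encoding from model training: the prompt features, which are the most expensive to compute, are computed before training or during the first epoch and saved to external high-performance file storage. In subsequent training, the model reads these precomputed features directly from storage instead of encoding them again. This significantly reduces wasted compute and substantially lowers GPU memory usage, leaving more memory for model training.\r\n\r\nUsing the configuration below, we measured per-epoch and per-step training statistics with and without feature storage. The experiments show that feature storage **reduces multi-epoch training time by about 30% and frees about 20% of GPU memory.**\r\n\r\n|  Configuration  |              Details              |\r\n| :-------------: | :-------------------------------: |\r\n|      Model      | Open-Sora Plan v1.5.0 (2B scale)  |\r\n|     Dataset     |    100K images and 10K videos     |\r\n|   GPU server    |          8x NVIDIA A800           |\r\n| Feature storage |    Huawei OceanStor AI storage    |\r\n\r\nResults:\r\n\r\n| Training stage  | Scheme          | Batch Size | Time per step | Time per epoch | GPU memory |\r\n| --------------- | --------------- | ---------- | ------------- | -------------- | ---------- |\r\n| Low-res images  | 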
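Baseline        | 64         | 6.53s         | 21min12s       | 56GB       |\r\n|                 | Feature storage | 64         | 4.10s         | 13min19s       | 40GB       |\r\n|                 | Baseline        | 128        | 12.78s        | 20min39s       | 74GB       |\r\n|                 | Feature storage | 128        | 7.81s         | 12min38s       | 50GB       |\r\n| Low-res videos  | Baseline        | 8          | 8.90s         | 26min23s       | 68GB       |\r\n|                 | Feature storage | 8          | 7.78s         | 23min05s       | 51GB       |\r\n| High-res videos | Baseline        | 4          | 17s           | 101min         | 71GB       |\r\n|                 | Feature storage | 4          | 16s           | 97min          | 57GB       |\r\n\r\n2. Better DiTs pretrained with sparse or linear attention. In v1.3.0 we released the community's first DiT pretrained with sparse attention, and in v1.5.0 we extended it into the SUV architecture, bringing the sparse DiT to performance comparable to a dense DiT. Sparse and linear attention have achieved great success in LLMs, but their use in video generation remains limited. In future versions we will further explore sparse and linear attention for video generation.\r\n\r\n3. MoE-based DiTs. Since the release of Mixtral 8x7B, LLMs have commonly used MoE to scale to larger parameter counts. The largest open-source video models are currently only 14B parameters, which is still small compared with LLMs at hundreds of billions of parameters. Introducing MoE into the DiT architecture, and combining MoE with sparse and linear attention, are directions the Open-Sora Plan team is considering for the future.\r\n\r\n4. Video generation models that unify generation and understanding. The GPT-4o update in March showed that generative models with a unified generation-and-understanding architecture can gain capabilities entirely different from those of pure generation models. We should likewise look forward to what a unified generative model can bring to the video domain.\r\n\r\n5. Better Image-to-Video models. Image-to-Video currently still largely follows either the SVD (Stable Video Diffusion) paradigm or the inpainting paradigm that Open-Sora Plan has used since v1.2.0. Both require lengthy fine-tuning on top of Text-to-Video weights. From an application standpoint, Text-to-Video is closer to academic exploration, while Image-to-Video is closer to real production settings. New Image-to-Video paradigms will therefore also be a key direction for the Open-Sora Plan team.\r\n\r\n"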
  },
  {
    "path": "docs/VAE.md",
    "content": "\n### Data preparation\nOrganizing the training data is simple: put all the videos in a single root directory, nested in subdirectories at any depth. This makes it convenient to train on multiple datasets at once.\n``` shell\nTraining Dataset\n|——sub_dataset1\n    |——sub_sub_dataset1\n        |——video1.mp4\n        |——video2.mp4\n        ......\n    |——sub_sub_dataset2\n        |——video3.mp4\n        |——video4.mp4\n        ......\n|——sub_dataset2\n    |——video5.mp4\n    |——video6.mp4\n    ......\n|——video7.mp4\n|——video8.mp4\n```\n\n### Training\n``` shell\nbash scripts/causalvae/train.sh\n```\nThe important training arguments are listed below.\n\n| Argparse | Usage |\n|:---|:---|\n|_Training size_||\n|`--num_frames`|The number of frames used from each training video|\n|`--resolution`|The resolution of the VAE input|\n|`--batch_size`|The local batch size on each GPU|\n|`--sample_rate`|The frame sampling interval used when loading training videos|\n|_Data processing_||\n|`--video_path`|/path/to/dataset|\n|_Load weights_||\n|`--model_name`| `CausalVAE` or `WFVAE`|\n|`--model_config`|/path/to/config.json. The VAE model config; use this parameter to train from scratch.|\n|`--pretrained_model_name_or_path`|A directory containing a model checkpoint and its config. 
This loads only the weights, not the optimizer state.|\n|`--resume_from_checkpoint`|/path/to/checkpoint. Resumes training from the checkpoint, restoring both the weights and the optimizer state.|\n\n### Inference\n\n``` shell\nbash scripts/causalvae/rec_video.sh\n```\nThe important inference arguments are listed below.\n\n| Argparse | Usage |\n|:---|:---|\n|_Output video size_||\n|`--num_frames`|The number of frames of the reconstructed video|\n|`--height`|The height of the reconstructed video|\n|`--width`|The width of the reconstructed video|\n|_Data processing_||\n|`--video_path`|The path to the original video|\n|`--rec_path`|The path where the reconstructed video is saved|\n|_Load weights_||\n|`--ae_path`|/path/to/model_dir. A directory containing the VAE checkpoint used for inference and its config.json|\n|_Other_||\n|`--enable_tiling`|Use tiling to handle high-resolution, long-duration videos|\n|`--save_memory`|Reduce memory usage during inference at a slight cost in quality|\n\n\n### Evaluation\n\nThe evaluation process consists of two steps:\n\n1. Reconstruct videos in batches: `bash scripts/causalvae/prepare_eval.sh`\n2. Evaluate video metrics: `bash scripts/causalvae/eval.sh`\n\nBoth scripts are configured through environment variables. 
For step 1 (`bash scripts/causalvae/prepare_eval.sh`):\n\n```bash\n# Experiment name\nEXP_NAME=wfvae\n# Video parameters\nSAMPLE_RATE=1\nNUM_FRAMES=33\nRESOLUTION=256\n# Model weights\nCKPT=ckpt\n# Select subset size (0 for full set)\nSUBSET_SIZE=0\n# Dataset directory\nDATASET_DIR=test_video\n```\n\nFor step 2 (`bash scripts/causalvae/eval.sh`):\n\n```bash\n# Experiment name\nEXP_NAME=wfvae-4dim\n# Video parameters\nSAMPLE_RATE=1\nNUM_FRAMES=33\nRESOLUTION=256\n# Evaluation metric\nMETRIC=lpips\n# Select subset size (0 for full set)\nSUBSET_SIZE=0\n# Path to the ground truth videos, which can be saved during video reconstruction by setting `--output_origin`\nORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin\n# Path to the reconstructed videos\nRECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}\n```"
  },
  {
    "path": "examples/cond_pix_path.txt",
    "content": "examples/test_img1.png\nexamples/test_img2.png\nexamples/test_img3.png"
  },
  {
    "path": "examples/cond_prompt.txt",
    "content": "A rocket ascends slowly into the sky.\nAlong the coast, variously sized boats float on the lake.\nThe landscape at sunset is profound and expansive."
  },
  {
    "path": "examples/rec_image.py",
    "content": "import sys\nsys.path.append(\".\")  # make the local opensora package importable\nimport argparse\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import ToTensor, Compose, Resize, Lambda\n\nfrom opensora.models.causalvideovae import ae_wrapper\n\n\ndef preprocess(image: Image.Image, short_size: int = 128) -> torch.Tensor:\n    # PIL image -> tensor scaled to [-1, 1], with the short side resized to short_size\n    transform = Compose(\n        [\n            ToTensor(),\n            Lambda(lambda x: 2. * x - 1.),\n            Resize(size=short_size),\n        ]\n    )\n    outputs = transform(image)\n    # (c, h, w) -> (b=1, c, t=1, h, w): treat the image as a single-frame video\n    outputs = outputs.unsqueeze(0).unsqueeze(2)\n    return outputs\n\n\ndef main(args: argparse.Namespace):\n    device = args.device\n    kwarg = {}\n\n    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)\n    if args.enable_tiling:\n        vae.vae.enable_tiling()\n        vae.vae.tile_overlap_factor = args.tile_overlap_factor\n    vae = vae.half()\n\n    with torch.no_grad():\n        # convert to RGB so that grayscale or RGBA inputs also yield 3 channels\n        x_vae = preprocess(Image.open(args.image_path).convert(\"RGB\"), args.short_size)\n        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w\n        latents = vae.encode(x_vae)\n        latents = latents.to(torch.float16)\n        image_recon = vae.decode(latents)  # b t c h w\n    # first sample, first (only) frame: (c, h, w)\n    x = image_recon[0, 0, :, :, :]\n    x = x.squeeze()\n    x = x.detach().cpu().numpy()\n    # map [-1, 1] back to [0, 255] and reorder to (h, w, c) for PIL\n    x = np.clip(x, -1, 1)\n    x = (x + 1) / 2\n    x = (255 * x).astype(np.uint8)\n    x = x.transpose(1, 2, 0)\n    image = Image.fromarray(x)\n    image.save(args.rec_path)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_path', type=str, default='')\n    parser.add_argument('--rec_path', type=str, default='')\n    parser.add_argument('--ae', type=str, 
default='')\n    parser.add_argument('--ae_path', type=str, default='')\n    parser.add_argument('--model_path', type=str, default='results/pretrained')\n    parser.add_argument('--short_size', type=int, default=336)\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)\n    parser.add_argument('--enable_tiling', action='store_true')\n    \n    args = parser.parse_args()\n    main(args)"
  },
  {
    "path": "examples/rec_video.py",
    "content": "import argparse\nimport sys\n\nimport cv2\nimport numpy as np\nimport numpy.typing as npt\nimport torch\nfrom decord import VideoReader, cpu\nfrom torchvision.transforms import Lambda, Compose\n\nsys.path.append(\".\")  # make the local opensora package importable\nfrom opensora.models.causalvideovae import ae_wrapper\nfrom opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo\n\n\ndef array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:\n    # write a sequence of (h, w, c) uint8 RGB frames to an mp4 file\n    height, width, channels = image_array[0].shape\n    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))\n\n    for image in image_array:\n        # OpenCV expects BGR channel order\n        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n        video_writer.write(image_bgr)\n\n    video_writer.release()\n\n\ndef custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:\n    # x: (t, c, h, w) with values in [-1, 1]\n    x = x.detach().cpu()\n    x = torch.clamp(x, -1, 1)\n    x = (x + 1) / 2\n    x = x.permute(0, 2, 3, 1).float().numpy()\n    x = (255 * x).astype(np.uint8)\n    array_to_video(x, fps=fps, output_file=output_file)\n\n\ndef read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:\n    decord_vr = VideoReader(video_path, ctx=cpu(0))\n    total_frames = len(decord_vr)\n    sample_frames_len = sample_rate * num_frames\n\n    # sample num_frames evenly spaced frames from the first sample_frames_len frames;\n    # clamp the window so that videos shorter than sample_frames_len are not indexed past the end\n    s = 0\n    e = min(sample_frames_len, total_frames)\n    print(f'sample_frames_len {sample_frames_len}; total frames in', video_path, '=',\n            
total_frames)\n\n    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)\n    video_data = decord_vr.get_batch(frame_id_list).asnumpy()\n    video_data = torch.from_numpy(video_data)\n    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)\n    return video_data\n\n\ndef preprocess(video_data: torch.Tensor, height: int = 128, width: int = 128) -> torch.Tensor:\n    transform = Compose(\n        [\n            ToTensorVideo(),\n            CenterCropResizeVideo((height, width)),\n            Lambda(lambda x: 2. * x - 1.)\n        ]\n    )\n\n    video_outputs = transform(video_data)\n    video_outputs = torch.unsqueeze(video_outputs, 0)\n\n    return video_outputs\n\n\ndef main(args: argparse.Namespace):\n    device = args.device\n    kwarg = {}\n    # vae = getae_wrapper(args.ae)(args.model_path, subfolder=\"vae\", cache_dir='cache_dir', **kwarg).to(device)\n    # vae = CausalVAEModelWrapper(args.ae_path, **kwarg).to(device)\n    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)\n    if args.enable_tiling:\n        vae.vae.enable_tiling()\n        vae.vae.tile_overlap_factor = args.tile_overlap_factor\n        # vae.vae.tile_sample_min_size = 512\n        # vae.vae.tile_latent_min_size = 64\n        # vae.vae.tile_sample_min_size_t = 29\n        # vae.vae.tile_latent_min_size_t = 8\n        # if args.save_memory:\n        #     vae.vae.tile_sample_min_size = 256\n        #     vae.vae.tile_latent_min_size = 32\n        #     vae.vae.tile_sample_min_size_t = 9\n        #     vae.vae.tile_latent_min_size_t = 3\n    dtype = torch.bfloat16\n    vae.eval()\n    vae = vae.to(device, dtype=dtype)\n    \n    with torch.no_grad():\n        x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height,\n                           args.width)\n        print(\"input shape\", x_vae.shape)\n        x_vae = x_vae.to(device, dtype=dtype)  # b c t h w\n        # for i in range(10000):\n        
latents = vae.encode(x_vae)\n        latents = latents.to(dtype)\n        video_recon = vae.decode(latents)  # b t c h w\n        print(\"recon shape\", video_recon.shape)\n\n    # save the first reconstructed sample, shape (t, c, h, w)\n    custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--video_path', type=str, default='')\n    parser.add_argument('--rec_path', type=str, default='')\n    parser.add_argument('--ae', type=str, default='')\n    parser.add_argument('--ae_path', type=str, default='')\n    parser.add_argument('--model_path', type=str, default='results/pretrained')\n    parser.add_argument('--fps', type=int, default=30)\n    parser.add_argument('--height', type=int, default=336)\n    parser.add_argument('--width', type=int, default=336)\n    parser.add_argument('--num_frames', type=int, default=100)\n    parser.add_argument('--sample_rate', type=int, default=1)\n    parser.add_argument('--device', type=str, default=\"cuda\")\n    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)\n    parser.add_argument('--tile_sample_min_size', type=int, default=512)\n    parser.add_argument('--tile_sample_min_size_t', type=int, default=33)\n    parser.add_argument('--tile_sample_min_size_dec', type=int, default=256)\n    parser.add_argument('--tile_sample_min_size_dec_t', type=int, default=33)\n    parser.add_argument('--enable_tiling', 
action='store_true')\n    parser.add_argument('--save_memory', action='store_true')\n\n    args = parser.parse_args()\n    main(args)"
  },
  {
    "path": "examples/sora.txt",
    "content": "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\nSeveral giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field\nA movie trailer featuring the adventures of the 30 year old spaceman wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors.\nDrone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.\nAnimated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. 
Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image.\nA gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.\nThis close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird's striking appearance.\nPhotorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.\nA young man in his 20s is sitting on a piece of cloud in the sky, reading a book.\nA petri dish with a bamboo forest growing within it that has tiny red pandas running around.\nThe camera rotates around a large stack of vintage televisions all showing different programs: 1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery.\n3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.\nHistorical footage of California during the gold rush.\nA close up view of a glass sphere that has a zen garden within it. 
There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand.\nExtreme close up of a 24 year old woman's eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic.\nA cartoon kangaroo disco dances.\nA beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera.\nA cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer.\nBorneo wildlife on the Kinabatangan River\nA Chinese Lunar New Year celebration video with Chinese Dragon.\nThe camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. 
The dirt road itself is surrounded by steep hills and mountains with a clear blue sky above with wispy clouds.\nReflections in the window of a train traveling through the Tokyo suburbs.\nA drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast, Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography\nA large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. 
The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.\nA flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds.\nA beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack.\nNew York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York.\nA litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.\nTour of an art gallery with many beautiful works of art in different styles.\nBeautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes.\nA stop motion animation of a flower growing out of the windowsill of a suburban house.\nThe story of a robot's life in a cyberpunk setting.\nAn extreme close-up of a gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret and glasses and has a very professorial appearance, and at the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.\nBasketball through hoop then explodes\nArcheologists discover a generic plastic chair in the desert, excavating and dusting it with great care\nA grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness with a happy glow in her eye. 
She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood\nStep-printing scene of a person running, cinematic film shot in 35mm\nFive gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing.\nTiltshift of a construction site filled with workers, equipment, and heavy machinery.\nA giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lightning bolts down to the earth.\nA Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glisten off of their fur.\nThe Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot.\nThe camera directly faces colorful buildings in Burano, Italy. An adorable dalmatian looks through a window on a building on the ground floor. 
Many people are walking and cycling along the canal streets in front of the buildings.\nAn adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style.\nThis close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal's striking appearance.\nA corgi vlogging itself in tropical Maui.\nA white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. The scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat's orange fur. The shot is clear and sharp, with a shallow depth of field.\nAerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking, and the lighting creates a beautiful, serene atmosphere."
  },
  {
    "path": "opensora/__init__.py",
    "content": "# "
  },
  {
    "path": "opensora/acceleration/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/acceleration/communications.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom einops import rearrange\nfrom opensora.acceleration.parallel_states import hccl_info, lccl_info, enable_LCCL\ntry:\n    from lcalib.functional import lcal_all2allvc\nexcept ImportError:\n    lcal_all2allvc = None\n\ndef broadcast(input_: torch.Tensor):\n    sp_size = hccl_info.world_size\n    src = hccl_info.rank // sp_size * sp_size\n    dist.broadcast(input_, src=src, group=hccl_info.group)\n\n_COUNT = 0\ndef _all_to_all(\n    input_: torch.Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n):\n    group = hccl_info.group\n    sp_size = hccl_info.world_size\n    input_list = [t.contiguous() for t in torch.tensor_split(input_, sp_size, scatter_dim)]\n    output_list = [torch.empty_like(input_list[0]) for _ in range(sp_size)]\n    dist.all_to_all(output_list, input_list, group=group)\n    return torch.cat(output_list, dim=gather_dim).contiguous()\n\ndef _single_all_to_all(\n    input_: torch.Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n    enable_HCCL=False,\n):\n    if enable_LCCL:\n        sp_size = lccl_info.world_size\n    else:\n        sp_size = hccl_info.world_size\n    inp_shape = list(input_.shape)\n    inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size\n    if scatter_dim < 1:\n        input_t = input_.reshape(\n            [sp_size, inp_shape[scatter_dim]] + \\\n            inp_shape[scatter_dim + 1:]\n        )\n    else:\n        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!\n        input_t = input_.reshape(\n            [-1, sp_size, inp_shape[scatter_dim]] + \\\n            inp_shape[scatter_dim + 1:]\n        ).transpose(0, 1).contiguous()\n\n    output = torch.empty_like(input_t)\n    if enable_LCCL and not enable_HCCL:\n        matrix_count = torch.ones([sp_size, sp_size], dtype=torch.int64, device=input_t.device) * (\n                    input_t.numel() // sp_size)\n        lcal_all2allvc(input_t, output, matrix_count, 
lccl_info.group)\n    else:\n        dist.all_to_all_single(output, input_t, group=hccl_info.group)\n    # if scattering the seq-dim, transpose the heads back to the original dimension\n    if scatter_dim < 1:\n        output = output.transpose(0, 1).contiguous()\n\n    return output.reshape(\n        inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:])\n\n\nclass _AllToAll(torch.autograd.Function):\n    \"\"\"All-to-all communication with autograd support.\n\n    The backward pass runs the same all-to-all with scatter_dim and gather_dim swapped.\n\n    Args:\n        input_: input tensor\n        scatter_dim: dimension split across the sequence-parallel ranks\n        gather_dim: dimension along which the received chunks are concatenated\n        all_to_all_func: communication function to use (_all_to_all or _single_all_to_all)\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input_, scatter_dim, gather_dim, all_to_all_func):\n        ctx.scatter_dim = scatter_dim\n        ctx.gather_dim = gather_dim\n        ctx.all_to_all = all_to_all_func\n        output = ctx.all_to_all(input_, scatter_dim, gather_dim)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_output = ctx.all_to_all(\n            grad_output,\n            ctx.gather_dim,\n            ctx.scatter_dim,\n        )\n        return (\n            grad_output,\n            None,\n            None,\n            None,\n        )\n\ndef all_to_all_SBH(\n    input_: torch.Tensor,\n    scatter_dim: int = 1,\n    gather_dim: int = 0,\n):\n    return _AllToAll.apply(input_, scatter_dim, gather_dim, _single_all_to_all)\n\ndef all_to_all_BSND(\n    input_: torch.Tensor,\n    scatter_dim: int = 2,\n    gather_dim: int = 1,\n):\n    return _AllToAll.apply(input_, scatter_dim, gather_dim, _all_to_all)\n\n\ndef prepare_parallel_data(\n        hidden_states, \n        encoder_hidden_states, \n        attention_mask, \n        encoder_attention_mask, \n        pooled_projections=None, \n        ):\n    def all_to_all(\n            hidden_states, \n            encoder_hidden_states, \n            attention_mask, \n            encoder_attention_mask, \n            
pooled_projections, \n            ):\n        # hidden_states          (b c t h w)   -gather0-> (sp*b c t h w)   -scatter2-> (sp*b c t//sp h w)\n        # encoder_hidden_states  (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)\n        # attention_mask         (b t*sp h w)  -gather0-> (sp*b t*sp h w)  -scatter1-> (sp*b t h w)\n        # encoder_attention_mask (b sp l)      -gather0-> (sp*b sp l)      -scatter1-> (sp*b 1 l)\n        # pooled_projections     (b sp d)      -gather0-> (sp*b sp d)      -scatter1-> (sp*b 1 d)\n        hidden_states = _single_all_to_all(hidden_states, scatter_dim=2, gather_dim=0, enable_HCCL=True)\n        encoder_hidden_states = _single_all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        attention_mask = _single_all_to_all(attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        encoder_attention_mask = _single_all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        if pooled_projections is not None:\n            pooled_projections = _single_all_to_all(pooled_projections, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n\n        return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections\n\n    sp_size = hccl_info.world_size\n    frame = hidden_states.shape[2]\n    assert frame % sp_size == 0, \"frame should be a multiple of sp_size\"\n\n    encoder_hidden_states = rearrange(\n        encoder_hidden_states, 'b 1 (n x) h -> b n x h',\n        n=sp_size, x=encoder_hidden_states.shape[2]//sp_size\n        ).contiguous()\n    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections = all_to_all(\n        hidden_states, \n        encoder_hidden_states, \n        attention_mask.repeat(1, sp_size, 1, 1), \n        encoder_attention_mask.repeat(1, sp_size, 1), \n        # pooled_projections is optional (default None), so only repeat it when present\n        pooled_projections.repeat(1, sp_size, 1) if pooled_projections is not None else None,\n        )\n\n    return hidden_states, 
encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections"
  },
  {
    "path": "opensora/acceleration/parallel_states.py",
    "content": "import torch\nimport torch_npu\nimport torch.distributed as dist\nimport os\ntry:\n    from lcalib.functional import lcal_initialize\n    enable_LCCL = True\nexcept:\n    lcal_initialize = None\n    enable_LCCL = False\nclass COMM_INFO:\n    def __init__(self):\n        self.group = None\n        self.world_size = 0\n        self.rank = -1\n\nlccl_info = COMM_INFO()\nhccl_info = COMM_INFO()\n_SEQUENCE_PARALLEL_STATE = False\ndef initialize_sequence_parallel_state(sequence_parallel_size):\n    global _SEQUENCE_PARALLEL_STATE\n    if sequence_parallel_size > 1:\n        _SEQUENCE_PARALLEL_STATE = True\n        initialize_sequence_parallel_group(sequence_parallel_size)\n\ndef set_sequence_parallel_state(state):\n    global _SEQUENCE_PARALLEL_STATE\n    _SEQUENCE_PARALLEL_STATE = state\n\ndef get_sequence_parallel_state():\n    return _SEQUENCE_PARALLEL_STATE\n\ndef initialize_sequence_parallel_group(sequence_parallel_size):\n    \"\"\"Initialize the sequence parallel group.\"\"\"\n    rank = int(os.getenv('RANK', '0'))\n    world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n    assert world_size % sequence_parallel_size == 0, \"world_size must be divisible by sequence_parallel_size\"\n    # hccl\n    hccl_info.world_size = sequence_parallel_size\n    hccl_info.rank = rank\n    num_sequence_parallel_groups: int = world_size // sequence_parallel_size\n    for i in range(num_sequence_parallel_groups):\n        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)\n        group = dist.new_group(ranks)\n        if rank in ranks:\n            hccl_info.group = group\n\n    if enable_LCCL:\n        assert sequence_parallel_size == 8, \"sequence_parallel_size should be 8 when enable_LCCL is True\"\n        rank %= sequence_parallel_size\n        lccl_info.world_size = sequence_parallel_size\n        lccl_info.group = lcal_initialize(rank, sequence_parallel_size)\n        lccl_info.rank = rank\n\ndef destroy_sequence_parallel_group():\n    
\"\"\"Destroy the sequence parallel group.\"\"\"\n    dist.destroy_process_group()\n"
  },
  {
    "path": "opensora/adaptor/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/adaptor/bf16_optimizer.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\nfrom collections import OrderedDict\nimport torch\nimport sys\nimport os\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\nfrom deepspeed import comm as dist\nfrom deepspeed.runtime.constants import PIPE_REPLICATED\nfrom deepspeed.runtime import ZeROOptimizer\nfrom packaging import version as pkg_version\n\nfrom deepspeed.git_version_info import version\nfrom deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,\n                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,\n                                     is_model_parallel_parameter, see_memory_usage, graph_process)\n\nfrom deepspeed.utils import link_hp_params, fragment_address\nfrom deepspeed.checkpoint import enable_universal_checkpoint\nfrom deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,\n                                            SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,\n                                            PARAM_SLICE_MAPPINGS)\n\nsetattr(sys.modules[__name__], 'fragment_address', fragment_address)\n\n\ndef contigous_flatten(tensors):\n    return _flatten_dense_tensors([tensor.contiguous() for tensor in tensors])\n\n\nclass BF16_Optimizer(ZeROOptimizer):\n\n    def __init__(self,\n                 init_optimizer,\n                 param_names,\n                 mpu=None,\n                 clip_grad=0.0,\n                 norm_type=2,\n                 allgather_bucket_size=5000000000,\n                 dp_process_group=None,\n                 timers=None,\n                 grad_acc_dtype=None,\n                 graph_harvesting=False):\n        # super().__init__()\n        # base_class = ZeROOptimizer.__bases__[0]\n        # # Directly call the base class's __init__ method\n        # base_class.__init__()\n        
see_memory_usage('begin bf16_optimizer', force=True)\n        self.timers = timers\n        self.optimizer = init_optimizer\n        self.param_names = param_names\n        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)\n\n        assert grad_acc_dtype in [torch.float32, torch.bfloat16\n                                  ], f\"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}\"\n        self.grad_acc_dtype = grad_acc_dtype\n\n        self.clip_grad = clip_grad\n        self.norm_type = norm_type\n        self.mpu = mpu\n        self.allgather_bucket_size = int(allgather_bucket_size)\n        self.dp_process_group = dp_process_group\n        self.dp_rank = dist.get_rank(group=self.dp_process_group)\n        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]\n\n        # Use torch (un)flatten ops\n        self.flatten = contigous_flatten\n        self.unflatten = _unflatten_dense_tensors\n\n        # align nccl all-gather send buffers to a 32-byte boundary (16 bf16 elements)\n        self.nccl_start_alignment_factor = 16\n\n        # Build BF16/FP32 groups\n        self.bf16_groups = []\n        self.bf16_groups_flat = []\n        self.bf16_partitioned_groups = []\n\n        self.fp32_groups_flat_partition = []\n\n        # Maintain different fp32 gradients views for convenience\n        self.fp32_groups_gradients = []\n        self.fp32_groups_gradient_dict = {}\n        self.fp32_groups_gradients_flat = []\n        self.fp32_groups_actual_gradients_flat = []\n        self.fp32_groups_gradient_flat_partition = []\n        self.fp32_groups_has_gradients = []\n\n        self.group_paddings = []\n        self.graph_harvesting = graph_harvesting\n        if self.using_real_optimizer:\n            self._setup_for_real_optimizer()\n\n        see_memory_usage('end bf16_optimizer', force=True)\n\n    def _setup_for_real_optimizer(self):\n        dp_world_size = 
dist.get_world_size(group=self.dp_process_group)\n        self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]\n\n        for i, param_group in enumerate(self.optimizer.param_groups):\n            see_memory_usage(f'before initializing group {i}', force=True)\n\n            partition_id = dist.get_rank(group=self.real_dp_process_group[i])\n\n            # grab the original list\n            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]\n            self.bf16_groups.append(trainable_parameters)\n\n            # create flat bf16 params\n            self.bf16_groups_flat.append(\n                self._flatten_dense_tensors_aligned(self.bf16_groups[i],\n                                                    self.nccl_start_alignment_factor * dp_world_size))\n\n            # Make bf16 params point to flat tensor storage\n            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],\n                                                     flat_tensor=self.bf16_groups_flat[i])\n\n            # divide flat weights into equal sized partitions\n            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size\n            bf16_dp_partitions = [\n                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)\n                for dp_index in range(dp_world_size)\n            ]\n            self.bf16_partitioned_groups.append(bf16_dp_partitions)\n\n            # create fp32 params partition\n            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())\n            self.fp32_groups_flat_partition[i].requires_grad = True\n\n            num_elem_list = [t.numel() for t in self.bf16_groups[i]]\n\n            # create fp32 gradients\n            self.fp32_groups_gradients_flat.append(\n                torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))\n\n            # track 
individual fp32 gradients for entire model\n            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],\n                                                     num_elem_list=num_elem_list)\n            self.fp32_groups_gradients.append(fp32_gradients)\n            self.fp32_groups_gradient_dict[i] = fp32_gradients\n\n            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)\n            length_without_padding = sum(num_elem_list)\n            self.fp32_groups_actual_gradients_flat.append(\n                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))\n\n            # flat tensor corresponding to gradient partition\n            self.fp32_groups_gradient_flat_partition.append(\n                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))\n\n            # track fp32 gradient updates\n            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))\n\n            # Record padding required for alignment\n            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:\n                padding = self.bf16_groups_flat[i].numel() - length_without_padding\n            else:\n                padding = 0\n\n            self.group_paddings.append(padding)\n\n            # update optimizer param groups to reference fp32 params partition\n            param_group['params'] = [self.fp32_groups_flat_partition[i]]\n\n            see_memory_usage(f'after initializing group {i}', force=True)\n\n        see_memory_usage('before initialize_optimizer', force=True)\n        self.initialize_optimizer_states()\n        see_memory_usage('end initialize_optimizer', force=True)\n\n        # Need optimizer states initialized before linking lp to optimizer state\n        self._link_all_hp_params()\n        self._enable_universal_checkpoint()\n        self._param_slice_mappings = 
self._create_param_mapping()\n\n    def _enable_universal_checkpoint(self):\n        for lp_param_group in self.bf16_groups:\n            enable_universal_checkpoint(param_list=lp_param_group)\n\n    def _create_param_mapping(self):\n        param_mapping = []\n        for i, _ in enumerate(self.optimizer.param_groups):\n            param_mapping_per_group = OrderedDict()\n            for lp in self.bf16_groups[i]:\n                if lp._hp_mapping is not None:\n                    lp_name = self.param_names[lp]\n                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()\n            param_mapping.append(param_mapping_per_group)\n\n        return param_mapping\n\n    def _link_all_hp_params(self):\n        dp_world_size = dist.get_world_size(group=self.dp_process_group)\n        for i, _ in enumerate(self.optimizer.param_groups):\n            # Link bf16 and fp32 params in partition\n            partition_id = dist.get_rank(group=self.real_dp_process_group[i])\n            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size\n            flat_hp_partition = self.fp32_groups_flat_partition[i]\n            link_hp_params(lp_param_list=self.bf16_groups[i],\n                           flat_hp_partition=flat_hp_partition,\n                           gradient_dict=self.fp32_groups_gradient_dict,\n                           offload_gradient_dict=None,\n                           use_offload=False,\n                           param_group_index=i,\n                           partition_start=partition_id * partition_size,\n                           partition_size=partition_size,\n                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],\n                           dp_group=self.real_dp_process_group[i])\n\n    def initialize_optimizer_states(self):\n        \"\"\"Take an optimizer step with zero-valued gradients to allocate internal\n        optimizer state.\n\n        This helps prevent 
memory fragmentation by allocating optimizer state at the\n        beginning of training instead of after activations have been allocated.\n        \"\"\"\n        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,\n                                                   self.fp32_groups_gradient_flat_partition):\n            # In case of grad acc dtype different than FP32, need to cast to high precision.\n            param_partition.grad = grad_partition.to(\n                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition\n\n        self.optimizer.step()\n\n        if self.grad_acc_dtype is not torch.float32:\n            for param_partition in self.fp32_groups_flat_partition:\n                param_partition.grad = None\n\n        self.clear_hp_grads()\n\n    def _split_flat_tensor(self, flat_tensor, num_elem_list):\n        assert sum(num_elem_list) <= flat_tensor.numel()\n        tensor_list = []\n        offset = 0\n        for num_elem in num_elem_list:\n            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)\n            tensor_list.append(dense_tensor)\n            offset += num_elem\n\n        return tensor_list\n\n    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):\n        updated_params = self.unflatten(flat_tensor, tensor_list)\n        for p, q in zip(tensor_list, updated_params):\n            p.data = q.data\n\n    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):\n        return self.flatten(align_dense_tensors(tensor_list, alignment))\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        if closure is not None:\n            raise NotImplementedError(f'{self.__class__} does not support closure.')\n\n        all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),\n                                                     mpu=self.mpu,\n                                                     
norm_type=self.norm_type,\n                                                     use_graph=self.graph_harvesting)\n        self._global_grad_norm = all_groups_norm\n\n        assert all_groups_norm > 0.\n        if self.clip_grad > 0.:\n            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),\n                                        max_norm=self.clip_grad,\n                                        global_norm=all_groups_norm,\n                                        mpu=self.mpu,\n                                        use_graph=self.graph_harvesting)\n\n        self.optimizer.step()\n\n        self.update_lp_params()\n\n        self.clear_hp_grads()\n\n    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):\n        \"\"\"Perform a backward pass and copy the low-precision gradients to the\n        high-precision copy.\n\n        We copy/accumulate to the high-precision grads now to prevent accumulating in the\n        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1)\n\n        The low-precision grads are deallocated during this procedure.\n        \"\"\"\n        self.clear_lp_grads()\n        loss.backward(**bwd_kwargs)\n\n        if update_hp_grads:\n            self.update_hp_grads(clear_lp_grads=clear_lp_grads)\n\n    @torch.no_grad()\n    def update_hp_grads(self, clear_lp_grads=False):\n\n        def _update_hp_grads_func(clear_lp_grads=False):\n            for i, group in enumerate(self.bf16_groups):\n                for j, lp in enumerate(group):\n                    if lp.grad is None:\n                        continue\n                    hp_grad = self.fp32_groups_gradients[i][j]\n                    assert hp_grad is not None, \\\n                        f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'\n                    hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))\n                   
 lp._hp_grad = hp_grad\n                    self.fp32_groups_has_gradients[i][j] = True\n                    # clear gradients\n                    if clear_lp_grads:\n                        lp.grad.zero_()\n\n        if self.graph_harvesting:\n            graph_process(False, _update_hp_grads_func, clear_lp_grads)\n        else:\n            _update_hp_grads_func(clear_lp_grads)\n        #cpu op\n        for i, group in enumerate(self.bf16_groups):\n            for j, lp in enumerate(group):\n                if lp.grad is None:\n                    continue\n                self.fp32_groups_has_gradients[i][j] = True\n\n    @torch.no_grad()\n    def get_grads_for_reduction(self):\n        return self.fp32_groups_gradients_flat\n\n    @torch.no_grad()\n    def get_grads_for_norm(self, for_clipping=False):\n        grads = []\n        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)\n        for i, group in enumerate(self.bf16_groups):\n            for j, lp in enumerate(group):\n                if not for_clipping:\n                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:\n                        continue\n\n                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):\n                        continue\n\n                if not self.fp32_groups_has_gradients[i][j]:\n                    continue\n\n                grads.append(self.fp32_groups_gradients[i][j])\n\n        return grads\n\n    @torch.no_grad()\n    def update_lp_params(self):\n        for i, (bf16_partitions,\n                fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):\n            partition_id = dist.get_rank(group=self.real_dp_process_group[i])\n            bf16_partitions[partition_id].data.copy_(fp32_partition.data)\n            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)\n            # if i == 0:\n            #     print_rank_0(f'{fp32_partition[:10]=}', 
force=True)\n\n        all_gather_dp_groups(groups_flat=self.bf16_groups_flat,\n                             partitioned_param_groups=self.bf16_partitioned_groups,\n                             dp_process_group=self.real_dp_process_group,\n                             start_alignment_factor=self.nccl_start_alignment_factor,\n                             allgather_bucket_size=self.allgather_bucket_size)\n\n    def clear_hp_grads(self):\n        for flat_gradients in self.fp32_groups_gradients_flat:\n            flat_gradients.zero_()\n\n        for i, group in enumerate(self.fp32_groups_gradients):\n            self.fp32_groups_has_gradients[i] = [False] * len(group)\n\n    def clear_lp_grads(self):\n        for group in self.bf16_groups:\n            for param in group:\n                if param.grad is not None:\n                    # Use zero_() to keep a fixed memory address for graph replay\n                    param.grad.zero_()\n\n    def state_dict(self):\n        state_dict = {}\n        state_dict[CLIP_GRAD] = self.clip_grad\n        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()\n        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition\n        state_dict[GROUP_PADDINGS] = self.group_paddings\n        state_dict[PARTITION_COUNT] = self.partition_count\n        state_dict[DS_VERSION] = version\n        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings\n\n        return state_dict\n\n    # Restore base optimizer fp32 weights from bfloat16 weights\n    def _restore_from_bit16_weights(self):\n        for i, (bf16_partitions,\n                fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):\n            partition_id = dist.get_rank(group=self.real_dp_process_group[i])\n            fp32_partition.data.copy_(bf16_partitions[partition_id].data)\n\n    def refresh_fp32_params(self):\n        self._restore_from_bit16_weights()\n\n    def 
load_state_dict(self,\n                        state_dict_list,\n                        checkpoint_folder,\n                        load_optimizer_states=True,\n                        load_from_fp32_weights=False,\n                        load_serial=None):\n        if checkpoint_folder:\n            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)\n        else:\n            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)\n\n    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):\n\n        dp_rank = dist.get_rank(group=self.dp_process_group)\n        current_rank_sd = state_dict_list[dp_rank]\n\n        ckpt_version = current_rank_sd.get(DS_VERSION, False)\n        assert ckpt_version, f\"Empty ds_version in checkpoint, not clear how to proceed\"\n        ckpt_version = pkg_version.parse(ckpt_version)\n\n        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)\n\n        if load_optimizer_states:\n            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])\n\n        if load_from_fp32_weights:\n            for current, saved in zip(self.fp32_groups_flat_partition,\n                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):\n                src_tensor = _get_padded_tensor(saved, current.numel())\n                current.data.copy_(src_tensor.data)\n\n        if load_optimizer_states:\n            self._link_all_hp_params()\n\n    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):\n        self._load_hp_checkpoint_state(checkpoint_folder)\n\n    @property\n    def param_groups(self):\n        \"\"\"Forward the wrapped optimizer's parameters.\"\"\"\n        return self.optimizer.param_groups\n\n    def _load_hp_checkpoint_state(self, checkpoint_dir):\n        checkpoint_dir = 
os.path.join(checkpoint_dir, \"zero\")\n        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)\n        tp_world_size = self.mpu.get_slice_parallel_world_size()\n\n        for i, _ in enumerate(self.optimizer.param_groups):\n            for lp in self.bf16_groups[i]:\n                if lp._hp_mapping is not None:\n                    #print(f\"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}\")\n                    lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,\n                                                tp_world_size)\n\n\ndef _get_padded_tensor(src_tensor, size):\n    if src_tensor.numel() >= size:\n        return src_tensor\n    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)\n    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())\n    slice_tensor.data.copy_(src_tensor.data)\n    return padded_tensor\n"
  },
  {
    "path": "opensora/adaptor/engine.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\nimport os\nimport re\nimport stat\nimport torch\nimport hashlib\nfrom collections import defaultdict, OrderedDict, deque\nfrom shutil import copyfile\nimport gc\n\nfrom torch.nn.modules import Module\nfrom torch.nn.parameter import Parameter\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import _LRScheduler\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\nfrom typing import Callable, Dict, Union, Iterable\n\nimport deepspeed\n\nfrom deepspeed import comm as dist\nfrom deepspeed.runtime.utils import see_memory_usage, DummyOptim\nfrom deepspeed.runtime.zero.offload_config import OffloadDeviceEnum\nfrom deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer\nfrom deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\nfrom deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException\nfrom deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload\nfrom deepspeed.runtime.zero.config import ZERO_OPTIMIZATION\n\nfrom deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer\nfrom deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer\nfrom deepspeed.runtime.bf16_optimizer import BF16_Optimizer\n\nfrom deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \\\n    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \\\n    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \\\n    MUSGD_OPTIMIZER, LION_OPTIMIZER\n\nfrom deepspeed.runtime.dataloader import DeepSpeedDataLoader\nfrom deepspeed.runtime.constants import \\\n    ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \\\n    PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \\\n    DATA_PARALLEL_GROUP, GLOBAL_RANK\nfrom 
deepspeed.runtime.zero.config import ZeroStageEnum\nfrom deepspeed.compression import compression_scheduler\nfrom deepspeed.compression.constants import \\\n    WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \\\n    WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \\\n    WEIGHT_QUANTIZE_ENABLED, \\\n    WEIGHT_QUANTIZE_GROUPS, \\\n    WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \\\n    WEIGHT_QUANTIZE_CHANGE_RATIO, \\\n    WEIGHT_QUANTIZE_TYPE, \\\n    WEIGHT_QUANTIZE_ROUNDING, \\\n    WEIGHT_QUANTIZE_VERBOSE, \\\n    WEIGHT_QUANTIZE_KERNEL\nfrom deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS\nfrom deepspeed.runtime.sparse_tensor import SparseTensor\n\nfrom deepspeed.runtime import lr_schedules\nfrom deepspeed.utils import groups\nfrom deepspeed.utils import logger, log_dist, instrument_w_nvtx\nfrom deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \\\n    FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \\\n    STEP_MICRO_TIMER, \\\n    FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \\\n    STEP_GLOBAL_TIMER\nfrom deepspeed.utils.debug import debug_extract_module_and_param_names\nfrom deepspeed.monitor.monitor import MonitorMaster\nfrom deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop\nfrom deepspeed.runtime.utils import clip_grad_norm_\nfrom deepspeed.runtime.eigenvalue import Eigenvalue\nfrom deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \\\n    DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \\\n    CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \\\n    RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \\\n    RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \\\n    RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY\nfrom deepspeed.runtime.data_pipeline.curriculum_scheduler import 
CurriculumScheduler\nfrom deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler\nfrom deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict\nfrom deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop\n\nfrom deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine\nfrom deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint\n\nfrom deepspeed.runtime.pipe.module import PipelineModule\nfrom deepspeed.runtime.utils import get_ma_status\nfrom deepspeed.ops.adam import FusedAdam\nfrom deepspeed.moe.sharded_moe import TopKGate, MOELayer\nfrom deepspeed.moe.layer import MoE\nfrom deepspeed.moe.utils import is_moe_param\nfrom deepspeed.git_version_info import version\n\nfrom deepspeed.profiling.flops_profiler.profiler import FlopsProfiler\nfrom deepspeed.utils.logging import print_json_dist, print_configuration\n\nfrom deepspeed.accelerator import get_accelerator\n\nfrom deepspeed.runtime.config import DtypeEnum\n\nfrom opensora.adaptor.zp_manager import zp_manager\n\nMEMORY_OPT_ALLREDUCE_SIZE = 500000000\n\nDeepSpeedOptimizerCallable = \\\n    Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]\nDeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]\n\ntry:\n    import apex\n    from apex import amp\n    APEX_INSTALLED = True\nexcept ImportError:\n    # Fail silently so we don't spam logs unnecessarily if user isn't using amp\n    APEX_INSTALLED = False\n\n\ndef split_half_float_double_sparse(tensors):\n    device_type = get_accelerator().device_name()\n    supported_types = [\n        \"torch.{}.HalfTensor\".format(device_type), \"torch.{}.FloatTensor\".format(device_type),\n        \"torch.{}.DoubleTensor\".format(device_type), \"torch.{}.BFloat16Tensor\".format(device_type),\n        SparseTensor.type()\n    ]\n\n    for t in tensors:\n        assert t.type() in supported_types, f\"attempting to 
reduce an unsupported grad type: {t.type()}\"\n\n    buckets = []\n    for i, dtype in enumerate(supported_types):\n        bucket = [t for t in tensors if t.type() == dtype]\n        if bucket:\n            buckets.append((dtype, bucket))\n    return buckets\n\n\nclass EngineTimers(object):\n    r\"\"\"Wallclock timers for DeepSpeedEngine\"\"\"\n\n    def __init__(self, enable_micro_timers, enable_global_timers):\n        self.forward_timers = []\n        self.backward_timers = []\n        self.backward_inner_timers = []\n        self.backward_reduce_timers = []\n        self.step_timers = []\n        self.global_timers = []\n        self.micro_timers = []\n\n        if enable_micro_timers:\n            self.forward_timers += [FORWARD_MICRO_TIMER]\n            self.backward_timers += [BACKWARD_MICRO_TIMER]\n            self.backward_inner_timers += [BACKWARD_INNER_MICRO_TIMER]\n            self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER]\n            self.step_timers += [STEP_MICRO_TIMER]\n            self.micro_timers += [\n                FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,\n                STEP_MICRO_TIMER\n            ]\n\n        if enable_global_timers:\n            self.forward_timers += [FORWARD_GLOBAL_TIMER]\n            self.backward_timers += [BACKWARD_GLOBAL_TIMER]\n            self.backward_inner_timers += [BACKWARD_INNER_GLOBAL_TIMER]\n            self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER]\n            self.step_timers += [STEP_GLOBAL_TIMER]\n            self.global_timers += [\n                FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,\n                STEP_GLOBAL_TIMER\n            ]\n\n\nclass DeepSpeedEngine(Module):\n    r\"\"\"DeepSpeed engine for training.\"\"\"\n\n    def __init__(\n        self,\n        args,\n        model,\n        optimizer=None,\n        model_parameters=None,\n      
  training_data=None,\n        lr_scheduler=None,\n        mpu=None,\n        dist_init_required=None,\n        collate_fn=None,\n        config=None,\n        config_class=None,\n        dont_change_device=False,\n    ):\n        super(DeepSpeedEngine, self).__init__()\n        self.dont_change_device = dont_change_device\n        self.client_optimizer = optimizer\n        self.client_lr_scheduler = lr_scheduler\n        self.training_data = training_data\n        self.collate_fn = collate_fn\n        self.mpu = mpu\n        self.all_to_all_group = None\n        self.data_parallel_group = None\n        self.global_steps = 0\n        self.global_samples = 0\n        self.micro_steps = 0\n        self.skipped_steps = 0\n        self.gradient_average = True\n        self.warn_unscaled_loss = True\n        self.config = config\n        self._config = config_class\n        self.loaded_checkpoint_mp_world_size = None\n        self.loaded_checkpoint_dp_world_size = None\n        self.loaded_checkpoint_zp_world_size = None\n        self.enable_backward_allreduce = True\n        self.progressive_layer_drop = None\n        self.eigenvalue = None\n        self.block_eigenvalue = None\n        self.gas_boundary_ctr = 0\n        self.dist_backend = get_accelerator().communication_backend_name()\n        self.has_moe_layers = False\n        self.num_experts = []\n        self.gate_modules = []\n        self.moe_layers = []\n        self._step_applied = False\n        self._global_grad_norm = None\n        self.use_ds_comm = False  # False --> Use torch.dist, True --> Use ds.comm backend.\n\n        self.checkpoint_engine = None\n\n        self._is_gradient_accumulation_boundary = None\n        self.scale_wrt_gas = None\n        self.losses = 0.0\n\n        # for debug purposes - can then debug print: debug_get_module_name(module)\n        debug_extract_module_and_param_names(model)\n\n        self._do_args_sanity_check(args)\n        self._configure_with_arguments(args, mpu)\n  
      self._do_sanity_check()\n        see_memory_usage(f\"DeepSpeed Engine: After args sanity test\", force=self.memory_breakdown())\n        if mpu is not None:\n            if self.elasticity_enabled():\n                if not self.is_elastic_model_parallel_supported():\n                    assert not self.elasticity_enabled(), (\"Elasticity is not currently supported\"\n                                                           \" with model parallelism.\")\n\n        self._set_distributed_vars(args)\n\n        dist.configure(self._config)\n\n        self.monitor = MonitorMaster(self._config.monitor_config)\n\n        see_memory_usage(\n            f\"DeepSpeed Engine: Before configure distributed model\",\n            force=self.memory_breakdown(),\n        )\n\n        self.pipeline_parallelism = isinstance(model, PipelineModule)\n\n        # Configure distributed model\n        self._configure_distributed_model(model)\n\n        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict\n        self.param_names = {param: name for name, param in model.named_parameters()}\n\n        self._get_model_parameters()\n\n        see_memory_usage(f\"DeepSpeed Engine: After configure distributed model\")\n\n        # Configure wall clock timers\n        self.timers = SynchronizedWallClockTimer()\n        # Throughput timer\n        self.tput_timer = ThroughputTimer(\n            batch_size=self.train_batch_size(),\n            steps_per_output=self.steps_per_print(),\n            monitor_memory=False,\n        )\n\n        log_dist(f\"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}\", ranks=[0])\n\n        if self.flops_profiler_enabled():\n            self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor())\n\n        if training_data:\n            self.training_dataloader = self.deepspeed_io(training_data)\n        else:\n            self.training_dataloader = None\n\n        # 
Configure optimizer and scheduler\n        self.optimizer = None\n        self.basic_optimizer = None\n        self.lr_scheduler = None\n        has_optimizer = False\n\n        if optimizer or self.optimizer_name():\n            has_optimizer = True\n        # If no parameters given by init default to module parameters\n        if model_parameters is None:\n            model_parameters = self.module.parameters()\n\n        # Convert model parameters from generator to list\n        if not isinstance(model_parameters, list):\n            model_parameters = list(model_parameters)\n\n        if has_optimizer:\n            self._configure_optimizer(optimizer, model_parameters)\n            self._configure_lr_scheduler(lr_scheduler)\n            self._report_progress(0)\n        elif self.zero_optimization():\n            # no optim selected but zero is enabled\n            self.optimizer = self._configure_zero_optimizer(optimizer=None)\n        elif self.bfloat16_enabled():\n            self.optimizer = self._configure_bf16_optimizer(optimizer=None)\n\n        # Hook optimizer for snip_momentum pruning\n        if hasattr(model, 'pruners'):\n            from deepspeed.compression.helper import rewrite_optimizer_step\n            self.optimizer.pruners = model.pruners\n            rewrite_optimizer_step(self.optimizer)\n\n        # Bookkeeping for sparse support\n        self.sparse_tensor_module_names = set()\n        # if self.sparse_gradients_enabled():\n        for name, module in self.module.named_modules():\n            if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():\n                self.sparse_tensor_module_names.add(name + \".weight\")\n                logger.info(\"Will convert {} to sparse tensor during training\".format(name))\n\n        self.save_non_zero_checkpoint = False\n        self.save_zero_checkpoint = False\n        if not isinstance(self.optimizer, DeepSpeedZeRoOffload):\n            
self._configure_checkpointing(dist_init_required)\n\n        if self.eigenvalue_enabled():\n            self.eigenvalue = self._configure_eigenvalue()\n\n        if self.pld_enabled():\n            self.progressive_layer_drop = self._configure_progressive_layer_drop()\n\n        if self.curriculum_enabled_legacy():\n            self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()\n\n        if self.random_ltd_enabled():\n            random_ltd_config = self.random_ltd_config()\n            random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()\n            random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()\n            self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)\n\n        # Engine timers\n\n        self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),\n                                          enable_global_timers=self.wall_clock_breakdown()\n                                          or self.flops_profiler_enabled())\n\n        if self.global_rank == 0:\n            self._config.print(\"DeepSpeedEngine configuration\")\n            if self.dump_state():\n                print_configuration(self, \"DeepSpeedEngine\")\n\n        # Use torch (un)flatten ops\n        self.flatten = _flatten_dense_tensors\n        self.unflatten = _unflatten_dense_tensors\n\n    def destroy(self):\n        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):\n            self.optimizer.destroy()\n\n    def _get_model_parameters(self):\n        if self.autotuning_profile_model_info():\n            self.autotuning_model_info = {}\n            num_params = 0\n            trainable_num_params = 0\n\n            for p in self.module.parameters():\n                # since user code might call deepspeed.zero.Init() before deepspeed.initialize(), need to check the attribute to check if the parameter is partitioned in zero 3 
already or not\n                n = 0\n                if hasattr(p, \"ds_tensor\"):  # if the parameter is partitioned in zero 3\n                    n += p.ds_numel\n                else:  # if the parameter is not partitioned in zero 3 yet\n                    n += p.numel()\n                num_params += n\n                if p.requires_grad:\n                    trainable_num_params += n\n            if self.global_rank == 0:\n                self.autotuning_model_info[\"num_params\"] = num_params * self.mp_world_size\n                self.autotuning_model_info[\"trainable_num_params\"] = trainable_num_params * self.mp_world_size\n\n            logger.info(f\"model parameter = {num_params}\")\n\n    def get_batch_info(self):\n        \"\"\"Get all training batch related settings.\n        Returns:\n            train_batch_size (int): The effective training batch size. This is the amount of data\n                samples that leads to one step of model update.\n            train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one\n                step (without gradient accumulation).\n            gradient_accumulation_steps (int): Number of training steps to accumulate gradients\n                before averaging and applying them.\n        \"\"\"\n        return (\n            self.train_batch_size,\n            self.train_micro_batch_size_per_gpu,\n            self.gradient_accumulation_steps,\n        )\n\n    def set_train_batch_size(self, train_batch_size):\n        \"\"\"Adjust the global batch size by increasing or decreasing the number of\n        micro-batches (i.e., gradient accumulation steps). 
The size of each micro-batch\n        (i.e., ``train_micro_batch_size_per_gpu``) is not changed.\n        Args:\n            train_batch_size (int): The new global batch size for training.\n        Raises:\n            ValueError: if ``train_batch_size`` is not divisible by the\n                configured micro-batch size and data parallelism.\n        \"\"\"\n        if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:\n            #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')\n            raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism')\n        new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size)\n        # overwrite config\n        self._config.train_batch_size = train_batch_size\n        self._config.gradient_accumulation_steps = new_gas\n\n    def set_train_micro_batch_size(self, micro_batch_size):\n        \"\"\"Adjust the micro batch size (i.e., the micro batch size in every data parallel group),\n        while keeping the gradient accumulation steps the same.\n        Args:\n            micro_batch_size (int): The new micro batch size for training.\n        \"\"\"\n        # overwrite config\n        new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size\n        self._config.train_batch_size = new_global_batch_size\n        self._config.train_micro_batch_size_per_gpu = micro_batch_size\n\n    def set_data_post_process_func(self, post_process_func):\n        if self.training_dataloader is not None:\n            self.training_dataloader.post_process_func = post_process_func\n\n    def set_custom_curriculum_learning_schedule(self, schedule_func_dict):\n        if self.training_dataloader is not None and self.curriculum_learning_enabled():\n            self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)\n\n    
def get_global_grad_norm(self) -> float:\n        \"\"\"Return the 2-norm of all gradients. If there is model parallelism,\n        the norm will be global.\n        The computed norm will be cached and reused until the next step() pass.\n        .. note::\n            In the presence of model parallelism, this is a collective call\n            and acts as a barrier among ``mpu.get_model_parallel_group()``.\n        Returns:\n            float: norm\n        \"\"\"\n        return self._global_grad_norm\n\n    def __getattr__(self, name):\n        \"\"\"\n        Pass through attributes defined in the model if they are not overridden by ds-engine.\n        \"\"\"\n\n        _module = {}\n        if \"module\" in self.__dict__:\n            _module = self.__dict__['module']\n        if name in dir(self):\n            return getattr(self, name)\n        elif name in dir(_module):\n            return getattr(_module, name)\n        else:\n            raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\n\n    def checkpoint_tag_validation_enabled(self):\n        return self._config.checkpoint_tag_validation_enabled\n\n    def checkpoint_tag_validation_fail(self):\n        return self._config.checkpoint_tag_validation_fail\n\n    def elasticity_enabled(self):\n        return self._config.elasticity_enabled\n\n    def is_elastic_model_parallel_supported(self):\n        if self.elasticity_enabled():\n            # Add code for finding number of GPUs per node automatically\n            if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0:\n                return True\n            else:\n                return False\n\n    def pld_enabled(self):\n        return self._config.pld_enabled\n\n    def pld_params(self):\n        return self._config.pld_params\n\n    def pld_theta(self):\n        return self.pld_params()[PLD_THETA]\n\n    def pld_gamma(self):\n        return self.pld_params()[PLD_GAMMA]\n\n    def 
eigenvalue_enabled(self):\n        return self._config.eigenvalue_enabled\n\n    def eigenvalue_verbose(self):\n        return self._config.eigenvalue_verbose\n\n    def eigenvalue_max_iter(self):\n        return self._config.eigenvalue_max_iter\n\n    def eigenvalue_tol(self):\n        return self._config.eigenvalue_tol\n\n    def eigenvalue_stability(self):\n        return self._config.eigenvalue_stability\n\n    def eigenvalue_gas_boundary_resolution(self):\n        return self._config.eigenvalue_gas_boundary_resolution\n\n    def eigenvalue_layer_name(self):\n        return self._config.eigenvalue_layer_name\n\n    def eigenvalue_layer_num(self):\n        return self._config.eigenvalue_layer_num\n\n    def curriculum_enabled_legacy(self):\n        return self._config.curriculum_enabled_legacy\n\n    def curriculum_params_legacy(self):\n        return self._config.curriculum_params_legacy\n\n    def data_efficiency_enabled(self):\n        return self._config.data_efficiency_enabled\n\n    def data_efficiency_config(self):\n        return self._config.data_efficiency_config\n\n    def data_sampling_enabled(self):\n        return self._config.data_efficiency_config[DATA_SAMPLING][DATA_SAMPLING_ENABLED]\n\n    def data_sampling_config(self):\n        return self._config.data_efficiency_config[DATA_SAMPLING]\n\n    def curriculum_learning_enabled(self):\n        return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]\n\n    def curriculum_learning_config(self):\n        return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]\n\n    def random_ltd_enabled(self):\n        return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED]\n\n    def random_ltd_config(self):\n        return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]\n\n    def random_ltd_initialize(self):\n        assert self.random_ltd_enabled()\n        random_ltd_config = 
self.random_ltd_config()\n        random_ltd_queue = deque([x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])])\n        count = 0\n        for name, layer in self.module.named_modules():\n            if isinstance(layer, RandomLayerTokenDrop):\n                if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name:  ###[1,2,3]\n                    layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)\n                    random_ltd_queue.popleft()\n                    count += 1\n\n        if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:\n            raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \\\n                equivalent to the len of random_ltd_layer_id {count}')\n\n        if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:\n            assert self.client_lr_scheduler is None\n            raise ValueError(f'not yet supported')\n            #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)\n\n    def wall_clock_breakdown(self):\n        return self._config.wall_clock_breakdown\n\n    def flops_profiler_enabled(self):\n        return self._config.flops_profiler_config.enabled or self.autotuning_enabled()\n\n    def flops_profiler_recompute_fwd_factor(self):\n        return self._config.flops_profiler_config.recompute_fwd_factor\n\n    def flops_profiler_profile_step(self):\n        step = self._config.flops_profiler_config.profile_step\n        if self._config.autotuning_config.enabled:\n            step = self.autotuning_start_profile_step()\n        return step\n\n    def flops_profiler_module_depth(self):\n        return self._config.flops_profiler_config.module_depth\n\n    def flops_profiler_top_modules(self):\n        return self._config.flops_profiler_config.top_modules\n\n    def flops_profiler_detailed(self):\n        if self._config.autotuning_config.enabled:\n            
return False\n        return self._config.flops_profiler_config.detailed\n\n    def flops_profiler_output_file(self):\n        return self._config.flops_profiler_config.output_file\n\n    def memory_breakdown(self):\n        return self._config.memory_breakdown\n\n    def autotuning_enabled(self):\n        return self._config.autotuning_config.enabled\n\n    def autotuning_start_profile_step(self):\n        return self._config.autotuning_config.start_profile_step\n\n    def autotuning_end_profile_step(self):\n        return self._config.autotuning_config.end_profile_step\n\n    def autotuning_metric_path(self):\n        path = self._config.autotuning_config.metric_path\n        if not path:\n            path = os.path.join(os.getcwd(), \"autotuning_metric.json\")\n        return path\n\n    def autotuning_model_info_path(self):\n        path = self._config.autotuning_config.model_info_path\n        if not path:\n            path = os.path.join(os.getcwd(), \"autotuning_model_info.json\")\n        return path\n\n    def autotuning_metric(self):\n        return self._config.autotuning_config.metric\n\n    def autotuning_profile_model_info(self):\n        return self.autotuning_enabled(\n        ) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get(\n            \"profile\", False)\n\n    def sparse_gradients_enabled(self):\n        return self._config.sparse_gradients_enabled\n\n    def train_batch_size(self):\n        return self._config.train_batch_size\n\n    def train_micro_batch_size_per_gpu(self):\n        return self._config.train_micro_batch_size_per_gpu\n\n    def optimizer_name(self):\n        return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name)\n\n    def optimizer_params(self):\n        return self._config.optimizer_params\n\n    def optimizer_legacy_fusion(self):\n        return self._config.optimizer_legacy_fusion\n\n    def scheduler_name(self):\n        
return self._config.scheduler_name\n\n    def scheduler_params(self):\n        return self._config.scheduler_params\n\n    def quantize_training(self):\n        return (\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]\n            [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]\n            [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE],\n            self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL],\n        )\n\n    def zero_optimization(self):\n        return self._config.zero_enabled\n\n    def zero_allow_untested_optimizer(self):\n        return self._config.zero_allow_untested_optimizer\n\n    def zero_force_ds_cpu_optimizer(self):\n        return self._config.zero_force_ds_cpu_optimizer\n\n    def zero_reduce_scatter(self):\n        return self._config.zero_config.reduce_scatter\n\n    def zero_overlap_comm(self):\n        return self._config.zero_config.overlap_comm\n\n    def zero_offload_optimizer(self):\n        return self._config.zero_config.offload_optimizer\n\n    def zero_offload_param(self):\n        return self._config.zero_config.offload_param\n\n    def zero_use_cpu_optimizer(self):\n        if self._config.zero_config.offload_optimizer is not 
None:\n            return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]\n        return False\n\n    def zero_cpu_offload(self):\n        if self._config.zero_config.offload_optimizer is not None:\n            return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu\n        return False\n\n    def zero_partial_offload(self):\n        return getattr(self._config.zero_config.offload_optimizer, \"ratio\", 1.0)\n\n    def zero_sub_group_size(self):\n        return self._config.zero_config.sub_group_size\n\n    def zero_optimization_stage(self):\n        return self._config.zero_optimization_stage\n\n    def mics_shard_size(self):\n        return self._config.mics_shard_size\n\n    def zero_reduce_bucket_size(self):\n        return self._config.zero_config.reduce_bucket_size\n\n    def zero_multi_rank_bucket_allreduce(self):\n        return self._config.zero_config.use_multi_rank_bucket_allreduce\n\n    def zero_allgather_bucket_size(self):\n        return self._config.zero_config.allgather_bucket_size\n\n    def zero_optimization_partition_gradients(self):\n        return self.zero_optimization_stage() >= ZeroStageEnum.gradients\n\n    def zero_optimization_partition_weights(self):\n        return self.zero_optimization_stage() >= ZeroStageEnum.weights\n\n    def is_first_weights_partition_group(self):\n        ret = True if self.mics_shard_size() < 0 \\\n            and self.zero_optimization_partition_weights() else False\n        if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size():\n            ret = True\n        return ret\n\n    def zero_contiguous_gradients(self):\n        return self._config.zero_config.contiguous_gradients\n\n    def zero_load_from_fp32_weights(self):\n        return self._config.zero_config.load_from_fp32_weights\n\n    def zero_elastic_checkpoint(self):\n        return self._config.zero_config.elastic_checkpoint\n\n    def 
zero_max_live_parameters(self):\n        return self._config.zero_config.max_live_parameters\n\n    def zero_max_reuse_distance(self):\n        return self._config.zero_config.max_reuse_distance\n\n    def zero_prefetch_bucket_size(self):\n        return self._config.zero_config.prefetch_bucket_size\n\n    def zero_param_persistence_threshold(self):\n        return self._config.zero_config.param_persistence_threshold\n\n    def zero_model_persistence_threshold(self):\n        return self._config.zero_config.model_persistence_threshold\n\n    def zero_gather_16bit_weights_on_model_save(self):\n        return self._config.zero_config.gather_16bit_weights_on_model_save\n\n    def zero_grad_hooks(self):\n        return self._config.zero_config.grad_hooks\n\n    def zero_legacy_stage1(self):\n        return self._config.zero_config.legacy_stage1\n\n    def zero_ignore_unused_parameters(self):\n        return self._config.zero_config.ignore_unused_parameters\n\n    def graph_harvesting(self):\n        return self._config.graph_harvesting\n\n    def fp16_enabled(self):\n        return self._config.fp16_enabled\n\n    def bfloat16_enabled(self):\n        return self._config.bfloat16_enabled\n\n    def fp16_master_weights_and_gradients(self):\n        return self._config.fp16_master_weights_and_gradients\n\n    def amp_enabled(self):\n        return self._config.amp_enabled\n\n    def amp_params(self):\n        return self._config.amp_params\n\n    def fp16_auto_cast(self):\n        return self._config.fp16_auto_cast\n\n    def loss_scale(self):\n        return self._config.loss_scale\n\n    def gradient_accumulation_steps(self):\n        return self._config.gradient_accumulation_steps\n\n    def use_node_local_storage(self):\n        return self._config.use_node_local_storage\n\n    def load_universal_checkpoint(self):\n        return self._config.load_universal_checkpoint\n\n    @property\n    def communication_data_type(self):\n        res = 
self._config.communication_data_type\n        if res is not None:\n            return res\n\n        if self.fp16_enabled():\n            return torch.float16\n\n        if self.bfloat16_enabled():\n            return torch.bfloat16\n\n        return torch.float32\n\n    @communication_data_type.setter\n    def communication_data_type(self, value):\n        self._config.communication_data_type = value\n\n    def postscale_gradients(self):\n        return not self._config.prescale_gradients\n\n    def gradient_predivide_factor(self):\n        return self._config.gradient_predivide_factor\n\n    def steps_per_print(self):\n        return self._config.steps_per_print\n\n    def zero_allgather_partitions(self):\n        return self._config.zero_config.allgather_partitions\n\n    def zero_round_robin_gradients(self):\n        return self._config.zero_config.round_robin_gradients\n\n    def zero_hpz_partition_size(self):\n        return self._config.zero_config.zero_hpz_partition_size\n\n    def zero_quantized_weights(self):\n        return self._config.zero_config.zero_quantized_weights\n\n    def zero_quantized_nontrainable_weights(self):\n        return self._config.zero_config.zero_quantized_nontrainable_weights\n\n    def zero_quantized_gradients(self):\n        return self._config.zero_config.zero_quantized_gradients\n\n    def dump_state(self):\n        return self._config.dump_state\n\n    def gradient_clipping(self):\n        return self._config.gradient_clipping\n\n    def dynamic_loss_scale(self):\n        return self._config.loss_scale == 0\n\n    def initial_dynamic_scale(self):\n        return self._config.initial_dynamic_scale\n\n    def dynamic_loss_scale_args(self):\n        return self._config.dynamic_loss_scale_args\n\n    def swap_tensor_config(self):\n        return self._config.swap_tensor_config\n\n    def aio_config(self):\n        return self._config.aio_config\n\n    def get_data_types(self):\n        model_dtype = torch.float32\n        if 
self.fp16_enabled():\n            model_dtype = torch.float16\n        elif self.bfloat16_enabled():\n            model_dtype = torch.bfloat16\n\n        if self._config.grad_accum_dtype is None:\n            if model_dtype == torch.bfloat16 and not self.zero_optimization():\n                grad_accum_dtype = torch.float32\n            else:\n                grad_accum_dtype = model_dtype\n        else:\n            grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value\n\n        return (model_dtype, grad_accum_dtype)\n\n    def _optimizer_has_ckpt_event_prologue(self):\n        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue')\n\n    def _optimizer_has_ckpt_event_epilogue(self):\n        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')\n\n    def _configure_lr_scheduler(self, client_lr_scheduler):\n        # First check for scheduler in json configuration\n        lr_scheduler = self._scheduler_from_config(self.optimizer)\n        if lr_scheduler:\n            log_dist(f\"DeepSpeed using configured LR scheduler = {self.scheduler_name()}\", ranks=[0])\n            self.lr_scheduler = lr_scheduler\n        else:\n            if isinstance(client_lr_scheduler, Callable):\n                log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])\n                self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)\n            else:\n                log_dist('DeepSpeed using client LR scheduler', ranks=[0])\n                self.lr_scheduler = client_lr_scheduler\n\n        log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])\n\n    def _configure_checkpointing(self, dist_init_required):\n        self.checkpoint_engine = TorchCheckpointEngine()\n\n        if self._config is not None and self._config.nebula_config.enabled:\n            try:\n                from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import 
\\\n                    NebulaCheckpointEngine\n                self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)\n            except ImportError as err:\n                logger.error(f\"No torch_nebula was found! Will fall back to torch.save. Details: {err}\")\n                self.checkpoint_engine = TorchCheckpointEngine()\n\n        dp_rank = groups._get_sequence_data_parallel_rank()\n\n        rank = self.local_rank if self.use_node_local_storage() else dp_rank\n\n        # only the first data parallel process needs to store the model checkpoint\n        # if you want to use node local storage this must be done by rank 0 on each\n        # node\n        self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()\n                                                        and self.is_first_weights_partition_group())\n\n        if self.zero_optimization() or self.bfloat16_enabled():\n            param_rank = dist.get_rank(group=self.optimizer.zp_process_group)\n\n            # Only the first parameter parallel process needs to store the\n            # optimizer state checkpoints for zero\n            self.save_zero_checkpoint = param_rank == dp_rank\n\n    def _scheduler_from_config(self, optimizer):\n        scheduler_name = self.scheduler_name()\n        if scheduler_name is not None:\n            if hasattr(lr_schedules, scheduler_name):\n                scheduler = getattr(lr_schedules, scheduler_name)\n            else:\n                assert hasattr(torch.optim.lr_scheduler,\n                               scheduler_name), f\"DeepSpeed does not recognize LR scheduler {scheduler_name}\"\n\n                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)\n\n            scheduler_params = self.scheduler_params()\n            instantiated_scheduler = scheduler(optimizer, **scheduler_params)\n            return instantiated_scheduler\n        else:\n            return None\n\n    
def _set_distributed_vars(self, args):\n        device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank\n        if device_rank >= 0:\n            get_accelerator().set_device(device_rank)\n            self.device = torch.device(get_accelerator().device_name(), device_rank)\n            self.world_size = dist.get_world_size()\n            self.global_rank = dist.get_rank()\n        else:\n            self.world_size = 1\n            self.global_rank = 0\n            self.device = torch.device(get_accelerator().device_name())\n\n    # Configure based on command line arguments\n    def _configure_with_arguments(self, args, mpu):\n        # After the distributed backend is initialized we are guaranteed the LOCAL_RANK\n        # environment variable is set. We must align args.local_rank to this value for\n        # backwards compatibility with scripts relying on [args|self].local_rank containing\n        # the correct local rank info. _do_args_sanity_check will ensure this is the case.\n\n        if \"OMPI_COMM_WORLD_LOCAL_RANK\" in os.environ:\n            ompi_local_rank = os.environ.get(\"OMPI_COMM_WORLD_LOCAL_RANK\")\n            local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)\n            assert ompi_local_rank == local_rank, f\"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), \" \\\n                \"not sure how to proceed as we're seeing conflicting local rank info.\"\n            os.environ['LOCAL_RANK'] = local_rank\n\n        self.local_rank = int(os.environ['LOCAL_RANK'])\n        if hasattr(args, 'local_rank'):\n            args.local_rank = self.local_rank\n\n    # Validate command line arguments\n    def _do_args_sanity_check(self, args):\n        assert \"LOCAL_RANK\" in os.environ or \"OMPI_COMM_WORLD_LOCAL_RANK\" in os.environ, \"DeepSpeed requires the LOCAL_RANK environment \" \\\n            \"variable, it is set by the deepspeed launcher, 
deepspeed.init_distributed, or the torch launcher. If using a \" \\\n            \"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed.\"\n\n        if hasattr(args, 'local_rank') and args.local_rank is not None:\n            assert isinstance(args.local_rank,\n                              int), f\"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}\"\n            if args.local_rank >= 0:\n                env_local_rank = int(os.environ.get(\"LOCAL_RANK\"))\n                assert (\n                    env_local_rank == args.local_rank\n                ), f\"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}.\"\n\n    def _is_supported_optimizer(self, optimizer_name):\n        return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)\n\n    def _supported_optims(self):\n        FairseqOptimizer = None\n        try:\n            from fairseq.optim.fairseq_optimizer import FairseqOptimizer\n        except ImportError:\n            pass\n\n        expected_optim_types = [Optimizer]\n        if FairseqOptimizer:\n            # fairseq optims are not torch.optim objects\n            expected_optim_types.append(FairseqOptimizer)\n        return expected_optim_types\n\n    # Validate configuration based on command line arguments\n    def _do_sanity_check(self):\n        expected_optim_types = self._supported_optims()\n        expected_optim_types += [type(None), Callable]\n        assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \\\n            f'Client Optimizer is of unexpected type {type(self.client_optimizer)}'\n\n        if not self.client_optimizer:\n            if self.optimizer_name() is not None:\n                assert self._is_supported_optimizer(\n                    self.optimizer_name()), \"{} is not a supported DeepSpeed Optimizer\".format(self.optimizer_name())\n\n        
if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER):\n            assert (self.dynamic_loss_scale()), \"DeepSpeed {} optimizer requires dynamic loss scaling\".format(\n                self.optimizer_name())\n\n        # Detect invalid combinations of client optimizer and client scheduler\n        if isinstance(self.client_lr_scheduler, _LRScheduler):\n            assert isinstance(self.client_optimizer, Optimizer), \\\n                f'Client Optimizer (type = {type(self.client_optimizer)}) is not instantiated but Client LR Scheduler is instantiated'\n\n    def _broadcast_model(self):\n\n        def is_replicated(p):\n            if hasattr(p, \"ds_status\") and p.ds_status is not ZeroParamStatus.AVAILABLE:\n                return False\n            return True\n\n        for p in self.module.parameters():\n            # Broadcast the model for different parameters\n            if is_moe_param(p):\n                if torch.is_tensor(p) and is_replicated(p):\n                    dist.broadcast(p,\n                                   groups._get_expert_broadcast_src_rank(p.group_name),\n                                   group=self.expert_data_parallel_group[p.group_name])\n            else:\n                if torch.is_tensor(p) and is_replicated(p):\n                    dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)\n\n    @staticmethod\n    def __check_params(model: Module, dtype: torch.dtype) -> None:\n        return\n        if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:\n            raise ValueError(f\"{dtype} is enabled but the following parameters have dtype that is \"\n                             f\"not {dtype}: \"\n                             f\"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}\")\n\n    def _set_client_model(self, model):\n        # register client model in _modules so that nn.module 
methods work correctly\n        modules = self.__dict__.get('_modules')\n        modules['module'] = model\n        # register module attribute in engine but avoid getattr\n        self.__dict__['module'] = model\n\n    def _configure_distributed_model(self, model):\n        self._set_client_model(model)\n        is_zero_init_model = self.zero_optimization_partition_weights() and any(\n            [hasattr(param, \"ds_id\") for param in self.module.parameters()])\n\n        if self.fp16_enabled():\n            if is_zero_init_model:\n                self.__check_params(self.module, torch.half)\n            self.module.half()\n        elif self.bfloat16_enabled():\n            if is_zero_init_model:\n                self.__check_params(self.module, torch.bfloat16)\n            self.module.bfloat16()\n        else:\n            self.__check_params(self.module, torch.float)\n\n        # zero.Init() handles device placement of model\n        if not (self.dont_change_device or is_zero_init_model):\n            self.module.to(self.device)\n\n        # MoE related initialization\n        for _, module in self.module.named_modules():\n            if isinstance(module, MoE):\n                self.has_moe_layers = True\n                self.num_experts.append(module.num_experts)\n\n        if self.has_moe_layers:\n            for _, module in self.module.named_modules():\n                if isinstance(module, TopKGate):\n                    self.gate_modules.append(module)\n                    if self.wall_clock_breakdown():\n                        module.wall_clock_breakdown = True\n                if isinstance(module, MOELayer):\n                    self.moe_layers.append(module)\n                    if self.wall_clock_breakdown():\n                        module.wall_clock_breakdown = True\n\n        # Pass the mpu from here to groups. 
For subsequent use, just query groups\n        if self.mpu is not None:\n            groups.mpu = self.mpu\n\n        # Set deepspeed parallelism spec. for the model including expert parallelism\n        for _, module in self.module.named_modules():\n            if hasattr(module, 'set_deepspeed_parallelism'):\n                module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_)\n\n        # Query the groups module to get information about various parallel groups\n        self.local_all_to_all_group = None\n        if self.zero_quantized_gradients():\n            log_dist(\"Using quantized gradients\", ranks=[0])\n            self.local_all_to_all_group = groups._get_local_all_to_all_group()\n        self.data_parallel_group = groups._get_data_parallel_group()\n        self.dp_world_size = groups._get_data_parallel_world_size()\n        self.zp_world_size = zp_manager.zp_size\n        self.seq_data_parallel_group = groups._get_sequence_data_parallel_group()\n        self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size()\n        self.mp_world_size = groups._get_model_parallel_world_size()\n        self.expert_parallel_group = groups._get_expert_parallel_group_dict()\n        self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()\n        self.sequence_parallel_size = groups._get_sequence_parallel_world_size()\n        if self.sequence_parallel_size > 1:\n            self.communication_data_type = self._config.seq_parallel_communication_data_type\n\n        if not (self.amp_enabled() or is_zero_init_model):\n            self._broadcast_model()\n\n    # check if parameters are duplicated in optimizer param_groups\n    def _check_for_duplicates(self, optimizer):\n        for name, param in self.module.named_parameters():\n            param_id = id(param)\n\n            def ids_list(group):\n                return [id(param) for param in group]\n\n            occurrence = sum([\n                
ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0\n                for group in optimizer.param_groups\n            ])\n            assert occurrence <= 1, f\"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior.\"\n\n    def _do_optimizer_sanity_check(self, basic_optimizer):\n        model_dtype, grad_accum_dtype = self.get_data_types()\n        zero_enabled = self.zero_optimization()\n        amp_enabled = self.amp_enabled()\n        # config based assertions\n        assert (\n            not (amp_enabled and zero_enabled)\n        ), \"Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2\"\n        if zero_enabled:\n            if not is_zero_supported_optimizer(basic_optimizer):\n                assert (\n                    self.zero_allow_untested_optimizer()\n                ), 'You are using an untested ZeRO Optimizer. 
Please add <\"zero_allow_untested_optimizer\": true> in the configuration file to use it.'\n\n                if self.global_rank == 0:\n                    logger.warning(\"**** You are using ZeRO with an untested optimizer, proceed with caution *****\")\n            if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(\n            ) == 1 and not self.zero_cpu_offload():\n                return BFLOAT16\n            return ZERO_OPTIMIZATION\n        elif amp_enabled:\n            if model_dtype != grad_accum_dtype:\n                raise NotImplementedError(\n                    \"Model data type and gradient accumulation data type must be equal to use Amp\")\n            if model_dtype == torch.bfloat16 or model_dtype == torch.float16:\n                raise NotImplementedError(\"Cannot enable both amp with (legacy) fp16 or bfloat16 mode\")\n            try:\n                logger.info(\"Initializing Apex amp from: {}\".format(amp.__path__))\n            except NameError:\n                # If apex/amp is available it will be imported above\n                raise RuntimeError(\"Unable to import apex/amp, please make sure it is installed\")\n            return AMP\n        # data type checks\n        elif model_dtype == grad_accum_dtype:\n            if model_dtype == torch.bfloat16:\n                if self.pipeline_parallelism:\n                    logger.warning(\n                        \"**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****\"\n                    )\n                    return BFLOAT16\n                else:\n                    raise NotImplementedError(\n                        \"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation\"\n                    )\n            if model_dtype == torch.float16:\n                return FP16\n            # else 
optimizer_wrapper = None\n        elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:\n            return BFLOAT16\n        else:\n            raise NotImplementedError(\"unsupported mix of model dtype and gradient accumulation type\")\n\n        return None\n\n    # Configure optimizer\n    def _configure_optimizer(self, client_optimizer, model_parameters):\n        if client_optimizer is None:\n            basic_optimizer = self._configure_basic_optimizer(model_parameters)\n            log_dist(f\"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer\", ranks=[0])\n        else:\n            if isinstance(client_optimizer, tuple(self._supported_optims())):\n                basic_optimizer = client_optimizer\n                log_dist('Using client Optimizer as basic optimizer', ranks=[0])\n            else:\n                basic_optimizer = client_optimizer(model_parameters)\n                log_dist('Using client callable to create basic optimizer', ranks=[0])\n\n            if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):\n                if self.zero_force_ds_cpu_optimizer():\n                    msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. 
ZeRO-Offload and understand the performance impacts you can also set <\"zero_force_ds_cpu_optimizer\": false> in your configuration file.'\n                    raise ZeRORuntimeException(msg)\n\n        basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg[\"params\"]) != 0]\n        log_dist(\"Removing param_group that has no 'params' in the basic Optimizer\", ranks=[0])\n\n        self._check_for_duplicates(basic_optimizer)\n\n        self.basic_optimizer = basic_optimizer\n        log_dist(\"DeepSpeed Basic Optimizer = {}\".format(basic_optimizer.__class__.__name__), ranks=[0])\n\n        optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)\n\n        if optimizer_wrapper == ZERO_OPTIMIZATION:\n            self.optimizer = self._configure_zero_optimizer(basic_optimizer)\n        elif optimizer_wrapper == AMP:\n            amp_params = self.amp_params()\n            log_dist(f\"Initializing AMP with these params: {amp_params}\", ranks=[0])\n            model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)\n            self._set_client_model(model)\n            self._broadcast_model()\n            # TODO: maybe need to broadcast experts differently?\n        elif optimizer_wrapper == FP16:\n            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)\n        elif optimizer_wrapper == BFLOAT16:\n            self.optimizer = self._configure_bf16_optimizer(basic_optimizer)\n        else:\n            self.optimizer = basic_optimizer\n\n        log_dist(\"DeepSpeed Final Optimizer = {}\".format(self.optimizer_name()), ranks=[0])\n\n        self.compression_scheduler = self._configure_compression_scheduler()\n        self.quantizer = self._configure_quantization()\n\n    def _configure_basic_optimizer(self, model_parameters):\n        optimizer_parameters = self.optimizer_params()\n        if optimizer_parameters is None:\n            optimizer_parameters = {}\n        # 
print(optimizer_parameters.keys())\n        if \"max_grad_norm\" in optimizer_parameters.keys():\n            raise ValueError(\n                \"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details\"\n            )\n\n        if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:\n            torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)\n            adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)\n\n            # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set\n            effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode\n\n            if torch_adam:\n                if not effective_adam_w_mode:\n                    optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)\n                else:\n                    optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)\n            else:\n                if self.zero_use_cpu_optimizer():\n                    from deepspeed.ops.adam import DeepSpeedCPUAdam\n                    optimizer = DeepSpeedCPUAdam(model_parameters,\n                                                 **optimizer_parameters,\n                                                 adamw_mode=effective_adam_w_mode)\n                else:\n                    from deepspeed.ops.adam import FusedAdam\n\n                    optimizer = FusedAdam(\n                        model_parameters,\n                        **optimizer_parameters,\n                        adam_w_mode=effective_adam_w_mode,\n                    )\n\n        elif self.optimizer_name() == ADAGRAD_OPTIMIZER:\n            if self.zero_use_cpu_optimizer():\n                from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad\n                optimizer = DeepSpeedCPUAdagrad(model_parameters, 
**optimizer_parameters)\n            else:\n                optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)\n        elif self.optimizer_name() == LAMB_OPTIMIZER:\n            from deepspeed.ops.lamb import FusedLamb\n\n            optimizer = FusedLamb(model_parameters, **optimizer_parameters)\n        elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:\n            assert not self.zero_optimization(), \"1bit-Adam is not compatible with ZeRO\"\n            from deepspeed.runtime.fp16.onebit.adam import OnebitAdam\n\n            optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)\n            if not self.fp16_enabled():\n                logger.warning(f\"Currently the convergence of 1-bit Adam is only verified under FP16\")\n        elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:\n            assert not self.zero_optimization(), \"0/1 Adam is not compatible with ZeRO\"\n            from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam\n\n            optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)\n            if not self.fp16_enabled():\n                logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16')\n        elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:\n            assert not self.zero_optimization(), \"1bit-Lamb is not compatible with ZeRO\"\n            from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb\n\n            optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)\n            if not self.fp16_enabled():\n                logger.warning(f\"Currently the convergence of 1-bit Lamb is only verified under FP16\")\n        elif self.optimizer_name() == LION_OPTIMIZER:\n            if self.zero_use_cpu_optimizer():\n                from deepspeed.ops.lion import DeepSpeedCPULion\n                optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters)\n            else:\n                from 
deepspeed.ops.lion import FusedLion\n                optimizer = FusedLion(model_parameters, **optimizer_parameters)\n        elif self.optimizer_name() == MUADAM_OPTIMIZER:\n            try:\n                from mup import MuAdam\n            except ImportError:\n                logger.error(\"Install mup to use MuAdam optimizer\")\n                raise\n            optimizer = MuAdam(model_parameters, **optimizer_parameters)\n        elif self.optimizer_name() == MUADAMW_OPTIMIZER:\n            try:\n                from mup import MuAdamW\n            except ImportError:\n                logger.error(\"Install mup to use MuAdamW optimizer\")\n                raise\n            optimizer = MuAdamW(model_parameters, **optimizer_parameters)\n        elif self.optimizer_name() == MUSGD_OPTIMIZER:\n            try:\n                from mup import MuSGD\n            except ImportError:\n                logger.error(\"Install mup to use MuSGD optimizer\")\n                raise\n            optimizer = MuSGD(model_parameters, **optimizer_parameters)\n        else:\n            torch_optimizer = getattr(torch.optim, self.optimizer_name())\n            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)\n        return optimizer\n\n    def _configure_compression_scheduler(self):\n        return compression_scheduler(self.module, self._config.compression_config)\n\n    def _configure_random_ltd_scheduler(self, configs):\n        return RandomLTDScheduler(configs)\n\n    def _configure_quantization(self):\n        (\n            quantize_weight_in_forward,\n            quantize_enabled,\n            q_groups,\n            q_mixed_fp16,\n            q_change_ratio,\n            q_type,\n            q_rounding,\n            q_verbose,\n            use_quantizer_kernel,\n        ) = self.quantize_training()\n        if quantize_enabled and not quantize_weight_in_forward:\n            assert self.fp16_enabled(\n            ), \"MoQ (quantize in optimization step) weight quantization is only supported for FP16\"\n        
quantizer = None\n        if quantize_enabled and not quantize_weight_in_forward:\n            from deepspeed.runtime.quantize import Quantizer\n\n            quantizer = Quantizer(\n                q_groups,\n                q_mixed_fp16,\n                q_change_ratio,\n                q_type,\n                q_rounding,\n                q_verbose,\n                self.eigenvalue_enabled(),\n                use_quantizer_kernel,\n                self.eigenvalue_layer_num() if self.eigenvalue_enabled() else 0,\n            )\n        return quantizer\n\n    def _configure_fp16_optimizer(self, optimizer):\n        initial_dynamic_scale = self.initial_dynamic_scale()\n        dynamic_loss_args = self.dynamic_loss_scale_args()\n        clip_grad = self.gradient_clipping()\n        if APEX_INSTALLED:\n            fused_opts = (apex.optimizers.FusedAdam, FusedAdam)\n        else:\n            fused_opts = FusedAdam\n        if isinstance(optimizer, fused_opts) \\\n                or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:\n            if self.dynamic_loss_scale():\n                log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0])\n                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()\n                optimizer = FP16_Optimizer(\n                    optimizer,\n                    deepspeed=self,\n                    dynamic_loss_scale=True,\n                    initial_dynamic_scale=initial_dynamic_scale,\n                    dynamic_loss_args=dynamic_loss_args,\n                    mpu=self.mpu,\n                    clip_grad=clip_grad,\n                    fused_adam_legacy=self.optimizer_legacy_fusion(),\n                    timers=timers,\n                    has_moe_layers=self.has_moe_layers,\n                )\n            else:\n                log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])\n                optimizer = 
FP16_Optimizer(\n                    optimizer,\n                    deepspeed=self,\n                    static_loss_scale=self.loss_scale(),\n                    mpu=self.mpu,\n                    clip_grad=clip_grad,\n                    fused_adam_legacy=self.optimizer_legacy_fusion(),\n                    has_moe_layers=self.has_moe_layers,\n                )\n        else:\n            log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])\n            optimizer = FP16_UnfusedOptimizer(\n                optimizer,\n                deepspeed=self,\n                static_loss_scale=self.loss_scale(),\n                dynamic_loss_scale=self.dynamic_loss_scale(),\n                dynamic_loss_args=dynamic_loss_args,\n                mpu=self.mpu,\n                clip_grad=clip_grad,\n                fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,\n            )\n\n        return optimizer\n\n    def _configure_bf16_optimizer(self, optimizer):\n        clip_grad = self.gradient_clipping()\n\n        if optimizer is None:\n            optimizer = DummyOptim(list(self.module.parameters()))\n\n        log_dist('Creating BF16 optimizer', ranks=[0])\n\n        timers = self.timers if self.wall_clock_breakdown() else NoopTimer()\n        optimizer = BF16_Optimizer(optimizer,\n                                   self.param_names,\n                                   mpu=self.mpu,\n                                   clip_grad=clip_grad,\n                                   allgather_bucket_size=self.zero_allgather_bucket_size(),\n                                   dp_process_group=self.seq_data_parallel_group,\n                                   timers=timers,\n                                   grad_acc_dtype=self.get_data_types()[1],\n                                   graph_harvesting=self.graph_harvesting())\n\n        return optimizer\n\n    def _configure_zero_optimizer(self, optimizer):\n        zero_stage = 
self.zero_optimization_stage()\n\n        mics_shard_size = self.mics_shard_size()\n        model_dtype, gradient_accumulation_dtype = self.get_data_types()\n\n        timers = self.timers if self.wall_clock_breakdown() else NoopTimer()\n\n        if optimizer is None:\n            optimizer = DummyOptim(list(self.module.parameters()))\n\n        if self.zero_legacy_stage1():\n            raise Exception(\n                \"The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO.\"\n            )\n\n        if zero_stage <= ZeroStageEnum.gradients:\n            overlap_comm = self.zero_overlap_comm()\n            contiguous_gradients = self.zero_contiguous_gradients()\n            round_robin_gradients = self.zero_round_robin_gradients()\n            assert not isinstance(optimizer, DummyOptim), \"zero stage {} requires an optimizer\".format(zero_stage)\n\n            log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])\n            # Overlap and contiguous grads are meaningless in stage 1 and are ignored\n            if zero_stage == ZeroStageEnum.optimizer_states:\n                overlap_comm = False\n                round_robin_gradients = False\n                # Non-MoE requires contiguous grads to be disabled w. 
stage 1\n                if not self.has_moe_layers:\n                    contiguous_gradients = False\n\n            if isinstance(self.module, PipelineModule):\n                if overlap_comm:\n                    logger.warning(\"Pipeline parallelism does not support overlapped communication, will be disabled.\")\n                    overlap_comm = False\n            optimizer = DeepSpeedZeroOptimizer(\n                optimizer,\n                self.param_names,\n                timers=timers,\n                static_loss_scale=self.loss_scale(),\n                dynamic_loss_scale=self.dynamic_loss_scale(),\n                dynamic_loss_args=self.dynamic_loss_scale_args(),\n                clip_grad=self.gradient_clipping(),\n                contiguous_gradients=contiguous_gradients,\n                reduce_bucket_size=self.zero_reduce_bucket_size(),\n                use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(),\n                allgather_bucket_size=self.zero_allgather_bucket_size(),\n                dp_process_group=self.seq_data_parallel_group,\n                expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,\n                expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,\n                reduce_scatter=self.zero_reduce_scatter(),\n                overlap_comm=overlap_comm,\n                offload_optimizer_config=self.zero_offload_optimizer(),\n                mpu=self.mpu,\n                postscale_gradients=self.postscale_gradients(),\n                gradient_predivide_factor=self.gradient_predivide_factor(),\n                gradient_accumulation_steps=self.gradient_accumulation_steps(),\n                ignore_unused_parameters=self.zero_ignore_unused_parameters(),\n                partition_grads=zero_stage == ZeroStageEnum.gradients,\n                round_robin_gradients=round_robin_gradients,\n                
has_moe_layers=self.has_moe_layers,\n                fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),\n                gradient_accumulation_dtype=gradient_accumulation_dtype,\n                communication_data_type=self.communication_data_type,\n                elastic_checkpoint=self.zero_elastic_checkpoint())\n\n        elif zero_stage == ZeroStageEnum.weights:\n            assert not self.has_moe_layers, \"MoE not supported with Stage 3\"\n            if isinstance(optimizer, DummyOptim):\n                log_dist(\"Creating ZeRO Offload\", ranks=[0])\n                zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()\n                if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:\n                    self._set_zero_group_parallelism()\n                    zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()\n                optimizer = DeepSpeedZeRoOffload(\n                    self.module,\n                    timers=timers,\n                    ds_config=self.config,\n                    overlap_comm=self.zero_overlap_comm(),\n                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),\n                    max_reuse_distance=self.zero_max_reuse_distance(),\n                    max_live_parameters=self.zero_max_live_parameters(),\n                    param_persistence_threshold=self.zero_param_persistence_threshold(),\n                    model_persistence_threshold=self.zero_model_persistence_threshold(),\n                    offload_param_config=self.zero_offload_param(),\n                    mpu=self.mpu,\n                    zero_param_parallel_group=zero_param_parallel_group,\n                    zero_quantized_weights=self.zero_quantized_weights(),\n                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),\n                )\n            else:\n                log_dist(\n                    
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'\n                    f' MiCS is enabled {mics_shard_size>0},'\n                    f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',\n                    ranks=[0])\n                if mics_shard_size > 0:\n                    return self._return_mics_optimizer(optimizer, timers)\n\n                log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])\n                from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3\n                optimizer = DeepSpeedZeroOptimizer_Stage3(\n                    self.module,\n                    optimizer,\n                    timers=timers,\n                    ds_config=self.config,\n                    static_loss_scale=self.loss_scale(),\n                    dynamic_loss_scale=self.dynamic_loss_scale(),\n                    dynamic_loss_args=self.dynamic_loss_scale_args(),\n                    clip_grad=self.gradient_clipping(),\n                    contiguous_gradients=self.zero_contiguous_gradients(),\n                    reduce_bucket_size=self.zero_reduce_bucket_size(),\n                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),\n                    max_reuse_distance=self.zero_max_reuse_distance(),\n                    max_live_parameters=self.zero_max_live_parameters(),\n                    param_persistence_threshold=self.zero_param_persistence_threshold(),\n                    model_persistence_threshold=self.zero_model_persistence_threshold(),\n                    dp_process_group=self.seq_data_parallel_group,\n                    all2all_process_group=self.local_all_to_all_group,\n                    reduce_scatter=self.zero_reduce_scatter(),\n                    overlap_comm=self.zero_overlap_comm(),\n                    offload_optimizer_config=self.zero_offload_optimizer(),\n                    offload_param_config=self.zero_offload_param(),\n                    
sub_group_size=self.zero_sub_group_size(),\n                    offload_ratio=self.zero_partial_offload(),\n                    mpu=self.mpu,\n                    postscale_gradients=self.postscale_gradients(),\n                    gradient_predivide_factor=self.gradient_predivide_factor(),\n                    gradient_accumulation_steps=self.gradient_accumulation_steps(),\n                    aio_config=self.aio_config(),\n                    gradient_accumulation_dtype=gradient_accumulation_dtype,\n                    communication_data_type=self.communication_data_type,\n                    zero_hpz_partition_size=self.zero_hpz_partition_size(),\n                    zero_quantized_weights=self.zero_quantized_weights(),\n                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),\n                )\n\n        else:\n            raise NotImplementedError(\"ZeRO stage {} not implemented\".format(zero_stage))\n\n        return optimizer\n\n    def _return_mics_optimizer(self, basic_optimizer, timers):\n        from deepspeed.runtime.zero.mics import MiCS_Optimizer\n        model_dtype, gradient_accumulation_dtype = self.get_data_types()\n        optimizer = MiCS_Optimizer(self.module,\n                                   basic_optimizer,\n                                   timers=timers,\n                                   ds_config=self.config,\n                                   static_loss_scale=self.loss_scale(),\n                                   dynamic_loss_scale=self.dynamic_loss_scale(),\n                                   dynamic_loss_args=self.dynamic_loss_scale_args(),\n                                   clip_grad=self.gradient_clipping(),\n                                   contiguous_gradients=self.zero_contiguous_gradients(),\n                                   reduce_bucket_size=self.zero_reduce_bucket_size(),\n                                   prefetch_bucket_size=self.zero_prefetch_bucket_size(),\n          
                         max_reuse_distance=self.zero_max_reuse_distance(),\n                                   max_live_parameters=self.zero_max_live_parameters(),\n                                   param_persistence_threshold=self.zero_param_persistence_threshold(),\n                                   model_persistence_threshold=self.zero_model_persistence_threshold(),\n                                   dp_process_group=self.seq_data_parallel_group,\n                                   reduce_scatter=self.zero_reduce_scatter(),\n                                   overlap_comm=self.zero_overlap_comm(),\n                                   offload_optimizer_config=self.zero_offload_optimizer(),\n                                   offload_param_config=self.zero_offload_param(),\n                                   sub_group_size=self.zero_sub_group_size(),\n                                   mpu=self.mpu,\n                                   postscale_gradients=self.postscale_gradients(),\n                                   gradient_predivide_factor=self.gradient_predivide_factor(),\n                                   gradient_accumulation_steps=self.gradient_accumulation_steps(),\n                                   aio_config=self.aio_config(),\n                                   gradient_accumulation_dtype=gradient_accumulation_dtype,\n                                   communication_data_type=self.communication_data_type)\n        return optimizer\n\n    def _configure_eigenvalue(self):\n        eigenvalue = Eigenvalue(\n            verbose=self.eigenvalue_verbose(),\n            max_iter=self.eigenvalue_max_iter(),\n            tol=self.eigenvalue_tol(),\n            stability=self.eigenvalue_stability(),\n            gas_boundary_resolution=self.eigenvalue_gas_boundary_resolution(),\n            layer_name=self.eigenvalue_layer_name(),\n            layer_num=self.eigenvalue_layer_num(),\n        )\n\n        return eigenvalue\n\n    def 
_configure_progressive_layer_drop(self):\n        pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())\n\n        return pld\n\n    def _configure_curriculum_scheduler_legacy(self):\n        scheduler = CurriculumScheduler(self.curriculum_params_legacy())\n        return scheduler\n\n    @staticmethod\n    def is_map_style_dataset(obj):\n        return hasattr(obj, \"__getitem__\") and hasattr(obj, \"__len__\")\n\n    @staticmethod\n    def is_iterable_style_dataset(obj):\n        return isinstance(obj, torch.utils.data.IterableDataset)  # hasattr(obj, \"__iter__\") should work as well\n\n    def dataloader_drop_last(self):\n        return self._config.dataloader_drop_last\n\n    def was_step_applied(self) -> bool:\n        \"\"\"Returns True if the latest ``step()`` resulted in parameter updates.\n        Note that a ``False`` return is not an error condition. Steps are frequently\n        no-ops, such as between gradient accumulation boundaries or when overflows\n        occur.\n        Returns:\n            bool: Whether the latest ``step()`` modified model parameters.\n        \"\"\"\n        return self._step_applied\n\n    def deepspeed_io(self,\n                     dataset,\n                     batch_size=None,\n                     route=ROUTE_TRAIN,\n                     pin_memory=True,\n                     data_sampler=None,\n                     collate_fn=None,\n                     num_local_io_workers=None):\n        if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):\n            raise ValueError(\"Training data must be a torch Dataset\")\n\n        if batch_size is None:\n            batch_size = self.train_micro_batch_size_per_gpu()\n\n        if collate_fn is None:\n            collate_fn = self.collate_fn\n\n        # Currently we only use timer in train route\n        deepspeed_io_timer = None\n        if route == ROUTE_TRAIN:\n            deepspeed_io_timer = self.tput_timer\n\n      
  # If mpu is provided, forward world size and parallel rank to sampler.\n        data_parallel_world_size = self.dp_world_size\n        data_parallel_rank = self.global_rank\n        if self.mpu is not None:\n            data_parallel_world_size = self.mpu.get_data_parallel_world_size()\n            data_parallel_rank = self.mpu.get_data_parallel_rank()\n\n        if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):\n            data_sampler = torch.utils.data.DistributedSampler(\n                dataset,\n                num_replicas=data_parallel_world_size,\n                rank=data_parallel_rank,\n                shuffle=False,\n            )\n\n        deepspeed_dataloader_config = {}\n        if self.curriculum_learning_enabled():\n            deepspeed_dataloader_config = {\n                CURRICULUM_LEARNING: self.curriculum_learning_enabled(),\n                DATA_EFFICIENCY: self.data_efficiency_config(),\n                DATA_PARALLEL_GROUP: self.data_parallel_group,\n                GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),\n                GLOBAL_RANK: self.global_rank,\n                DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]\n            }\n\n        return DeepSpeedDataLoader(dataset=dataset,\n                                   batch_size=batch_size,\n                                   pin_memory=pin_memory,\n                                   collate_fn=collate_fn,\n                                   local_rank=self.local_rank,\n                                   tput_timer=deepspeed_io_timer,\n                                   num_local_io_workers=num_local_io_workers,\n                                   data_sampler=data_sampler,\n                                   data_parallel_world_size=data_parallel_world_size,\n                                   data_parallel_rank=data_parallel_rank,\n                                   
dataloader_drop_last=self.dataloader_drop_last(),\n                                   deepspeed_dataloader_config=deepspeed_dataloader_config)\n\n    def train(self, mode=True):\n        r\"\"\"\"\"\"\n\n        self.warn_unscaled_loss = True\n        self.module.train(mode)\n\n    def eval(self):\n        r\"\"\"\"\"\"\n\n        self.warn_unscaled_loss = True\n        self.module.train(False)\n\n    def _scale_loss_by_gas(self, prescaled_loss):\n        if isinstance(prescaled_loss, torch.Tensor):\n            scaled_loss = prescaled_loss / self.gradient_accumulation_steps()\n        elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):\n            scaled_loss = []\n            for l in prescaled_loss:\n                if isinstance(l, torch.Tensor):\n                    scaled_loss.append(l / self.gradient_accumulation_steps())\n                else:\n                    scaled_loss.append(l)\n        else:\n            scaled_loss = prescaled_loss\n            if self.warn_unscaled_loss:\n                logger.warning(f\"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}\")\n                self.warn_unscaled_loss = False\n\n        return scaled_loss\n\n    @instrument_w_nvtx\n    def forward(self, *inputs, **kwargs):\n        r\"\"\"Execute forward propagation\n        Arguments:\n            *inputs: Variable length input list\n            **kwargs: variable length keyword arguments\n        \"\"\"\n\n        if self.autotuning_profile_model_info():\n            ma = get_ma_status()\n        else:\n            see_memory_usage(\"Engine before forward\", force=self.memory_breakdown())\n\n        flops_profiler_active = (self.flops_profiler_enabled()\n                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)\n\n        # used to check quantization happens at step 0!\n        if self.global_steps == 0 and hasattr(self, \"compression_scheduler\"):\n          
  self.compression_scheduler.step(step_zero_check=True)\n            if self.quantizer:\n                tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(\n                ) == 2 else self.optimizer.fp16_groups\n                if self.compression_scheduler.weight_quantization_enabled:\n                    self.quantizer.quantize(\n                        tensor_to_quantize,\n                        (self.optimizer.overflow if self.fp16_enabled() else False),\n                        self.eigenvalue_enabled(),\n                        None,\n                    )\n\n        if flops_profiler_active:\n            self.flops_profiler.start_profile(ignore_list=None)\n\n        if self.module.training:\n            if self.progressive_layer_drop:\n                kwargs.update(self.progressive_layer_drop.get_state())\n\n        if self.__class__.__name__ != \"PipelineEngine\":\n            # TODO: The above if condition is a HACK since for PipelineEngine\n            # it's difficult to inject argument in forward pass.\n            if self.module.training and self.curriculum_enabled_legacy():\n                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)\n                if self.curriculum_params_legacy()[\"curriculum_type\"] == \"seqlen\":\n                    kwargs.update({\"curriculum_seqlen\": self.curriculum_scheduler_legacy.get_current_difficulty()})\n\n        if self.module.training and self.random_ltd_enabled():\n            self.random_ltd_scheduler.update_seq(self.global_steps)\n\n        if self.zero_optimization_partition_weights():\n            # Enable automated discovery of external parameters by indicating that\n            # we are in a forward pass.\n            for module in self.module.modules():\n                module._parameters._in_forward = True\n\n        self._start_timers(self.engine_timers.forward_timers)\n\n        if self.training_dataloader is None:\n            
self.tput_timer.start()\n\n        if self.fp16_auto_cast():\n            inputs = self._cast_inputs_half(inputs)\n        # print(f\"RANK[{self.global_rank}] self.fp16_auto_cast() is {self.fp16_auto_cast()}\")\n\n        loss = self.module(*inputs, **kwargs)\n\n        # print(f\"RANK[{self.global_rank}]'s loss is {loss}\")\n\n        if self.zero_optimization_partition_weights():\n            # Disable automated discovery of external parameters\n            for module in self.module.modules():\n                module._parameters._in_forward = False\n\n        self._stop_timers(self.engine_timers.forward_timers)\n\n        if flops_profiler_active:\n            self.flops_profiler.stop_profile()\n\n        if self.autotuning_profile_model_info():\n            activation_mem = get_ma_status() - ma\n            self.autotuning_model_info[\"activation_mem_per_gpu\"] = activation_mem\n            print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())\n            exit()\n        else:\n            see_memory_usage(\"Engine after forward\", force=self.memory_breakdown())\n        return loss\n\n    def _cast_inputs_half(self, inputs):\n        if isinstance(inputs, (list, tuple)):\n            new_inputs = []\n            for v in inputs:\n                new_inputs.append(self._cast_inputs_half(v))\n            return inputs.__class__(new_inputs)\n        elif isinstance(inputs, dict):\n            new_inputs = {}\n            for k, v in inputs.items():\n                new_inputs[k] = self._cast_inputs_half(v)\n            return new_inputs\n        elif hasattr(inputs, 'half'):\n            return inputs.half()\n        else:\n            return inputs\n\n    def print_forward_breakdown(self, fwd_time):\n        gate_time = 0.0\n        moe_time = 0.0\n        falltoall = 0.0\n        salltoall = 0.0\n\n        for gate in self.gate_modules:\n            #logger.info(f\"Individual TopK gate time: {gate.gate_time:.2f} ms\")\n        
    gate_time += gate.gate_time\n\n        for l in self.moe_layers:\n            #logger.info(f\"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}\")\n            moe_time += l.time_moe\n            falltoall += l.time_falltoall\n            salltoall += l.time_salltoall\n\n        # TODO: Allreduce/average them across ranks for more accurate timing.\n\n        # if deepspeed.comm.get_rank() == 0:\n        log_dist(\n            f\"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})\",\n            ranks=[0])\n\n    @instrument_w_nvtx\n    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):\n        assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \\\n            f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled'\n\n        # Pass (PP) gas boundary flag to optimizer (required for zero)\n        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()\n        # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well\n        if self.zero_optimization_partition_gradients():\n            self.optimizer.overlapping_partition_gradients_reduce_epilogue()\n\n        # Communicate only at gradient accumulation boundaries\n        elif self.is_gradient_accumulation_boundary():\n            if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(\n                    self.optimizer, 'reduce_gradients'):\n                self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)\n            else:\n                self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)\n\n    @instrument_w_nvtx\n    def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):\n        r\"\"\"Execute backward pass on 
the loss\n        Arguments:\n            loss: Torch tensor on which to execute backward propagation\n            allreduce_gradients: deprecated, ignored, and will soon be removed\n            release_loss: currently has no effect\n            retain_graph: bool, default: False\n                forwarded to the underlying backward call\n            scale_wrt_gas: bool, default: True\n                if True and gradient_accumulation_steps > 1, divide the loss by gradient_accumulation_steps\n                (overridden by the engine's scale_wrt_gas attribute when it is set)\n        \"\"\"\n\n        see_memory_usage(\"Engine before backward\", force=self.memory_breakdown())\n\n        if self.scale_wrt_gas is not None:\n            scale_wrt_gas = self.scale_wrt_gas\n\n        if not allreduce_gradients:\n            logger.warning(f\"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed\")\n\n        # scale loss w.r.t. gradient accumulation if needed\n        if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:\n            loss = self._scale_loss_by_gas(loss.float())\n\n        # Log training loss\n        self.losses += loss.mean().item()\n        if self.monitor.enabled:\n            if self.is_gradient_accumulation_boundary():\n                if self.global_rank == 0:\n                    self.summary_events = [(\n                        f\"Train/Samples/train_loss\",\n                        self.losses,\n                        self.global_samples,\n                    )]\n                    self.monitor.write_events(self.summary_events)\n\n        self._start_timers(self.engine_timers.backward_timers)\n\n        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \\\n            \"must provide optimizer during init in order to use backward\"\n\n        self._start_timers(self.engine_timers.backward_inner_timers)\n\n        if self.zero_optimization():\n            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()\n            self.optimizer.backward(loss, retain_graph=retain_graph)\n        elif self.amp_enabled():\n            # AMP requires delaying unscale when inside gradient accumulation boundaries\n            # 
https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations\n            delay_unscale = not self.is_gradient_accumulation_boundary()\n            with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:\n                scaled_loss.backward(retain_graph=retain_graph)\n        elif self.fp16_enabled():\n            if self.eigenvalue_enabled():\n                self.optimizer.backward(loss, create_graph=True, retain_graph=True)\n            else:\n                self.optimizer.backward(loss, retain_graph=retain_graph)\n        elif self.bfloat16_enabled():\n            self.optimizer.backward(loss)\n        else:\n            if self.eigenvalue_enabled():\n                loss.backward(create_graph=True, retain_graph=True)\n            else:\n                loss.backward(retain_graph=retain_graph)\n\n        self._stop_timers(self.engine_timers.backward_inner_timers)\n\n        self._start_timers(self.engine_timers.backward_reduce_timers)\n\n        if allreduce_gradients and self.enable_backward_allreduce:\n            # Traditional code path that allreduces the module parameter grads\n            self.allreduce_gradients()\n\n        self._stop_timers(self.engine_timers.backward_reduce_timers)\n\n        self._stop_timers(self.engine_timers.backward_timers)\n\n        if release_loss:\n            # loss.data = None\n            pass\n\n        see_memory_usage(\"Engine after backward\", force=self.memory_breakdown())\n\n        return loss\n\n    def is_gradient_accumulation_boundary(self):\n        \"\"\"\n        Query whether the current micro-batch is at the boundary of\n        gradient accumulation, and thus will trigger gradient reductions and\n        an optimizer step.\n\n        Returns:\n            bool: if the current step is a gradient accumulation boundary.\n\n        \"\"\"\n        if self._is_gradient_accumulation_boundary is None:\n            return (self.micro_steps + 1) % \\\n          
      self.gradient_accumulation_steps() == 0\n        else:\n            return self._is_gradient_accumulation_boundary\n\n    def set_gradient_accumulation_boundary(self, is_boundary):\n        \"\"\"\n        Manually overrides the DeepSpeed engine's gradient accumulation boundary state. This is an optional\n        feature and should be used with care. The state should be set to the intended value\n        before each forward/backward, and the final forward/backward should have the\n        boundary state set to True. This style allows client code to call engine.step() only once after all\n        the gradient accumulation passes are complete. See example below:\n\n        .. code-block:: python\n\n            engine.set_gradient_accumulation_boundary(False)\n            for _ in range(gradient_accumulation_steps - 1):\n                micro_batch = next(data_loader)\n                loss = engine(micro_batch)\n                engine.backward(loss)\n            engine.set_gradient_accumulation_boundary(True)\n            micro_batch = next(data_loader)\n            loss = engine(micro_batch)\n            engine.backward(loss)\n            engine.step()\n\n        Arguments:\n            is_boundary (bool): are we at a gradient accumulation boundary or not?\n        \"\"\"\n        self._is_gradient_accumulation_boundary = is_boundary\n        self.optimizer.is_gradient_accumulation_boundary = is_boundary\n\n    def zero_grad(self):\n        \"\"\"\n        Reset parameter gradients by setting them to None.\n        \"\"\"\n        for param_name, param in self.module.named_parameters():\n            param.grad = None\n\n    def clip_fp32_gradients(self):\n        clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)\n\n    def _take_model_step(self, lr_kwargs, block_eigenvalue={}):\n        if self.gradient_clipping() > 0.0:\n            if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):\n                
self.clip_fp32_gradients()\n            elif self.amp_enabled():\n                # AMP's recommended way of doing clipping\n                # https://nvidia.github.io/apex/advanced.html#gradient-clipping\n                master_params = amp.master_params(self.optimizer)\n                clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)\n        self.optimizer.step()\n\n        if hasattr(self.optimizer, '_global_grad_norm'):\n            self._global_grad_norm = self.optimizer._global_grad_norm\n\n        # Quantize the updated parameter if there is no overflow\n        if self.quantizer:\n            tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(\n            ) == 2 else self.optimizer.fp16_groups\n            if self.compression_scheduler.weight_quantization_enabled:\n                self.quantizer.quantize(\n                    tensor_to_quantize,\n                    (self.optimizer.overflow if self.fp16_enabled() else False),\n                    self.eigenvalue_enabled(),\n                    block_eigenvalue,\n                )\n        # zero grad in basic optimizer could be unreliable and may not exhibit\n        # the behavior that we want\n        if self.bfloat16_enabled():\n            # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated\n            if self.zero_optimization() and hasattr(self.optimizer, \"zero_grad\"):\n                self.optimizer.zero_grad()\n            else:\n                pass\n        elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():\n            self.optimizer.zero_grad()\n        else:\n            self.zero_grad()\n\n        report_progress = self.global_rank == 0 if self.global_rank else True\n\n        # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.\n        overflow = False\n        if hasattr(self.optimizer, \"overflow\"):\n            overflow 
= self.optimizer.overflow\n        self._step_applied = not overflow\n\n        if overflow:\n            self.skipped_steps += 1\n        else:\n            self.compression_scheduler.step()\n            if self.lr_scheduler is not None:\n                try:\n                    self.lr_scheduler.step(**(lr_kwargs or {}))\n                except TypeError:\n                    # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.\n                    # We don't currently have a way to specify lr_kwargs from\n                    # pipe_engine.train_batch()\n                    self.lr_scheduler.step(self.train_batch_size())\n\n        if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:\n            self._report_progress(self.global_steps + 1)\n\n        self.losses = 0.0\n        self.global_steps += 1\n        self.global_samples += self.train_batch_size()\n\n    def step(self, lr_kwargs=None):\n        r\"\"\"Execute the weight update step after forward and backward propagation\n        on effective_train_batch.\n        \"\"\"\n        see_memory_usage(\"Engine before step\", force=self.memory_breakdown())\n\n        # Check early because self.global_steps is incremented at some point here.\n        # TODO: Delay self.global_steps increment until very end of this function.\n        flops_profiler_active = self.flops_profiler_enabled(\n        ) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0\n\n        self._start_timers(self.engine_timers.step_timers)\n\n        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \\\n            \"must provide optimizer during init in order to use step\"\n\n        report_progress = False\n\n        self._step_applied = False  # assume False, will flip to True\n\n        # Update the model when we reach gradient accumulation boundaries\n        if self.is_gradient_accumulation_boundary():\n            self.gas_boundary_ctr += 
1\n\n            if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)\n                    and self.quantizer.any_precision_switch()):\n                log_dist(f\"computing eigenvalue...\", ranks=[0])\n                self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,\n                                                                           self.optimizer.cur_scale)\n\n            if self.progressive_layer_drop:\n                self.progressive_layer_drop.update_state(self.global_steps)\n\n            if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()\n                    and self.quantizer.any_precision_switch()):\n                self._take_model_step(lr_kwargs, self.block_eigenvalue)\n            else:\n                self._take_model_step(lr_kwargs)\n\n            report_progress = self.global_rank == 0 if self.global_rank else True\n\n        self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)\n\n        self._stop_timers(self.engine_timers.step_timers)\n\n        # Log learning rate\n        if self.monitor.enabled:\n            if self.is_gradient_accumulation_boundary():\n                if self.global_rank == 0:\n                    self.summary_events = [(f\"Train/Samples/lr\", self.get_lr()[0], self.global_samples)]\n\n                    if self.fp16_enabled() and hasattr(self.optimizer, \"cur_scale\"):\n                        self.summary_events.append((\n                            f\"Train/Samples/loss_scale\",\n                            self.optimizer.cur_scale,\n                            self.global_samples,\n                        ))\n\n                    if (self.eigenvalue_enabled()\n                            and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):\n                        ev_values = 
list(self.block_eigenvalue.values())\n                        for i in range(len(ev_values)):\n                            self.summary_events.append((\n                                f\"Train/Eigenvalues/ModelBlockParam_{i}\",\n                                ev_values[i][0],\n                                self.global_samples,\n                            ))\n                    self.monitor.write_events(self.summary_events)\n\n        # Check flops profiling\n        if flops_profiler_active:\n            if self.autotuning_enabled():\n                self.flops = self.flops_profiler.get_total_flops() * 3\n                self.fwd_duration = self.flops_profiler.get_total_duration()\n            else:\n                self.flops_profiler.print_model_profile(\n                    profile_step=self.global_steps,\n                    module_depth=self.flops_profiler_module_depth(),\n                    top_modules=self.flops_profiler_top_modules(),\n                    detailed=self.flops_profiler_detailed(),\n                    output_file=self.flops_profiler_output_file(),\n                )\n            self.flops_profiler.end_profile()\n\n        if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):\n            self._autotuning_exit()\n\n        if self.wall_clock_breakdown():\n            # Log micro timing and reset\n            self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())\n\n        if self.wall_clock_breakdown() or self.flops_profiler_enabled():\n            # Log global timing and reset\n            if self.is_gradient_accumulation_boundary():\n                if self.monitor.enabled:\n                    self._write_monitor()\n\n                if self.has_moe_layers:\n                    fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False)\n                    self.print_forward_breakdown(fwd_time=fwd_time)\n\n                
self.timers.log(self.engine_timers.global_timers)\n\n        self.micro_steps += 1\n        see_memory_usage(\"Engine after step\", force=self.memory_breakdown())\n\n    def _start_timers(self, timer_names):\n        for name in timer_names:\n            self.timers(name).start()\n\n    def _stop_timers(self, timer_names):\n        record = self.is_gradient_accumulation_boundary() and \\\n            self.flops_profiler_enabled() and \\\n                (self.global_steps >= self.flops_profiler_profile_step())\n        for name in timer_names:\n            self.timers(name).stop(record=record)\n\n    def _autotuning_exit(self):\n        if self.global_rank == 0:\n            msg = self.timers.get_mean([\n                FORWARD_GLOBAL_TIMER,\n                BACKWARD_GLOBAL_TIMER,\n                STEP_GLOBAL_TIMER,\n            ], reset=False)\n            titer = 0.0\n            titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0\n            titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0\n            titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0\n            titer *= self.gradient_accumulation_steps()\n            msg[\"latency\"] = titer\n            msg[\"FLOPS_per_gpu\"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer\n            msg[\"throughput\"] = self.train_batch_size() * 1_000_000 / \\\n                msg[\"latency\"]\n            print_json_dist(msg, [0], path=self.autotuning_metric_path())\n            log_dist(\n                f\"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}\",\n                ranks=[0])\n            import atexit\n            atexit.register(print, \"Autotuning: done with running current ds config.\")\n        exit()\n\n    def _write_monitor(self):\n        if self.global_rank == 0:\n            self.summary_events = [\n                (\n                    
f\"Train/Samples/elapsed_time_ms_forward\",\n                    self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False),\n                    self.global_samples,\n                ),\n                (\n                    f\"Train/Samples/elapsed_time_ms_backward\",\n                    self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False),\n                    self.global_samples,\n                ),\n                (\n                    f\"Train/Samples/elapsed_time_ms_backward_inner\",\n                    self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False),\n                    self.global_samples,\n                ),\n                (\n                    f\"Train/Samples/elapsed_time_ms_backward_allreduce\",\n                    self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False),\n                    self.global_samples,\n                ),\n                (\n                    f\"Train/Samples/elapsed_time_ms_step\",\n                    self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False),\n                    self.global_samples,\n                ),\n            ]\n            self.monitor.write_events(self.summary_events)\n\n    def _get_optimizer_param(self, param_name):\n        result = []\n        if not self.optimizer:\n            return result\n        for group in self.optimizer.param_groups:\n            if param_name in group:\n                result.append(group[param_name])\n            else:\n                result.append(0.0)\n        return result\n\n    def get_lr(self):\n        return self._get_optimizer_param(\"lr\")\n\n    def get_type(self):\n        return self._get_optimizer_param(\"type\")\n\n    def get_mom(self):\n        if self.optimizer_name() in [\"SGD\", \"RMSprop\"]:\n            return self._get_optimizer_param(\"momentum\")\n        else:\n            return self._get_optimizer_param(\"betas\")\n\n    def get_pld_theta(self):\n        if self.progressive_layer_drop:\n            return 
self.progressive_layer_drop.get_theta()\n        else:\n            return None\n\n    def _report_progress(self, step):\n        lr = self.get_lr()\n        mom = self.get_mom()\n        log_dist(f\"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}\", ranks=[0])\n\n    def allreduce_bucket(self, bucket, dp_group):\n        tensor = self.flatten(bucket)\n\n        tensor_to_allreduce = tensor\n\n        if self.communication_data_type != tensor.dtype:\n            tensor_to_allreduce = tensor.to(self.communication_data_type)\n\n        if self.postscale_gradients():\n            if self.gradient_predivide_factor() != 1.0:\n                tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor())\n\n            dist.all_reduce(tensor_to_allreduce, group=dp_group)\n            if self.gradient_average:\n                if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):\n                    tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))\n        else:\n            tensor_to_allreduce.mul_(1. 
/ dist.get_world_size(group=dp_group))\n            dist.all_reduce(tensor_to_allreduce, group=dp_group)\n\n        if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:\n            tensor.copy_(tensor_to_allreduce)\n\n        return tensor\n\n    def allreduce_and_copy(self, small_bucket, dp_group):\n        allreduced = self.allreduce_bucket(small_bucket, dp_group)\n        for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):\n            buf.copy_(synced)\n\n    def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):\n        small_bucket = []\n        numel = 0\n        for tensor in bucket:\n            small_bucket.append(tensor)\n            numel = numel + tensor.numel()\n            if numel > numel_per_bucket:\n                self.allreduce_and_copy(small_bucket, dp_group)\n                small_bucket = []\n                numel = 0\n        if len(small_bucket) > 0:\n            self.allreduce_and_copy(small_bucket, dp_group)\n\n    def _get_gradients_for_reduction(self):\n        non_expert_grads = []\n        expert_grads = {}\n        if self.has_moe_layers:\n            for key in self.expert_data_parallel_group.keys():\n                expert_grads[key] = []\n\n        for param_name, param in self.module.named_parameters():\n            if not param.requires_grad:\n                continue\n\n            if param.grad is None:\n                # In cases where there is an imbalance of empty grads across\n                # ranks we must create empty grads, this will ensure that every\n                # rank is reducing the same size. In some cases it may make\n                # sense in the future to support the ability to average not\n                # w.r.t. 
world size but with a different value.\n                param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)\n\n            grad_data = param.grad.data\n            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:\n                # Call param.grad without data to avoid problem with setting of updated grads\n                grad_data = SparseTensor(param.grad)\n\n            if is_moe_param(param):\n                expert_grads[param.group_name].append(grad_data)\n            else:\n                non_expert_grads.append(grad_data)\n\n        return non_expert_grads, expert_grads\n\n    def _reduce_non_expert_gradients(self, grads, elements_per_buffer):\n        split_buckets = split_half_float_double_sparse(grads)\n        for _, bucket_tuple in enumerate(split_buckets):\n            bucket_type, bucket = bucket_tuple\n\n            if self.pipeline_parallelism:\n                dp_group = self.mpu.get_data_parallel_group()\n            else:\n                dp_group = groups._get_sequence_data_parallel_group()\n\n            if bucket_type == SparseTensor.type():\n                self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)\n            else:\n                self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)\n\n    def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):\n        for ep_name, expert_grads_group in expert_grads.items():\n            expert_split_buckets = split_half_float_double_sparse(expert_grads_group)\n            for i, bucket_tuple in enumerate(expert_split_buckets):\n                bucket_type, bucket = bucket_tuple\n                if bucket_type == SparseTensor.type():\n                    self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))\n                else:\n                    # Separate between diff groups\n                    self.allreduce_no_retain(bucket,\n                
                             dp_group=groups._get_expert_data_parallel_group(ep_name),\n                                             numel_per_bucket=elements_per_buffer)\n\n    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):\n        if grads is None:\n            non_expert_grads, expert_grads = self._get_gradients_for_reduction()\n        else:\n            assert not self.has_moe_layers, \"attempting to reduce grads in unsupported way w.r.t. MoE\"\n            non_expert_grads = grads\n\n        self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer)\n\n        if self.has_moe_layers:\n            self._reduce_expert_gradients(expert_grads, elements_per_buffer)\n\n    def sparse_allreduce_no_retain(self, bucket, dp_group):\n        allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)\n        # Densify sparse tensor and copy back to original location\n        for tensor in allreduced_sparses:\n            if tensor.is_sparse:\n                tensor.orig_dense_tensor.data = tensor.to_coo_tensor()\n            else:\n                tensor.orig_dense_tensor.copy_(tensor.to_dense())\n\n    def sparse_allreduce_bucket(self, bucket, dp_group):\n        sparse_list = []\n        for sparse in bucket:\n            sparse_list.append(self.sparse_allreduce(sparse, dp_group))\n        return sparse_list\n\n    def sparse_allreduce(self, sparse, dp_group):\n        original_data_type = sparse.values.dtype\n        if self.communication_data_type != sparse.values.dtype:\n            if self.communication_data_type in (torch.float16, torch.bfloat16):\n                indices = sparse.indices.to(torch.int32)\n            else:\n                indices = sparse.indices\n            values = sparse.values.to(self.communication_data_type)\n        else:\n            indices = sparse.indices\n            values = sparse.values\n\n        if self.postscale_gradients():\n            if self.gradient_average:\n      
          values.mul_(self.gradient_predivide_factor() /\n                            (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))\n        else:\n            values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))\n\n        indices_device_list = self.sparse_all_gather(indices, dp_group)\n        values_device_list = self.sparse_all_gather(values, dp_group)\n\n        sparse.indices = torch.cat(indices_device_list).to(torch.long)\n        sparse.values = torch.cat(values_device_list).to(original_data_type)\n        return sparse\n\n    def sparse_all_gather(self, value, dp_group):\n        my_size = torch.LongTensor([value.size()[0]]).to(self.device)\n        all_sizes = self.all_gather_scalar(my_size, dp_group)\n        max_size = torch.cat(all_sizes).max()\n        fill_size = max_size - my_size\n\n        assert value.dim() in [1, 2]\n        if value.dim() == 1:\n            if fill_size > 0:\n                value = torch.cat([value, value.new_empty(fill_size)])\n            tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))]\n        else:\n            if fill_size > 0:\n                value = torch.cat([value, value.new_empty(fill_size, value.size()[1])])\n            tensor_list = [\n                value.new_empty(max_size,\n                                value.size()[1]) for _ in range(dist.get_world_size(group=dp_group))\n            ]\n\n        dist.all_gather(tensor_list, value, group=dp_group)\n        tensors = []\n        for dev_idx, t in enumerate(tensor_list):\n            size = all_sizes[dev_idx][0]\n            tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device)))\n\n        return tensors\n\n    def all_gather_scalar(self, value, dp_group):\n        tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))]\n        dist.all_gather(tensor_list, value, 
group=dp_group)\n        return tensor_list\n\n    def module_state_dict(self, destination=None, prefix=\"\", keep_vars=False, exclude_frozen_parameters=False):\n        sd = self.module.state_dict(destination, prefix, keep_vars)\n\n        # Remove frozen parameter weights from state_dict if specified\n        if exclude_frozen_parameters:\n            for n, p in self.module.named_parameters():\n                if not p.requires_grad and n in sd:\n                    del sd[n]\n\n        if self.random_ltd_enabled():\n            sd = remove_random_ltd_state_dict(sd)\n        return sd\n\n    @staticmethod\n    def load_moe_state_dict(checkpoint_path,\n                            tag,\n                            state_dict,\n                            old_moe_load,\n                            model=None,\n                            mpu=None,\n                            num_experts=1,\n                            checkpoint_engine=TorchCheckpointEngine()):\n        if old_moe_load:\n            expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())\n\n            num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(\n                groups._get_max_expert_size_name())\n            for local_expert_id in range(num_local_experts):\n                global_expert_id = expp_rank * num_local_experts + local_expert_id\n                expert_state_dict = checkpoint_engine.load(\n                    DeepSpeedEngine._get_expert_ckpt_name(\n                        checkpoint_path,\n                        -1,  # -1 means ignore layer_id\n                        global_expert_id,\n                        tag,\n                        mpu),\n                    map_location=torch.device('cpu'))\n\n                # Updating global -> local expert ids\n                moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'\n                for key in list(expert_state_dict.keys()):\n                    
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',\n                                            f'{moe_str_prefix}{local_expert_id}')\n                    expert_state_dict[local_key] = expert_state_dict.pop(key)\n                state_dict.update(expert_state_dict)\n\n        else:\n            moe_layer_id = 0\n            for n_module, module in model.named_modules():\n                if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:\n                    group_name = module.expert_group_name\n                    num_local_experts = module.num_local_experts\n                    expp_rank = groups._get_expert_parallel_rank(group_name)\n                    # loop all local_experts\n                    for local_expert_id in range(num_local_experts):\n                        global_expert_id = expp_rank * num_local_experts + local_expert_id\n                        expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(\n                            checkpoint_path, moe_layer_id, global_expert_id, tag, mpu),\n                                                                   map_location=torch.device('cpu'))\n                        # print(expert_state_dict.keys())\n                        # Updating global -> local expert ids\n                        moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'\n                        for key in list(expert_state_dict.keys()):\n                            local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',\n                                                    f'{moe_str_prefix}{local_expert_id}')\n                            expert_state_dict[local_key] = expert_state_dict.pop(key)\n                        state_dict.update(expert_state_dict)\n                    moe_layer_id += 1\n\n    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):\n        if fetch_z3_params:\n            params_to_fetch = [\n 
               p for p in self.module.parameters()\n                if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE\n            ]\n        else:\n            params_to_fetch = []\n\n        with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0):\n            module_state_dict = checkpoint['module']\n            if custom_load_fn:\n                custom_load_fn(src=module_state_dict, dst=self.module)\n            else:\n                self.module.load_state_dict(\n                    module_state_dict,  # TODO\n                    strict=strict)\n\n        if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:\n            saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]\n            for param in self.module.parameters():\n                if param.requires_grad:\n                    continue\n                if param not in self.param_names:\n                    raise ValueError(f\"failed to find frozen {param} in named params\")\n                name = self.param_names[param]\n                if hasattr(param, 'ds_id'):\n                    param.ds_tensor.data.copy_(saved_frozen_params[name].data)\n                else:\n                    param.data.copy_(saved_frozen_params[name].data)\n\n    def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):\n        return f'{\"bf16_\" if bf16_mode else \"\"}zero_pp_rank_{dp_rank}'\n\n    def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode):\n        file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode)\n        zero_ckpt_name = os.path.join(\n            checkpoints_path,\n            str(tag),\n            f\"{file_prefix}_mp_rank_{mp_rank:02d}_optim_states.pt\",\n        )\n        return zero_ckpt_name\n\n    def _get_zero_ckpt_name(self, checkpoints_path, tag):\n        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()\n        pp_rank = 
dist.get_rank(group=self.optimizer.zp_process_group)\n        bf16_mode = self.bfloat16_enabled()\n        return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)\n\n    def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):\n        if mp_placeholder is not None:\n            mp_rank_str = mp_placeholder\n        else:\n            mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()\n            mp_rank_str = f\"{mp_rank:02d}\"\n\n        if self.zero_optimization_partition_weights():\n            filename = \"zero_pp_rank_{}\".format(dist.get_rank(group=self.optimizer.zp_process_group))\n            ckpt_name = os.path.join(\n                checkpoints_path,\n                str(tag),\n                f\"{filename}_mp_rank_{mp_rank_str}_model_states.pt\",\n            )\n        else:\n            ckpt_name = os.path.join(\n                checkpoints_path,\n                str(tag),\n                \"mp_rank_\" + mp_rank_str + \"_model_states.pt\",\n            )\n        return ckpt_name\n\n    def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank):\n        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()\n        ckpt_name = os.path.join(checkpoints_path, str(tag),\n                                 f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')\n        return ckpt_name\n\n    @staticmethod\n    def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None):\n        mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()\n        if layer_id <= -1:\n            # Used to support old checkpoint loading\n            ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),\n                                     f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')\n        else:\n            # Used to support new checkpoint loading\n            ckpt_name = os.path.join(checkpoints_path, '' if 
tag is None else str(tag),\n                                     f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')\n        return ckpt_name\n\n    def _get_all_ckpt_names(self, checkpoints_path, tag):\n        # It is required that (checkpoints_path, tag) are consistent among all ranks.\n        ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder=\"*\")\n        import glob\n\n        ckpt_files = glob.glob(ckpt_file_pattern)\n        ckpt_files.sort()\n        return ckpt_files\n\n    def load_checkpoint(self,\n                        load_dir,\n                        tag=None,\n                        load_module_strict=True,\n                        load_optimizer_states=True,\n                        load_lr_scheduler_states=True,\n                        load_module_only=False,\n                        custom_load_fn=None):\n        \"\"\"\n        Load a training checkpoint.\n\n        Arguments:\n            load_dir: Required. Directory to load the checkpoint from.\n            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the 'latest' file ('latest_universal' when loading a universal checkpoint).\n            load_module_strict: Optional. Boolean to strictly enforce that the keys in the module's state_dict and in the checkpoint match.\n            load_optimizer_states: Optional. Boolean to load the training optimizer states (e.g., Adam's momentum and variance) from the checkpoint.\n            load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.\n            load_module_only: Optional. Boolean to load only the model weights from the checkpoint, e.g., for warm-starting.\n            custom_load_fn: Optional. Custom model load function.\n\n        Returns:\n            A tuple of ``load_path`` and ``client_state``.\n            *``load_path``: Path of the loaded checkpoint. 
``None`` if loading the checkpoint failed.\n            *``client_state``: State dictionary used for loading required training states in the client code.\n\n        Important: under ZeRO-3, a checkpoint cannot be loaded with ``engine.load_checkpoint()`` right\n        after ``engine.save_checkpoint()``. This is because ``engine.module`` is partitioned, while\n        ``load_checkpoint()`` expects a pristine model. To do so anyway, reinitialize the engine\n        before calling ``load_checkpoint()``.\n\n        \"\"\"\n\n        if tag is None:\n            latest_tag = \"latest_universal\" if self.load_universal_checkpoint() else \"latest\"\n            latest_path = os.path.join(load_dir, latest_tag)\n            if os.path.isfile(latest_path):\n                with open(latest_path, \"r\") as fd:\n                    tag = fd.read().strip()\n            else:\n                if self.load_universal_checkpoint():\n                    raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')\n                else:\n                    logger.warning(\n                        f\"Unable to find latest file at {latest_path}. If trying to load the latest \"\n                        \"checkpoint, please ensure this file exists or pass an explicit checkpoint tag.\"\n                    )\n                    return None, None\n\n        if self._optimizer_has_ckpt_event_prologue():\n            # Prepare for checkpoint load by ensuring all parameters are partitioned\n            self.optimizer.checkpoint_event_prologue()\n\n        load_path, client_states = self._load_checkpoint(load_dir,\n                                                         tag,\n                                                         load_module_strict=load_module_strict,\n                                                         load_optimizer_states=load_optimizer_states,\n                                                         
load_lr_scheduler_states=load_lr_scheduler_states,\n                                                         load_module_only=load_module_only,\n                                                         custom_load_fn=custom_load_fn)\n\n        load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()\n                                                                                    or self.bfloat16_enabled())\n        if load_zero_checkpoint:\n            success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)\n            if not success:\n                self.optimizer._restore_from_bit16_weights()\n\n        if self._optimizer_has_ckpt_event_epilogue():\n            self.optimizer.checkpoint_event_epilogue()\n\n        if self.load_universal_checkpoint():\n            self.optimizer.update_lp_params()\n            if load_zero_checkpoint:\n                self.update_optimizer_step(step=client_states['iteration'] + 1)\n\n        return load_path, client_states\n\n    def _load_checkpoint(self,\n                         load_dir,\n                         tag,\n                         load_module_strict=True,\n                         load_optimizer_states=True,\n                         load_lr_scheduler_states=True,\n                         load_module_only=False,\n                         custom_load_fn=None):\n\n        from deepspeed.runtime.state_dict_factory import SDLoaderFactory\n\n        ckpt_list = self._get_all_ckpt_names(load_dir, tag)\n        sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)\n\n        is_pipe_parallel = isinstance(self.module, PipelineModule)\n\n        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()\n        load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)\n\n        if checkpoint is None:\n            return None, 
None\n\n        fetch_z3_params = False\n        if self.zero_optimization_partition_weights() and not load_optimizer_states:\n            checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir)\n            fetch_z3_params = True\n\n        if is_pipe_parallel:\n            # Pipeline parallelism uses this to load its own checkpoint files.\n            self._curr_ckpt_path = os.path.join(load_dir, tag)\n\n        if self.has_moe_layers:\n            # print(checkpoint.keys())\n            old_moe_load = False\n            if not isinstance(checkpoint['num_experts'], list):\n                old_moe_load = True\n            DeepSpeedEngine.load_moe_state_dict(load_dir,\n                                                tag,\n                                                state_dict=checkpoint['module'],\n                                                old_moe_load=old_moe_load,\n                                                model=self.module,\n                                                mpu=self.mpu,\n                                                num_experts=self.num_experts,\n                                                checkpoint_engine=self.checkpoint_engine)\n        if not self.load_universal_checkpoint():\n            self.load_module_state_dict(checkpoint=checkpoint,\n                                        strict=load_module_strict,\n                                        custom_load_fn=custom_load_fn,\n                                        fetch_z3_params=fetch_z3_params)\n\n        self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']\n        if 'zp_world_size' not in checkpoint:\n            checkpoint['zp_world_size'] = self.zp_world_size\n        self.loaded_checkpoint_zp_world_size = checkpoint['zp_world_size']\n\n        optim_checkpoint = None\n        if load_module_only:\n            deepspeed_states = ['module']\n            if self.optimizer is not None and self.fp16_enabled():\n                
self.optimizer.refresh_fp32_params()\n        else:\n            has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()\n            if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:\n                if self.has_moe_layers:\n                    largest_group_name = groups._get_max_expert_size_name()\n                    expp_rank = groups._get_expert_parallel_rank(largest_group_name)\n                    optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)\n                    optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))\n                else:\n                    optim_checkpoint = checkpoint\n\n                if self.fp16_enabled() or self.bfloat16_enabled():\n                    self.optimizer.load_state_dict(optim_checkpoint['optimizer'],\n                                                   load_optimizer_states=load_optimizer_states)\n                else:\n                    optim_checkpoint = checkpoint\n\n                self.optimizer.load_state_dict(optim_checkpoint['optimizer'])\n\n            if load_lr_scheduler_states and self.lr_scheduler is not None:\n                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n\n            if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:\n                self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd'])\n\n            if self.training_dataloader is not None and self.curriculum_learning_enabled(\n            ) and 'data_sampler' in checkpoint:\n                self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler'])\n\n            def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters):\n                result = set()\n\n                for name in original_set:\n                    if name in loaded_parameters and name 
not in loaded_set:\n                        continue  # parameter existed in previous model and was not sparse\n                    result.add(name)\n\n                for name in loaded_set:\n                    if name in original_parameters:\n                        result.add(name)  # parameter exists in both configs and it was sparse\n\n                return result\n\n            if 'sparse_tensor_module_names' in checkpoint:\n                sparse_tensor_module_names = checkpoint['sparse_tensor_module_names']\n            elif 'csr_tensor_module_names' in checkpoint:\n                sparse_tensor_module_names = checkpoint['csr_tensor_module_names']\n            else:\n                sparse_tensor_module_names = None\n            if sparse_tensor_module_names is not None:\n                if load_module_strict:\n                    self.sparse_tensor_module_names = sparse_tensor_module_names\n                else:\n                    self.sparse_tensor_module_names = get_sparse_tensor_module_names(\n                        self.sparse_tensor_module_names, sparse_tensor_module_names,\n                        dict(self.module.named_parameters()), checkpoint[\"module\"])\n\n            self.global_steps = checkpoint['global_steps']\n            self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size())\n            self.skipped_steps = checkpoint['skipped_steps']\n            self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']\n            deepspeed_states = [\n                'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'zp_world_size',\n                'mp_world_size', 'data_sampler', 'random_ltd', 'dp_world_size',\n            ]\n        client_state = {}\n\n        if load_lr_scheduler_states:\n            deepspeed_states.append('lr_scheduler')\n        if load_optimizer_states:\n            deepspeed_states.append('optimizer')\n\n        client_state = {key: value for 
key, value in checkpoint.items() if key not in deepspeed_states}\n\n        if optim_checkpoint is not None:\n            client_state['optimizer'] = optim_checkpoint['optimizer']\n\n        return load_path, client_state\n\n    def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):\n\n        load_serial = None\n        # When checkpoints are loaded serially, loading starts from local rank 0;\n        # every other local rank waits until its (rank - 1) peer has finished and notified it.\n        if self._config.zero_config.pipeline_loading_checkpoint:\n            assert self.zero_optimization_stage(\n            ) == ZeroStageEnum.weights, \"Only ZeRO stage 3 supports pipeline checkpoint loading\"\n            load_serial = torch.zeros(1).to(self.device)\n            if dist.get_local_rank() != 0:\n                dist.recv(tensor=load_serial, src=dist.get_rank() - 1)\n        if self.load_universal_checkpoint():\n            zero_sd_list = None\n            checkpoint_folder = f'{os.path.join(load_dir, tag)}'\n        else:\n            if load_optimizer_states and self.zp_world_size != self.loaded_checkpoint_zp_world_size:\n                raise ZeRORuntimeException(\"The checkpoint being loaded used a DP \" \\\n                    f\"world size of {self.loaded_checkpoint_zp_world_size} but the \" \\\n                    f\"current world size is {self.zp_world_size}. 
Automatic adjustment \" \\\n                    \"of ZeRO's optimizer state partitioning with a new world size is not \" \\\n                    \"currently supported.\")\n            checkpoint_folder = None\n            zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)\n            if zero_sd_list is None:\n                return False\n\n        self.optimizer.load_state_dict(state_dict_list=zero_sd_list,\n                                       load_optimizer_states=load_optimizer_states,\n                                       load_from_fp32_weights=self.zero_load_from_fp32_weights(),\n                                       checkpoint_folder=checkpoint_folder,\n                                       load_serial=load_serial)\n\n        if self.load_universal_checkpoint():\n            logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')\n        else:\n            logger.info(f\"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}\")\n        return True\n\n    def update_optimizer_step(self, step):\n\n        def set_step(d):\n            if isinstance(d['step'], torch.Tensor):\n                d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)\n            else:\n                d['step'] = step\n\n        optimizer = self.optimizer\n        base_optimizer = optimizer.optimizer\n        state = base_optimizer.state\n        for group in optimizer.param_groups:\n            if 'step' in group:\n                set_step(group)\n            for p in group['params']:\n                if p in state and len(state[p]) > 0 and 'step' in state[p]:\n                    set_step(state[p])\n\n    def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):\n        zero_ckpt_names = []\n        for dp_rank in range(dp_world_size):\n            ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,\n         
                                             tag=tag,\n                                                      mp_rank=mp_rank,\n                                                      dp_rank=dp_rank,\n                                                      bf16_mode=bf16_mode)\n            zero_ckpt_names.append(ckpt_name)\n\n        return zero_ckpt_names\n\n    def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):\n        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()\n        zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir,\n                                                                  tag=tag,\n                                                                  mp_rank=mp_rank,\n                                                                  dp_world_size=self.loaded_checkpoint_dp_world_size,\n                                                                  bf16_mode=bf16_mode)\n        for i, ckpt_name in enumerate(zero_ckpt_names):\n            if not os.path.exists(ckpt_name):\n                # transparently handle the old file pattern for optim_states\n                if \"optim_states.pt\" in ckpt_name:\n                    ckpt_name_try = ckpt_name.replace(\"_optim_states.pt\", \"optim_states.pt\")\n                    if os.path.exists(ckpt_name_try):\n                        zero_ckpt_names[i] = ckpt_name_try\n                        continue\n\n        return zero_ckpt_names\n\n    def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names):\n        zero_sd_list = []\n        for i, ckpt_name in enumerate(zero_ckpt_names):\n            _state = None\n            if ckpt_name is None:\n                _state = {OPTIMIZER_STATE_DICT: None}\n            # Fully load state for current rank\n            elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.zp_process_group) == i:\n                _state = self.checkpoint_engine.load(\n                    
ckpt_name,\n                    map_location='cpu',\n                )\n            else:\n                _state = {OPTIMIZER_STATE_DICT: None}\n            zero_sd_list.append(_state)\n\n        zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list]\n        logger.info(f\"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}\")\n        return zero_optimizer_sd\n\n    def _get_all_zero_checkpoints(self, load_dir, tag):\n        for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]:\n            zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode)\n            if zero_ckpt_names is not None:\n                # Warn if the checkpoint uses a different 16-bit type than the engine\n                if bf16_mode is not self.bfloat16_enabled():\n                    checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16\n                    engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16\n                    logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')\n                return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)\n\n        return None\n\n    def _checkpoint_tag_validation(self, tag):\n        if self.checkpoint_tag_validation_enabled():\n            s_hash = hashlib.sha1(tag.encode())\n            bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)\n            max_bhash = bhash.clone()\n            min_bhash = bhash.clone()\n            dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)\n            dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)\n            valid = all(min_bhash == bhash) and all(max_bhash == bhash)\n            msg = (f\"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across \"\n                   \"all ranks. 
Including rank unique information in checkpoint tag could cause issues when \"\n                   \"restoring with different world sizes.\")\n            if self.checkpoint_tag_validation_fail():\n                assert valid, msg\n            elif not valid:\n                logger.warning(msg)\n\n    def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False):\n        \"\"\"Save training checkpoint\n\n        Arguments:\n            save_dir: Required. Directory for saving the checkpoint\n            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is\n                used if not provided. Tag name must be the same across all ranks.\n            client_state: Optional. State dictionary used for saving required training states in the client code.\n            save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.\n            exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.\n        Important: all processes must call this method and not just the process with rank 0. It is\n        because each process needs to save its master weights and scheduler+optimizer states. 
This\n        method will hang waiting to synchronize with other processes if it's called just for the\n        process with rank 0.\n\n        \"\"\"\n        if self._optimizer_has_ckpt_event_prologue():\n            # Custom preparation for checkpoint save, if applicable\n            self.optimizer.checkpoint_event_prologue()\n\n        rank = self.local_rank if self.use_node_local_storage() else self.global_rank\n\n        # This is to make sure the checkpoint names are created without collision\n        # There seems to be issue creating them in parallel\n\n        # Ensure save_dir directory exists\n        if rank == 0:\n            self.checkpoint_engine.makedirs(save_dir, exist_ok=True)\n        dist.barrier()\n\n        if tag is None:\n            tag = f\"global_step{self.global_steps}\"\n\n        # Ensure tag is a string\n        tag = str(tag)\n        self.checkpoint_engine.create(tag)\n\n        # Ensure checkpoint tag is consistent across ranks\n        self._checkpoint_tag_validation(tag)\n\n        if self.has_moe_layers:\n            self.save_non_zero_checkpoint = False\n            self._create_checkpoint_file(save_dir, tag, False)\n            self._save_moe_checkpoint(save_dir,\n                                      tag,\n                                      client_state=client_state,\n                                      exclude_frozen_parameters=exclude_frozen_parameters)\n\n        # We distribute the task of saving layer checkpoint files among\n        # data parallel instances, so all procs should call _save_checkpoint.\n        # All procs then call module_state_dict(), but only procs of data\n        # parallel rank 0 save the general model params.\n        if not self.has_moe_layers:\n            self._create_checkpoint_file(save_dir, tag, False)\n            self._save_checkpoint(save_dir,\n                                  tag,\n                                  client_state=client_state,\n                                  
exclude_frozen_parameters=exclude_frozen_parameters)\n\n        if self.save_zero_checkpoint:\n            self._create_zero_checkpoint_files(save_dir, tag)\n            self._save_zero_checkpoint(save_dir, tag)\n\n        if self._optimizer_has_ckpt_event_epilogue():\n            self.optimizer.checkpoint_event_epilogue()\n\n        # Save latest checkpoint tag\n        self.checkpoint_engine.commit(tag)\n        if save_latest and rank == 0:\n            with open(os.path.join(save_dir, 'latest'), 'w') as fd:\n                fd.write(tag)\n\n        dist.barrier()\n\n        return True\n\n    def _get_non_moe_state_dict(self, full_state_dict):\n        \"\"\"\n            Get the state dict of the non-moe layers\n        \"\"\"\n        for key in list(full_state_dict.keys()):\n            if 'expert' in key and 'moe.gate.wg.weight' not in key:\n                full_state_dict.pop(key)\n\n        return full_state_dict\n\n    def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):\n        save_path = self._get_ckpt_name(save_dir, tag)\n        # A hack to save the checkpointing directory. Pipeline parallelism overrides\n        # module_state_dict() and uses this path to save the model. 
module_state_dict()\n        # then instead just returns None.\n\n        # Using layer_#_export_# to save the model's expert state_dict\n        moe_layer_id = 0\n        for n_module, module in self.module.named_modules():\n            if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:\n                group_name = module.expert_group_name\n                num_local_experts = module.num_local_experts\n                expp_rank = groups._get_expert_parallel_rank(group_name)\n                exp_dp_rank = groups._get_expert_data_parallel_rank(group_name)\n                # print(expp_rank, exp_dp_rank)\n                if exp_dp_rank != 0:\n                    moe_layer_id += 1\n                    continue\n\n                # get all moe parameters\n                moe_state_dict = {}\n                for n, p in module.state_dict().items():\n                    if 'expert' in n and 'moe.gate.wg.weight' not in n:\n                        moe_state_dict[n_module + '.' + n] = p\n                moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'\n                # print(moe_state_dict.keys()) # until now, everything is fine. 
So the bug happens at next few lines\n                # Reorder the moe name rank, so that each checkpoint only has one expert\n                experts_state_dict = defaultdict(dict)\n                for key in list(moe_state_dict.keys()):\n                    m = re.match(f\".*{moe_str_prefix}([0-9]+).*\", key)\n\n                    local_expert_id = None\n                    if not m:\n                        logger.warn(f'No expert found in key {key}.')\n                    else:\n                        local_expert_id = m.group(1)\n\n                    global_expert_id = expp_rank * \\\n                        num_local_experts + int(local_expert_id)\n                    expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',\n                                             f'{moe_str_prefix}{global_expert_id}')\n                    # truncating extra tensor (shared) storage\n                    truncated = moe_state_dict.pop(key).clone().detach()\n                    experts_state_dict[str(global_expert_id)][expert_key] = truncated\n\n                # let save the moe parameters\n                for global_expert_id, expert_state_dict in experts_state_dict.items():\n                    # save the moe parameters\n                    moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)\n                    if self.random_ltd_enabled():\n                        expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)\n                    self.checkpoint_engine.save(expert_state_dict, moe_save_path)\n                moe_layer_id += 1\n\n        self._curr_ckpt_path = os.path.join(save_dir, tag)\n\n        largest_group_name = groups._get_max_expert_size_name()\n        expp_rank = groups._get_expert_parallel_rank(largest_group_name)\n        exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name)\n\n        # In the case of E + D parallelism, only the\n        # first expert parallel 
group should save the expert weights\n        # since each expert parallel group is a copy of the model's experts\n        if exp_dp_rank != 0:\n            return\n\n        # Save optimizer states. They are different across each exp parallel rank.\n        optimizer_state = {\n            'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None\n        }\n        # TODO: why use BufferedWriter not the path\n        file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)\n        self.checkpoint_engine.save(optimizer_state, file_path)\n\n        # get non-moe parameters\n        model_state_dict = self._get_non_moe_state_dict(\n            self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))\n\n        if expp_rank == 0:\n            # TODO: update num experts info,.. in checkpoint\n            state = {\n                'module':\n                model_state_dict,\n                'lr_scheduler':\n                self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,\n                'data_sampler':\n                self.training_dataloader.data_sampler.state_dict() if\n                (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,\n                'random_ltd':\n                self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,\n                'sparse_tensor_module_names':\n                self.sparse_tensor_module_names,\n                'skipped_steps':\n                self.skipped_steps,\n                'global_steps':\n                self.global_steps,\n                'global_samples':\n                self.global_samples,\n                'zp_world_size':\n                self.zp_world_size,\n                'dp_world_size':\n                self.dp_world_size,\n                'mp_world_size':\n                self.mp_world_size,\n                'num_experts':\n                
self.num_experts\n            }\n            state.update(client_state)\n            logger.info(f'Saving model checkpoint: {save_path}')\n            self.checkpoint_engine.save(state, save_path)\n\n    def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):\n        name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)\n        try:\n            checkpoint_name = name_function(save_dir, tag)\n            path = os.path.dirname(checkpoint_name)\n            self.checkpoint_engine.makedirs(path, exist_ok=True)\n        except:\n            logger.error(f\"Failed saving model checkpoint to {save_dir} with tag {tag}\")\n            return False\n\n        return True\n\n    def _create_zero_checkpoint_files(self, save_dir, tag):\n        success = True\n        # zero checkpoint files are created sequentially\n        for rank in range(dist.get_world_size(self.optimizer.zp_process_group)):\n            if rank == self.global_rank:\n                success = self._create_checkpoint_file(save_dir, tag, True)\n\n            dist.barrier(group=self.optimizer.zp_process_group)\n\n        return success\n\n    def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):\n\n        save_path = self._get_ckpt_name(save_dir, tag)\n\n        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()\n\n        save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters\n\n        # A hack to save the checkpointing directory. Pipeline parallelism overrides\n        # module_state_dict() and uses this path to save the model. module_state_dict()\n        # then instead just returns None.  
The module_state_dict() implementation in\n        # PipelineEngine expects the save path to be set in self._curr_ckpt_path.\n        self._curr_ckpt_path = os.path.join(save_dir, tag)\n        module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)\n        self._curr_ckpt_path = None\n\n        state = dict(module=module,\n                     buffer_names=self._get_buffer_names(),\n                     optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,\n                     param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,\n                     frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)\n                     if save_frozen_param else None,\n                     shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,\n                     frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)\n                     if save_frozen_param else None,\n                     lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,\n                     data_sampler=self.training_dataloader.data_sampler.state_dict() if\n                     (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,\n                     random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,\n                     sparse_tensor_module_names=self.sparse_tensor_module_names,\n                     skipped_steps=self.skipped_steps,\n                     global_steps=self.global_steps,\n                     global_samples=self.global_samples,\n                     dp_world_size=self.seq_dp_world_size,\n                     mp_world_size=self.mp_world_size,\n                     ds_config=self.config,\n                     ds_version=version)\n        
state.update(client_state)\n\n        if self.save_non_zero_checkpoint:\n            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])\n            self.checkpoint_engine.save(state, save_path)\n\n    def _get_buffer_names(self):\n        buffer_names = []\n\n        # we save buffer names so that we could extract later the real buffers from the saved\n        # state_dict[\"module\"] in the non-zero checkpoint - the buffers are already there but they\n        # are intermixed with param placeholders\n\n        # have to traverse the tree to be able to skip non-persistent buffers\n        def get_layer_named_buffers(module, prefix=\"\"):\n            for name, buf in module.named_buffers(recurse=False):\n                if buf is not None and name not in module._non_persistent_buffers_set:\n                    buffer_names.append(prefix + name)\n\n            for name, child in module.named_children():\n                if child is not None:\n                    get_layer_named_buffers(child, prefix + name + \".\")\n\n        get_layer_named_buffers(self.module, prefix=\"\")\n\n        return buffer_names\n\n    def _get_param_shape_func(self, param):\n        return param.ds_shape if hasattr(param, 'ds_id') else param.shape\n\n    def _get_param_fragment_func(self, param):\n        return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()\n\n    def _get_zero_frozen_param_attributes(self, attr_func):\n        frozen_param_fragments = OrderedDict()\n\n        for param in self.module.parameters():\n            if param.requires_grad:\n                continue\n            if param not in self.param_names:\n                raise ValueError(f\"failed to find frozen {param} in named params\")\n            name = self.param_names[param]\n            frozen_param_fragments[name] = attr_func(param)\n\n        return frozen_param_fragments\n\n    def _get_zero_param_shapes(self):\n        \"\"\"Returns a dict of name 
to shape mapping, only for the flattened fp32 weights saved by the\n        optimizer. the names are exactly as in state_dict. The order is absolutely important, since\n        the saved data is just flattened data with no identifiers and requires reconstruction in the\n        same order it was saved.\n        We can't rely on self.module.named_parameters() to get the saved tensors, as some params\n        will be missing and others unsaved and then it'd be impossible to reconstruct state_dict\n        from the flattened weights.\n        optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions.\n        \"\"\"\n        param_group_shapes = []\n        cnt = 0\n        numel = 0\n\n        # zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups -\n        # if we don't use it, we get parameters ordered incorrectly\n        if hasattr(self.optimizer, \"round_robin_bit16_groups\"):\n            bit16_groups = self.optimizer.round_robin_bit16_groups\n        elif self.bfloat16_enabled() and hasattr(self.optimizer, \"bf16_groups\"):\n            bit16_groups = self.optimizer.bf16_groups\n        else:\n            bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(\n            ) == 2 else self.optimizer.fp16_groups\n\n        for bit16_group in bit16_groups:\n            param_shapes = OrderedDict()\n            for param in bit16_group:\n                cnt += 1\n                numel += param.ds_numel if hasattr(param, \"ds_numel\") else param.numel()\n                shape = param.ds_shape if hasattr(param, \"ds_shape\") else param.shape\n                if param not in self.param_names:\n                    raise ValueError(f\"failed to find optimizer param in named params\")\n                name = self.param_names[param]\n                param_shapes[name] = shape\n\n                # uncomment to debug zero_to_fp32.py problems\n                # if self.global_rank == 0: 
print(f\"saving param {name} {shape} (numel={shape.numel()})\")\n            param_group_shapes.append(param_shapes)\n        # if self.global_rank == 0: print(f\"Total saved {numel} numels in {cnt} params\")\n\n        return param_group_shapes\n\n    def _get_shared_params(self):\n        \"\"\"\n        Returns a dict of shared params, which can later be used to reconstruct the original state dict,\n        e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name\n        of the variable that isn't stored and the value is the actual param holding data.\n        \"\"\"\n        shared_index = {}\n        shared_params_by_full_name = {}\n\n        is_zero3_model = (self.zero_optimization_partition_weights()\n                          and any(hasattr(param, \"ds_id\") for param in self.module.parameters()))\n\n        def get_layer_state_dict(module, prefix=\"\"):\n            # handle params\n            for name, param in module.named_parameters(recurse=False):\n                if param is None or (is_zero3_model and not hasattr(param, \"ds_id\")):\n                    continue\n                key = prefix + name\n\n                # When weights are manged by stage 3, we can't rely on param.data_ptr() as it will be reused\n                # as weights get gathered and reduced, but param.ds_id is unique across all zero weights\n                # (and shared params will have the same param.ds_id)\n                param_id = param.ds_id if is_zero3_model else param.data_ptr()\n\n                if param_id in shared_index:\n                    # shared weights\n                    #print(f\"`{key}` is shared with `{shared_index[param_id]}`\")\n                    shared_params_by_full_name[key] = shared_index[param_id]\n                else:\n                    shared_index[param_id] = key\n\n            for name, child in module.named_children():\n                if child is not None:\n                    
get_layer_state_dict(child, prefix + name + \".\")\n\n        if dist.get_rank() == 0:\n            get_layer_state_dict(self.module, prefix=\"\")\n\n        return shared_params_by_full_name\n\n    def _copy_recovery_script(self, save_path):\n        base_dir = os.path.dirname(os.path.dirname(__file__))\n        script = \"zero_to_fp32.py\"\n        src = os.path.join(base_dir, \"utils\", script)\n        dst = os.path.join(save_path, script)\n        #logger.info(f\"creating recovery script {dst}\")\n        copyfile(src, dst)\n        self._change_recovery_script_permissions(dst)\n\n    def _change_recovery_script_permissions(self, dst):\n        # make executable (safeguard for file shares - Azure as example)\n        try:\n            os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)\n        except (FileNotFoundError, PermissionError) as e:\n            #this message is used in unit test TestZeRONonDistributed\n            logger.info(\n                f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.'\n            )\n\n    def _save_zero_checkpoint(self, save_path, tag):\n        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)\n        zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version)\n        self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)\n\n        if self.global_rank == 0:\n            self._copy_recovery_script(save_path)\n        ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'\n        logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')\n\n    def _zero3_consolidated_16bit_state_dict(self):\n        \"\"\"\n        Get a full non-partitioned state_dict with fp16 weights on cpu.\n        Important: this function must be called on all ranks and not just rank 0.\n        This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:\n        1. 
consolidates the weights from different partitions on gpu0\n        2. works on one layer at a time to require as little gpu0 memory as possible, by\n        moving the already consolidated weights to cpu\n        3. takes care to keep the shared params shared when gradually copying the params to cpu\n        Returns:\n            a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks\n        \"\"\"\n        if not self.zero_optimization_partition_weights():\n            raise ValueError(\"this function requires ZeRO-3 mode\")\n\n        state_dict = OrderedDict() if dist.get_rank() == 0 else None\n        shared_params = {}\n\n        def get_layer_state_dict(module, prefix=\"\"):\n            # gather one layer at a time to be memory-efficient\n            # must use modifier_rank=0 to release GPU memory after each layer gathered\n            #see_memory_usage(\"before GatheredParameters\", force=True)\n            with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):\n                if dist.get_rank() == 0:\n                    # handle params\n                    for name, param in module.named_parameters(recurse=False):\n                        if param is None:\n                            continue\n                        key = prefix + name\n                        # can't rely on param.data_ptr() as it will be reused as weights gets\n                        # gathered and reduced, but param.ds_id is unique across all zero weights\n                        # (and shared params will have the same param.ds_id)\n                        if param.ds_id in shared_params:\n                            # shared weights\n                            #print(f\"`{key}` is shared with `{shared_params[param.ds_id]}`\")\n                            state_dict[key] = state_dict[shared_params[param.ds_id]]\n                        else:\n                            state_dict[key] = param.detach().cpu()\n      
                      shared_params[param.ds_id] = key\n                        #print(f\"param {param.ds_id} {param.shape} {key} \")\n\n                    # now buffers - not sure if need to take care of potentially shared weights here\n                    for name, buf in module.named_buffers(recurse=False):\n                        if (buf is not None and name not in module._non_persistent_buffers_set):\n                            state_dict[prefix + name] = buf.detach().cpu()\n            #see_memory_usage(\"after GatheredParameters\", force=True)\n\n            for name, child in module.named_children():\n                if child is not None:\n                    get_layer_state_dict(child, prefix + name + \".\")\n\n        # Prepare for checkpoint save by ensuring all parameters are partitioned\n        if self._optimizer_has_ckpt_event_prologue():\n            self.optimizer.checkpoint_event_prologue()\n\n        see_memory_usage(\"before get_layer_state_dict\", force=False)\n        get_layer_state_dict(self.module, prefix=\"\")\n        see_memory_usage(\"after get_layer_state_dict\", force=False)\n\n        if self._optimizer_has_ckpt_event_epilogue():\n            self.optimizer.checkpoint_event_epilogue()\n\n        return state_dict\n\n    def save_fp16_model(self, save_dir, save_filename=\"pytorch_model.bin\"):\n        \"\"\"has been renamed to save_16bit_model, keeping this around for backwards\n        compatibility\"\"\"\n        return self.save_16bit_model(save_dir, save_filename)\n\n    def save_16bit_model(self, save_dir, save_filename=\"pytorch_model.bin\"):\n        \"\"\"\n        Save 16bit model weights\n\n        This method saves the 16bit model weights at the desired destination.\n\n        Arguments:\n            save_dir: Required. Directory for saving the model\n            save_filename: Optional. Filename to save to. 
Defaults to ``pytorch_model.bin``\n\n        Returns:\n            ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if\n            stage3_gather_16bit_weights_on_model_save is ``False``.\n\n        Important: all processes must call this method and not just the process with rank 0. It is\n        because the processes need to work in sync to gather the weights. This method will hang\n        waiting to synchronize with other processes if it's called just for the process with rank 0.\n\n        \"\"\"\n\n        path = os.path.join(save_dir, save_filename)\n\n        if self.zero_optimization_partition_weights():\n            if self.zero_gather_16bit_weights_on_model_save():\n                # consolidation is expensive in time and memory and therefore isn't a default\n                state_dict = self._zero3_consolidated_16bit_state_dict()\n            else:\n                # the model will be bogus if not consolidated so don't confuse the user by saving it\n                logger.info(\n                    f\"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False\")\n                return False\n        else:\n            state_dict = self.module.state_dict()\n\n        tag = f\"global_step{self.global_steps}\"\n        tag = str(tag)\n        self.checkpoint_engine.create(tag)\n\n        if dist.get_rank() == 0:\n            self.checkpoint_engine.makedirs(save_dir, exist_ok=True)\n            logger.info(f\"Saving model weights to {path}, tag: {tag}\")\n            self.checkpoint_engine.save(state_dict, path)\n\n        self.checkpoint_engine.commit(tag)\n\n        return True\n\n    def empty_partition_cache(self):\n        \"\"\"\n        Release GPU memory consumed by offloaded model parameters.\n        \"\"\"\n        if hasattr(self.optimizer, 'empty_partition_cache'):\n            self.optimizer.empty_partition_cache()\n            gc.collect()\n            
get_accelerator().empty_cache()\n"
  },
  {
    "path": "opensora/adaptor/modules.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ndef fp32_layer_norm_forward(self, inputs: torch.Tensor) -> torch.Tensor:\n    origin_dtype = inputs.dtype\n    return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None,\n                        self.bias.float() if self.bias is not None else None, self.eps).to(origin_dtype)\n\n\ndef fp32_silu_forward(self, inputs: torch.Tensor) -> torch.Tensor:\n    return torch.nn.functional.silu(inputs.float(), inplace=self.inplace).to(inputs.dtype)\n\n\ndef fp32_gelu_forward(self, inputs: torch.Tensor) -> torch.Tensor:\n    return torch.nn.functional.gelu(inputs.float(), approximate=self.approximate).to(inputs.dtype)\n\n\ndef replace_with_fp32_forwards():\n    nn.GELU.forward = fp32_gelu_forward\n    nn.SiLU.forward = fp32_silu_forward\n    nn.LayerNorm.forward = fp32_layer_norm_forward\n"
  },
  {
    "path": "opensora/adaptor/stage_1_and_2.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\nimport torch\nimport os\nimport pdb\nfrom deepspeed import comm as dist\nfrom packaging import version as pkg_version\nfrom collections import OrderedDict\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\nfrom deepspeed.runtime import ZeROOptimizer\nfrom deepspeed.runtime.fp16.loss_scaler import CreateLossScaler\nfrom deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage,\n                                     inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)\n\nfrom deepspeed.runtime.zero.config import ZeroStageEnum\nfrom deepspeed.runtime.zero.offload_config import OffloadDeviceEnum\nfrom deepspeed.ops.adam import DeepSpeedCPUAdam\nfrom deepspeed.utils import logger\nfrom deepspeed.moe.utils import is_moe_param\nfrom deepspeed.git_version_info import version\n\nfrom deepspeed.runtime.constants import PIPE_REPLICATED\nfrom deepspeed.accelerator import get_accelerator\n\nfrom deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,\n                                            SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,\n                                            BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)\nfrom deepspeed.utils import link_hp_params\nfrom deepspeed.checkpoint import enable_universal_checkpoint\n\nfrom deepspeed.utils import groups\n\nfrom opensora.adaptor.zp_manager import zp_manager\n\n# Toggle this to true to enable correctness test\n# with gradient partitioning and without\npg_correctness_test = False\n\nOPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather'\nOPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients'\nOPTIMIZER_STEP_TIMER = 'optimizer_step'\nOPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, 
OPTIMIZER_STEP_TIMER]\n\n\ndef input(msg):\n    return\n\n\ndef split_half_float_double(tensors):\n    device_type = get_accelerator().device_name()\n    dtypes = [\n        \"torch.{}.HalfTensor\".format(device_type), \"torch.{}.FloatTensor\".format(device_type),\n        \"torch.{}.DoubleTensor\".format(device_type), \"torch.{}.BFloat16Tensor\".format(device_type)\n    ]\n    buckets = []\n    for i, dtype in enumerate(dtypes):\n        bucket = [t for t in tensors if t.type() == dtype]\n        if bucket:\n            buckets.append(bucket)\n    return buckets\n\n\ndef isclose(a, b, rtol=1e-09, atol=0.0):\n    return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)\n\n\ndef lcm(x, y):\n    from math import gcd  # fractions.gcd was removed in Python 3.9; math.gcd is the replacement\n    return x * y // gcd(x, y)\n\n\ndef get_alignment_padding(tensor_list, alignment):\n    num_elements = sum([tensor.numel() for tensor in tensor_list])\n    remainder = num_elements % alignment\n    return (alignment - remainder) if remainder else remainder\n\n\ndef move_to_cpu(tensor_list):\n    for tensor in tensor_list:\n        tensor.data = tensor.data.cpu()\n\n\ndef print_rank_msg(msg):\n    print(f\"rank {dist.get_rank()} - {msg}\")\n\n\ndef _get_padded_tensor(src_tensor, size):\n    if src_tensor.numel() >= size:\n        return src_tensor\n    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)\n    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())\n    slice_tensor.data.copy_(src_tensor.data)\n    return padded_tensor\n\n\ndef contigous_flatten(tensors):\n    return _flatten_dense_tensors([tensor.contiguous() for tensor in tensors])\n\n\ndef all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group):\n    for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):\n        partition_id = dist.get_rank(group=zp_process_group[group_id])\n        dp_world_size = 
dist.get_world_size(group=zp_process_group[group_id])\n        if dp_world_size == 1:\n            # no groups share optimizer states\n            # pipeline parallel with bf16 will default call this even if dp size = 1.\n            continue\n        input_tensor = partitioned_params[partition_id].contiguous()\n        # print(f\"call all_gather_into_tensor_dp_groups, input size is {input_tensor.size()}, \"\n        #       f\"output size is {group_flat.size()}\")\n        #\n        # print(f\"groups_flat.size = {groups_flat.numel()}\")\n        # print(f\"partitioned_param_groups = {sum([v.numel() for v in partitioned_param_groups])}\")\n        dist.all_gather_into_tensor(group_flat, input_tensor, zp_process_group[group_id])\n\n\nclass DeepSpeedZeroOptimizer(ZeROOptimizer):\n    \"\"\"\n    DeepSpeedZeroOptimizer designed to reduce the memory footprint\n    required for training large deep learning models.\n\n    For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models\n    https://arxiv.org/abs/1910.02054\n\n    For usage examples, refer to TODO: DeepSpeed Tutorial\n\n    \"\"\"\n\n    def __init__(self,\n                 init_optimizer,\n                 param_names,\n                 timers,\n                 static_loss_scale=1.0,\n                 dynamic_loss_scale=False,\n                 dynamic_loss_args=None,\n                 verbose=True,\n                 contiguous_gradients=True,\n                 reduce_bucket_size=500000000,\n                 use_multi_rank_bucket_allreduce=True,\n                 allgather_bucket_size=5000000000,\n                 dp_process_group=None,\n                 expert_parallel_group=None,\n                 expert_data_parallel_group=None,\n                 reduce_scatter=True,\n                 overlap_comm=False,\n                 offload_optimizer_config=None,\n                 mpu=None,\n                 clip_grad=0.0,\n                 
gradient_accumulation_dtype=torch.float32,\n                 communication_data_type=torch.float16,\n                 postscale_gradients=True,\n                 gradient_predivide_factor=1.0,\n                 gradient_accumulation_steps=1,\n                 ignore_unused_parameters=True,\n                 partition_grads=True,\n                 round_robin_gradients=False,\n                 has_moe_layers=False,\n                 fp16_master_weights_and_gradients=False,\n                 elastic_checkpoint=False):\n\n        if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:\n            self.cpu_offload = True\n            self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory\n        else:\n            self.cpu_offload = False\n            self.cpu_offload_pin_memory = False\n\n        if dist.get_rank() == 0:\n            logger.info(f\"Reduce bucket size {reduce_bucket_size}\")\n            logger.info(f\"Allgather bucket size {allgather_bucket_size}\")\n            logger.info(f\"CPU Offload: {self.cpu_offload}\")\n            logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')\n        # The fused optimizer does all the work. We need this layer for two reasons:\n        # 1. maintain the same user API as apex.fp16_utils\n        # 2. keep common code here in case we need to add a new fused optimizer later\n\n        self.elastic_checkpoint = elastic_checkpoint\n        self.param_names = param_names\n        self.mpu = mpu\n        # differences from apex.fp16_utils:\n        # - assume all model params in fp16\n        # - assume all params require grad\n        # - flat by groups, not keeping state. TODO: remove state explicitly?\n        # - master grad and unflat master weight never exist. 
TODO: a way to save out unflat master?\n        if not get_accelerator().is_available():\n            raise SystemError(\"Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).\")\n        self.optimizer = init_optimizer\n\n        # Use torch (un)flatten ops\n        self.flatten = contigous_flatten\n        self.unflatten = _unflatten_dense_tensors\n\n        # ZeRO stage 1 (False) or 2 (True)\n        self.partition_gradients = partition_grads\n        self.zero_stage_string = \"ZeRO-2\" if partition_grads else \"ZeRO-1\"\n\n        self.timers = timers\n\n        self.reduce_scatter = reduce_scatter\n\n        self.overlap_comm = overlap_comm\n\n        self.deepspeed_adam_offload = self.cpu_offload\n\n        self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'\n\n        zp_manager.init_group()\n        self.zp_process_group = zp_manager.zp_group\n        zp_rank = dist.get_rank(group=self.zp_process_group)\n        zp_size = dist.get_world_size(group=self.zp_process_group)\n        print(f\"zp rank is {zp_rank}, zp_size={zp_size}\")\n\n        self.dp_process_group = dp_process_group\n\n        self.sequence_parallel_size = groups._get_sequence_parallel_world_size()\n        # expert parallel group\n        self.ep_process_group = expert_parallel_group\n\n        # data parallel group for experts\n        self.expert_dp_process_group = expert_data_parallel_group\n\n        # data parallel size for non-experts\n        dp_size = dist.get_world_size(group=self.dp_process_group)\n\n        # For MoE models this maybe different for different param group\n        # It will be modified during MoE setup later in the init\n        self.real_zp_process_group = [self.zp_process_group for i in range(len(self.optimizer.param_groups))]\n        self.real_dp_process_group = [self.dp_process_group for i in range(len(self.optimizer.param_groups))]\n        self.partition_count = [zp_manager.zp_size for i in 
range(len(self.optimizer.param_groups))]\n\n        self.is_gradient_accumulation_boundary = True\n\n        # CPU-Offload requires contiguous gradients\n        self.contiguous_gradients = contiguous_gradients or self.cpu_offload\n\n        self.has_moe_layers = has_moe_layers\n        if self.has_moe_layers:\n            self._configure_moe_settings()\n        self._global_grad_norm = 0.\n\n        if mpu is None:\n            self.model_parallel_group = None\n            self.model_parallel_world_size = 1\n            self.model_parallel_rank = 0\n        else:\n            self.model_parallel_group = mpu.get_model_parallel_group()\n            self.model_parallel_world_size = mpu.get_model_parallel_world_size()\n            self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)\n\n        self.overflow = False\n        self.clip_grad = clip_grad\n        self.communication_data_type = communication_data_type\n        self.gradient_predivide_factor = gradient_predivide_factor\n        self.postscale_gradients = postscale_gradients\n        self.gradient_accumulation_steps = gradient_accumulation_steps\n        self.micro_step_id = 0\n        self.ignore_unused_parameters = ignore_unused_parameters\n        self.round_robin_gradients = round_robin_gradients\n\n        self.extra_large_param_to_reduce = None\n        self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients\n\n        if self.fp16_master_weights_and_gradients:\n            assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \\\n                f\"fp16_master_weights_and_gradients requires the optimizer to support keeping fp16 master weights and gradients while keeping the optimizer states in fp32. \" \\\n                f\"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. 
But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}.\" \\\n                f\"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam.\"\n\n        if self.reduce_scatter:\n            valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)\n            assert self.communication_data_type in valid_reduce_scatter_dtypes, f\"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'\"\n            assert self.gradient_predivide_factor == 1.0, f\"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled\"\n            assert self.postscale_gradients, f\"pre-scale gradients are not yet supported with {self.zero_stage_string} with reduce scatter enabled\"\n\n        # param flattened by groups\n        self.bit16_groups = []\n        self.bit16_groups_flat = []\n\n        # param partitioned by data parallel degree\n        # this will contain a list of equal sized tensors\n        # each of which will be updated by a different process\n        self.parallel_partitioned_bit16_groups = []\n\n        # a single 32-bit partition of the parallel partitioned parameters\n        # that this process will update\n        self.single_partition_of_fp32_groups = []\n\n        # param partition info\n\n        # These are the parameters in each group that will not be updated by this process directly\n        self.params_not_in_partition = []\n\n        # These are the parameters that will be updated by this process directly\n        self.params_in_partition = []\n\n        # Offset from the first parameter in the self.params_in_partition\n        # the parameter boundaries may not align with partition boundaries\n        # so we need to keep track of the offset\n        self.first_offset = []\n\n        # number of 
elements per partition in each group\n        self.partition_size = []\n\n        # align nccl all-gather send buffers to a 32-byte boundary\n        self.nccl_start_alignment_factor = 16  # 32-byte alignment / sizeof(fp16) (2 bytes) = 16 elements\n\n        assert (\n                allgather_bucket_size % self.nccl_start_alignment_factor == 0\n        ), f\"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} \"\n\n        self.all_reduce_print = False\n        self.dtype = self.optimizer.param_groups[0]['params'][0].dtype\n        self.gradient_accumulation_dtype = gradient_accumulation_dtype\n\n        if self.dtype != self.gradient_accumulation_dtype:\n            self.use_separate_grad_accum = True\n        else:\n            self.use_separate_grad_accum = False\n        if self.use_separate_grad_accum and not self.partition_gradients:\n            self.use_grad_accum_attribute = True\n        else:\n            self.use_grad_accum_attribute = False\n\n        self.round_robin_bit16_groups = []\n        self.round_robin_bit16_indices = []\n\n        # Use different parallel to do all_to_all_reduce related things\n        # padding on each partition for alignment purposes\n        self.groups_padding = []\n        # loop to deal with groups\n        for i, param_group in enumerate(self.optimizer.param_groups):\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n\n            # push this group to the list before modifying it\n            # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group\n            trainable_parameters = []\n            for param in param_group['params']:\n                if param.requires_grad:\n                    param.grad_accum = None\n                    trainable_parameters.append(param)\n            self.bit16_groups.append(trainable_parameters)\n\n            # not sure why apex was cloning the weights before flattening\n            # 
removing cloning here\n\n            see_memory_usage(f\"Before moving param group {i} to CPU\")\n            # move all the parameters to cpu to free up GPU space for creating flat buffer\n            move_to_cpu(self.bit16_groups[i])\n            empty_cache()\n            see_memory_usage(f\"After moving param group {i} to CPU\", force=False)\n\n            # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.\n            # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.\n            # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging\n            # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).\n            if self.round_robin_gradients:\n                round_robin_tensors, round_robin_indices = self._round_robin_reorder(\n                    self.bit16_groups[i], dist.get_world_size(group=self.real_zp_process_group[i]))\n            else:\n                round_robin_tensors = self.bit16_groups[i]\n                round_robin_indices = list(range(len(self.bit16_groups[i])))\n\n            self.round_robin_bit16_groups.append(round_robin_tensors)\n            self.round_robin_bit16_indices.append(round_robin_indices)\n\n            # create flat buffer in CPU and move to GPU\n            self.bit16_groups_flat.append(\n                self.flatten_dense_tensors_aligned(\n                    self.round_robin_bit16_groups[i],\n                    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i])).to(\n                    get_accelerator().current_device_name()))\n            see_memory_usage(f\"After flattening and moving param group {i} to GPU\", force=False)\n\n            # Record padding required for alignment\n            if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:\n                padding = 
self.bit16_groups_flat[i].numel() - sum(\n                    [t.numel() for t in self.round_robin_bit16_groups[i]])\n            else:\n                padding = 0\n            self.groups_padding.append(padding)\n\n            if dist.get_rank(group=self.real_zp_process_group[i]) == 0:\n                see_memory_usage(f\"After Flattening and after emptying param group {i} cache\", force=False)\n\n            # set model bit16 weight to slices of flattened buffer\n            self._update_model_bit16_weights(i)\n\n            # divide the flat weights into near-equal partitions, one per data parallel rank\n            # each process will compute on a different part of the partition\n            data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)\n            self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)\n\n            # print(f\"self.bit16_groups_flat[i].size = {self.bit16_groups_flat[i].numel()}\")\n            # print(f\"data_parallel_partitions = {sum([v.numel() for v in data_parallel_partitions])}\")\n\n            # verify that data partition start locations are aligned to (2 * nccl_start_alignment_factor) bytes\n            for partitioned_data in data_parallel_partitions:\n                assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)\n\n            # A partition of the fp32 master weights that will be updated by this process.\n            # Note that the params in single_partition_of_fp32_groups are cloned and detached\n            # from the original params of the model.\n            if not fp16_master_weights_and_gradients:\n                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(\n                    self.device).clone().float().detach())\n            else:\n                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(\n                    
self.device).clone().half().detach())\n\n            # Set local optimizer to have flat params of its own partition.\n            # After this, the local optimizer will only contain its own partition of params.\n            # In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).\n            self.single_partition_of_fp32_groups[\n                i].requires_grad = True  # keep this in case internal optimizer uses it\n            param_group['params'] = [self.single_partition_of_fp32_groups[i]]\n\n            partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_zp_process_group[i])\n            params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(\n                self.round_robin_bit16_groups[i], partition_size, partition_id)\n\n            self.partition_size.append(partition_size)\n            self.params_in_partition.append(params_in_partition)\n            self.params_not_in_partition.append(params_not_in_partition)\n            self.first_offset.append(first_offset)\n\n        self.reduce_bucket_size = int(reduce_bucket_size)\n        self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce\n        self.allgather_bucket_size = int(allgather_bucket_size)\n\n        self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream()\n        # self.copy_grad_stream = get_accelerator().Stream()\n        self.callback_queued = False\n\n        self.param_dict = {}\n\n        # map between param_id and bool to specify if a param is in this partition\n        self.is_param_in_current_partition = {}\n\n        self.grads_in_ipg_bucket = []\n        self.params_in_ipg_bucket = []\n        self.elements_in_ipg_bucket = 0\n        self.params_already_reduced = []\n        self._release_ipg_buffers()\n        self.previous_reduced_grads = None\n        
self.ipg_bucket_has_moe_params = False\n\n        # simplified param id\n        self.param_id = {}\n\n        # interesting code: unique ids being assigned to individual parameters\n        largest_param_numel = 0\n        count = 0\n        for i, params_group in enumerate(self.bit16_groups):\n            for param in params_group:\n                unique_id = id(param)\n                self.param_id[unique_id] = count\n                self.param_dict[count] = param\n                self.params_already_reduced.append(False)\n                if param.numel() > largest_param_numel:\n                    largest_param_numel = param.numel()\n                count = count + 1\n\n        for param_group in self.params_in_partition:\n            for param in param_group:\n                self.is_param_in_current_partition[self.get_param_id(param)] = True\n\n        for param_group in self.params_not_in_partition:\n            for param in param_group:\n                self.is_param_in_current_partition[self.get_param_id(param)] = False\n\n        if self.cpu_offload:\n            self.accumulated_grads_in_cpu = {}\n            self.norm_for_param_grads = {}\n            self.local_overflow = False\n            self.grad_position = {}\n            self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel,\n                                                                device=self.device,\n                                                                dtype=self.dtype)\n            if self.cpu_offload_pin_memory:\n                self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(\n                    self.temp_grad_buffer_for_cpu_offload)\n            self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,\n                                                                device=get_accelerator().current_device_name(),\n                                                                dtype=self.dtype)\n            for i, 
params_group in enumerate(self.bit16_groups):\n                self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])\n\n        # mapping from parameter to partition that it belongs to\n        self.param_to_partition_ids = {}\n\n        # stores if a partition has been reduced in this step\n        self.is_partition_reduced = {}\n\n        # number of grads in partition that still need to be computed\n        self.remaining_grads_in_partition = {}\n\n        # total number of grads in partition\n        self.total_grads_in_partition = {}\n\n        # stores if a grad in a partition has been computed or not\n        self.is_grad_computed = {}\n\n        # stores the offset at which a parameter gradient needs to be inserted in a partition\n        self.grad_partition_insertion_offset = {}\n\n        # the offset in the gradient at which it must be inserted at the beginning of the partition\n        self.grad_start_offset = {}\n\n        # will store the averaged gradients required by this partition\n        self.averaged_gradients = {}\n\n        # For cpu_offload, will store the averaged gradients required by this partition\n        self.offload_gradient_dict = {}\n\n        # store index of first parameter in each partition\n        self.first_param_index_in_partition = {}\n\n        # initializes all data structures for implementing gradient partitioning\n        self.initialize_gradient_partitioning_data_structures()\n\n        # resets the data structure value for the next backward propagation\n        self.reset_partition_gradient_structures()\n\n        # creates backward hooks for gradient partitioning\n        if self.partition_gradients or self.overlap_comm:\n            self.create_reduce_and_remove_grad_hooks()\n\n        self.custom_loss_scaler = False\n        self.external_loss_scale = None\n\n        # we may have a way of fusing dynamic scale. 
Do not support for now\n        self.loss_scaler = CreateLossScaler(dtype=self.dtype,\n                                            static_loss_scale=static_loss_scale,\n                                            dynamic_scaling=dynamic_loss_scale,\n                                            dynamic_loss_args=dynamic_loss_args)\n        self.dynamic_loss_scale = self.loss_scaler.dynamic\n\n        if self.dtype != torch.float16:\n            # Only fp16 should use dynamic loss scaling\n            assert self.loss_scaler.cur_scale == 1.0\n            assert not self.dynamic_loss_scale\n\n        see_memory_usage(\"Before initializing optimizer states\", force=True)\n        self.initialize_optimizer_states()\n        see_memory_usage(\"After initializing optimizer states\", force=True)\n\n        if dist.get_rank() == 0:\n            logger.info(f\"optimizer state initialized\")\n\n        if dist.get_rank(group=self.zp_process_group) == 0:\n            see_memory_usage(f\"After initializing ZeRO optimizer\", force=True)\n\n        self._link_all_hp_params()\n        self._enable_universal_checkpoint()\n        self._param_slice_mappings = self._create_param_mapping()\n\n    def _enable_universal_checkpoint(self):\n        for lp_param_group in self.bit16_groups:\n            enable_universal_checkpoint(param_list=lp_param_group)\n\n    def _create_param_mapping(self):\n        param_mapping = []\n        for i, _ in enumerate(self.optimizer.param_groups):\n            param_mapping_per_group = OrderedDict()\n            for lp in self.bit16_groups[i]:\n                if lp._hp_mapping is not None:\n                    lp_name = self.param_names[lp]\n                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()\n            param_mapping.append(param_mapping_per_group)\n\n        return param_mapping\n\n    def _link_all_hp_params(self):\n        dp_world_size = dist.get_world_size(group=self.zp_process_group)\n        if 
self.cpu_offload:\n            self._get_offload_gradient_dict()\n\n        for i, _ in enumerate(self.optimizer.param_groups):\n            # Link bit16 and fp32 params in partition\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n            partition_size = self.bit16_groups_flat[i].numel() // dp_world_size\n            flat_hp_partition = self.single_partition_of_fp32_groups[i]\n            link_hp_params(lp_param_list=self.bit16_groups[i],\n                           flat_hp_partition=flat_hp_partition,\n                           gradient_dict=self.averaged_gradients,\n                           offload_gradient_dict=self.offload_gradient_dict,\n                           use_offload=self.cpu_offload,\n                           param_group_index=i,\n                           partition_start=partition_id * partition_size,\n                           partition_size=partition_size,\n                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],\n                           dp_group=self.real_zp_process_group[i])\n\n    def is_moe_group(self, group):\n        return 'moe' in group and group['moe']\n\n    def _configure_moe_settings(self):\n        # if we're using ZeRO stage 2, ensure contiguous gradients are used\n        if self.partition_gradients:\n            assert self.contiguous_gradients, \"Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE\"\n        # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion\n        if not self.partition_gradients and not self.contiguous_gradients:\n            logger.warning(\n                \"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.\")\n        assert self.reduce_scatter, \"Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. 
Other code paths are not tested with MoE\"\n\n        assert any(\n            [self.is_moe_group(group) for group in self.optimizer.param_groups]\n        ), \"The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer\"\n        self.is_moe_param_group = []\n        for i, group in enumerate(self.optimizer.param_groups):\n            if self.is_moe_group(group):\n                assert all([is_moe_param(param)\n                            for param in group['params']]), \"All params in MoE group must be MoE params\"\n                self.real_zp_process_group[i] = self.expert_dp_process_group[group['name']]\n                self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']])\n                self.is_moe_param_group.append(True)\n            else:\n                self.is_moe_param_group.append(False)\n\n        assert self.expert_dp_process_group is not None, \"Expert data parallel group should be configured with MoE\"\n        assert self.ep_process_group is not None, \"Expert parallel group should be configured with MoE\"\n\n    def _update_model_bit16_weights(self, group_index):\n        updated_params = self.unflatten(self.bit16_groups_flat[group_index],\n                                        self.round_robin_bit16_groups[group_index])\n        for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):\n            p.data = q.data\n\n        # set model fp16 weight to slices of reordered flattened buffer\n        for param_index, param in enumerate(self.bit16_groups[group_index]):\n            new_index = self.round_robin_bit16_indices[group_index][param_index]\n            param.data = self.round_robin_bit16_groups[group_index][new_index].data\n\n    def _round_robin_reorder(self, tensor_list, num_partitions):\n\n        # disable round robin if need to debug something\n        # return tensor_list, 
list(range(len(tensor_list)))\n\n        partition_tensors = {}\n\n        for i, tensor in enumerate(tensor_list):\n            j = i % num_partitions\n            if not j in partition_tensors:\n                partition_tensors[j] = []\n            partition_tensors[j].append((i, tensor))\n\n        reordered_tensors = []\n        reordered_indices = {}\n\n        for partition_index in partition_tensors.keys():\n            for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):\n                reordered_indices[original_index] = len(reordered_tensors)\n                reordered_tensors.append(tensor)\n\n        return reordered_tensors, reordered_indices\n\n    def _release_ipg_buffers(self):\n        if self.contiguous_gradients:\n            self.ipg_buffer = None\n            self.grads_in_partition = None\n            self.grads_in_partition_offset = 0\n\n    def initialize_optimizer_states(self):\n\n        for i, group in enumerate(self.bit16_groups):\n            single_grad_partition = torch.zeros(int(self.partition_size[i]),\n                                                dtype=self.single_partition_of_fp32_groups[i].dtype,\n                                                device=self.device)\n            self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(\n                single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition\n\n        # Initialize the optimizer states with the flattened fp32 partition.\n        # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers\n        # which do lazy initialization of the state at the first call to step.\n        if isinstance(self.optimizer, torch.optim.Adagrad):\n            self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)\n        else:\n            self.optimizer.step()\n\n        if not self.cpu_offload:\n            for group 
in self.single_partition_of_fp32_groups:\n                group.grad = None  # class init\n\n        return\n\n    #########################################################################\n    #################### ZeRO Stage 1 - reduce gradients ####################\n    #########################################################################\n    def reduce_gradients(self, pipeline_parallel=False):\n        world_size = dist.get_world_size(self.zp_process_group)\n        my_rank = dist.get_rank(self.zp_process_group)\n\n        # with PP we must create ipg buffer, since backward is handled outside zero\n        if pipeline_parallel and self.contiguous_gradients:\n            self.ipg_buffer = []\n            buf_0 = torch.empty(int(self.reduce_bucket_size),\n                                dtype=self.dtype,\n                                device=get_accelerator().current_device_name())\n            self.ipg_buffer.append(buf_0)\n            self.ipg_index = 0\n\n        if not self.overlap_comm:\n            for i, group in enumerate(self.bit16_groups):\n                for param in group:\n                    grad_reduc = self.get_gradient_for_reduction(param)\n                    if grad_reduc is not None:\n                        self.reduce_ready_partitions_and_remove_grads(param, i)\n        # reduce any pending grads in either hook/non-hook case\n        self.overlapping_partition_gradients_reduce_epilogue()\n\n    #########################################################################\n    #########################ZeRO Partition Gradients########################\n    #########################################################################\n\n    def get_first_param_index(self, group_id, param_group, partition_id):\n        for index, param in enumerate(param_group):\n            param_id = self.get_param_id(param)\n            if partition_id in self.param_to_partition_ids[group_id][param_id]:\n                return index\n        return None\n\n    
def initialize_gradient_partitioning_data_structures(self):\n\n        for i, param_group in enumerate(self.round_robin_bit16_groups):\n            total_partitions = dist.get_world_size(group=self.real_zp_process_group[i])\n\n            self.param_to_partition_ids[i] = {}\n            self.is_partition_reduced[i] = {}\n            self.total_grads_in_partition[i] = {}\n            self.remaining_grads_in_partition[i] = {}\n            self.is_grad_computed[i] = {}\n            self.grad_partition_insertion_offset[i] = {}\n            self.grad_start_offset[i] = {}\n            self.first_param_index_in_partition[i] = {}\n\n            for partition_id in range(total_partitions):\n                self.is_grad_computed[i][partition_id] = {}\n                self.grad_partition_insertion_offset[i][partition_id] = {}\n                self.grad_start_offset[i][partition_id] = {}\n                self.total_grads_in_partition[i][partition_id] = 0\n                self.initialize_gradient_partition(i, param_group, partition_id)\n                self.is_partition_reduced[i][partition_id] = False\n                self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index(\n                    i, param_group, partition_id)\n\n    def independent_gradient_partition_epilogue(self):\n        self.report_ipg_memory_usage(f\"In ipg_epilogue before reduce_ipg_grads\", 0)\n        self.reduce_ipg_grads()\n        self.report_ipg_memory_usage(f\"In ipg_epilogue after reduce_ipg_grads\", 0)\n\n        # if dist.get_rank() == 0:\n        #    logger.info(\"Params already reduced %s\", self.params_already_reduced)\n        for i in range(len(self.params_already_reduced)):\n            self.params_already_reduced[i] = False\n\n        if self.overlap_comm:\n            get_accelerator().synchronize()\n            # It is safe to clear previously reduced grads of other partitions\n            self._clear_previous_reduced_grads()\n\n        if self.cpu_offload is 
False:\n            for i, _ in enumerate(self.bit16_groups):\n\n                if i not in self.averaged_gradients or self.averaged_gradients[i] is None:\n                    self.averaged_gradients[i] = self.get_flat_partition(\n                        self.params_in_partition[i],\n                        self.first_offset[i],\n                        self.partition_size[i],\n                        dtype=self.gradient_accumulation_dtype,\n                        device=get_accelerator().current_device_name(),\n                        return_tensor_list=True)\n                else:\n                    avg_new = self.get_flat_partition(self.params_in_partition[i],\n                                                      self.first_offset[i],\n                                                      self.partition_size[i],\n                                                      dtype=self.gradient_accumulation_dtype,\n                                                      device=get_accelerator().current_device_name(),\n                                                      return_tensor_list=True)\n\n                    for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):\n                        accumulated_grad.add_(new_avg_grad)\n\n        self._release_ipg_buffers()\n\n        # No need to keep the gradients anymore.\n        # All gradients required by the step\n        # are in self.averaged_gradients\n        self.zero_grad(set_to_none=True)\n        see_memory_usage(f\"End ipg_epilogue\")\n\n    # Resets every partition to not reduced,\n    # sets the remaining grads to the total number of grads in each partition,\n    # and marks is_grad_computed as False for every grad in each partition\n    def reset_partition_gradient_structures(self):\n        for i, _ in enumerate(self.bit16_groups):\n            total_partitions = dist.get_world_size(group=self.real_zp_process_group[i])\n            for partition_id in range(total_partitions):\n                
self.is_partition_reduced[i][partition_id] = False\n                self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id]\n\n                for param_id in self.is_grad_computed[i][partition_id]:\n                    self.is_grad_computed[i][partition_id][param_id] = False\n\n    def initialize_gradient_partition(self, i, param_group, partition_id):\n\n        def set_key_value_list(dictionary, key, value):\n            if key in dictionary:\n                dictionary[key].append(value)\n            else:\n                dictionary[key] = [value]\n\n        def increment_value(dictionary, key):\n            if key in dictionary:\n                dictionary[key] += 1\n            else:\n                dictionary[key] = 1\n\n        partition_size = self.partition_size[i]\n\n        start_index = partition_size * partition_id\n        end_index = partition_size * (partition_id + 1)\n\n        current_index = 0\n        first_offset = 0\n\n        for param in param_group:\n\n            param_size = param.numel()\n            param_id = self.get_param_id(param)\n\n            if start_index <= current_index < end_index:\n                set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)\n                increment_value(self.total_grads_in_partition[i], partition_id)\n\n                self.is_grad_computed[i][partition_id][param_id] = False\n\n                self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index\n                self.grad_start_offset[i][partition_id][param_id] = 0\n\n            elif current_index < start_index < (current_index + param_size):\n                assert (first_offset == 0\n                        ), \"This can happen at most once, as this must be the first tensor in the partition\"\n                first_offset = start_index - current_index\n\n                set_key_value_list(self.param_to_partition_ids[i], 
param_id, partition_id)\n                increment_value(self.total_grads_in_partition[i], partition_id)\n\n                self.is_grad_computed[i][partition_id][param_id] = False\n\n                self.grad_partition_insertion_offset[i][partition_id][param_id] = 0\n                self.grad_start_offset[i][partition_id][param_id] = first_offset\n\n            current_index = current_index + param_size\n\n    def overlapping_partition_gradients_reduce_epilogue(self):\n        self.independent_gradient_partition_epilogue()\n\n    def fill_grad_accum_attribute(self):\n        for group in self.bit16_groups:\n            for param in group:\n                if param.grad is not None:\n                    if param.grad_accum is None:\n                        param.grad_accum = param.grad.to(self.gradient_accumulation_dtype)\n                    else:\n                        param.grad_accum.add_(\n                            param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))\n                    param.grad = None\n\n    def get_gradient_for_reduction(self, param):\n        if self.use_grad_accum_attribute:\n            return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None\n        else:\n            return param.grad\n\n    def get_param_gradient_attribute(self, param):\n        return param.grad_accum if self.use_grad_accum_attribute else param.grad\n\n    # Clear the tensor the reduction gradient attribute is pointing to\n    def clear_grad_attribute(self, param):\n        if self.use_grad_accum_attribute:\n            param.grad_accum = None\n        else:\n            param.grad = None\n\n    def create_reduce_and_remove_grad_hooks(self):\n        self.grad_accs = []\n        for i, param_group in enumerate(self.bit16_groups):\n            for param in param_group:\n                if param.requires_grad:\n                    def wrapper(param, i):\n                        param_tmp = param.expand_as(param)\n 
                       grad_acc = param_tmp.grad_fn.next_functions[0][0]\n\n                        def reduce_partition_and_remove_grads(*notneeded):\n                            self.reduce_ready_partitions_and_remove_grads(param, i)\n\n                        grad_acc.register_hook(reduce_partition_and_remove_grads)\n                        self.grad_accs.append(grad_acc)\n\n                    wrapper(param, i)\n\n    def get_param_id(self, param):\n        unique_id = id(param)\n        return self.param_id[unique_id]\n\n    def report_ipg_memory_usage(self, tag, param_elems):\n        elem_count = self.elements_in_ipg_bucket + param_elems\n        percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size\n        see_memory_usage(\n            f\"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}\"\n        )\n\n    # create a flat tensor aligned at the alignment boundary\n    def flatten_dense_tensors_aligned(self, tensor_list, alignment):\n        return self.flatten(align_dense_tensors(tensor_list, alignment))\n\n    ############### Independent Partition Gradient ########################\n    def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):\n\n        grad_reduc = self.get_gradient_for_reduction(param)\n        if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:\n            self.report_ipg_memory_usage(\"In ipg_remove_grads before reduce_ipg_grads\", param.numel())\n            self.reduce_ipg_grads()\n            if self.contiguous_gradients and self.overlap_comm:\n                # Swap ipg_index between 0 and 1\n                self.ipg_index = 1 - self.ipg_index\n            self.report_ipg_memory_usage(\"In ipg_remove_grads after reduce_ipg_grads\", param.numel())\n\n        param_id = self.get_param_id(param)\n        assert self.params_already_reduced[param_id] == False, \\\n            f\"The parameter {param_id} has already been 
reduced. \\\n            Gradient computed twice for this partition. \\\n            Multiple gradient reduction is currently not supported\"\n\n        if self.contiguous_gradients:\n            if param.numel() > self.reduce_bucket_size:\n                self.extra_large_param_to_reduce = param\n            else:\n                # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening\n                new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())\n                new_grad_tensor.copy_(grad_reduc.view(-1))\n                grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)\n\n        self.elements_in_ipg_bucket += param.numel()\n\n        assert grad_reduc is not None, f\"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient\"\n\n        self.grads_in_ipg_bucket.append(grad_reduc)\n        self.params_in_ipg_bucket.append((i, param, param_id))\n\n        # make sure the average tensor function knows how to average the gradients\n        if is_moe_param(param):\n            self.ipg_bucket_has_moe_params = True\n\n        self.report_ipg_memory_usage(\"End ipg_remove_grads\", 0)\n\n    def print_rank_0(self, message):\n        if dist.get_rank() == 0:\n            logger.info(message)\n\n    def gradient_reduction_w_predivide(self, tensor):\n\n        dp_world_size = dist.get_world_size(group=self.dp_process_group)\n\n        tensor_to_allreduce = tensor\n\n        if self.communication_data_type != tensor.dtype:\n            tensor_to_allreduce = tensor.to(self.communication_data_type)\n\n        if self.postscale_gradients:\n            if self.gradient_predivide_factor != 1.0:\n                tensor_to_allreduce.mul_(1. 
/ self.gradient_predivide_factor)\n\n            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)\n\n            if self.gradient_predivide_factor != dp_world_size:\n                tensor_to_allreduce.mul_(self.gradient_predivide_factor /\n                                         (dp_world_size / float(self.sequence_parallel_size)))\n        else:\n            tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size))\n            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)\n\n        if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:\n            tensor.copy_(tensor_to_allreduce)\n\n        return tensor\n\n    def allreduce_and_copy_with_multiple_ranks(self,\n                                               small_bucket,\n                                               log=None,\n                                               divide=True,\n                                               process_group=None,\n                                               bucket_ranks=None):\n        process_group = self.zp_process_group if process_group is None else process_group\n        allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group)\n        for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks):\n            if dist.get_rank(group=process_group) == bucket_rank:\n                buf.copy_(synced)\n\n    def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None):\n        small_bucket = []\n        small_bucket_ranks = []\n        numel = 0\n        allreduce_sizes = []\n\n        for i, bucket_elem in enumerate(bucket):\n            rank, tensor = bucket_elem\n            small_bucket.append(tensor)\n            small_bucket_ranks.append(rank)\n            numel = numel + tensor.numel()\n            if numel > numel_per_bucket:\n       
         self.allreduce_and_copy_with_multiple_ranks(small_bucket,\n                                                            log=None,\n                                                            divide=divide,\n                                                            process_group=process_group,\n                                                            bucket_ranks=small_bucket_ranks)\n                small_bucket = []\n                small_bucket_ranks = []\n                numel = 0\n\n        if len(small_bucket) > 0:\n            self.allreduce_and_copy_with_multiple_ranks(small_bucket,\n                                                        log=None,\n                                                        divide=divide,\n                                                        process_group=process_group,\n                                                        bucket_ranks=small_bucket_ranks)\n\n    def average_tensor(self, tensor):\n        if self.overlap_comm:\n            stream = self.reduction_stream\n            if not get_accelerator().is_synchronized_device():\n                stream.wait_stream(get_accelerator().current_stream())\n        else:\n            stream = get_accelerator().current_stream()\n\n        with get_accelerator().stream(stream):\n            if not self.reduce_scatter:\n                self.gradient_reduction_w_predivide(tensor)\n                return\n\n            # Accumulate destination ranks and bucket offsets for each gradient slice.\n            # Note: potential future optimization, record access pattern of parameters\n            # in backward pass and partition gradients w.r.t. access pattern so that our\n            # bucket is guaranteed to be contiguous w.r.t. 
ranks\n            rank_and_offsets = []\n            real_dp_process_group = []\n            curr_size = 0\n            prev_id, prev_process_group = -1, None\n\n            process_group = self.zp_process_group\n            # count = 0\n            for i, param, param_id in self.params_in_ipg_bucket:\n\n                process_group = self.zp_process_group\n                grad_reduc = self.get_gradient_for_reduction(param)\n                # Averages gradients at parameter level if ipg has a moe param\n                # Otherwise averaging is done at the entire buffer level at the end of the loop\n                # MoE params use different process groups\n                if self.ipg_bucket_has_moe_params:\n                    process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(\n                        param) else self.zp_process_group\n                    grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))\n\n                partition_ids = self.param_to_partition_ids[i][param_id]\n                assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids\n                            ]), f\"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}\"\n                partition_size = self.partition_size[i]\n                # Get all partition ids + their offsets\n                partition_ids_w_offsets = []\n                for partition_id in partition_ids:\n                    offset = self.grad_start_offset[i][partition_id][param_id]\n                    partition_ids_w_offsets.append((partition_id, offset))\n                partition_ids_w_offsets.sort(key=lambda t: t[1])\n\n                # Calculate rank and offsets for grad slices\n                for idx in range(len(partition_ids_w_offsets)):\n                    partition_id, offset = partition_ids_w_offsets[idx]\n\n                    # if dist.get_rank() == 0 and count < 100:\n        
            #     print(f\"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}\")\n                    # count += 1\n\n                    # Calculate numel for grad slice depending on partition location\n                    if idx == len(partition_ids_w_offsets) - 1:\n                        # Last partition_id uses its own offset\n                        numel = param.numel() - offset\n                    else:\n                        # Set numel to next partition's offset\n                        numel = partition_ids_w_offsets[idx + 1][1] - offset\n\n                    # Merge bucket ranges if they belong to the same rank\n                    if partition_id == prev_id and process_group == prev_process_group:\n                        prev_pid, prev_size, prev_numel = rank_and_offsets[-1]\n                        rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)\n                    else:\n                        rank_and_offsets.append((partition_id, curr_size, numel))\n                        real_dp_process_group.append(process_group)\n                    curr_size += numel\n                    prev_id, prev_process_group = partition_id, process_group\n\n            if not self.ipg_bucket_has_moe_params:\n                tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))\n\n            buckets = {}\n            for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):\n                grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))\n                bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else (\n                    dst, real_dp_process_group[i])\n                if bucket_key not in buckets:\n                    buckets[bucket_key] = []\n                if 
self.use_multi_rank_bucket_allreduce:\n                    buckets[bucket_key].append((dst, grad_slice))\n                else:\n                    buckets[bucket_key].append(grad_slice)\n\n            for bucket_key in buckets:\n                if self.use_multi_rank_bucket_allreduce:\n                    self.allreduce_and_scatter(buckets[bucket_key],\n                                               numel_per_bucket=self.reduce_bucket_size,\n                                               divide=self.ipg_bucket_has_moe_params,\n                                               process_group=bucket_key)\n                else:\n                    dst, process_group = bucket_key\n                    self.allreduce_no_retain(buckets[bucket_key],\n                                             numel_per_bucket=self.reduce_bucket_size,\n                                             rank=dst,\n                                             divide=self.ipg_bucket_has_moe_params,\n                                             process_group=process_group)\n\n    ##############################################################################\n    ############################# CPU Offload Methods#############################\n    ##############################################################################\n    def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):\n        current_offset = 0\n\n        for i, tensor in enumerate(tensor_list):\n            param_id = self.get_param_id(tensor)\n            param_start_offset = 0\n\n            num_elements = tensor.numel()\n\n            # we need to offset to get to the right element\n            if i == 0 and first_offset > 0:\n                tensor_offset = first_offset\n                num_elements = num_elements - tensor_offset\n                param_start_offset = first_offset\n\n            # we don't need all elements of the tensor\n            if num_elements > (partition_size - current_offset):\n   
             num_elements = partition_size - current_offset\n\n            self.grad_position[param_id] = [\n                int(group_id), int(param_start_offset),\n                int(current_offset), int(num_elements)\n            ]\n            current_offset += num_elements\n\n    def update_overflow_tracker_for_param_grad(self, param):\n        grad_accum = self.get_param_gradient_attribute(param)\n        if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):\n            self.local_overflow = True\n\n    def _get_offload_gradient_dict(self):\n        for param_group_index, _ in enumerate(self.optimizer.param_groups):\n            self.offload_gradient_dict[param_group_index] = []\n            for lp_param in self.params_in_partition[param_group_index]:\n                param_id = self.get_param_id(lp_param)\n                [_, _, dest_offset, num_elements] = self.grad_position[param_id]\n                dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow(\n                    0, dest_offset, num_elements)\n                self.offload_gradient_dict[param_group_index].append(dest_tensor)\n\n    def async_accumulate_grad_in_cpu_via_gpu(self, param):\n        param_id = self.get_param_id(param)\n\n        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]\n\n        # copy to a preexisting buffer to avoid memory allocation penalty\n        dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel())\n\n        # buffer for storing gradients for this parameter in CPU\n        def buffer_to_accumulate_to_in_cpu():\n            if not self.fp16_master_weights_and_gradients:\n                buffer = torch.zeros(param.numel(), dtype=param.dtype, device=self.device)\n                return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer\n            else:\n                return 
self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)\n\n        # accumulate gradients into param.grad_accum or the parts of it that belong to this partition\n        def accumulate_gradients():\n            grad_accum = self.get_param_gradient_attribute(param)\n            if not self.fp16_master_weights_and_gradients:\n                dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)\n                grad_accum.data.view(-1).add_(dest_buffer)\n            else:\n                dest_buffer.narrow(0, source_offset,\n                                   num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1),\n                                                       non_blocking=True)\n                grad_accum.data.view(-1).narrow(0, source_offset,\n                                                num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))\n\n        # move accumulated gradients back to CPU\n        def copy_gradients_to_cpu():\n            grad_accum = self.get_param_gradient_attribute(param)\n            if not self.fp16_master_weights_and_gradients:\n                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True)\n            else:\n                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow(\n                    0, source_offset, num_elements),\n                    non_blocking=True)\n\n        if param_id not in self.accumulated_grads_in_cpu:\n            self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu()\n\n        if self.micro_step_id > 0:\n            accumulate_gradients()\n\n        # at the boundary we will send 32bit directly\n        if not self.is_gradient_accumulation_boundary:\n            copy_gradients_to_cpu()\n\n    def set_norm_for_param_grad(self, param):\n        param_id = self.get_param_id(param)\n        grad_accum = 
self.get_param_gradient_attribute(param)\n        accumulated_grad = self.accumulated_grads_in_cpu[\n            param_id] if self.gradient_accumulation_steps > 1 else grad_accum\n\n        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]\n\n        start = source_offset\n        accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)\n\n        self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)\n\n    def set_norm_for_param_grad_in_gpu(self, param):\n        param_id = self.get_param_id(param)\n        grad_accum = self.get_param_gradient_attribute(param)\n        if grad_accum is None:\n            accumulated_grad = param.grad\n        else:\n            accumulated_grad = grad_accum\n\n        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]\n\n        start = source_offset\n        accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)\n\n        self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)\n\n    def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):\n        param_id = self.get_param_id(param)\n\n        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]\n\n        dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)\n\n        grad_accum = self.get_param_gradient_attribute(param)\n        if grad_accum is None:\n            # No grad_accum tensor attached: fall back to param.grad (calling .view on None would raise)\n            src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)\n        else:\n            src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)\n        if not self.fp16_master_weights_and_gradients:\n            src_tensor = src_tensor.float()\n\n        dest_tensor.copy_(src_tensor, non_blocking=True)\n        param.grad = None  # offload only\n\n    def complete_grad_norm_calculation_for_cpu_offload(self, params):\n        total_norm = 0.0\n        norm_type = 2.0\n   
     for p in params:\n            # Pipeline parallelism may replicate parameters. Avoid multi-counting.\n            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:\n                continue\n\n            if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):\n                param_id = self.get_param_id(p)\n                # Some models have trainable parameters that are skipped in training;\n                # their backward hooks in self.create_reduce_and_remove_grad_hooks() never run,\n                # so they have no entry in norm_for_param_grads\n                if param_id in self.norm_for_param_grads:\n                    param_norm = self.norm_for_param_grads[param_id]\n                    total_norm += param_norm.item() ** 2\n                else:\n                    # Unused parameters in modules may be unexpected, so raise an\n                    # explicit error when they occur and offer an option to\n                    # suppress it\n                    assert self.ignore_unused_parameters, \"\"\"\n                        This assert indicates that your module has parameters that\n                        were not used in producing the loss.\n                        You can avoid this assert by\n                        (1) enabling the ignore_unused_parameters option in the zero_optimization config; or\n                        (2) making sure all trainable parameters and `forward` function\n                            outputs participate in calculating the loss.\n                    \"\"\"\n\n        # Sum across all data-parallel and model-parallel ranks.\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)\n\n        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)\n\n        total_norm = total_norm_cuda[0].item() ** (1. 
/ norm_type)\n\n        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:\n            total_norm = -1\n\n        return total_norm\n\n    ############################################################################################\n    def copy_grads_in_partition(self, param):\n        if self.cpu_offload:\n\n            if self.gradient_accumulation_steps > 1:\n                self.async_accumulate_grad_in_cpu_via_gpu(param)\n\n            if self.is_gradient_accumulation_boundary:\n                self.set_norm_for_param_grad_in_gpu(param)\n\n                self.update_overflow_tracker_for_param_grad(param)\n\n                self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)\n\n            return\n        # print(f\"ID {self.get_param_id(param)} grad norm {param.grad.norm()}\")\n        if self.grads_in_partition is None:\n            self.grads_in_partition_offset = 0\n            total_size = 0\n            for group in self.params_in_partition:\n                for param_in_partition in group:\n                    total_size += param_in_partition.numel()\n\n            see_memory_usage(f\"before copying {total_size} gradients into partition\")\n            self.grads_in_partition = torch.empty(int(total_size),\n                                                  dtype=self.dtype,\n                                                  device=get_accelerator().current_device_name())\n            see_memory_usage(f\"after copying {total_size} gradients into partition\")\n\n        grad_reduc = self.get_gradient_for_reduction(param)\n        # The allreduce buffer will be rewritten. 
Copy the gradients in partition to a new buffer\n        new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel())\n        new_grad_tensor.copy_(grad_reduc.view(-1))\n        grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)\n        # print(f\"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}\")\n        self.grads_in_partition_offset += param.numel()\n\n    def reduce_ipg_grads(self):\n        if self.contiguous_gradients:\n            if self.extra_large_param_to_reduce is not None:\n                assert len(self.params_in_ipg_bucket) == 1, \"more than 1 param in ipg bucket, this shouldn't happen\"\n                _, _, param_id = self.params_in_ipg_bucket[0]\n                assert self.get_param_id(self.extra_large_param_to_reduce\n                                         ) == param_id, \"param in ipg bucket does not match extra-large param\"\n                extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce)\n                self.average_tensor(extra_large_grad_reduc.view(-1))\n                self.extra_large_param_to_reduce = None\n            else:\n                self.average_tensor(self.ipg_buffer[self.ipg_index])\n        else:\n            self.buffered_reduce_fallback(None,\n                                          self.grads_in_ipg_bucket,\n                                          elements_per_buffer=self.elements_in_ipg_bucket)\n\n        if self.overlap_comm:\n            stream = self.reduction_stream\n        elif self.cpu_offload:\n            # TODO: copy_grad_stream is disabled because of race with reduce. 
This hurts perf and should be fixed.\n            #            get_accelerator().synchronize()\n            #            stream = self.copy_grad_stream\n            stream = get_accelerator().current_stream()\n        else:\n            stream = get_accelerator().current_stream()\n\n        with get_accelerator().stream(stream):\n            for _, param, param_id in self.params_in_ipg_bucket:\n\n                assert self.params_already_reduced[param_id] == False, \\\n                    f\"The parameter {param_id} has already been reduced. \\\n                    Gradient computed twice for this partition. \\\n                    Multiple gradient reduction is currently not supported\"\n\n                self.params_already_reduced[param_id] = True\n                if self.partition_gradients:\n                    if not self.is_param_in_current_partition[param_id]:\n                        if self.overlap_comm and self.contiguous_gradients is False:\n                            # Clear grads of other partitions during the next reduction\n                            # to avoid clearing them before the reduction is complete.\n                            if self.previous_reduced_grads is None:\n                                self.previous_reduced_grads = []\n                            self.previous_reduced_grads.append(param)\n                        else:\n                            self.clear_grad_attribute(param)\n                    elif self.contiguous_gradients:\n                        self.copy_grads_in_partition(param)\n                else:  # zero stage 1 - partition only optimizer state\n                    if self.contiguous_gradients and self.is_param_in_current_partition[param_id]:\n                        self.copy_grads_in_partition(param)\n\n        self.grads_in_ipg_bucket = []\n        self.params_in_ipg_bucket = []\n        self.ipg_bucket_has_moe_params = False\n        self.elements_in_ipg_bucket = 0\n        
#####################################################################\n\n    def reduce_ready_partitions_and_remove_grads(self, param, i):\n        if self.partition_gradients or self.is_gradient_accumulation_boundary:\n            self.reduce_independent_p_g_buckets_and_remove_grads(param, i)\n\n    def zero_reduced_gradients(self, partition_id, i):\n\n        def are_all_related_partitions_reduced(params_id):\n            for partition_id in self.param_to_partition_ids[i][params_id]:\n                if not self.is_partition_reduced[i][partition_id]:\n                    return False\n            return True\n\n        for params_id in self.is_grad_computed[i][partition_id]:\n            if are_all_related_partitions_reduced(params_id):\n                self.param_dict[params_id].grad = None  # dead code\n\n    def flatten_and_print(self, message, tensors, start=0, n=5):\n        flatten_tensor = self.flatten(tensors)\n\n        def print_func():\n            logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))\n\n        self.sequential_execution(print_func, message)\n\n    def get_grads_to_reduce(self, i, partition_id):\n\n        def get_reducible_portion(key):\n            grad = self.param_dict[key].grad\n            total_elements = grad.numel()\n            start = self.grad_start_offset[i][partition_id][key]\n            num_elements = min(total_elements - start,\n                               self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key])\n            if not pg_correctness_test:\n                if num_elements == total_elements:\n                    return grad\n                else:\n                    return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements))\n            else:\n                if num_elements == total_elements:\n                    return grad.clone()\n                else:\n                    return grad.clone().contiguous().view(-1).narrow(0, int(start), 
int(num_elements))\n\n        grads_to_reduce = []\n        for key in self.is_grad_computed[i][partition_id]:\n            grad = get_reducible_portion(key)\n            grads_to_reduce.append(grad)\n        return grads_to_reduce\n\n    def sequential_execution(self, function, message, group=None):\n        if group is None:\n            group = self.zp_process_group\n        if dist.get_rank(group=group) == 0:\n            logger.info(message)\n        for id in range(dist.get_world_size(group=group)):\n            if id == dist.get_rank(group=group):\n                function()\n            dist.barrier(group=group)\n\n    def set_none_gradients_to_zero(self, i, partition_id):\n        for param_id in self.is_grad_computed[i][partition_id]:\n            param = self.param_dict[param_id]\n            if param.grad is None:\n                param.grad = torch.zeros_like(param)\n\n    ######################Reduction Related Methods##############################\n    def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None):\n        rank = None\n        tensor = self.flatten(bucket)\n\n        process_group = self.zp_process_group if process_group is None else process_group\n\n        tensor_to_allreduce = tensor\n\n        if pg_correctness_test or self.sequence_parallel_size > 1:\n            communication_data_type = torch.float32\n        else:\n            communication_data_type = self.communication_data_type\n\n        if communication_data_type != tensor.dtype:\n            tensor_to_allreduce = tensor.to(communication_data_type)\n\n        if divide:\n            tensor_to_allreduce.div_(\n                dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))\n\n        tensor_to_allreduce = tensor_to_allreduce.contiguous()\n        if rank is None:\n            #    \"All Reducing\"\n            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)\n        else:\n            
global_rank = dist.get_global_rank(self.dp_process_group, rank)\n            dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)\n\n        if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:\n            if rank is None or rank == dist.get_rank(group=process_group):\n                tensor.copy_(tensor_to_allreduce)\n\n        return tensor\n\n    def _clear_previous_reduced_grads(self):\n        if self.previous_reduced_grads is not None:\n            for param in self.previous_reduced_grads:\n                self.clear_grad_attribute(param)\n            self.previous_reduced_grads = None\n\n    # if rank is specified do a reduction instead of an allreduce\n    def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None):\n        process_group = self.zp_process_group if process_group is None else process_group\n        if self.overlap_comm:\n            get_accelerator().synchronize()\n            # It is safe to clear the previously reduced grads of other partitions\n            self._clear_previous_reduced_grads()\n            stream = self.reduction_stream\n        else:\n            stream = get_accelerator().current_stream()\n\n        with get_accelerator().stream(stream):\n            allreduced = self.allreduce_bucket(\n                small_bucket,\n                rank=rank,\n                log=log,\n                divide=divide,\n                process_group=process_group,\n            )\n            if rank is None or rank == dist.get_rank(group=self.zp_process_group):\n                for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):\n                    buf.copy_(synced)\n\n    def allreduce_no_retain(\n            self,\n            bucket,\n            numel_per_bucket=500000000,\n            rank=None,\n            log=None,\n            divide=True,\n            process_group=None,\n    ):\n        small_bucket = []\n        
numel = 0\n        for tensor in bucket:\n            small_bucket.append(tensor)\n            numel = numel + tensor.numel()\n            if numel > numel_per_bucket:\n                self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group)\n                small_bucket = []\n                numel = 0\n\n        if len(small_bucket) > 0:\n            self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group)\n\n    # allows using reduction of gradients instead of using all_reduce\n\n    def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None):\n        split_buckets = split_half_float_double(grads)\n\n        for i, bucket in enumerate(split_buckets):\n            self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log)\n\n    #############################################################################\n    #############################################################################\n    #############################################################################\n\n    # views the tensor as multiple partitions and returns\n    # those partitions\n    def get_data_parallel_partitions(self, tensor, group_id):\n        partitions = []\n\n        dp = dist.get_world_size(group=self.real_zp_process_group[group_id])\n        # dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])\n\n        total_num_elements = tensor.numel()\n\n        base_size = total_num_elements // dp\n        remaining = total_num_elements % dp\n\n        start = 0\n        for id in range(dp):\n            partition_size = base_size\n            if id < remaining:\n                partition_size = partition_size + 1\n            partitions.append(tensor.narrow(0, start, partition_size))\n            start = start + partition_size\n        return partitions\n\n    def get_partition_info(self, tensor_list, partition_size, 
partition_id):\n        params_in_partition = []\n        params_not_in_partition = []\n\n        start_index = partition_size * partition_id\n        end_index = partition_size * (partition_id + 1)\n\n        current_index = 0\n        first_offset = 0\n\n        for tensor in tensor_list:\n\n            tensor_size = tensor.numel()\n\n            if start_index <= current_index < end_index:\n                params_in_partition.append(tensor)\n\n            elif current_index < start_index < (current_index + tensor_size):\n                params_in_partition.append(tensor)\n\n                assert (first_offset == 0\n                        ), \"This can happen either zero or only once as this must be the first tensor in the partition\"\n                first_offset = start_index - current_index\n\n            else:\n                params_not_in_partition.append(tensor)\n\n            current_index = current_index + tensor_size\n\n        return params_in_partition, params_not_in_partition, first_offset\n\n    def zero_grad(self, set_to_none=True):\n        \"\"\"\n        Zero FP16 parameter grads.\n        \"\"\"\n        # FP32 grad should never exist.\n        # For speed, set model fp16 grad to None by default\n        # zero all pointers to grad tensors\n        for group in self.bit16_groups:\n            for p in group:\n                if set_to_none:\n                    p.grad = None  # epilogue and in step\n                    p.grad_accum = None\n                else:\n                    if p.grad is not None:\n                        p.grad.detach_()\n                        p.grad.zero_()\n\n    def _model_parallel_all_reduce(self, tensor, op):\n        \"\"\" Perform all reduce within model parallel group, if any.\n        \"\"\"\n        if self.model_parallel_group is None or self.model_parallel_world_size == 1:\n            pass\n        else:\n            dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)\n\n    def 
get_grad_norm_direct(self, gradients, params, norm_type=2):\n        \"\"\"Computes the total norm of a list of gradients.\n\n        This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with\n        added functionality to handle model parallel parameters. The gradients\n        are not modified; clipping is applied separately by unscale_and_clip_grads.\n\n        Arguments:\n            gradients (Iterable[Tensor]): gradient tensors whose norm is computed\n            params (Iterable[Tensor]): parameters corresponding to ``gradients``, used in\n                the L2 case to avoid counting replicated parameters more than once\n            norm_type (float or int): type of the used p-norm. Only 2 and ``'inf'``\n                (infinity norm) are computed correctly.\n\n        Returns:\n            Total norm of the gradients (viewed as a single vector), or -1 if the\n            norm is inf or NaN.\n        \"\"\"\n        norm_type = float(norm_type)\n        if norm_type == inf:\n            total_norm = max(g.data.abs().max() for g in gradients)\n            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)\n\n            # Take max across all GPUs.\n            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)\n            total_norm = total_norm_cuda[0].item()\n        else:\n            total_norm = 0.0\n            # if dist.get_rank() == 0:\n            #    logger.info(f\"Total Norm beginning {total_norm}\")\n            for g, p in zip(gradients, params):\n                # Pipeline parallelism may replicate parameters. 
Avoid multi-counting.\n                if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:\n                    continue\n                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):\n                    param_norm = g.data.double().norm(2)\n                    total_norm += param_norm.item() ** 2\n            # Sum across all model parallel GPUs.\n            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)\n\n            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)\n\n            total_norm = total_norm_cuda[0].item() ** (1. / norm_type)\n\n        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:\n            total_norm = -1\n\n        return total_norm\n\n    # creates a flat fused tensor from the tensor list starting at the first_offset\n    # in the first tensor of the list. 
If there are not enough elements in the tensor\n    # list then the flat tensor will be padded with zeros\n    def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):\n        flat_tensor_list = []\n        current_size = 0\n\n        for i, tensor in enumerate(tensor_list):\n            grad_accum = self.get_param_gradient_attribute(tensor)\n            if grad_accum is None:\n                grad_accum = torch.zeros_like(tensor, dtype=dtype)\n\n            tensor = grad_accum\n            num_elements = tensor.numel()\n            tensor_offset = 0\n\n            # we need to offset to get to the right element\n            if i == 0 and first_offset > 0:\n                tensor_offset = first_offset\n                num_elements = num_elements - tensor_offset\n\n            # we don't need all elements of the tensor\n            if num_elements > (partition_size - current_size):\n                num_elements = partition_size - current_size\n\n            # we need a narrow view of the tensor based on the tensor offset and number of elements that\n            # we need from this tensor\n            if tensor_offset > 0 or num_elements < tensor.numel():\n                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements)))\n            else:\n                flat_tensor_list.append(tensor)\n\n            current_size = current_size + num_elements\n\n        # this means it's the last partition and does not align with the dp boundary. 
We need to pad before flattening\n        if current_size < partition_size:\n            flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))\n\n        if return_tensor_list:\n            return flat_tensor_list\n\n        return self.flatten(flat_tensor_list)\n\n    def free_grad_in_param_list(self, param_list):\n        for p in param_list:\n            p.grad = None  # in step\n            p.grad_accum = None\n\n    def reset_cpu_buffers(self):\n        self.norm_for_param_grads = {}\n        self.local_overflow = False\n\n    def set_lr(self, lr):\n        \"\"\"Set the learning rate.\"\"\"\n        for param_group in self.optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n    def get_lr(self):\n        \"\"\"Return the current learning rate.\"\"\"\n        return self.optimizer.param_groups[0][\"lr\"]\n\n    def override_loss_scale(self, loss_scale):\n        if loss_scale != self.external_loss_scale:\n            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')\n        self.custom_loss_scaler = True\n        self.external_loss_scale = loss_scale\n\n    def scaled_global_norm(self, norm_type=2):\n        assert norm_type == 2, \"only L2 norm supported\"\n        norm_groups = []\n        for i, group in enumerate(self.bit16_groups):\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n            if self.cpu_offload:\n                norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))\n                single_grad_partition = self.single_partition_of_fp32_groups[i].grad\n            else:\n                norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))\n\n        if self.has_moe_layers:\n            self._average_expert_grad_norms(norm_groups)\n\n        # note that the get_global_norm function only supports l2 norm\n        
return get_global_norm(norm_list=norm_groups)\n\n    def get_bit16_param_group(self, group_no):\n        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]\n        partition_id = dist.get_rank(group=self.real_zp_process_group[group_no])\n        return [bit16_partitions[dist.get_rank(group=self.real_zp_process_group[group_no])]]\n\n    def _optimizer_step(self, group_no):\n        original_param_groups = self.optimizer.param_groups\n        self.optimizer.param_groups = [original_param_groups[group_no]]\n        # Disabling this as the C++ side copy & synchronize is not working correctly\n        # from deepspeed.ops.adam import DeepSpeedCPUAdam\n        # if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:\n        #    self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])\n        # else:\n        #    self.optimizer.step()\n        self.optimizer.step()\n        self.optimizer.param_groups = original_param_groups\n\n    def step(self, closure=None):\n        \"\"\"\n        Not supporting closure.\n        \"\"\"\n        self.micro_step_id = -1\n\n        see_memory_usage(f\"In step before checking overflow\")\n\n        # First compute norm for all group so we know if there is overflow\n        if self.dtype == torch.float16:\n            self.check_overflow()\n\n        prev_scale = self.loss_scale\n        self._update_scale(self.overflow)\n        if self.overflow:\n            see_memory_usage('After overflow before clearing gradients')\n            self.zero_grad(set_to_none=True)\n            if self.cpu_offload:\n                self.reset_cpu_buffers()\n            else:\n                self.averaged_gradients = {}\n\n            see_memory_usage('After overflow after clearing gradients')\n\n            for timer in OPTIMIZER_TIMERS:\n                self.timers(timer).start()\n                self.timers(timer).stop()\n            return\n\n        # Step 1:- Calculate gradient norm 
using bit-16 grads\n        see_memory_usage('Before norm calculation')\n        scaled_global_grad_norm = self.scaled_global_norm()\n        self._global_grad_norm = scaled_global_grad_norm / prev_scale\n        see_memory_usage('After norm before optimizer')\n\n        # Step 2:- run optimizer and upscaling simultaneously\n        for i, group in enumerate(self.bit16_groups):\n            self.timers(OPTIMIZER_GRADIENTS_TIMER).start()\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n            if self.cpu_offload:\n                single_grad_partition = self.single_partition_of_fp32_groups[i].grad\n                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)\n\n                self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()\n                self.timers(OPTIMIZER_STEP_TIMER).start()\n                self._optimizer_step(i)\n\n                # Disabled, this is not currently working\n                # from deepspeed.ops.adam import DeepSpeedCPUAdam\n                # if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):\n                #    bit16_partitions = self.parallel_partitioned_bit16_groups[i]\n                #    fp32_partition = self.single_partition_of_fp32_groups[i]\n                #    bit16_partitions[partition_id].data.copy_(fp32_partition.data)\n                bit16_partitions = self.parallel_partitioned_bit16_groups[i]\n                fp32_partition = self.single_partition_of_fp32_groups[i]\n                bit16_partitions[partition_id].data.copy_(fp32_partition.data)\n\n                self.timers(OPTIMIZER_STEP_TIMER).stop()\n            else:\n                # free gradients for all the parameters that are not updated by this process(ZeRO stage2)\n                self.free_grad_in_param_list(self.params_not_in_partition[i])\n\n                # create a flat gradients for parameters updated by this process\n                # If we are last partition, 
ensure we have same size grads and partition size, if not pad with zero tensors\n                if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:\n                    single_grad_partition = self.flatten_dense_tensors_aligned(\n                        self.averaged_gradients[i],\n                        int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype)\n                else:\n                    single_grad_partition = self.flatten(self.averaged_gradients[i]).to(\n                        self.single_partition_of_fp32_groups[i].dtype)\n                assert single_grad_partition.numel() == self.partition_size[i], \\\n                    \"averaged gradients have a different number of elements than the partition size {} {} {} {}\".format(\n                        single_grad_partition.numel(), self.partition_size[i], i, partition_id)\n\n                self.single_partition_of_fp32_groups[i].grad = single_grad_partition\n                # release all the gradients since we have already created a necessary copy in dp_grad_partition(ZeRO stage2)\n                self.free_grad_in_param_list(self.params_in_partition[i])\n\n                self.averaged_gradients[i] = None\n\n                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)\n\n                self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()\n\n                # Step 3:- run the optimizer if no offloading\n                self.timers(OPTIMIZER_STEP_TIMER).start()\n                self._optimizer_step(i)\n                # Step 4:- get rid of the fp32 gradients. 
Not needed anymore\n                self.single_partition_of_fp32_groups[i].grad = None\n                del single_grad_partition\n                bit16_partitions = self.parallel_partitioned_bit16_groups[i]\n                fp32_partition = self.single_partition_of_fp32_groups[i]\n                bit16_partitions[partition_id].data.copy_(fp32_partition.data)\n                self.timers(OPTIMIZER_STEP_TIMER).stop()\n\n        see_memory_usage('After optimizer before all-gather')\n        if self.cpu_offload:\n            self.reset_cpu_buffers()\n\n        self.timers(OPTIMIZER_ALLGATHER_TIMER).start()\n\n        # if dist.get_rank(group=self.dp_process_group) == 0:\n        #     pdb.set_trace()  # or use another debugging tool\n\n        # Gather the updated weights from everyone.\n        # Then all partitions of the model parameters are updated and ready for the next forward pass.\n        all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,\n                                         partitioned_param_groups=self.parallel_partitioned_bit16_groups,\n                                         zp_process_group=self.real_zp_process_group)\n        self.timers(OPTIMIZER_ALLGATHER_TIMER).stop()\n\n        # TODO: we probably don't need this? 
just to be safe\n        for i in range(len(self.bit16_groups)):\n            self._update_model_bit16_weights(i)\n\n        self.timers.log(OPTIMIZER_TIMERS)\n        see_memory_usage('After zero_optimizer step')\n\n        return\n\n    @torch.no_grad()\n    def update_lp_params(self):\n        for i, (bit16_partitions, fp32_partition) in enumerate(\n                zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n            bit16_partitions[partition_id].data.copy_(fp32_partition.data)\n            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)\n            # if i == 0:\n            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)\n        all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,\n                                         partitioned_param_groups=self.parallel_partitioned_bit16_groups,\n                                         zp_process_group=self.real_zp_process_group)\n\n    def _average_expert_grad_norms(self, norm_groups):\n        for i, norm in enumerate(norm_groups):\n            if self.is_moe_param_group[i]:\n                scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.dp_process_group))\n                scaled_norm_tensor = torch.tensor(scaled_norm,\n                                                  device=get_accelerator().device_name(),\n                                                  dtype=torch.float)\n                dist.all_reduce(scaled_norm_tensor, group=self.dp_process_group)\n                norm_groups[i] = scaled_norm_tensor.item()\n\n    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):\n        # compute combined scale factor for this group\n        combined_scale = self.loss_scale\n        if self.clip_grad > 0.:\n            # norm is in fact norm*scale\n            clip = ((total_norm / self.loss_scale) + 1e-6) / 
self.clip_grad\n            if clip > 1:\n                combined_scale = clip * self.loss_scale\n\n        for grad in grad_groups_flat:\n            if isinstance(grad, list):\n                sub_partitions = grad\n                for g in sub_partitions:\n                    g.data.mul_(1. / combined_scale)\n            else:\n                grad.data.mul_(1. / combined_scale)\n\n    def _check_overflow(self, partition_gradients=True):\n        self.overflow = self.has_overflow(partition_gradients)\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow_serial(self, params, is_grad_list=False):\n        for p in params:\n            if p.grad is not None and self._has_inf_or_nan(p.grad.data):\n                return True\n\n        return False\n\n    def has_overflow_partitioned_grads_serial(self):\n        for i in range(len(self.bit16_groups)):\n            for j, grad in enumerate(self.averaged_gradients[i]):\n                if grad is not None and self._has_inf_or_nan(grad.data, j):\n                    return True\n        return False\n\n    def has_overflow(self, partition_gradients=True):\n        if partition_gradients:\n            overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()\n            overflow_gpu = get_accelerator().ByteTensor([overflow])\n            '''This will capture overflow across all data parallel and expert parallel process\n            Since expert parallel process are a subset of data parallel process'''\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)\n\n        else:\n            params = []\n            for group in self.bit16_groups:\n                for param in group:\n                    params.append(param)\n\n            overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)\n            overflow_gpu = get_accelerator().ByteTensor([overflow])\n\n        # Since each model 
parallel GPU carries only part of the model,\n        # make sure the overflow flag is synced across all the model parallel GPUs\n        self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)\n\n        overflow = overflow_gpu[0].item()\n        return bool(overflow)\n\n    # `x` is a torch.Tensor\n    @staticmethod\n    def _has_inf_or_nan(x, j=None):\n        try:\n            # if x is half, the .float() incurs an additional deep copy, but it's necessary if\n            # PyTorch's .sum() creates a one-element tensor of the same type as x\n            # (which is true for some recent versions of PyTorch).\n            cpu_sum = float(x.float().sum())\n            # More efficient version that can be used if .sum() returns a Python scalar\n            # cpu_sum = float(x.sum())\n        except RuntimeError as instance:\n            # We want to check if instance is actually an overflow exception.\n            # RuntimeError could come from a different error.\n            # If so, we still want the exception to propagate.\n            if \"value cannot be converted\" not in instance.args[0]:\n                raise\n            return True\n        else:\n            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n                return True\n            return False\n\n    def backward(self, loss, retain_graph=False):\n        \"\"\"\n        :attr:`backward` performs the following steps:\n\n        1. fp32_loss = loss.float()\n        2. scaled_loss = fp32_loss*loss_scale\n        3. 
scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves\n        \"\"\"\n        self.micro_step_id += 1\n\n        if self.contiguous_gradients:\n            self.ipg_buffer = []\n            buf_0 = torch.empty(int(self.reduce_bucket_size),\n                                dtype=self.dtype,\n                                device=get_accelerator().current_device_name())\n            self.ipg_buffer.append(buf_0)\n\n            # Use double buffers to avoid data access conflict when overlap_comm is enabled.\n            if self.overlap_comm:\n                buf_1 = torch.empty(int(self.reduce_bucket_size),\n                                    dtype=self.dtype,\n                                    device=get_accelerator().current_device_name())\n                self.ipg_buffer.append(buf_1)\n            self.ipg_index = 0\n\n        if self.custom_loss_scaler:\n            scaled_loss = self.external_loss_scale * loss\n            scaled_loss.backward(retain_graph=retain_graph)\n        else:\n            self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)\n\n        # Only for Stage 1, Mode 2\n        if self.use_grad_accum_attribute:\n            self.fill_grad_accum_attribute()\n\n    def check_overflow(self, partition_gradients=True):\n        self._check_overflow(partition_gradients)\n\n    def _update_scale(self, has_overflow=False):\n        self.loss_scaler.update_scale(has_overflow)\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return self.optimizer.param_groups\n\n    def 
_set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n    # Promote loss scale so it can be retrieved or set via \"fp16_optimizer_instance.loss_scale\"\n    def _get_loss_scale(self):\n        if self.custom_loss_scaler:\n            return self.external_loss_scale\n        else:\n            return self.loss_scaler.cur_scale\n\n    def _set_loss_scale(self, value):\n        self.loss_scaler.cur_scale = value\n\n    loss_scale = property(_get_loss_scale, _set_loss_scale)\n    cur_scale = property(_get_loss_scale, _set_loss_scale)\n\n    # Return group tensor after removing paddings that are added for alignment to DP world size.\n    # This method works on the assumption that each group contains a single flattened tensor.\n    def _get_groups_without_padding(self, groups_with_padding):\n        groups_without_padding = []\n        for i, group in enumerate(groups_with_padding):\n            lean_length = group.numel() - self.groups_padding[i]\n            groups_without_padding.append(group[:lean_length])\n\n        return groups_without_padding\n\n    # Return optimizer state after removing paddings that are added for alignment.\n    def _get_state_without_padding(self, state_with_padding, padding):\n        lean_state = {}\n        for key, value in state_with_padding.items():\n            if torch.is_tensor(value):\n                lean_length = value.numel() - padding\n                lean_state[key] = value[:lean_length]\n            else:\n                lean_state[key] = value\n\n        return lean_state\n\n    # Return base optimizer states.\n    # This method assumes that each param group contains a single flattened tensor.\n    def _get_base_optimizer_state(self):\n        optimizer_groups_state = []\n        for i, group in enumerate(self.optimizer.param_groups):\n            p = group['params'][0]\n            lean_optimizer_state = 
self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i])\n            optimizer_groups_state.append(lean_optimizer_state)\n\n        return optimizer_groups_state\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict[LOSS_SCALER] = self.loss_scaler\n        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale\n        state_dict['overflow'] = self.overflow\n        state_dict[CLIP_GRAD] = self.clip_grad\n\n        if self.elastic_checkpoint:\n            state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state()\n\n            if \"step\" in self.optimizer.param_groups[0]:\n                # Assuming \"step\" is the only item that changes through training iterations\n                assert all(group[\"step\"] == self.optimizer.param_groups[0][\"step\"]\n                           for group in self.optimizer.param_groups), \"All param groups must have the same step value\"\n                state_dict[BASE_OPTIMIZER_STATE_STEP] = self.optimizer.param_groups[0][\"step\"]\n        else:\n            state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()\n\n        # Remove paddings for DP alignment to enable loading for other alignment values\n        fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups)\n        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding\n\n        state_dict[\n            ZERO_STAGE] = ZeroStageEnum.gradients if self.partition_gradients 
else ZeroStageEnum.optimizer_states\n        state_dict[GROUP_PADDINGS] = self.groups_padding\n        state_dict[PARTITION_COUNT] = self.partition_count\n\n        state_dict[DS_VERSION] = version\n        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings\n\n        return state_dict\n\n    # Restore base optimizer fp32 weights from elastic checkpoint by:\n    # 1) Merging fp32 weights from checkpoints of all partitions\n    # 2) Extracting fp32 weights for current partition from merged weights\n    # 3) Using extracted weights to update base optimizer weights directly.\n    def _restore_from_elastic_fp32_weights(self, all_state_dict):\n        merged_single_partition_of_fp32_groups = []\n\n        for i in range(len(self.single_partition_of_fp32_groups)):\n            partition_id = dist.get_rank(group=self.real_zp_process_group[i])\n            merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict]\n            if self.is_moe_group(self.optimizer.param_groups[i]):\n                ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])\n                merged_partitions = [merged_partitions[i] for i in ranks]\n            flat_merged_partitions = self.flatten_dense_tensors_aligned(\n                merged_partitions,\n                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i]))\n            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)\n            merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])\n\n        for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):\n            current.data.copy_(saved.data)\n\n    # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights\n    def _restore_from_bit16_weights(self):\n        for group_id, (bit16_partitions, fp32_partition) in enumerate(\n                
zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):\n            partition_id = dist.get_rank(group=self.real_zp_process_group[group_id])\n            fp32_partition.data.copy_(bit16_partitions[partition_id].data)\n\n    # Refresh the fp32 master params from the fp16 or bfloat16 copies.\n    def refresh_fp32_params(self):\n        self._restore_from_bit16_weights()\n\n    # Extract optimizer state for current partition from merged states of all partitions\n    def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):\n        partition_id = dist.get_rank(group=self.real_zp_process_group[group_id])\n        alignment = dist.get_world_size(group=self.real_zp_process_group[group_id])\n        if torch.is_tensor(all_partition_states[0]):\n            flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment)\n            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id)\n            return dp_partitions[partition_id]\n        else:\n            # Assume non-tensor states are not partitioned and equal across ranks, so return first one\n            return all_partition_states[0]\n\n    def _restore_base_optimizer_state(self, base_optimizer_group_states):\n        if type(base_optimizer_group_states) == dict:\n            base_optimizer_group_states = base_optimizer_group_states['state']\n        for i, group in enumerate(self.optimizer.param_groups):\n            p = group['params'][0]\n            for key, saved in base_optimizer_group_states[i].items():\n                if torch.is_tensor(self.optimizer.state[p][key]):\n                    dst_tensor = self.optimizer.state[p][key]\n                    src_tensor = _get_padded_tensor(saved, dst_tensor.numel())\n                    self.optimizer.state[p][key].data.copy_(src_tensor.data)\n                else:\n                    self.optimizer.state[p][key] = saved\n\n    def 
get_ep_ranks(self, rank=0, group_name=None):\n        from deepspeed.utils import groups\n        expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)\n        world_size = groups._get_data_parallel_world_size()\n        rank = groups._get_expert_parallel_rank(group_name)\n        ranks = range(rank, world_size, expert_parallel_size_)\n        return list(ranks)\n\n    # Restore base optimizer state from elastic checkpoint by\n    # 1) Merging optimizer state from checkpoints of all partitions\n    # 2) Extracting optimizer state for current partition from the merged state\n    # 3) Using the extracted value to directly update the base optimizer.\n    def _restore_elastic_base_optimizer_state(self, all_state_dict):\n        base_optimizer_group_states = []\n        for i in range(len(self.optimizer.param_groups)):\n            partition_states = {}\n            all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict]\n\n            if self.is_moe_group(self.optimizer.param_groups[i]):\n                ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])\n                all_partition_group_states = [all_partition_group_states[i] for i in ranks]\n\n            for key in all_partition_group_states[0].keys():\n                all_partition_states = [all_states[key] for all_states in all_partition_group_states]\n                partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)\n            base_optimizer_group_states.append(partition_states)\n\n        self._restore_base_optimizer_state(base_optimizer_group_states)\n\n        # Restore step\n        if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]:\n            assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]\n                       for sd in all_state_dict), \"State dicts of all partitions must have the same step value\"\n            loaded_param_groups_step = 
all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]\n            for param_group in self.optimizer.param_groups:\n                param_group['step'] = loaded_param_groups_step\n\n    def load_state_dict(self,\n                        state_dict_list,\n                        load_optimizer_states=True,\n                        load_from_fp32_weights=False,\n                        checkpoint_folder=None,\n                        load_serial=None):\n        if checkpoint_folder:\n            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)\n        else:\n            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)\n\n    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):\n        self._load_hp_checkpoint_state(checkpoint_folder)\n\n    @property\n    def param_groups(self):\n        \"\"\"Forward the wrapped optimizer's parameters.\"\"\"\n        return self.optimizer.param_groups\n\n    def _load_hp_checkpoint_state(self, checkpoint_dir):\n        checkpoint_dir = os.path.join(checkpoint_dir, \"zero\")\n        optim_state_path = os.path.join(checkpoint_dir, \"optimizer_state.pt\")\n        assert os.path.isfile(\n            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! 
Cannot proceed.'\n        optim_sd = torch.load(optim_state_path)\n        self._load_global_state(optim_sd)\n\n        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)\n        tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, \"get_slice_parallel_world_size\") \\\n            else self.mpu.get_tensor_model_parallel_world_size()\n\n        for i, _ in enumerate(self.optimizer.param_groups):\n            for lp in self.bit16_groups[i]:\n                if lp._hp_mapping is not None:\n                    # print(f\"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}\")\n                    lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,\n                                                tp_world_size)\n\n    def _load_global_state(self, sd):\n        self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)\n        self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale)\n        self.overflow = sd.get('overflow', self.overflow)\n        self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad)\n\n        ckpt_version = sd.get(DS_VERSION, False)\n        assert ckpt_version, f\"Empty ds_version in checkpoint, not clear how to proceed\"\n        ckpt_version = pkg_version.parse(ckpt_version)\n\n        # zero stage 1 mode\n        if not self.partition_gradients:\n            required_version = pkg_version.parse(\"0.3.17\")\n            error_str = f\"ZeRO stage 1 changed in {required_version} and is not backwards compatible \" \\\n                        \"with older stage 1 checkpoints. 
If you'd like to load an old ZeRO-1 checkpoint \" \\\n                        \"please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json.\"\n            assert required_version <= ckpt_version, f\"Old version: {ckpt_version} {error_str}\"\n\n    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):\n        r\"\"\"Loading ZeRO checkpoint\n\n        Arguments:\n            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.\n                Note that the number of saved partitions may differ from number of loading partitions to support\n                changing GPU count, specifically DP world size, between saving and loading checkpoints.\n            load_optimizer_states: Boolean indicating whether or not to load base optimizer states\n            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32\n            copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).\n        \"\"\"\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n\n        # I think it should actually be 
ok to reload the optimizer before the model.\n        dp_rank = dist.get_rank(group=self.zp_process_group)\n        current_rank_sd = state_dict_list[dp_rank]\n        self._load_global_state(current_rank_sd)\n\n        ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict)\n\n        # padding is always at the last rank/partition\n        # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank\n        # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus\n        # scenario-2 (expand): saving w. 2 gpus -> loading w. 4 gpus\n        # if load_optimizer_states:\n        #     if new_dp_size:\n        #         self.strip_padding()\n        #         self.add_padding_w_new_dp_size()\n        #     self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])\n\n        if load_optimizer_states:\n            if ckpt_is_rigid:\n                # loading rigid ckpt into either rigid or elastic exec\n                self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])\n            else:\n                if self.elastic_checkpoint:\n                    # loading elastic into elastic exec\n                    self._restore_elastic_base_optimizer_state(state_dict_list)\n                else:\n                    # loading an elastic checkpoint into rigid exec\n                    self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])\n\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.\n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  
There are two options.\n        # 1:  Refresh the master params from the model's fp16 params.\n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 1 if changing DP degree and option 2 otherwise.\n        #\n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device\n        # of their associated parameters, because it's possible those buffers might not exist yet in\n        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been\n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n\n        if load_from_fp32_weights:\n            # option 2 from above\n            if self.elastic_checkpoint and not ckpt_is_rigid:\n                self._restore_from_elastic_fp32_weights(state_dict_list)\n            else:\n                # For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient.\n                for current, saved in zip(self.single_partition_of_fp32_groups,\n                                          current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):\n                    src_tensor = _get_padded_tensor(saved, current.numel())\n                    current.data.copy_(src_tensor.data)\n        else:\n            # option 1 from above\n            self._restore_from_bit16_weights()\n\n        if load_optimizer_states:\n            self._link_all_hp_params()\n\n\ndef _handle_overflow(cpu_sum, x, i):\n    import math\n    rank = dist.get_rank()\n    if rank == 0:\n        t_i = -1\n        for v_i, v in enumerate(x.data.contiguous().view(-1)):\n            if not math.isfinite(float(v)):\n                t_i = v_i\n                break\n        logger.info(f\"rank {rank} detected overflow {cpu_sum} in tensor 
{i}:{t_i} shape {x.shape}\")\n\n\ndef estimate_zero2_model_states_mem_needs(total_params,\n                                          num_gpus_per_node=1,\n                                          num_nodes=1,\n                                          cpu_offload=True,\n                                          additional_buffer_factor=1.5):\n    total_gpus = num_nodes * num_gpus_per_node\n\n    if cpu_offload:\n        gpu_mem = 2 * total_params\n        cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor\n    else:\n        gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)\n        cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor\n\n    return int(cpu_mem), int(gpu_mem)\n\n\ndef model_to_params(model):\n    # shared params calculated only once\n    total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())\n    return total_params\n\n\ndef estimate_zero2_model_states_mem_needs_all_live(model,\n                                                   num_gpus_per_node=1,\n                                                   num_nodes=1,\n                                                   additional_buffer_factor=1.5):\n    \"\"\"\n    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients\n    for a given ``model`` and hardware setup.\n\n    If you have an actual model object, use this function and everything will be derived\n    automatically.\n\n    If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass\n    the ``total_params`` explicitly.\n\n    Args:\n        - ``model``: ``nn.Module`` object\n        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)\n        - ``num_nodes``: how many nodes (defaults to 1),\n        - ``additional_buffer_factor``: estimation factor (defaults to 1.5):\n\n    \"\"\"\n\n    total_params = model_to_params(model)\n\n    
estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params,\n                                                   num_gpus_per_node=num_gpus_per_node,\n                                                   num_nodes=num_nodes,\n                                                   additional_buffer_factor=additional_buffer_factor)\n\n\ndef estimate_zero2_model_states_mem_needs_all_cold(total_params,\n                                                   num_gpus_per_node=1,\n                                                   num_nodes=1,\n                                                   additional_buffer_factor=1.5):\n    \"\"\"\n    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients\n    for a model with ``total_params`` parameters and a given hardware setup.\n\n    If it's a hypothetical model, use this function, where you have to pass\n    ``total_params`` explicitly.\n\n    If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything\n    will be derived automatically.\n\n    Args:\n        - ``total_params``: total model params\n        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)\n        - ``num_nodes``: how many nodes (defaults to 1),\n        - ``additional_buffer_factor``: estimation factor (defaults to 1.5):\n\n    \"\"\"\n\n    def format_options(cpu_offload):\n        enabled = []\n        device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else \"none\"\n        enabled.append(f\"offload_optimizer={device}\")\n        return \", \".join(enabled)\n\n    nodes_str = \"nodes\" if num_nodes > 1 else \"node\"\n    gpus_str = \"GPUs\" if num_gpus_per_node > 1 else \"GPU\"\n    print(\"Estimated memory needed for params, optim states and gradients for a:\\n\"\n          f\"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\\n\"\n          f\"SW: Model with {int(total_params / 1e6)}M total params.\")\n    print(\"  
per CPU  |  per GPU |   Options\")\n    for cpu_offload in [True, False]:\n        cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params,\n                                                                 num_gpus_per_node=num_gpus_per_node,\n                                                                 num_nodes=num_nodes,\n                                                                 cpu_offload=cpu_offload,\n                                                                 additional_buffer_factor=additional_buffer_factor)\n\n        options_str = format_options(cpu_offload=cpu_offload)\n        print(f\" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}\")\n"
  },
  {
    "path": "opensora/adaptor/utils.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\"\"\"\nCopyright NVIDIA/Megatron\n\nHelper functions and classes from multiple sources.\n\"\"\"\n\nfrom collections.abc import Iterable\nfrom deepspeed.moe.utils import is_moe_param\nimport os\nimport psutil\nimport gc\nfrom math import sqrt\nfrom packaging import version as pkg_version\n\nimport torch\nfrom deepspeed import comm as dist\n\ntry:\n    from torch._six import inf\nexcept ModuleNotFoundError:\n    from torch import inf\n\nfrom deepspeed.utils import groups, logger\nfrom deepspeed.runtime.constants import PIPE_REPLICATED\nfrom numpy import prod\nfrom deepspeed.accelerator import get_accelerator\n\nfrom deepspeed.module_inject.policy import transpose\nfrom torch.nn import functional as F\n\ntorch_memory_reserved = get_accelerator().memory_reserved\ntorch_max_memory_reserved = get_accelerator().max_memory_reserved\n\n\nclass DummyOptim():\n    \"\"\"\n    Dummy optimizer presents model parameters as a param group, this is\n    primarily used to allow ZeRO-3 without an optimizer\n    \"\"\"\n\n    def __init__(self, params):\n        self.param_groups = []\n        self.param_groups.append({'params': params})\n\n\ngraph_cache = {}\n\n\ndef graph_process(replay_first_step, func, *args, **kwargs):\n    # `func` should only contain operations on the GPU\n    # Please ensure that the memory address of the data required by 'func' remains constant\n    if func.__name__ not in graph_cache:\n        cuda_stream = get_accelerator().Stream()\n        cuda_stream.wait_stream(get_accelerator().current_stream())\n        with get_accelerator().stream(cuda_stream):\n            func(*args, **kwargs)\n        get_accelerator().current_stream().wait_stream(cuda_stream)\n        graph_cache[func.__name__] = get_accelerator().create_graph()\n        with get_accelerator().capture_to_graph(graph_cache[func.__name__]):\n            func(*args, **kwargs)\n        
if replay_first_step:\n            get_accelerator().replay_graph(graph_cache[func.__name__])\n    else:\n        get_accelerator().replay_graph(graph_cache[func.__name__])\n\n\ndef noop_decorator(func):\n    return func\n\n\nclass noop_context(object):\n\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\ndef ensure_directory_exists(filename):\n    \"\"\"Create the directory path to ``filename`` if it does not already exist.\n\n    Args:\n        filename (str): A file path.\n    \"\"\"\n    dirname = os.path.dirname(filename)\n    os.makedirs(dirname, exist_ok=True)\n\n\ndef set_random_seed(seed):\n    \"\"\"Set the random seed for common PRNGs used during training: random, numpy, and torch.\n\n    Args:\n        seed (int): the seed to use\n    \"\"\"\n    import numpy\n    import random\n    random.seed(seed)\n    numpy.random.seed(seed)\n    torch.manual_seed(seed)\n\n\ndef is_model_parallel_parameter(p) -> bool:\n    if hasattr(p, 'model_parallel') and p.model_parallel:\n        return True\n\n    if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel:\n        return True\n\n    return False\n\n\ndef bwc_tensor_model_parallel_rank(mpu=None):\n    \"\"\"Backwards-compatible way of querying the tensor model parallel rank from\n    an ``mpu`` object.\n\n    *Tensor* model parallelism means that tensors are physically split across\n    processes. This contrasts with *pipeline* model parallelism, in which the\n    layers are partitioned but tensors left intact.\n\n    The API for tensor model parallelism has changed across versions and this\n    helper provides a best-effort implementation across versions of ``mpu``\n    objects.  
The preferred mechanism is\n    ``mpu.get_tensor_model_parallel_rank()``.\n\n    This should \"just work\" with both Megatron-LM and DeepSpeed's pipeline\n    parallelism.\n\n    Args:\n        mpu (model parallel unit, optional): The model parallel unit to query.\n            If ``mpu=None``, returns 0. Defaults to ``None``.\n\n    Returns:\n        int: the rank\n    \"\"\"\n    if mpu is None:\n        # No model parallelism is easy :)\n        return 0\n\n    if hasattr(mpu, 'get_tensor_model_parallel_rank'):\n        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)\n        return mpu.get_tensor_model_parallel_rank()\n    elif hasattr(mpu, 'get_slice_parallel_rank'):\n        # Some DeepSpeed + pipeline parallelism versions\n        return mpu.get_slice_parallel_rank()\n    else:\n        # Deprecated Megatron and DeepSpeed convention\n        return mpu.get_model_parallel_rank()\n\n\ndef copy_to_device(item, device, criterion_func):\n    \"\"\"\n    Return a copy of the tensor on the specified device.\n    Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.\n    Parameters:\n        item: tensor to copy or (possibly nested) container of tensors to copy.\n        device: target device\n        criterion_func: Function that restricts the copy operation to items that meet the criterion\n\n    Returns:\n        ``item`` with every element that meets ``criterion_func`` copied to ``device``;\n        containers are rebuilt and all other elements are returned as-is.\n    \"\"\"\n    if criterion_func(item):\n        return item.to(device)\n    elif isinstance(item, list):\n        return [copy_to_device(v, device, criterion_func) for v in item]\n    elif isinstance(item, tuple):\n        return tuple([copy_to_device(v, device, criterion_func) for v in item])\n    elif isinstance(item, dict):\n        return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}\n    else:\n        return item\n\n\ndef move_to_device(item, device, criterion_func):\n    \"\"\"\n    Move the tensor onto the specified device by changing the storage.\n    Works on individual tensors, and 
tensors contained/nested in lists, tuples, and dicts.\n    Parameters:\n        item: tensor to move or (possibly nested) container of tensors to move.\n        device: target device\n        criterion_func: Function that restricts the move operation to items that meet the criterion\n\n    Returns:\n        ``item`` with every element that meets ``criterion_func`` moved to ``device`` in place\n        (its ``.data`` is replaced); containers are rebuilt and all other elements are returned as-is.\n    \"\"\"\n    if criterion_func(item):\n        device_copy = item.to(device)\n        item.data = device_copy.data\n        return item\n    elif isinstance(item, list):\n        return [move_to_device(v, device, criterion_func) for v in item]\n    elif isinstance(item, tuple):\n        return tuple([move_to_device(v, device, criterion_func) for v in item])\n    elif isinstance(item, dict):\n        return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}\n    else:\n        return item\n\n\nclass CheckOverflow(object):\n    '''Checks for gradient overflow across parallel processes'''\n\n    def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None):\n        self.mpu = mpu\n        self.params = [] if param_groups else None\n        self.zero_reduce_scatter = zero_reduce_scatter\n        self.deepspeed = deepspeed\n        self.has_moe_params = False\n        if param_groups:\n            for group in param_groups:\n                for param in group:\n                    self.params.append(param)\n                    if is_moe_param(param):\n                        self.has_moe_params = True\n\n    def check_using_norm(self, norm_group, reduce_overflow=True):\n        # TODO: I don't think reduce_overflow is needed if mpu is None\n        overflow = -1 in norm_group\n        overflow_gpu = get_accelerator().FloatTensor([overflow])\n        if self.has_moe_params:\n            # In this case, we need to do an all_reduce across\n            # the expert_parallel_group, so that if there was\n            # an overflow due to expert weights, we detect it\n\n            # Only need to check 
groups.get_largest_expert_parallel_group()\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())\n        if self.mpu is not None:\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())\n        elif reduce_overflow:\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)\n            dist.barrier()\n        overflow = overflow_gpu[0].item()\n        return bool(overflow)\n\n    def check(self, param_groups=None):\n        params = []\n        has_moe_params = False\n        if param_groups is None:\n            params = self.params\n            has_moe_params = self.has_moe_params\n        else:\n            assert param_groups is not None, \\\n                \"self.params and param_groups both cannot be none\"\n\n            for group in param_groups:\n                for param in group:\n                    params.append(param)\n                    if is_moe_param(param):\n                        has_moe_params = True\n\n        return self.has_overflow(params, has_moe_params=has_moe_params)\n\n    # `params` is a list / generator of torch.Variable\n    def has_overflow_serial(self, params):\n        for i, p in enumerate(params):\n            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):\n                return True\n        return False\n\n    def has_overflow(self, params, has_moe_params=None):\n        if has_moe_params is None:\n            has_moe_params = self.has_moe_params\n        overflow = self.has_overflow_serial(params)\n        # Since each model parallel GPU carries only part of the model,\n        # make sure overflow flag is synced across all the model parallel GPUs\n        overflow_gpu = get_accelerator().ByteTensor([overflow])\n        # deepspeed.comm.all_reduce(overflow_gpu,\n        #                             op=deepspeed.comm.ReduceOp.MAX,\n        #                             
group=mpu.get_model_parallel_group())\n        if has_moe_params:\n            # All reduce this across expert_parallel_group, so that if an expert\n            # overflows, we detect it here\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())\n        if self.zero_reduce_scatter:\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())\n        elif self.mpu is not None:\n            if self.deepspeed is not None:\n                using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')\n                if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or (\n                        not using_pipeline and self.deepspeed.enable_backward_allreduce is False):\n                    dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())\n        elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:\n            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())\n\n        overflow = overflow_gpu[0].item()\n        return bool(overflow)\n\n    # `x` is a torch.Tensor\n    @staticmethod\n    def _has_inf_or_nan(x, i):\n        try:\n            # if x is half, the .float() incurs an additional deep copy, but it's necessary if\n            # Pytorch's .sum() creates a one-element tensor of the same type as x\n            # (which is true for some recent version of pytorch).\n            cpu_sum = float(x.float().sum())\n            # More efficient version that can be used if .sum() returns a Python scalar\n            # cpu_sum = float(x.sum())\n        except RuntimeError as instance:\n            # We want to check if inst is actually an overflow exception.\n            # RuntimeError could come from a different 
error.\n            # If so, we still want the exception to propagate.\n            if \"value cannot be converted\" not in instance.args[0]:\n                raise\n            return True\n        else:\n            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:\n                return True\n            return False\n\n\ndef _handle_overflow(cpu_sum, x, i):\n    import math\n    rank = dist.get_rank()\n    if rank == 0:\n        t_i = -1\n        for v_i, v in enumerate(x.data.contiguous().view(-1)):\n            if not math.isfinite(float(v)):\n                t_i = v_i\n                break\n        logger.info(f\"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}\")\n\n\ndef get_global_norm(norm_list):\n    \"\"\" Compute total from a list of norms\n    \"\"\"\n    total_norm = 0.0\n    for norm in norm_list:\n        total_norm += norm**2.0\n    # logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')\n    return sqrt(total_norm)\n\n\ndef clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):\n    \"\"\"Clips gradient norm of an iterable of parameters.\n\n    This has been adapted from Nvidia megatron. We add norm averaging\n    to consider MoE params when calculating norm as they will result\n    in different norms across different ranks.\n\n    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and\n    added functionality to handle model parallel parameters. Note that\n    the gradients are modified in place.\n\n    Arguments:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor that will have gradients normalized\n        max_norm (float or int): max norm of the gradients\n        norm_type (float or int): type of the used p-norm. 
Can be ``'inf'`` for\n            infinity norm.\n\n    Returns:\n        Total norm of the parameters (viewed as a single vector).\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = list(filter(lambda p: p.grad is not None, parameters))\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    if norm_type == inf:\n        total_norm = max(p.grad.data.abs().max() for p in parameters)\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        # Take max across all GPUs.\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()\n    else:\n        total_norm = 0\n        for p in parameters:\n            if mpu is not None:\n                if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p):\n                    param_norm = p.grad.data.norm(norm_type)\n                    total_norm += param_norm.item()**norm_type\n            else:\n                param_norm = p.grad.data.float().norm(norm_type)\n                total_norm += param_norm.item()**norm_type\n\n        # Sum across all model parallel GPUs.\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()**(1. 
/ norm_type)\n\n    # Need to average total_norm across different GPUs due to the presence of moe params\n    pg = groups._get_data_parallel_group()\n    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))\n\n    scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)])\n    dist.all_reduce(scaled_norm_tensor, group=pg)\n    total_norm = scaled_norm_tensor.item()\n\n    clip_coef = max_norm / (total_norm + 1e-6)\n    if clip_coef < 1:\n        for p in parameters:\n            p.grad.data.mul_(clip_coef)\n    return total_norm\n\n\ndef get_grad_norm(parameters, norm_type=2, mpu=None):\n    \"\"\"Get grad norm of an iterable of parameters.\n\n    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and\n    added functionality to handle model parallel parameters. Unlike\n    clip_grad_norm_, the gradients are not modified. Taken from Nvidia Megatron.\n\n    Arguments:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor whose gradient norm will be computed\n        norm_type (float or int): type of the used p-norm. 
Can be ``'inf'`` for\n            infinity norm.\n\n    Returns:\n        Total norm of the parameters (viewed as a single vector).\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = list(filter(lambda p: p.grad is not None, parameters))\n\n    norm_type = float(norm_type)\n    if norm_type == inf:\n        total_norm = max(p.grad.data.abs().max() for p in parameters)\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        # Take max across all GPUs.\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()\n    else:\n        total_norm = 0.\n        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)\n        for p in parameters:\n            # Pipeline parallelism may replicate parameters. Avoid multi-counting.\n            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:\n                continue\n\n            # Filter to avoid over-counting replicated tensors from tensor\n            # model parallelism\n            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):\n                continue\n\n            param_norm = p.grad.data.float().norm(norm_type)\n            total_norm += param_norm.item()**norm_type\n\n        # Sum across all model parallel GPUs.\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()**(1. 
/ norm_type)\n\n    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:\n        total_norm = -1\n\n    return total_norm\n\n\ndef get_grad_zeros(parameters, mpu=None):\n    \"\"\"Compute the number of gradient elements with zero values.\n\n    This is adapted from get_grad_norm.\n\n    Arguments:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor whose gradients will be inspected\n\n    Returns:\n        Total number of zero-valued gradient elements across all parameters.\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = list(filter(lambda p: p.grad is not None, parameters))\n\n    total_zeros = 0.\n    tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)\n    for p in parameters:\n        # Pipeline parallelism may replicate parameters. Avoid multi-counting.\n        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:\n            continue\n\n        # Filter to avoid over-counting replicated tensors from tensor\n        # model parallelism\n        if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):\n            continue\n\n        count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)\n        total_zeros += count_zeros.item()\n\n    # Sum across all model parallel GPUs.\n    total_zeros_cuda = get_accelerator().FloatTensor([float(total_zeros)])\n    if mpu is not None:\n        dist.all_reduce(total_zeros_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())\n    total_zeros = total_zeros_cuda[0].item()\n\n    return total_zeros\n\n\ndef get_weight_norm(parameters, norm_type=2, mpu=None):\n    \"\"\"Get norm of an iterable of parameters.\n\n    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and\n    added functionality to handle model parallel parameters. It uses the\n    parameter values rather than their gradients and modifies nothing. 
Taken from Nvidia Megatron.\n\n    Arguments:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor whose norm will be computed\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n\n    Returns:\n        Total norm of the parameters (viewed as a single vector).\n        -1 if the norm value is NaN or Inf.\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n\n    norm_type = float(norm_type)\n    if norm_type == inf:\n        total_norm = max(p.data.abs().max() for p in parameters)\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        # Take max across all GPUs.\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()\n    else:\n        total_norm = 0.\n        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)\n        for p in parameters:\n            # Pipeline parallelism may replicate parameters. Avoid multi-counting.\n            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:\n                continue\n\n            # Filter to avoid over-counting replicated tensors from tensor\n            # model parallelism\n            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):\n                continue\n\n            param_norm = p.data.float().norm(norm_type)\n            total_norm += param_norm**norm_type\n\n        # Sum across all model parallel GPUs.\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()**(1. 
/ norm_type)\n\n    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:\n        total_norm = -1\n\n    return total_norm\n\n\ndef prefix_sum_inc(weights):\n    \"\"\" Compute an inclusive prefix sum.\n\n    Example:\n        >>> prefix_sum_inc([3,4,5])\n        [3, 7, 12]\n    \"\"\"\n    weights_ = [w for w in weights]\n    for x in range(1, len(weights_)):\n        weights_[x] += weights_[x - 1]\n    return weights_\n\n\ndef partition_uniform(num_items, num_parts):\n    import numpy\n    parts = [0] * (num_parts + 1)\n    # First check for the trivial edge case\n    if num_items <= num_parts:\n        for p in range(num_parts + 1):\n            parts[p] = min(p, num_items)\n        return parts\n\n    chunksize = num_items // num_parts\n    residual = num_items - (chunksize * num_parts)\n\n    parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize)\n\n    for i in range(residual):\n        parts[i + 1:] += 1\n    parts = parts.tolist()\n\n    return parts\n\n\ndef partition_balanced(weights, num_parts):\n    \"\"\"\n    use dynamic programming solve `The Linear Partition Problem`.\n    see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM\n    \"\"\"\n    import numpy as np\n    n = len(weights)\n    m = num_parts\n\n    if n <= m:\n        return partition_uniform(n, m)\n\n    dp_max = np.full((n + 1, m + 1), np.inf)\n    dp_min = np.full((n + 1, m + 1), np.inf)\n    dp_cost = np.full((n + 1, m + 1), np.inf)\n    position = np.zeros((n + 1, m + 1), dtype=int)\n    prefix_sum = np.zeros((n + 1))\n    prefix_sum[1:] = np.cumsum(weights)\n\n    dp_max[0, 0] = 0\n    dp_cost[0, 0] = 0\n    for i in range(1, n + 1):\n        for j in range(1, min(i, m) + 1):\n            for k in range(i):\n                max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k])\n                min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k])\n                cost = max_sum - min_sum\n        
        if dp_cost[i, j] >= cost:\n                    dp_cost[i, j] = cost\n                    dp_max[i, j] = max_sum\n                    dp_min[i, j] = min_sum\n                    position[i, j] = k\n\n    parts = [n]\n    for i in reversed(range(1, m + 1)):\n        parts.append(position[parts[-1], i])\n    parts.reverse()\n\n    return parts\n\n\nclass PartitionedTensor:\n\n    def __init__(self, tensor, group, partition_meta=None):\n        super().__init__()\n\n        self.group = group\n        self.num_parts = dist.get_world_size(group=self.group)\n        self.rank = dist.get_rank(group=self.group)\n        self.orig_size = list(tensor.size())\n        self.orig_device = tensor.device\n        self.local_data, self.partition = self._partition_tensor(tensor)\n        self.even_split = tensor.numel() % self.num_parts == 0\n\n    @classmethod\n    def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()):\n        assert meta.dtype == torch.long\n        dummy = torch.ones(dist.get_world_size(group=group))\n        part_obj = cls(tensor=dummy, group=group)\n\n        meta = meta.tolist()\n\n        # [N, list0, ..., listN-1]\n        part_obj.orig_size = meta[1:(1 + meta[0])]\n        meta = meta[1 + meta[0]:]\n\n        part_obj.orig_device = device\n        part_obj.local_data = local_part.detach()\n\n        part_obj.group = group\n\n        # Partition is encoded like the rowptr of a CSR matrix:\n        # [num_parts, rank, 0, part_1, ..., part_num_parts]\n        # TODO: support shuffle between different partition granularities\n        assert part_obj.num_parts == meta[0]\n        assert part_obj.rank == meta[1]\n        part_obj.partition = meta[2:]  # length num_parts+1\n\n        return part_obj\n\n    def _partition_tensor(self, tensor):\n        partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)\n        start = partition[self.rank]\n        length = partition[self.rank + 1] - start\n     
   tensor_part = tensor.detach().contiguous().view(-1).narrow(0, start=start, length=length).clone()\n\n        return tensor_part, partition\n\n    def full(self, device=None):\n        if device is None:\n            device = self.orig_device\n\n        # Allocate the full tensor as a flat buffer.\n        full_numel = prod(self.full_size())\n        flat_tensor = torch.zeros([full_numel], dtype=self.local_data.dtype, device=device)\n        if self.even_split:\n            # Collect the full tensor\n            dist.all_gather_into_tensor(flat_tensor, self.local_data, group=self.group)\n        else:\n            for part_id in range(self.num_parts):\n                part_size = self.partition[part_id + 1] - self.partition[part_id]\n                buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)\n                if part_id == self.rank:\n                    buf.copy_(self.local_data)\n                dist.broadcast(buf, part_id, self.group)\n        return flat_tensor.view(self.full_size()).clone().detach()\n\n    def to_meta(self):\n        \"\"\"Returns a torch.LongTensor that encodes partitioning information.\n\n        Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for\n        communication.\n\n        Returns:\n            torch.LongTensor: a tensor encoding the meta-information for the partitioning\n        \"\"\"\n        meta = []\n        meta.append(len(self.orig_size))\n        meta += list(self.orig_size)\n        meta.append(self.num_parts)\n        meta.append(self.rank)\n        meta += self.partition\n        return torch.LongTensor(data=meta).to(self.orig_device)\n\n    def data(self):\n        return self.local_data\n\n    def local_size(self):\n        return self.local_data.size()\n\n    def full_size(self):\n        return self.orig_size\n\n\nmem_alloced = 0\nmem_cached = 0\n\n\ndef memory_status(msg, print_rank=-1, reset_max=False):\n    global mem_alloced, mem_cached\n\n    rank = 
dist.get_rank()\n    if print_rank != -1 and rank != print_rank:\n        return\n\n    get_accelerator().synchronize()\n\n    if reset_max:\n        get_accelerator().reset_max_memory_cached()\n        get_accelerator().reset_max_memory_allocated()\n\n    new_alloced = get_accelerator().memory_allocated()\n    new_cached = get_accelerator().memory_cached()\n\n    delta_alloced = new_alloced - mem_alloced\n    delta_cached = new_cached - mem_cached\n\n    mem_cached = new_cached\n    mem_alloced = new_alloced\n\n    max_alloced = get_accelerator().max_memory_allocated()\n    max_cached = get_accelerator().max_memory_cached()\n\n    # convert to GB for printing\n    new_alloced /= 1024**3\n    new_cached /= 1024**3\n    delta_alloced /= 1024**3\n    delta_cached /= 1024**3\n    max_alloced /= 1024**3\n    max_cached /= 1024**3\n\n    print(\n        f'RANK={rank} MEMSTATS', msg, f'device={get_accelerator().current_device_name()} '\n        f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '\n        f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')\n\n\ndef get_ma_status():\n    if dist.is_initialized() and not dist.get_rank() == 0:\n        return 0\n    return get_accelerator().memory_allocated()\n\n\ndef empty_cache():\n    get_accelerator().empty_cache()\n    get_accelerator().reset_peak_memory_stats()\n\n\ndef see_memory_usage(message, force=False):\n    if not force:\n        return\n    if dist.is_initialized() and not dist.get_rank() == 0:\n        return\n\n    # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports\n    gc.collect()\n\n    # Print message except when distributed but not rank 0\n    logger.info(message)\n    logger.info(f\"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \\\n        Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \\\n        CA 
{round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \\\n        Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024),2)} GB \")\n\n    vm_stats = psutil.virtual_memory()\n    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)\n    logger.info(f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%')\n\n    # get the peak memory to report correct data, so reset the counter for the next call\n    get_accelerator().reset_peak_memory_stats()\n\n\ndef call_to_str(base, *args, **kwargs):\n    \"\"\"Construct a string representation of a call.\n\n    Args:\n        base (str): name of the call\n        args (tuple, optional): args to ``base``\n        kwargs (dict, optional): kwargs supplied to ``base``\n\n    Returns:\n        str: A string representation of base(*args, **kwargs)\n    \"\"\"\n    name = f'{base}('\n    if args:\n        name += ', '.join(repr(arg) for arg in args)\n        if kwargs:\n            name += ', '\n    if kwargs:\n        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())\n    name += ')'\n    return name\n\n\ndef get_only_unique_item(items):\n    item_set = set(items)\n    if len(item_set) != 1:\n        raise RuntimeError(f\"expected there to be only one unique element in {items}\")\n    unique_item, = item_set\n\n    return unique_item\n\n\ndef clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):\n    \"\"\"Clip the gradient of a list of parameters.\n    Args:\n        parameters: List of parameters whose .grad will be clipped.\n        max_norm (float, optional): Maximum allowed global gradient norm. Defaults to 1.0.\n        global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.\n        mpu (optional): model parallelism unit. Defaults to None.\n        eps (float, optional): epsilon value added to grad norm. 
Defaults to 1e-6\n    Returns:\n        float: the global gradient norm\n    \"\"\"\n    if global_grad_norm is None:\n        global_grad_norm = get_grad_norm(parameters, mpu=mpu)\n    clip_coef = max_norm / (global_grad_norm + eps)\n    if clip_coef < 1:\n        for p in parameters:\n            p.grad.detach().mul_(clip_coef)\n    return global_grad_norm\n\n\ndef get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):\n    \"\"\"Get norm of an iterable of tensors.\n\n    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and\n    added functionality to handle model parallel parameters. Taken from Nvidia Megatron.\n\n    Arguments:\n        input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n\n    Returns:\n        Total norm of the tensors (viewed as a single vector).\n    \"\"\"\n    assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'\n    assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'\n\n    norm_type = float(norm_type)\n    if norm_type == inf:\n        total_norm = max(t.data.abs().max() for t in input_tensors)\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())\n            total_norm = total_norm_cuda[0].item()\n    else:\n        if use_graph:\n            if 'norm_tensors_compute_buffer' not in graph_cache:\n                graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]\n            compute_buffer = graph_cache['norm_tensors_compute_buffer']\n\n            def _norm_tensors(tensor_list, _compute_buffer, _norm_type):\n                for i, t in enumerate(tensor_list):\n               
     _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)\n                    if i != 0:\n                        _compute_buffer[0].data.add_(_compute_buffer[i].data)\n\n            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)\n\n            total_norm = compute_buffer[0]\n        else:\n            total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])\n\n        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()\n        if mpu is not None:\n            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())\n        total_norm = total_norm_cuda[0].item()**(1. / norm_type)\n\n    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:\n        total_norm = -1\n\n    return total_norm\n\n\ndef clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False):\n    \"\"\"Clip list of tensors by global norm.\n    Args:\n        input_tensors: List of tensors to be clipped\n        global_norm (float, optional): Precomputed norm. Defaults to None.\n        mpu (optional): model parallelism unit. Defaults to None.\n        eps (float, optional): epsilon value added to grad norm. 
Defaults to 1e-6\n    Returns:\n        float: the global norm\n    \"\"\"\n    if global_norm is None:\n        global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph)\n    clip_coef = max_norm / (global_norm + eps)\n    if clip_coef < 1:\n        if use_graph:\n\n            def clip_tensors(_tensor_list, _clip_coef_tensor):\n                for t in _tensor_list:\n                    t.detach().mul_(_clip_coef_tensor)\n\n            if 'clip_coef_tensor' not in graph_cache:\n                # Alloc memory\n                graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef,\n                                                               dtype=torch.float32).to(get_accelerator().device_name())\n            clip_coef_tensor = graph_cache['clip_coef_tensor']\n            clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32))\n            graph_process(False, clip_tensors, input_tensors, clip_coef_tensor)\n\n        else:\n            for t in input_tensors:\n                t.detach().mul_(clip_coef)\n    return global_norm\n\n\ndef align_dense_tensors(tensor_list, alignment):\n    num_elements = sum(t.numel() for t in tensor_list)\n    remaining = num_elements % alignment\n\n    if remaining:\n        elements_to_add = alignment - remaining\n        pad_tensor = torch.zeros(elements_to_add, device=tensor_list[0].device, dtype=tensor_list[0].dtype)\n        padded_tensor_list = tensor_list + [pad_tensor]\n    else:\n        padded_tensor_list = tensor_list\n\n    return padded_tensor_list\n\n\ndef all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group=None):\n\n    for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):\n        partition_id = dist.get_rank(group=zp_process_group[group_id])\n        dp_world_size = dist.get_world_size(group=dp_process_group)\n        if dp_world_size == 1:\n            # no groups 
share optimizer states\n            # pipeline parallel with bf16 will default call this even if dp size = 1.\n            continue\n        # print(\"call contiguous for all_gather_into_tensor_dp_groups\")\n        dist.all_gather_into_tensor(group_flat, partitioned_params[partition_id].contiguous(), dp_process_group)\n\n\ndef all_gather_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, start_alignment_factor,\n                         allgather_bucket_size, dp_process_group=None):\n    # if dist.has_all_gather_into_tensor():\n    return all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group)\n\n    # for group_id, partitioned_params in enumerate(partitioned_param_groups):\n    #     # Sequential AllGather Best of both worlds\n    #     partition_id = dist.get_rank(group=dp_process_group[group_id])\n    #     dp_world_size = dist.get_world_size(group=dp_process_group[group_id])\n    #\n    #     if dp_world_size == 1:\n    #         # no groups share optimizer states\n    #         # pipeline parallel with bf16 will default call this even if dp size = 1.\n    #         continue\n    #     num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)\n    #\n    #     shard_size = partitioned_params[partition_id].numel() // num_shards\n    #\n    #     # Enforce nccl/rccl alignment of start location of each shard\n    #     shard_size = shard_size - (shard_size % start_alignment_factor)\n    #\n    #     num_elements = shard_size\n    #\n    #     assert shard_size * num_shards <= partitioned_params[partition_id].numel()\n    #\n    #     for shard_id in range(num_shards):\n    #\n    #         if shard_id == (num_shards - 1):\n    #             num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size\n    #\n    #         shard_list = []\n    #         for dp_id in range(dp_world_size):\n    #             curr_shard = 
partitioned_params[dp_id].narrow(0, shard_id * shard_size, num_elements).detach()\n    #             shard_list.append(curr_shard)\n    #         dist.all_gather(shard_list, shard_list[partition_id].contiguous(), dp_process_group[group_id])\n\n\nclass TLinear(torch.nn.Linear):\n\n    def __init__(self, orig_layer, name=\"\"):\n        self.name = name\n        super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))\n        self.weight.data = transpose(orig_layer.weight.data)\n        self.bias = orig_layer.bias\n        self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd\n\n    def _fwd(self, input):\n        return F.linear(input, self.weight)\n\n    def _fwd_bias_add(self, input):\n        return F.linear(input, self.weight, bias=self.bias)\n\n    def forward(self, input):\n        return self._fwd_func(input)\n\n\ndef get_inactive_params(param_list):\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    return [param for param in param_list if (hasattr(param, 'ds_id') and \\\n                            param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]\n\n\ndef required_torch_version(min_version=None, max_version=None):\n    assert min_version or max_version, \"Must provide a min_version or max_version argument\"\n\n    torch_version = pkg_version.parse(torch.__version__)\n\n    if min_version and pkg_version.parse(str(min_version)) > torch_version:\n        return False\n\n    if max_version and pkg_version.parse(str(max_version)) < torch_version:\n        return False\n\n    return True\n"
  },
  {
    "path": "opensora/adaptor/zp_manager.py",
    "content": "import torch\nimport os\nimport torch.distributed as dist\n\n\nclass ZPManager(object):\n    def __init__(self, zp_size=8):\n        self.rank = int(os.getenv('RANK', '0'))\n        self.world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n        self.zp_size = zp_size\n        self.zp_group = None\n        self.zp_rank = None\n        self.is_initialized = False\n\n    def init_group(self):\n        \"\"\"Split the world into consecutive groups of `zp_size` ranks and record this rank's group.\n\n        If world_size is not divisible by zp_size, ranks beyond the last full group\n        are left without a zp group.\n        \"\"\"\n        if self.is_initialized:\n            return\n\n        self.is_initialized = True\n\n        num_zp_groups: int = self.world_size // self.zp_size\n        for i in range(num_zp_groups):\n            ranks = range(i * self.zp_size, (i + 1) * self.zp_size)\n            # dist.new_group must be called by every rank for every group,\n            # including groups this rank does not belong to.\n            group = dist.new_group(ranks)\n            if self.rank in ranks:\n                self.zp_group = group\n                self.zp_rank = self.rank % self.zp_size\n\n\nzp_manager = ZPManager()\n"
  },
  {
    "path": "opensora/dataset/__init__.py",
    "content": "from torchvision.transforms import Compose\nfrom transformers import AutoTokenizer, AutoImageProcessor\n\nfrom torchvision import transforms\nfrom torchvision.transforms import Lambda\n\ntry:\n    import torch_npu\nexcept:\n    torch_npu = None\n\nfrom opensora.dataset.t2v_datasets import T2V_dataset\nfrom opensora.dataset.inpaint_dataset import Inpaint_dataset\nfrom opensora.models.causalvideovae import ae_norm, ae_denorm\nfrom opensora.dataset.transform import ToTensorVideo, TemporalRandomCrop, MaxHWResizeVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo, NormalizeVideo, ToTensorAfterResize\n\n\n\ndef getdataset(args):\n    temporal_sample = TemporalRandomCrop(args.num_frames)  # 16 x\n    norm_fun = ae_norm[args.ae]\n    if args.force_resolution:\n        resize = [CenterCropResizeVideo((args.max_height, args.max_width)), ]\n    else:\n        resize = [\n            MaxHWResizeVideo(args.max_hxw), \n            SpatialStrideCropVideo(stride=args.hw_stride), \n        ]\n\n    tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)\n    tokenizer_2 = None\n    if args.text_encoder_name_2 is not None:\n        tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir)\n    if args.dataset == 't2v':\n        transform = transforms.Compose([\n            ToTensorVideo(),\n            *resize, \n            norm_fun\n        ])  # also work for img, because img is video when frame=1\n        return T2V_dataset(\n            args, transform=transform, temporal_sample=temporal_sample, \n            tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2\n            )\n    elif args.dataset == 'i2v' or args.dataset == 'inpaint':\n        resize_transform = Compose(resize)\n        transform = Compose([\n            ToTensorAfterResize(),\n            norm_fun,\n        ])\n        return Inpaint_dataset(\n            args, resize_transform=resize_transform, 
transform=transform, \n            temporal_sample=temporal_sample, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2\n        )\n    raise NotImplementedError(args.dataset)\n\n\nif __name__ == \"__main__\":\n    '''\n    python opensora/dataset/__init__.py\n    '''\n    from accelerate import Accelerator\n    from opensora.dataset.t2v_datasets import dataset_prog\n    from opensora.utils.dataset_utils import LengthGroupedSampler, Collate\n    from torch.utils.data import DataLoader\n    import random\n    from torch import distributed as dist\n    from tqdm import tqdm\n    args = type('args', (), \n    {\n        'ae': 'WFVAEModel_D32_4x8x8', \n        'dataset': 't2v', \n        'model_max_length': 512, \n        'max_height': 640,\n        'max_width': 640,\n        'hw_stride': 16, \n        'num_frames': 93,\n        'compress_kv_factor': 1, \n        'interpolation_scale_t': 1,\n        'interpolation_scale_h': 1,\n        'interpolation_scale_w': 1,\n        'cache_dir': '../cache_dir', \n        'data': '/home/image_data/gyy/mmdit/Open-Sora-Plan/scripts/train_data/current_hq_on_npu.txt', \n        'train_fps': 18, \n        'drop_short_ratio': 0.0, \n        'speed_factor': 1.0, \n        'cfg': 0.1, \n        'text_encoder_name_1': 'google/mt5-xxl', \n        'text_encoder_name_2': None,\n        'dataloader_num_workers': 8,\n        'force_resolution': False, \n        'use_decord': True, \n        'group_data': True, \n        'train_batch_size': 1, \n        'gradient_accumulation_steps': 1, \n        'ae_stride': 8, \n        'ae_stride_t': 4,  \n        'patch_size': 2, \n        'patch_size_t': 1, \n        'total_batch_size': 256, \n        'sp_size': 1, \n        'max_hxw': 384*384, \n        'min_hxw': 384*288, \n        # 'max_hxw': 236544, \n        # 'min_hxw': 102400, \n    }\n    )\n    # accelerator = Accelerator()\n    dataset = getdataset(args)\n    # data = next(iter(dataset))\n    # import ipdb;ipdb.set_trace()\n    # print()\n    
sampler = LengthGroupedSampler(\n                args.train_batch_size,\n                world_size=1, \n                gradient_accumulation_size=args.gradient_accumulation_steps, \n                initial_global_step=0, \n                lengths=dataset.lengths, \n                group_data=args.group_data, \n            )\n    train_dataloader = DataLoader(\n        dataset,\n        shuffle=False,\n        # pin_memory=True,\n        collate_fn=Collate(args),\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        sampler=sampler, \n        drop_last=False, \n        prefetch_factor=4\n    )\n    import ipdb;ipdb.set_trace()\n    import imageio\n    import numpy as np\n    from einops import rearrange\n    while True:\n        for idx, i in enumerate(tqdm(train_dataloader)):\n            pixel_values = i[0][0]\n            pixel_values_ = (pixel_values+1)/2\n            pixel_values_ = rearrange(pixel_values_, 'c t h w -> t h w c') * 255.0\n            pixel_values_ = pixel_values_.numpy().astype(np.uint8)\n            imageio.mimwrite(f'output{idx}.mp4', pixel_values_, fps=args.train_fps)\n            dist.barrier()\n            pass"
  },
  {
    "path": "opensora/dataset/inpaint_dataset.py",
    "content": "import time\nimport traceback\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\nimport glob\nimport json\nimport pickle\nimport os, io, csv, math, random\nimport numpy as np\nimport torchvision\nfrom einops import rearrange\nfrom os.path import join as opj\nfrom collections import Counter\n\nimport cv2\nimport pandas as pd\nimport time\nimport torch\nimport torchvision.transforms as transforms\nfrom torch.utils.data.dataset import Dataset\nfrom torch.utils.data import DataLoader, Dataset, get_worker_info\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom accelerate.logging import get_logger\nimport gc\nimport decord\n\nfrom opensora.utils.dataset_utils import DecordInit\nfrom opensora.utils.utils import text_preprocessing\nfrom opensora.dataset.transform import get_params, maxhwresize, add_masking_notice, calculate_statistics, \\\n    add_aesthetic_notice_image, add_aesthetic_notice_video\nfrom opensora.utils.mask_utils import MaskProcessor, STR_TO_TYPE\nfrom opensora.dataset.t2v_datasets import T2V_dataset, DataSetProg\n\nlogger = get_logger(__name__)\n\ndataset_prog = DataSetProg()\n\ndef type_ratio_normalize(mask_type_ratio_dict):\n    for k, v in mask_type_ratio_dict.items():\n        assert v >= 0, f\"mask_type_ratio_dict[{k}] should be non-negative, but got {v}\"\n    total = sum(mask_type_ratio_dict.values())\n    length = len(mask_type_ratio_dict)\n    if total == 0:\n        return {k: 1.0 / length for k in mask_type_ratio_dict.keys()}\n    return {k: v / total for k, v in mask_type_ratio_dict.items()}\n\nclass Inpaint_dataset(T2V_dataset):\n    def __init__(self, args, resize_transform, transform, temporal_sample, tokenizer_1, tokenizer_2):\n        super().__init__(\n            args=args, \n            transform=transform,  \n            temporal_sample=temporal_sample, \n            tokenizer_1=tokenizer_1, \n            tokenizer_2=tokenizer_2\n        
)\n\n        self.resize_transform = resize_transform\n\n        self.mask_type_ratio_dict_video = None  # only populated when training on videos (num_frames != 1)\n        if self.num_frames != 1:\n            self.mask_type_ratio_dict_video = args.mask_type_ratio_dict_video if args.mask_type_ratio_dict_video is not None else {'random_temporal': 1.0}\n            self.mask_type_ratio_dict_video = {STR_TO_TYPE[k]: v for k, v in self.mask_type_ratio_dict_video.items()}\n            self.mask_type_ratio_dict_video = type_ratio_normalize(self.mask_type_ratio_dict_video)\n                \n        self.mask_type_ratio_dict_image = args.mask_type_ratio_dict_image if args.mask_type_ratio_dict_image is not None else {'random_spatial': 1.0}\n        self.mask_type_ratio_dict_image = {STR_TO_TYPE[k]: v for k, v in self.mask_type_ratio_dict_image.items()}\n        self.mask_type_ratio_dict_image = type_ratio_normalize(self.mask_type_ratio_dict_image)\n\n        print(f\"mask_type_ratio_dict_video: {self.mask_type_ratio_dict_video}\")\n        print(f\"mask_type_ratio_dict_image: {self.mask_type_ratio_dict_image}\")\n\n        self.mask_processor = MaskProcessor(\n            max_height=args.max_height,\n            max_width=args.max_width,\n            min_clear_ratio=args.min_clear_ratio,\n            max_clear_ratio=args.max_clear_ratio,\n        )\n\n        self.default_text_ratio = args.default_text_ratio\n\n    def __getitem__(self, idx):\n        try:\n            # future = self.executor.submit(self.get_data, idx)\n            # data = future.result(timeout=self.timeout) \n            # return data\n            return self.get_data(idx)\n        except Exception as e:\n            # if len(str(e)) < 2:\n            #     e = f\"TimeoutError, {self.timeout}s timeout occur with {dataset_prog.cap_list[idx]['path']}\"\n            print(f'Error with {e}')\n            index_cand = self.shape_idx_dict[self.sample_size[idx]]  # pick same shape\n            return self.__getitem__(random.choice(index_cand))\n            # return self.__getitem__(idx)\n    \n    def 
get_data(self, idx):\n        path = dataset_prog.cap_list[idx]['path']\n        if not os.path.exists(path):\n            print(f\"File {path} does not exist; randomly choosing another sample with the same shape.\")\n            index_cand = self.shape_idx_dict[self.sample_size[idx]]\n            return self.__getitem__(random.choice(index_cand))\n        if path.endswith('.mp4'):\n            return self.get_video(idx)\n        else:\n            return self.get_image(idx)\n\n    def drop(self, text, is_video=True):\n        rand_num = random.random()\n        rand_num_text = random.random()\n\n        if rand_num < self.cfg:\n            if rand_num_text < self.default_text_ratio:\n                if not is_video:\n                    text = \"The image showcases a scene with coherent and clear visuals.\" \n                else:\n                    text = \"The video showcases a scene with coherent and clear visuals.\" \n            else:\n                text = ''\n\n        return dict(text=text)\n    \n    def get_video(self, idx):\n        # npu_config.print_msg(f\"current idx is {idx}\")\n        # video = random.choice([random_video_noise(65, 3, 336, 448), random_video_noise(65, 3, 1024, 1024), random_video_noise(65, 3, 360, 480)])\n        # # print('random shape', video.shape)\n        # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)\n        # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)\n        # logger.info(f'Now we use t2v dataset {idx}')\n        video_data = dataset_prog.cap_list[idx]\n        video_path = video_data['path']\n        # assert os.path.exists(video_path), f\"file {video_path} does not exist!\"\n        sample_h = video_data['resolution']['sample_height']\n        sample_w = video_data['resolution']['sample_width']\n        \n        if self.video_reader == 'decord':\n            video = self.decord_read(video_data)\n        elif self.video_reader == 'opencv':\n            
video = self.opencv_read(video_data)\n        else:\n            raise NotImplementedError(f'Found {self.video_reader}, but only decord and opencv are supported')\n        # import ipdb;ipdb.set_trace()\n\n        video = self.resize_transform(video)  # T C H W -> T C H W\n        assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape}), video_path ({video_path})'\n\n        inpaint_cond_data = self.mask_processor(video, mask_type_ratio_dict=self.mask_type_ratio_dict_video)\n        mask, masked_video = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']\n\n        video = self.transform(video)  # T C H W -> T C H W\n        masked_video = self.transform(masked_video)  # T C H W -> T C H W\n\n        video = torch.cat([video, masked_video, mask], dim=1)  # T 2C+1 H W\n\n        video = video.transpose(0, 1)  # T C H W -> C T H W\n        text = video_data['cap']\n        if not isinstance(text, list):\n            text = [text]\n        text = [random.choice(text)]\n        if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None:\n            aes = video_data.get('aesthetic', None) or video_data.get('aes', None)\n            text = [add_aesthetic_notice_video(text[0], aes)]\n\n        text = self.drop(text, is_video=True)['text']\n\n        text_tokens_and_mask_1 = self.tokenizer_1(\n            text,\n            max_length=self.model_max_length,\n            padding='max_length',\n            truncation=True,\n            return_attention_mask=True,\n            add_special_tokens=True,\n            return_tensors='pt'\n        )\n        input_ids_1 = text_tokens_and_mask_1['input_ids']\n        cond_mask_1 = text_tokens_and_mask_1['attention_mask']\n        \n        input_ids_2, cond_mask_2 = None, None\n        if self.tokenizer_2 is not None:\n            text_tokens_and_mask_2 = self.tokenizer_2(\n                text,\n                
max_length=self.tokenizer_2.model_max_length,\n                padding='max_length',\n                truncation=True,\n                return_attention_mask=True,\n                add_special_tokens=True,\n                return_tensors='pt'\n            )\n            input_ids_2 = text_tokens_and_mask_2['input_ids']\n            cond_mask_2 = text_tokens_and_mask_2['attention_mask']\n\n        return dict(\n            pixel_values=video, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1, \n            input_ids_2=input_ids_2, cond_mask_2=cond_mask_2,\n            )\n\n    def get_image(self, idx):\n        image_data = dataset_prog.cap_list[idx]  # [{'path': path, 'cap': cap}, ...]\n        sample_h = image_data['resolution']['sample_height']\n        sample_w = image_data['resolution']['sample_width']\n\n        image = Image.open(image_data['path']).convert('RGB')  # [h, w, c]\n        image = torch.from_numpy(np.array(image))  # [h, w, c]\n        image = rearrange(image, 'h w c -> c h w').unsqueeze(0)  #  [1 c h w]\n\n        image = self.resize_transform(image)  # [1 c h w]\n        assert image.shape[2] == sample_h and image.shape[3] == sample_w, f\"image_data: {image_data}, but found image {image.shape}\"\n\n        inpaint_cond_data = self.mask_processor(image, mask_type_ratio_dict=self.mask_type_ratio_dict_image)\n        mask, masked_image = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']\n\n        image = self.transform(image)\n        masked_image = self.transform(masked_image)\n\n        image = torch.cat([image, masked_image, mask], dim=1)  #  [1 2C+1 H W]\n        # image = [torch.rand(1, 3, 480, 640) for i in image_data]\n        image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]\n\n        caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]\n        caps = [random.choice(caps)]\n        # caps = [caps[0]]\n        if '/sam/' in image_data['path']:\n            caps = [add_masking_notice(caps[0])]\n        if image_data.get('aesthetic', None) is 
not None or image_data.get('aes', None) is not None:\n            aes = image_data.get('aesthetic', None) or image_data.get('aes', None)\n            caps = [add_aesthetic_notice_image(caps[0], aes)]\n        text = text_preprocessing(caps, support_Chinese=self.support_Chinese)\n        text = self.drop(text, is_video=False)['text']\n\n        text_tokens_and_mask_1 = self.tokenizer_1(\n            text,\n            max_length=self.model_max_length,\n            padding='max_length',\n            truncation=True,\n            return_attention_mask=True,\n            add_special_tokens=True,\n            return_tensors='pt'\n        )\n        input_ids_1 = text_tokens_and_mask_1['input_ids']  # 1, l\n        cond_mask_1 = text_tokens_and_mask_1['attention_mask']  # 1, l\n        \n        input_ids_2, cond_mask_2 = None, None\n        if self.tokenizer_2 is not None:\n            text_tokens_and_mask_2 = self.tokenizer_2(\n                text,\n                max_length=self.tokenizer_2.model_max_length,\n                padding='max_length',\n                truncation=True,\n                return_attention_mask=True,\n                add_special_tokens=True,\n                return_tensors='pt'\n            )\n            input_ids_2 = text_tokens_and_mask_2['input_ids']  # 1, l\n            cond_mask_2 = text_tokens_and_mask_2['attention_mask']  # 1, l\n\n        return dict(\n            pixel_values=image, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1, motion_score=None, \n            input_ids_2=input_ids_2, cond_mask_2=cond_mask_2\n            )"
  },
  {
    "path": "opensora/dataset/t2v_datasets.py",
    "content": "import time\nimport traceback\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\nimport glob\nimport json\nimport pickle\nimport os, io, csv, math, random\nimport numpy as np\nimport torchvision\nfrom einops import rearrange\nfrom os.path import join as opj\nfrom collections import Counter\n\nimport cv2\nimport pandas as pd\nimport time\nimport torch\nimport torchvision.transforms as transforms\nfrom torch.utils.data.dataset import Dataset\nfrom torch.utils.data import DataLoader, Dataset, get_worker_info\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom accelerate.logging import get_logger\nimport gc\n\nfrom opensora.utils.dataset_utils import DecordInit\nfrom opensora.utils.utils import text_preprocessing\nfrom opensora.dataset.transform import get_params, maxhwresize, add_masking_notice, calculate_statistics, \\\n    add_aesthetic_notice_image, add_aesthetic_notice_video\n\nimport decord\nfrom concurrent.futures import ThreadPoolExecutor, TimeoutError\n\nlogger = get_logger(__name__)\n\ndef filter_json_by_existed_files(directory, data, postfix=\".mp4\"):\n    # Build a glob pattern that matches files with the given suffix\n    pattern = os.path.join(directory, '**', f'*{postfix}')\n    mp4_files = glob.glob(pattern, recursive=True)  # find all matching files recursively with glob\n\n    # Build a set of absolute file paths\n    mp4_files_set = set(os.path.abspath(path) for path in mp4_files)\n\n    # Keep only the entries whose path is in the set of found files\n    filtered_items = [item for item in data if item['path'] in mp4_files_set]\n\n    return filtered_items\n\n\ndef random_video_noise(t, c, h, w):\n    vid = torch.rand(t, c, h, w) * 255.0\n    vid = vid.to(torch.uint8)\n    return vid\n\n\nclass SingletonMeta(type):\n    \"\"\"\n    Metaclass that turns any class using it into a singleton.\n    \"\"\"\n    _instances = {}\n\n    def __call__(cls, *args, **kwargs):\n        if cls not in cls._instances:\n            instance = super().__call__(*args, **kwargs)\n            cls._instances[cls] = instance\n        return 
cls._instances[cls]\n\n\nclass DataSetProg(metaclass=SingletonMeta):\n    def __init__(self):\n        self.cap_list = []\n        self.elements = []\n        self.num_workers = 1\n        self.n_elements = 0\n        self.worker_elements = dict()\n        self.n_used_elements = dict()\n\n    def set_cap_list(self, num_workers, cap_list, n_elements):\n        self.num_workers = num_workers\n        self.cap_list = cap_list\n        self.n_elements = n_elements\n        self.elements = list(range(n_elements))\n        \n        print(f\"n_elements: {len(self.elements)}\", flush=True)\n        # if torch_npu is not None:\n        #     random.shuffle(self.elements)\n        #     for i in range(self.num_workers):\n        #         self.n_used_elements[i] = 0\n        #         per_worker = int(math.ceil(len(self.elements) / float(self.num_workers)))\n        #         start = i * per_worker\n        #         end = min(start + per_worker, len(self.elements))\n        #         self.worker_elements[i] = self.elements[start: end]\n\n    def get_item(self, work_info):\n        if work_info is None:\n            worker_id = 0\n        else:\n            worker_id = work_info.id\n\n        idx = self.worker_elements[worker_id][self.n_used_elements[worker_id] % len(self.worker_elements[worker_id])]\n        self.n_used_elements[worker_id] += 1\n        return idx\n\n\ndataset_prog = DataSetProg()\n\ndef find_closest_y(x, vae_stride_t=4, model_ds_t=1):\n    min_num_frames = 29\n    if x < min_num_frames:\n        return -1  \n    for y in range(x, min_num_frames - 1, -1):\n        if (y - 1) % vae_stride_t == 0 and ((y - 1) // vae_stride_t + 1) % model_ds_t == 0:\n            # 4, 8: y in [29, 61, 93, 125, 157, 189, 221, 253, 285, 317, 349, 381, 413, 445, 477, 509, ...]\n            # 4, 4: y in [29, 45, 61, 77, 93, 109, 125, 141, 157, 173, 189, 205, 221, 237, 253, 269, 285, 301, 317, 333, 349, 365, 381, 397, 413, 429, 445, 461, 477, 493, 509, ...]\n            # 8, 1: y 
in [33, 41, 49, 57, 65, 73, 81, 89, 97, 105]\n            # 8, 2: y in [41, 57, 73, 89, 105]\n            # 8, 4: y in [57, 89]\n            # 8, 8: y in [57]\n            return y\n    return -1 \n\ndef filter_resolution(h, w, max_h_div_w_ratio=17/16, min_h_div_w_ratio=8 / 16):\n    if h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio:\n        return True\n    return False\n        \ndef read_parquet(path):\n    df = pd.read_parquet(path)\n    data = df.to_dict(orient='records')\n    return data\n\n\n\nclass DecordDecoder(object):\n    def __init__(self, url, num_threads=1):\n\n        self.num_threads = num_threads\n        self.ctx = decord.cpu(0)\n        self.reader = decord.VideoReader(url,\n                                    ctx=self.ctx,\n                                    num_threads=self.num_threads)\n\n    def get_avg_fps(self):\n        return self.reader.get_avg_fps() if self.reader.get_avg_fps() > 0 else 30.0\n\n    def get_num_frames(self):\n        return len(self.reader)\n\n    def get_height(self):\n        return self.reader[0].shape[0] if self.get_num_frames() > 0 else 0\n\n    def get_width(self):\n        return self.reader[0].shape[1] if self.get_num_frames() > 0 else 0\n\n    # output shape [T, H, W, C]\n    def get_batch(self, frame_indices):\n        try:\n            #frame_indices[0] = 1000\n            video_data = self.reader.get_batch(frame_indices).asnumpy()\n            video_data = torch.from_numpy(video_data)\n            return video_data\n        except Exception as e:\n            print('get_batch exception:', e)\n            return None\n        \nclass T2V_dataset(Dataset):\n    def __init__(self, args, transform, temporal_sample, tokenizer_1, tokenizer_2):\n        self.data = args.data\n        self.num_frames = args.num_frames\n        self.train_fps = args.train_fps\n        self.transform = transform\n        self.temporal_sample = temporal_sample\n        self.tokenizer_1 = tokenizer_1\n        
self.tokenizer_2 = tokenizer_2\n        self.model_max_length = args.model_max_length\n        self.cfg = args.cfg\n        self.speed_factor = args.speed_factor\n        self.max_height = args.max_height\n        self.max_width = args.max_width\n        self.drop_short_ratio = args.drop_short_ratio\n        self.hw_stride = args.hw_stride\n        self.force_resolution = args.force_resolution\n        self.max_hxw = args.max_hxw\n        self.min_hxw = args.min_hxw\n        self.sp_size = args.sp_size\n        assert self.speed_factor >= 1\n        self.video_reader = 'decord' if args.use_decord else 'opencv'\n        self.ae_stride_t = args.ae_stride_t\n        self.total_batch_size = args.total_batch_size\n        self.seed = 42\n        self.generator = torch.Generator().manual_seed(self.seed) \n        self.hw_aspect_thr = 2.0  # just a threshold\n        self.too_long_factor = 5.0\n\n        self.support_Chinese = False\n        if 'mt5' in args.text_encoder_name_1:\n            self.support_Chinese = True\n        if args.text_encoder_name_2 is not None and 'mt5' in args.text_encoder_name_2:\n            self.support_Chinese = True\n\n        s = time.time()\n        cap_list, self.sample_size, self.shape_idx_dict = self.define_frame_index(self.data)\n        e = time.time()\n        print(f'Build data time: {e-s}')\n        self.lengths = self.sample_size\n\n        n_elements = len(cap_list)\n        dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements)\n        print(f\"Data length: {len(dataset_prog.cap_list)}\")\n        self.executor = ThreadPoolExecutor(max_workers=1)\n        self.timeout = 60\n\n    def set_checkpoint(self, n_used_elements):\n        for i in range(len(dataset_prog.n_used_elements)):\n            dataset_prog.n_used_elements[i] = n_used_elements\n\n    def __len__(self):\n        return dataset_prog.n_elements\n\n    def __getitem__(self, idx):\n        try:\n            future = 
self.executor.submit(self.get_data, idx)\n            data = future.result(timeout=self.timeout) \n            # data = self.get_data(idx)\n            return data\n        except Exception as e:\n            if len(str(e)) < 2:\n                e = f\"TimeoutError, {self.timeout}s timeout occurred with {dataset_prog.cap_list[idx]['path']}\"\n            print(f'Error with {e}')\n            index_cand = self.shape_idx_dict[self.sample_size[idx]]  # pick same shape\n            return self.__getitem__(random.choice(index_cand))\n\n    def get_data(self, idx):\n        path = dataset_prog.cap_list[idx]['path']\n        if path.endswith('.mp4'):\n            return self.get_video(idx)\n        else:\n            return self.get_image(idx)\n    \n    def get_video(self, idx):\n        video_data = dataset_prog.cap_list[idx]\n        video_path = video_data['path']\n        assert os.path.exists(video_path), f\"file {video_path} does not exist!\"\n        sample_h = video_data['resolution']['sample_height']\n        sample_w = video_data['resolution']['sample_width']\n        if self.video_reader == 'decord':\n            video = self.decord_read(video_data)\n        elif self.video_reader == 'opencv':\n            video = self.opencv_read(video_data)\n        else:\n            raise NotImplementedError(f'Found {self.video_reader}, but only decord and opencv are supported')\n        # import ipdb;ipdb.set_trace()\n        video = self.transform(video)  # T C H W -> T C H W\n        assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape})'\n\n        # video = torch.rand(105, 3, 640, 640)\n\n        video = video.transpose(0, 1)  # T C H W -> C T H W\n        text = video_data['cap']\n        if not isinstance(text, list):\n            text = [text]\n        text = [random.choice(text)]\n        if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None:\n            aes = 
video_data.get('aesthetic', None) or video_data.get('aes', None)\n            text = [add_aesthetic_notice_video(text[0], aes)]\n        text = text_preprocessing(text, support_Chinese=self.support_Chinese)\n\n        text = text if random.random() > self.cfg else \"\"\n\n        text_tokens_and_mask_1 = self.tokenizer_1(\n            text,\n            max_length=self.model_max_length,\n            padding='max_length',\n            truncation=True,\n            return_attention_mask=True,\n            add_special_tokens=True,\n            return_tensors='pt'\n        )\n        input_ids_1 = text_tokens_and_mask_1['input_ids']\n        cond_mask_1 = text_tokens_and_mask_1['attention_mask']\n        \n        input_ids_2, cond_mask_2 = None, None\n        if self.tokenizer_2 is not None:\n            text_tokens_and_mask_2 = self.tokenizer_2(\n                text,\n                max_length=self.tokenizer_2.model_max_length,\n                padding='max_length',\n                truncation=True,\n                return_attention_mask=True,\n                add_special_tokens=True,\n                return_tensors='pt'\n            )\n            input_ids_2 = text_tokens_and_mask_2['input_ids']\n            cond_mask_2 = text_tokens_and_mask_2['attention_mask']\n\n        return dict(\n            pixel_values=video, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1, \n            input_ids_2=input_ids_2, cond_mask_2=cond_mask_2,\n            )\n\n    def get_image(self, idx):\n        image_data = dataset_prog.cap_list[idx]  # [{'path': path, 'cap': cap}, ...]\n        sample_h = image_data['resolution']['sample_height']\n        sample_w = image_data['resolution']['sample_width']\n\n        image = Image.open(image_data['path']).convert('RGB')  # [h, w, c]\n        image = torch.from_numpy(np.array(image))  # [h, w, c]\n        image = rearrange(image, 'h w c -> c h w').unsqueeze(0)  #  [1 c h w]\n\n        image = self.transform(image) #  [1 C H W] -> num_img 
[1 C H W]\n        assert image.shape[2] == sample_h and image.shape[3] == sample_w, f\"image_data: {image_data}, but found image {image.shape}\"\n        # image = torch.rand(1, 3, sample_h, sample_w)\n        image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]\n\n        caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]\n        caps = [random.choice(caps)]\n        if image_data.get('aesthetic', None) is not None or image_data.get('aes', None) is not None:\n            aes = image_data.get('aesthetic', None) or image_data.get('aes', None)\n            caps = [add_aesthetic_notice_image(caps[0], aes)]\n        text = text_preprocessing(caps, support_Chinese=self.support_Chinese)\n        text = text if random.random() > self.cfg else \"\"\n\n        text_tokens_and_mask_1 = self.tokenizer_1(\n            text,\n            max_length=self.model_max_length,\n            padding='max_length',\n            truncation=True,\n            return_attention_mask=True,\n            add_special_tokens=True,\n            return_tensors='pt'\n        )\n        input_ids_1 = text_tokens_and_mask_1['input_ids']  # 1, l\n        cond_mask_1 = text_tokens_and_mask_1['attention_mask']  # 1, l\n        \n        input_ids_2, cond_mask_2 = None, None\n        if self.tokenizer_2 is not None:\n            text_tokens_and_mask_2 = self.tokenizer_2(\n                text,\n                max_length=self.tokenizer_2.model_max_length,\n                padding='max_length',\n                truncation=True,\n                return_attention_mask=True,\n                add_special_tokens=True,\n                return_tensors='pt'\n            )\n            input_ids_2 = text_tokens_and_mask_2['input_ids']  # 1, l\n            cond_mask_2 = text_tokens_and_mask_2['attention_mask']  # 1, l\n\n        return dict(\n            pixel_values=image, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1, \n            input_ids_2=input_ids_2, 
cond_mask_2=cond_mask_2\n            )\n\n    def define_frame_index(self, data):\n        \n        shape_idx_dict = {}\n        new_cap_list = []\n        sample_size = []\n        aesthetic_score = []\n        cnt_vid = 0\n        cnt_img = 0\n        cnt_too_long = 0\n        cnt_too_short = 0\n        cnt_no_cap = 0\n        cnt_no_resolution = 0\n        cnt_no_aesthetic = 0\n        cnt_img_res_mismatch_stride = 0\n        cnt_vid_res_mismatch_stride = 0\n        cnt_img_aspect_mismatch = 0\n        cnt_vid_aspect_mismatch = 0\n        cnt_img_res_too_small = 0\n        cnt_vid_res_too_small = 0\n        cnt_vid_after_filter = 0\n        cnt_img_after_filter = 0\n        cnt = 0\n        \n\n        with open(data, 'r') as f:\n            folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]\n        for sub_root, anno in tqdm(folder_anno):\n            print(f'Building {anno}...')\n            if anno.endswith('.json'):\n                with open(anno, 'r') as f:\n                    sub_list = json.load(f)\n            elif anno.endswith('.pkl'):\n                with open(anno, \"rb\") as f:\n                    sub_list = pickle.load(f)\n            else:\n                raise ValueError(f'Unsupported annotation file {anno}, expected .json or .pkl')\n            for index, i in enumerate(tqdm(sub_list)):\n                cnt += 1\n                path = os.path.join(sub_root, i['path'])\n                i['path'] = path\n                if path.endswith('.mp4'):\n                    cnt_vid += 1\n                elif path.endswith('.jpg'):\n                    cnt_img += 1\n\n                # ======no aesthetic=====\n                # count as missing only when neither 'aesthetic' nor 'aes' is present\n                if i.get('aesthetic', None) is None and i.get('aes', None) is None:\n                    cnt_no_aesthetic += 1\n                else:\n                    aesthetic_score.append(i.get('aesthetic', None) or i.get('aes', None))\n\n                # ======no caption=====\n                cap = i.get('cap', None)\n                if cap is None:\n                    cnt_no_cap += 1\n                    
continue\n\n                # ======resolution mismatch=====\n                if i.get('resolution', None) is None:\n                    cnt_no_resolution += 1\n                    continue\n                else:\n                    if i['resolution'].get('height', None) is None or i['resolution'].get('width', None) is None:\n                        cnt_no_resolution += 1\n                        continue\n                    else:\n                        height, width = i['resolution']['height'], i['resolution']['width']\n                        if not self.force_resolution:\n                            if height <= 0 or width <= 0:\n                                cnt_no_resolution += 1\n                                continue\n                            \n                            tr_h, tr_w = maxhwresize(height, width, self.max_hxw)\n                            _, _, sample_h, sample_w = get_params(tr_h, tr_w, self.hw_stride)\n\n                            if sample_h <= 0 or sample_w <= 0:\n                                if path.endswith('.mp4'):\n                                    cnt_vid_res_mismatch_stride += 1\n                                elif path.endswith('.jpg'):\n                                    cnt_img_res_mismatch_stride += 1\n                                continue\n                            \n                            # filter min_hxw\n                            if sample_h * sample_w < self.min_hxw:\n                                if path.endswith('.mp4'):\n                                    cnt_vid_res_too_small += 1\n                                elif path.endswith('.jpg'):\n                                    cnt_img_res_too_small += 1\n                                continue\n\n                            # filter aspect\n                            is_pick = filter_resolution(\n                                sample_h, sample_w, max_h_div_w_ratio=self.hw_aspect_thr, min_h_div_w_ratio=1/self.hw_aspect_thr\n            
                    )\n                            if not is_pick:\n                                if path.endswith('.mp4'):\n                                    cnt_vid_aspect_mismatch += 1\n                                elif path.endswith('.jpg'):\n                                    cnt_img_aspect_mismatch += 1\n                                continue\n\n                            i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))\n                            \n                        else:\n                            aspect = self.max_height / self.max_width\n                            is_pick = filter_resolution(\n                                height, width, max_h_div_w_ratio=self.hw_aspect_thr*aspect, min_h_div_w_ratio=1/self.hw_aspect_thr*aspect\n                                )\n                            if not is_pick:\n                                if path.endswith('.mp4'):\n                                    cnt_vid_aspect_mismatch += 1\n                                elif path.endswith('.jpg'):\n                                    cnt_img_aspect_mismatch += 1\n                                continue\n                            sample_h, sample_w = self.max_height, self.max_width\n                            \n                            i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))\n\n\n                if path.endswith('.mp4'):\n                    fps = i.get('fps', 24)\n                    # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. 
\n                    if i['num_frames'] > self.too_long_factor * (self.num_frames * fps / self.train_fps * self.speed_factor):  # videos this long are not suitable for this training stage (self.num_frames)\n                        cnt_too_long += 1\n                        continue\n\n                    # resample high-fps videos (e.g. 50/60/90/144 fps) to train_fps (e.g. 24)\n                    frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps\n                    start_frame_idx = i.get('cut', [0])[0]\n                    i['start_frame_idx'] = start_frame_idx\n                    frame_indices = np.arange(start_frame_idx, start_frame_idx+i['num_frames'], frame_interval).astype(int)\n                    frame_indices = frame_indices[frame_indices < start_frame_idx+i['num_frames']]\n\n                    # comment this out to enable dynamic-frame training\n                    if len(frame_indices) < self.num_frames and torch.rand(1, generator=self.generator).item() < self.drop_short_ratio:\n                        cnt_too_short += 1\n                        continue\n\n                    # randomly crop overly long videos along the time axis\n                    if len(frame_indices) > self.num_frames:\n                        begin_index, end_index = self.temporal_sample(len(frame_indices))\n                        frame_indices = frame_indices[begin_index: end_index]\n                        # frame_indices = frame_indices[:self.num_frames]  # head crop\n                    # find a suitable end_frame_idx so that the video does not need padding\n                    end_frame_idx = find_closest_y(\n                        len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size\n                        )\n                    if end_frame_idx == -1:  # too short to be encoded exactly by the video VAE\n                        cnt_too_short += 1\n                        continue\n                    frame_indices = 
frame_indices[:end_frame_idx]\n\n                    i['sample_frame_index'] = frame_indices.tolist()\n\n                    new_cap_list.append(i)\n                    cnt_vid_after_filter += 1\n\n                elif path.endswith('.jpg'):  # image\n                    cnt_img_after_filter += 1\n                    i['sample_frame_index'] = [0]\n                    new_cap_list.append(i)\n\n                else:\n                    raise NameError(f\"Unknown file extension {path.split('.')[-1]}; only .mp4 (video) and .jpg (image) are supported\")\n\n                pre_define_shape = f\"{len(i['sample_frame_index'])}x{sample_h}x{sample_w}\"\n                sample_size.append(pre_define_shape)\n                # if shape_idx_dict.get(pre_define_shape, None) is None:\n                #     shape_idx_dict[pre_define_shape] = [index]\n                # else:\n                #     shape_idx_dict[pre_define_shape].append(index)\n        counter = Counter(sample_size)\n        if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None:\n            assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) <= self.max_hxw for k in counter.keys()])\n            assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) >= self.min_hxw for k in counter.keys()])\n\n        # drop minority shapes: keep only shapes that have at least 4 * total_batch_size samples\n        len_before_filter_major = len(sample_size)\n        filter_major_num = 4 * self.total_batch_size\n        new_cap_list, sample_size = zip(*[[i, j] for i, j in zip(new_cap_list, sample_size) if counter[j] >= filter_major_num])\n        for idx, shape in enumerate(sample_size):\n            if shape_idx_dict.get(shape, None) is None:\n                shape_idx_dict[shape] = [idx]\n            else:\n                shape_idx_dict[shape].append(idx)\n        cnt_filter_minority = len_before_filter_major - len(sample_size)\n        counter = Counter(sample_size)\n\n        print(f'no_cap: {cnt_no_cap}, 
no_resolution: {cnt_no_resolution}\\n'\n                f'too_long: {cnt_too_long}, too_short: {cnt_too_short}\\n'\n                f'cnt_img_res_mismatch_stride: {cnt_img_res_mismatch_stride}, cnt_vid_res_mismatch_stride: {cnt_vid_res_mismatch_stride}\\n'\n                f'cnt_img_res_too_small: {cnt_img_res_too_small}, cnt_vid_res_too_small: {cnt_vid_res_too_small}\\n'\n                f'cnt_img_aspect_mismatch: {cnt_img_aspect_mismatch}, cnt_vid_aspect_mismatch: {cnt_vid_aspect_mismatch}\\n'\n                f'cnt_filter_minority: {cnt_filter_minority}\\n'\n                f'Counter(sample_size): {counter}\\n'\n                f'cnt_vid: {cnt_vid}, cnt_vid_after_filter: {cnt_vid_after_filter}, use_ratio: {round(cnt_vid_after_filter/(cnt_vid+1e-6), 5)*100}%\\n'\n                f'cnt_img: {cnt_img}, cnt_img_after_filter: {cnt_img_after_filter}, use_ratio: {round(cnt_img_after_filter/(cnt_img+1e-6), 5)*100}%\\n'\n                f'before filter: {cnt}, after filter: {len(new_cap_list)}, use_ratio: {round(len(new_cap_list)/(cnt+1e-6), 5)*100}%')\n\n        if len(aesthetic_score) > 0:\n            stats_aesthetic = calculate_statistics(aesthetic_score)\n            print(f\"before filter: {cnt}, after filter: {len(new_cap_list)}\\n\"\n                f\"aesthetic_score: {len(aesthetic_score)}, cnt_no_aesthetic: {cnt_no_aesthetic}\\n\"\n                f\"score >= 5.75: {len([i for i in aesthetic_score if i>=5.75])}, score <= 4.5: {len([i for i in aesthetic_score if i<=4.5])}\\n\"\n                f\"Mean: {stats_aesthetic['mean']}, Var: {stats_aesthetic['variance']}, Std: {stats_aesthetic['std_dev']}\\n\"\n                f\"Min: {stats_aesthetic['min']}, Max: {stats_aesthetic['max']}\")\n\n        return new_cap_list, sample_size, shape_idx_dict\n\n    def decord_read(self, video_data):\n        path = video_data['path']\n        predefine_frame_indice = video_data['sample_frame_index']\n        start_frame_idx = 
video_data['start_frame_idx']\n        clip_total_frames = video_data['num_frames']\n        fps = video_data['fps']\n        s_x, e_x, s_y, e_y = video_data.get('crop', [None, None, None, None])\n\n        predefine_num_frames = len(predefine_frame_indice)\n        # decord_vr = decord.VideoReader(path, ctx=decord.cpu(0), num_threads=1)\n        decord_vr = DecordDecoder(path)\n\n        frame_indices = self.get_actual_frame(\n            fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice\n            )\n\n        # video_data = decord_vr.get_batch(frame_indices).asnumpy()\n        # video_data = torch.from_numpy(video_data)\n        video_data = decord_vr.get_batch(frame_indices)\n        if video_data is not None:\n            video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)\n            if s_y is not None:\n                video_data = video_data[:, :, s_y: e_y, s_x: e_x]\n        else:\n            raise ValueError(f'failed to decode frames from {path}')\n        # del decord_vr\n        # gc.collect()\n        return video_data\n\n    def opencv_read(self, video_data):\n        path = video_data['path']\n        predefine_frame_indice = video_data['sample_frame_index']\n        start_frame_idx = video_data['start_frame_idx']\n        clip_total_frames = video_data['num_frames']\n        fps = video_data['fps']\n        s_x, e_x, s_y, e_y = video_data.get('crop', [None, None, None, None])\n\n        predefine_num_frames = len(predefine_frame_indice)\n        cv2_vr = cv2.VideoCapture(path)\n        if not cv2_vr.isOpened():\n            raise ValueError(f'cannot open {path}')\n        frame_indices = self.get_actual_frame(\n            fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice\n            )\n\n        video_data = []\n        for frame_idx in frame_indices:\n            cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)  # seek to frame_idx\n            ok, frame = cv2_vr.read()\n            if not ok:\n                cv2_vr.release()\n                raise ValueError(f'cannot read frame {frame_idx} from {path}')\n         
   frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)\n        cv2_vr.release()\n        video_data = torch.stack(video_data)  # (T, C, H, W)\n        if s_y is not None:\n            video_data = video_data[:, :, s_y: e_y, s_x: e_x]\n        return video_data\n\n    def get_actual_frame(self, fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice):\n        # resample high-fps videos (e.g. 50/60/90/144 fps) to train_fps (e.g. 24)\n        frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps\n        frame_indices = np.arange(start_frame_idx, start_frame_idx+clip_total_frames, frame_interval).astype(int)\n        frame_indices = frame_indices[frame_indices < start_frame_idx+clip_total_frames]\n\n        # speed up playback by uniformly subsampling frames (factor capped at self.speed_factor)\n        max_speed_factor = len(frame_indices) / self.num_frames\n        if self.speed_factor > 1 and max_speed_factor > 1:\n            # speed_factor = random.uniform(1.0, min(self.speed_factor, max_speed_factor))\n            speed_factor = min(self.speed_factor, max_speed_factor)\n            target_frame_count = int(len(frame_indices) / speed_factor)\n            speed_frame_idx = np.linspace(0, len(frame_indices) - 1, target_frame_count, dtype=int)\n            frame_indices = frame_indices[speed_frame_idx]\n\n        # randomly crop overly long videos along the time axis\n        if len(frame_indices) > self.num_frames:\n            begin_index, end_index = self.temporal_sample(len(frame_indices))\n            frame_indices = frame_indices[begin_index: end_index]\n            # frame_indices = frame_indices[:self.num_frames]  # head crop\n\n        # find a suitable end_frame_idx so that the video does not need padding\n        end_frame_idx = find_closest_y(\n            len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size\n            )\n        if end_frame_idx == -1:  # too short to be 
encoded exactly by the video VAE\n            raise IndexError(f'video ({path}) has {clip_total_frames} frames, but only {len(frame_indices)} frames ({frame_indices}) can be sampled, too few to encode exactly')\n        frame_indices = frame_indices[:end_frame_idx]\n        if predefine_num_frames != len(frame_indices):\n            raise ValueError(f'video ({path}) predefine_num_frames ({predefine_num_frames}) ({predefine_frame_indice}) does not match len(frame_indices) ({len(frame_indices)}) ({frame_indices})')\n        if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1:\n            raise IndexError(f'video ({path}) has {clip_total_frames} frames, which yield only {len(frame_indices)} sampled frames ({frame_indices}), fewer than num_frames ({self.num_frames})')\n        return frame_indices\n"
  },
  {
    "path": "opensora/dataset/transform.py",
    "content": "import torch\nimport random\nimport numbers\nfrom torchvision.transforms import RandomCrop, RandomResizedCrop\nimport statistics\nimport numpy as np\nimport ftfy\nimport regex as re\nimport html\nfrom PIL import Image\n\n\ndef _is_tensor_video_clip(clip):\n    if not torch.is_tensor(clip):\n        raise TypeError(\"clip should be a Tensor. Got %s\" % type(clip))\n\n    if not clip.ndimension() == 4:\n        raise ValueError(\"clip should be 4D. Got %dD\" % clip.dim())\n\n    return True\n\n\ndef center_crop_arr(pil_image, image_size):\n    \"\"\"\n    Center cropping implementation from ADM.\n    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126\n    \"\"\"\n    while min(*pil_image.size) >= 2 * image_size:\n        pil_image = pil_image.resize(\n            tuple(x // 2 for x in pil_image.size), resample=Image.BOX\n        )\n\n    scale = image_size / min(*pil_image.size)\n    pil_image = pil_image.resize(\n        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC\n    )\n\n    arr = np.array(pil_image)\n    crop_y = (arr.shape[0] - image_size) // 2\n    crop_x = (arr.shape[1] - image_size) // 2\n    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])\n\n\ndef crop(clip, i, j, h, w):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n    \"\"\"\n    if len(clip.size()) != 4:\n        raise ValueError(\"clip should be a 4D tensor\")\n    return clip[..., i: i + h, j: j + w]\n\n\ndef resize(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be a tuple (height, width), instead got {target_size}\")\n    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)\n\n\ndef resize_scale(clip, target_size, interpolation_mode):\n    # scale so that the short edge equals target_size[0], keeping the aspect ratio\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be a tuple (height, width), instead got {target_size}\")\n    H, W = clip.size(-2), clip.size(-1)\n    scale_ = target_size[0] / min(H, W)\n    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)\n\n\ndef resized_crop(clip, i, j, h, w, size, interpolation_mode=\"bilinear\"):\n    \"\"\"\n    Spatially crop the video clip, then resize it\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        i (int): i in (i,j), i.e. the row of the upper left corner.\n        j (int): j in (i,j), i.e. the column of the upper left corner.\n        h (int): Height of the cropped region.\n        w (int): Width of the cropped region.\n        size (tuple(int, int)): height and width of resized clip\n    Returns:\n        clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    clip = crop(clip, i, j, h, w)\n    clip = resize(clip, size, interpolation_mode)\n    return clip\n\n\ndef center_crop(clip, crop_size):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    th, tw = crop_size\n    if h < th or w < tw:\n        raise ValueError(\"height and width must be no smaller than crop_size\")\n\n    i = int(round((h - th) / 2.0))\n    j = int(round((w - tw) / 2.0))\n    return crop(clip, i, j, th, tw)\n\n\ndef center_crop_using_short_edge(clip):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    if h < w:\n        th, tw = h, h\n        i = 0\n        j = int(round((w - tw) / 2.0))\n    else:\n        th, tw = w, w\n        i = int(round((h - th) / 2.0))\n        j = 0\n    return crop(clip, i, j, th, tw)\n\n\ndef center_crop_th_tw(clip, th, tw, top_crop):\n    # crop the largest region whose aspect ratio is th/tw (centered horizontally;\n    # vertically centered, or anchored at the top when top_crop is True)\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n\n    h, w = clip.size(-2), clip.size(-1)\n    tr = th / tw\n    if h / w > tr:\n        # hxw 720x1280  thxtw 320x640  hw_ratio 9/16 > tr_ratio 8/16  newh=1280*320/640=640  neww=1280\n        new_h = int(w * tr)\n        new_w = w\n    else:\n        # hxw 720x1280  thxtw 480x640  hw_ratio 9/16 < tr_ratio 12/16   newh=720 neww=720/(12/16)=960\n        # hxw 1080x1920  thxtw 720x1280  hw_ratio 9/16 = tr_ratio 9/16   newh=1080 neww=1080/(9/16)=1920\n        new_h = h\n        new_w = int(h / tr)\n\n    i = 0 if top_crop else int(round((h - new_h) / 2.0))\n    j = int(round((w - new_w) / 2.0))\n    return crop(clip, i, j, new_h, new_w)\n\n\ndef random_shift_crop(clip):\n    '''\n    Slide along the long 
edge, with the short edge as crop size\n    '''\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n\n    short_edge = min(h, w)\n    th, tw = short_edge, short_edge\n\n    i = torch.randint(0, h - th + 1, size=(1,)).item()\n    j = torch.randint(0, w - tw + 1, size=(1,)).item()\n    return crop(clip, i, j, th, tw)\n\n\ndef to_tensor(clip):\n    \"\"\"\n    Convert tensor data type from uint8 to float and divide values by 255.0;\n    the dimension order is left unchanged\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n    \"\"\"\n    _is_tensor_video_clip(clip)\n    if not clip.dtype == torch.uint8:\n        raise TypeError(\"clip tensor should have data type uint8. Got %s\" % str(clip.dtype))\n    # return clip.float().permute(3, 0, 1, 2) / 255.0\n    return clip.float() / 255.0\n\n\ndef to_tensor_after_resize(clip):\n    \"\"\"\n    Scale a resized float tensor with values in [0, 255] to [0, 1]\n    Args:\n        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W), but in [0, 1]\n    \"\"\"\n    _is_tensor_video_clip(clip)\n    # return clip.float().permute(3, 0, 1, 2) / 255.0\n    return clip.float() / 255.0\n\n\ndef normalize(clip, mean, std, inplace=False):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)\n        mean (tuple): pixel RGB mean. Size is (3)\n        std (tuple): pixel standard deviation. 
Size is (3)\n    Returns:\n        normalized clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    if not inplace:\n        clip = clip.clone()\n    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)\n    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)\n    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])\n    return clip\n\n\ndef hflip(clip):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)\n    Returns:\n        flipped clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    return clip.flip(-1)\n\n\nclass RandomCropVideo:\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: randomly cropped video clip.\n                size is (T, C, OH, OW)\n        \"\"\"\n        i, j, h, w = self.get_params(clip)\n        return crop(clip, i, j, h, w)\n\n    def get_params(self, clip):\n        h, w = clip.shape[-2:]\n        th, tw = self.size\n\n        if h < th or w < tw:\n            raise ValueError(f\"Required crop size {(th, tw)} is larger than input image size {(h, w)}\")\n\n        if w == tw and h == th:\n            return 0, 0, h, w\n\n        i = torch.randint(0, h - th + 1, size=(1,)).item()\n        j = torch.randint(0, w - tw + 1, size=(1,)).item()\n\n        return i, j, th, tw\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size})\"\n\n\ndef get_params(h, w, stride):\n    # round h and w down to multiples of stride and center the crop window\n    th, tw = h // stride * stride, w // stride * stride\n\n    i = (h - th) // 2\n    j = (w - tw) // 2\n\n    return i, j, th, tw\n\n\nclass SpatialStrideCropVideo:\n    def __init__(self, stride):\n        self.stride = stride\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: video clip cropped so that H and W are multiples of stride.\n                size is (T, C, OH, OW)\n        \"\"\"\n        h, w = clip.shape[-2:]\n        i, j, h, w = get_params(h, w, self.stride)\n        return crop(clip, i, j, h, w)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(stride={self.stride})\"\n\n\ndef longsideresize(h, w, size, skip_low_resolution):\n    # resize (keeping the aspect ratio) so that (h, w) fits inside size = (height, width)\n    if h <= size[0] and w <= size[1] and skip_low_resolution:\n        return h, w\n\n    if h / w > size[0] / size[1]:\n        # hxw 720x1280  size 320x640  hw_ratio 9/16 > size_ratio 8/16  neww=320/720*1280=568  newh=320\n        w = int(size[0] / h * w)\n        h = size[0]\n    else:\n        # hxw 720x1280  size 480x640  hw_ratio 9/16 < size_ratio 12/16   newh=640/1280*720=360 neww=640\n        # hxw 1080x1920  size 720x1280  hw_ratio 9/16 = size_ratio 9/16   newh=1280/1920*1080=720 neww=1280\n        h = int(size[1] / w * h)\n        w = size[1]\n    return h, w\n\n\ndef maxhwresize(ori_height, ori_width, max_hxw):\n    # downscale (keeping the aspect ratio) so that height * width <= max_hxw\n    if ori_height * ori_width > max_hxw:\n        scale_factor = np.sqrt(max_hxw / (ori_height * ori_width))\n        new_height = int(ori_height * scale_factor)\n        new_width = int(ori_width * scale_factor)\n    else:\n        new_height = ori_height\n        new_width = ori_width\n    return new_height, new_width\n\n\nclass LongSideResizeVideo:\n    '''\n    Resize the clip, keeping its aspect ratio, so that it fits inside size (height, width).\n    If skip_low_resolution is True, clips that already fit are left unchanged.\n    '''\n\n    def __init__(\n            self,\n            size,\n            skip_low_resolution=False,\n            interpolation_mode=\"bilinear\",\n    ):\n        self.size = size\n        self.skip_low_resolution = skip_low_resolution\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be resized. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: resized video clip.\n        \"\"\"\n        _, _, h, w = clip.shape\n        tr_h, tr_w = longsideresize(h, w, self.size, self.skip_low_resolution)\n        if h == tr_h and w == tr_w:\n            return clip\n        resize_clip = resize(clip, target_size=(tr_h, tr_w),\n                                         interpolation_mode=self.interpolation_mode)\n        return resize_clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass MaxHWResizeVideo:\n    '''\n    Downscale the clip, keeping its aspect ratio, so that height * width\n    does not exceed max_hxw\n    '''\n\n    def __init__(\n            self,\n            max_hxw,\n            interpolation_mode=\"bilinear\",\n    ):\n        self.max_hxw = max_hxw\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be resized. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: resized video clip.\n        \"\"\"\n        _, _, h, w = clip.shape\n        tr_h, tr_w = maxhwresize(h, w, self.max_hxw)\n        if h == tr_h and w == tr_w:\n            return clip\n        resize_clip = resize(clip, target_size=(tr_h, tr_w),\n                                         interpolation_mode=self.interpolation_mode)\n        return resize_clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(max_hxw={self.max_hxw}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass CenterCropResizeVideo:\n    '''\n    Crop the largest region with the target aspect ratio (centered, or anchored\n    at the top when top_crop is True), then resize it to the specified size\n    '''\n\n    def __init__(\n            self,\n            size,\n            top_crop=False,\n            interpolation_mode=\"bilinear\",\n    ):\n        if len(size) != 2:\n            raise ValueError(f\"size should be a tuple (height, width), instead got {size}\")\n        self.size = size\n        self.top_crop = top_crop\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: center cropped and resized video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_center_crop = center_crop_th_tw(clip, self.size[0], self.size[1], top_crop=self.top_crop)\n        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,\n                                         interpolation_mode=self.interpolation_mode)\n        return clip_center_crop_resize\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass UCFCenterCropVideo:\n    '''\n    First scale the short edge to size[0] (keeping the aspect ratio),\n    then center crop\n    '''\n\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be a tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        Returns:\n            torch.tensor: scale resized / center cropped video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)\n        clip_center_crop = center_crop(clip_resize, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass KineticsRandomCropResizeVideo:\n    '''\n    Slide along the long edge, with the short edge as crop size. 
Then resize to the desired size.\n    '''\n\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be a tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        clip_random_crop = random_shift_crop(clip)\n        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)\n        return clip_resize\n\n\nclass CenterCropVideo:\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be a tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: center cropped video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_center_crop = center_crop(clip, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass NormalizeVideo:\n    \"\"\"\n    Normalize the video clip by mean subtraction and division by standard deviation\n    Args:\n        mean (3-tuple): pixel RGB mean\n        std (3-tuple): pixel RGB standard deviation\n        inplace (boolean): whether to normalize in place\n    \"\"\"\n\n    def __init__(self, mean, std, inplace=False):\n        self.mean = mean\n        self.std = std\n        self.inplace = inplace\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)\n        \"\"\"\n        return normalize(clip, self.mean, self.std, self.inplace)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})\"\n\n\nclass ToTensorVideo:\n    \"\"\"\n    Convert tensor data type from uint8 to float and divide values by 255.0;\n    the dimension order is left unchanged\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n        \"\"\"\n        return to_tensor(clip)\n\n    def __repr__(self) -> str:\n        return self.__class__.__name__\n\n\nclass ToTensorAfterResize:\n    \"\"\"\n    Scale an already resized float clip with values in [0, 255] to [0, 1]\n    by dividing by 255.0; the dimension order is left unchanged\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def 
__call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W), but in [0, 1]\n        \"\"\"\n        return to_tensor_after_resize(clip)\n\n    def __repr__(self) -> str:\n        return self.__class__.__name__\n\n\nclass RandomHorizontalFlipVideo:\n    \"\"\"\n    Flip the video clip along the horizontal direction with a given probability\n    Args:\n        p (float): probability of the clip being flipped. Default value is 0.5\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor): Size is (T, C, H, W)\n        \"\"\"\n        if random.random() < self.p:\n            clip = hflip(clip)\n        return clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(p={self.p})\"\n\n\n#  ------------------------------------------------------------\n#  ---------------------  Sampling  ---------------------------\n#  ------------------------------------------------------------\nclass TemporalRandomCrop(object):\n    \"\"\"Temporally crop the given frame indices at a random location.\n\n    Args:\n        size (int): Number of frames the model will see.\n    \"\"\"\n\n    def __init__(self, size):\n        self.size = size\n\n    def __call__(self, total_frames):\n        rand_end = max(0, total_frames - self.size - 1)\n        begin_index = random.randint(0, rand_end)\n        end_index = min(begin_index + self.size, total_frames)\n        return begin_index, end_index\n\n\nclass DynamicSampleDuration(object):\n    \"\"\"Randomly truncate a clip of t frames to a shorter duration, keeping at least half.\n\n    Candidate lengths run from t // 2 to t in steps of t_stride; __call__ returns (0, chosen_length).\n\n    Args:\n        t_stride (int): step between candidate lengths.\n        extra_1 (bool): if True, one leading frame is set aside before choosing and added back afterwards.\n    \"\"\"\n\n    def 
__init__(self, t_stride, extra_1):\n        self.t_stride = t_stride\n        self.extra_1 = extra_1\n\n    def __call__(self, t, h, w):\n        if self.extra_1:\n            t = t - 1\n        truncate_t_list = list(range(t+1))[t//2:][::self.t_stride]  # need half at least\n        truncate_t = random.choice(truncate_t_list)\n        if self.extra_1:\n            truncate_t = truncate_t + 1\n        return 0, truncate_t\n\nkeywords = [\n        ' man ', ' woman ', ' person ', ' people ', 'human',\n        ' individual ', ' child ', ' kid ', ' girl ', ' boy ',\n    ]\nkeywords += [i[:-1] + 's ' for i in keywords]\n\nmasking_notices = [\n    \"Note: The faces in this image are blurred.\",\n    \"This image contains faces that have been pixelated.\",\n    \"Notice: Faces in this image are masked.\",\n    \"Please be aware that the faces in this image are obscured.\",\n    \"The faces in this image are hidden.\",\n    \"This is an image with blurred faces.\",\n    \"The faces in this image have been processed.\",\n    \"Attention: Faces in this image are not visible.\",\n    \"The faces in this image are partially blurred.\",\n    \"This image has masked faces.\",\n    \"Notice: The faces in this picture have been altered.\",\n    \"This is a picture with obscured faces.\",\n    \"The faces in this image are pixelated.\",\n    \"Please note, the faces in this image have been blurred.\",\n    \"The faces in this photo are hidden.\",\n    \"The faces in this picture have been masked.\",\n    \"Note: The faces in this picture are altered.\",\n    \"This is an image where faces are not clear.\",\n    \"Faces in this image have been obscured.\",\n    \"This picture contains masked faces.\",\n    \"The faces in this image are processed.\",\n    \"The faces in this picture are not visible.\",\n    \"Please be aware, the faces in this photo are pixelated.\",\n    \"The faces in this picture have been blurred.\", \n]\n\nwebvid_watermark_notices = [\n    \"This video has a 
faint Shutterstock watermark in the center.\", \n    \"There is a slight Shutterstock watermark in the middle of this video.\", \n    \"The video contains a subtle Shutterstock watermark in the center.\", \n    \"This video features a light Shutterstock watermark at its center.\", \n    \"A faint Shutterstock watermark is present in the middle of this video.\", \n    \"There is a mild Shutterstock watermark at the center of this video.\", \n    \"This video has a slight Shutterstock watermark in the middle.\", \n    \"You can see a faint Shutterstock watermark in the center of this video.\", \n    \"A subtle Shutterstock watermark appears in the middle of this video.\", \n    \"This video includes a light Shutterstock watermark at its center.\", \n]\n\n\nhigh_aesthetic_score_notices_video = [\n    \"This video has a high aesthetic quality.\", \n    \"The beauty of this video is exceptional.\", \n    \"This video scores high in aesthetic value.\", \n    \"With its harmonious colors and balanced composition.\", \n    \"This video ranks highly for aesthetic quality\", \n    \"The artistic quality of this video is excellent.\", \n    \"This video is rated high for beauty.\", \n    \"The aesthetic quality of this video is impressive.\", \n    \"This video has a top aesthetic score.\", \n    \"The visual appeal of this video is outstanding.\", \n]\n\nlow_aesthetic_score_notices_video = [\n    \"This video has a low aesthetic quality.\", \n    \"The beauty of this video is minimal.\", \n    \"This video scores low in aesthetic appeal.\", \n    \"The aesthetic quality of this video is below average.\", \n    \"This video ranks low for beauty.\", \n    \"The artistic quality of this video is lacking.\", \n    \"This video has a low score for aesthetic value.\", \n    \"The visual appeal of this video is low.\", \n    \"This video is rated low for beauty.\", \n    \"The aesthetic quality of this video is poor.\", \n]\n\n\nhigh_aesthetic_score_notices_image = [\n    \"This 
image has a high aesthetic quality.\", \n    \"The beauty of this image is exceptional\", \n    \"This photo scores high in aesthetic value.\", \n    \"With its harmonious colors and balanced composition.\", \n    \"This image ranks highly for aesthetic quality.\", \n    \"The artistic quality of this photo is excellent.\", \n    \"This image is rated high for beauty.\", \n    \"The aesthetic quality of this image is impressive.\", \n    \"This photo has a top aesthetic score.\", \n    \"The visual appeal of this image is outstanding.\", \n]\n\nlow_aesthetic_score_notices_image = [\n    \"This image has a low aesthetic quality.\", \n    \"The beauty of this image is minimal.\", \n    \"This image scores low in aesthetic appeal.\", \n    \"The aesthetic quality of this image is below average.\", \n    \"This image ranks low for beauty.\", \n    \"The artistic quality of this image is lacking.\", \n    \"This image has a low score for aesthetic value.\", \n    \"The visual appeal of this image is low.\", \n    \"This image is rated low for beauty.\", \n    \"The aesthetic quality of this image is poor.\", \n]\n\nhigh_aesthetic_score_notices_image_human = [\n    \"High-quality image with visible human features and high aesthetic score.\", \n    \"Clear depiction of an individual in a high-quality image with top aesthetics.\", \n    \"High-resolution photo showcasing visible human details and high beauty rating.\", \n    \"Detailed, high-quality image with well-defined human subject and strong aesthetic appeal.\", \n    \"Sharp, high-quality portrait with clear human features and high aesthetic value.\", \n    \"High-quality image featuring a well-defined human presence and exceptional aesthetics.\", \n    \"Visible human details in a high-resolution photo with a high aesthetic score.\", \n    \"Clear, high-quality image with prominent human subject and superior aesthetic rating.\", \n    \"High-quality photo capturing a visible human with excellent aesthetics.\", \n   
 \"Detailed, high-quality image of a human with high visual appeal and aesthetic value.\", \n]\n\n\ndef add_masking_notice(caption):\n    if any(keyword in caption for keyword in keywords):\n        notice = random.choice(masking_notices)\n        return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n    return caption\n\ndef add_webvid_watermark_notice(caption):\n    notice = random.choice(webvid_watermark_notices)\n    return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n\ndef add_aesthetic_notice_video(caption, aesthetic_score):\n    if aesthetic_score <= 4.25:\n        notice = random.choice(low_aesthetic_score_notices_video)\n        return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n    if aesthetic_score >= 5.75:\n        notice = random.choice(high_aesthetic_score_notices_video)\n        return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n    return caption\n\n\n\ndef add_aesthetic_notice_image(caption, aesthetic_score):\n    if aesthetic_score <= 4.25:\n        notice = random.choice(low_aesthetic_score_notices_image)\n        return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n    if aesthetic_score >= 5.75:\n        notice = random.choice(high_aesthetic_score_notices_image)\n        return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n    return caption\n\ndef add_high_aesthetic_notice_image(caption):\n    notice = random.choice(high_aesthetic_score_notices_image)\n    return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n\ndef add_high_aesthetic_notice_image_human(caption):\n    notice = random.choice(high_aesthetic_score_notices_image_human)\n    return random.choice([caption + ' ' + notice, notice + ' ' + caption])\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r\"\\s+\", \" \", text)\n 
   text = text.strip()\n    return text\n\n\ndef clean_youtube(text, is_tags=False):\n    text = text.lower() + ' '\n    text = re.sub(\n        r'#video|video|#shorts|shorts| shorts|#short| short|#youtubeshorts|youtubeshorts|#youtube| youtube|#shortsyoutube|#ytshorts|ytshorts|#ytshort|#shortvideo|shortvideo|#shortsfeed|#tiktok|tiktok|#tiktokchallenge|#myfirstshorts|#myfirstshort|#viral|viralvideo|viral|#viralshorts|#virlshort|#ytviralshorts|#instagram',\n        ' ', text)\n    text = re.sub(r' s |short|youtube|virlshort|#', ' ', text)\n    pattern = r'[^a-zA-Z0-9\\s\\.,;:?!\\'\\\"|]'\n    if is_tags:\n        pattern = r'[^a-zA-Z0-9\\s]'\n    text = re.sub(pattern, '', text)\n    text = whitespace_clean(basic_clean(text))\n    return text\n\ndef clean_vidal(text):\n    title_hashtags = text.split('#')\n    title, hashtags = title_hashtags[0], '#' + '#'.join(title_hashtags[1:])\n    title = clean_youtube(title)\n    hashtags = clean_youtube(hashtags, is_tags=True)\n    text = title + ', ' + hashtags\n    if text == '' or text.isspace():\n        raise ValueError('text is empty')\n    return text\n\ndef calculate_statistics(data):\n    if len(data) == 0:\n        return None\n    data = np.array(data)\n    mean = np.mean(data)\n    variance = np.var(data)\n    std_dev = np.std(data)\n    minimum = np.min(data)\n    maximum = np.max(data)\n\n    return {\n        'mean': mean,\n        'variance': variance,\n        'std_dev': std_dev,\n        'min': minimum,\n        'max': maximum\n    }\n\nif __name__ == '__main__':\n    from torchvision import transforms\n    import torchvision.io as io\n    import numpy as np\n    from torchvision.utils import save_image\n    import os\n\n    vframes, aframes, info = io.read_video(\n        filename='./v_Archery_g01_c03.avi',\n        pts_unit='sec',\n        output_format='TCHW'\n    )\n\n    trans = transforms.Compose([\n        ToTensorVideo(),\n        RandomHorizontalFlipVideo(),\n        UCFCenterCropVideo(512),\n        
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)\n    ])\n\n    target_video_len = 32\n    frame_interval = 1\n    total_frames = len(vframes)\n    print(total_frames)\n\n    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)\n\n    # Sampling video frames\n    start_frame_ind, end_frame_ind = temporal_sample(total_frames)\n    # print(start_frame_ind)\n    # print(end_frame_ind)\n    assert end_frame_ind - start_frame_ind >= target_video_len\n    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)\n    print(frame_indice)\n\n    select_vframes = vframes[frame_indice]\n    print(select_vframes.shape)\n    print(select_vframes.dtype)\n\n    select_vframes_trans = trans(select_vframes)\n    print(select_vframes_trans.shape)\n    print(select_vframes_trans.dtype)\n\n    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)\n    print(select_vframes_trans_int.dtype)\n    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)\n\n    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)\n\n    # save_image does not create missing directories, so create the output folder first\n    os.makedirs('./test000', exist_ok=True)\n    for i in range(target_video_len):\n        save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,\n                   value_range=(-1, 1))\n"
  },
  {
    "path": "opensora/dataset/virtual_disk.py",
    "content": "import subprocess\nimport json\nimport pickle\nfrom collections import OrderedDict\nfrom opensora.npu_config import npu_config\n\nimport sys\nimport os\n\n\nclass SuppressStdout:\n    \"\"\"Singleton context manager that redirects stdout to os.devnull.\"\"\"\n    _instance = None\n\n    def __new__(cls, *args, **kwargs):\n        if cls._instance is None:\n            cls._instance = super(SuppressStdout, cls).__new__(cls)\n        return cls._instance\n\n    def __enter__(self):\n        self._original_stdout = sys.stdout\n        sys.stdout = open(os.devnull, 'w')\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        sys.stdout.close()\n        sys.stdout = self._original_stdout\n\n\nclass ObsConnection:\n    \"\"\"\n    The temporary AK, SK and STS_TOKEN credentials issued by the cloud-computing site are valid for at most 24 h.\n    buckets & object: https://uconsole.ccaicc.com/#/mgt/modelarts -> Object console\n    keys & tokens: https://uconsole.ccaicc.com/#/mgt/modelarts -> Object console -> Obtain access keys (AK and SK)\n    \"\"\"\n    def __init__(self):\n        with open(f\"{npu_config.work_path}/scripts/train_data/key.json\", \"r\") as f:\n            key = json.load(f)\n        self.AK = key[\"AK\"]\n        self.SK = key[\"SK\"]\n        self.endpoint = key[\"EP\"]\n        self.bucket = \"sora\"\n        self.suppress_stdout = SuppressStdout()\n\n    def connect(self, obs):\n        config_command = [\n            obs, 'config',\n            '-i=' + self.AK,\n            '-k=' + self.SK,\n            '-e=' + self.endpoint\n        ]\n        result = subprocess.run(config_command, capture_output=True, text=True)\n        if result.returncode != 0:\n            print(f\"Failed to configure obsutil: {result.stderr}\")\n        else:\n            print(\"Successfully configured obsutil\")\n\n\nclass VirtualDisk:\n    \"\"\"\n    :param storage_dir: Mount-point path of the in-memory virtual disk.\n    :param size: Size of the in-memory virtual disk, e.g. '1G'.\n    :param obs: Path to the obsutil executable on the Linux system.\n    :ivar connection: ObsConnection instance that manages the OBS connection.\n    \"\"\"\n    def __init__(self, storage_dir, size=\"1G\", 
obs=\"/home/opensora/obsutil_linux_arm64_5.5.12/obsutil\"):\n        self.obs = obs\n        self.connection = ObsConnection()\n        self.connection.connect(obs)\n        os.makedirs(storage_dir, exist_ok=True)\n        self.storage_dir = storage_dir\n        self.size = self._convert_size_to_bytes(size)\n        if not self.is_tmpfs_mounted():\n            self.create_ramdisk()\n        else:\n            print(f\"{self.storage_dir} is already mounted as tmpfs.\")\n        self.index_file = os.path.join(self.storage_dir, 'index.pkl')\n        self.index = self.load_index()\n        self.lru = OrderedDict()\n        self.current_size = self.get_total_storage_size()  # compute the total stored size at initialization\n\n    def _convert_size_to_bytes(self, size):\n        unit = size[-1].upper()\n        size_value = int(size[:-1])\n        if unit == 'K':\n            return size_value * 1024\n        elif unit == 'M':\n            return size_value * 1024 ** 2\n        elif unit == 'G':\n            return size_value * 1024 ** 3\n        else:\n            raise ValueError(\"Invalid size unit. 
Use K, M, or G.\")\n\n    def create_ramdisk(self):\n        \"\"\"\n        Create and mount a tmpfs in-memory virtual disk.\n        \"\"\"\n        try:\n            # Create the mount-point directory if it does not exist\n            if not os.path.exists(self.storage_dir):\n                os.makedirs(self.storage_dir)\n            # Mount tmpfs on the mount point\n            subprocess.run(['sudo', 'mount', '-t', 'tmpfs', '-o', f'size={self.size}', 'tmpfs', self.storage_dir], check=True)\n            print(f\"Successfully mounted tmpfs on {self.storage_dir} with size {self.size}.\")\n        except subprocess.CalledProcessError as e:\n            print(f\"Failed to mount tmpfs: {e}\")\n        except Exception as e:\n            print(f\"An error occurred: {e}\")\n\n    def load_index(self):\n        \"\"\"\n        Load the index file.\n        :return: Index dictionary.\n        \"\"\"\n        if os.path.exists(self.index_file):\n            with open(self.index_file, 'rb') as f:\n                return pickle.load(f)\n        return {}\n\n    def save_index(self):\n        \"\"\"\n        Save the index file.\n        \"\"\"\n        with open(self.index_file, 'wb') as f:\n            pickle.dump(self.index, f)\n\n    def unmount_ramdisk(self):\n        \"\"\"\n        Unmount the in-memory virtual disk mounted at self.storage_dir.\n        \"\"\"\n        try:\n            # Make sure no process is using the mount point before unmounting\n            subprocess.run(['sudo', 'umount', self.storage_dir], check=True)\n            print(f\"Successfully unmounted tmpfs from {self.storage_dir}.\")\n        except subprocess.CalledProcessError as e:\n            print(f\"Failed to unmount tmpfs: {e}\")\n        except Exception as e:\n            print(f\"An error occurred: {e}\")\n\n    def is_tmpfs_mounted(self):\n        \"\"\"\n        Check whether self.storage_dir is already mounted.\n        Note that `mountpoint -q` reports any kind of mount, not only tmpfs.\n        :return: True if self.storage_dir is a mount point, otherwise False.\n        \"\"\"\n        try:\n            result = subprocess.run(['mountpoint', '-q', self.storage_dir], check=False)\n            if result.returncode == 0:\n                return True\n            return False\n  
      except Exception as e:\n            print(f\"An error occurred while checking if tmpfs is mounted: {e}\")\n            return False\n\n    def get_data(self, key):\n        \"\"\"\n        Download the data for key from OBS via obsutil into the virtual disk.\n        The local-cache lookup below is currently commented out, so the object is always downloaded.\n        :param key: Unique key of the data.\n        :return: Local file path of the downloaded data.\n        \"\"\"\n        # if key in self.index:\n        #     data_file = self.index[key]\n        #     if os.path.exists(data_file):\n        #         self.lru.move_to_end(key)\n        #         with open(data_file, 'rb') as f:\n        #             # print(f\"Successfully get {key} from local\")\n        #             return pickle.load(f)\n\n\n        # If the data is not present locally, fetch it from the remote with obsutil\n        object_name = key  # assume the key matches the remote object name\n        local_path = os.path.join(self.storage_dir, key)\n\n        with self.connection.suppress_stdout:\n            self.download_and_convert_to_pickle(self.connection.bucket, object_name, local_path)\n\n        # Record where the data is stored\n        # self.index[key] = local_path\n        # self.save_index()\n        # self.lru[key] = local_path\n        #\n        # file_size = os.path.getsize(local_path)\n        # self.current_size += file_size\n\n        # self.ensure_storage_limit()\n\n        return local_path\n\n    def del_data(self, local_path):\n        os.remove(local_path)\n\n    def download_and_convert_to_pickle(self, bucket, object_name, local_path):\n        \"\"\"\n        Download a file from OBS with obsutil to the local path.\n        Despite the method name, no pickle conversion is performed; the object is copied as-is.\n        :param bucket: OBS bucket name.\n        :param object_name: Object name in OBS.\n        :param local_path: Local file path.\n        \"\"\"\n        # try:\n            # Download the file to local_path\n        subprocess.run([self.obs, 'cp', f'obs://{bucket}/{object_name}', local_path], check=True)\n            # print(f\"Successfully downloaded obs://{bucket}/{object_name} to {local_path}.\")\n\n        # except subprocess.CalledProcessError as e:\n        #     print(f\"Failed to download obs://{bucket}/{object_name} to {local_path}: {e}\")\n\n    def 
ensure_storage_limit(self):\n        \"\"\"\n        Ensure the total stored size does not exceed the virtual-disk size; when it does, delete the oldest files according to the LRU policy.\n        \"\"\"\n        while self.current_size > self.size:\n            oldest_key, oldest_path = self.lru.popitem(last=False)\n            file_size = os.path.getsize(oldest_path)\n            os.remove(oldest_path)\n            del self.index[oldest_key]\n            self.save_index()\n            print(f\"Removed {oldest_key} to free up {file_size} bytes.\")\n            self.current_size -= file_size\n\n    def get_total_storage_size(self):\n        \"\"\"\n        Get the total size of all files currently tracked in the LRU table.\n        :return: Total size in bytes.\n        \"\"\"\n        total_size = 0\n        for path in self.lru.values():\n            if os.path.exists(path):\n                total_size += os.path.getsize(path)\n        return total_size"
  },
  {
    "path": "opensora/models/__init__.py",
    "content": "from .causalvideovae import CausalVAEModelWrapper, WFVAEModelWrapper"
  },
  {
    "path": "opensora/models/causalvideovae/__init__.py",
    "content": "from torchvision.transforms import Lambda\nfrom .model.vae import CausalVAEModel, WFVAEModel\nfrom einops import rearrange\nimport torch\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n    pass\nimport torch.nn as nn\nimport torch\n\nclass CausalVAEModelWrapper(nn.Module):\n    def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs):\n        super(CausalVAEModelWrapper, self).__init__()\n        self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)\n        \n    def encode(self, x):\n        x = self.vae.encode(x).sample().mul_(0.18215)\n        return x\n    def decode(self, x):\n        x = self.vae.decode(x / 0.18215)\n        x = rearrange(x, 'b c t h w -> b t c h w').contiguous()\n        return x\n\n    def dtype(self):\n        return self.vae.dtype\n    \nclass WFVAEModelWrapper(nn.Module):\n    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):\n        super(WFVAEModelWrapper, self).__init__()\n        self.vae = WFVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)\n        self.register_buffer('shift', torch.tensor(self.vae.config.shift)[None, :, None, None, None])\n        self.register_buffer('scale', torch.tensor(self.vae.config.scale)[None, :, None, None, None])\n        \n    def encode(self, x):\n        x = (self.vae.encode(x).sample() - self.shift.to(x.device, dtype=x.dtype)) * self.scale.to(x.device, dtype=x.dtype)\n        return x\n    \n    def decode(self, x):\n        x = x / self.scale.to(x.device, dtype=x.dtype) + self.shift.to(x.device, dtype=x.dtype)\n        x = self.vae.decode(x)\n        x = rearrange(x, 'b c t h w -> b t c h w').contiguous()\n        return x\n\n    def dtype(self):\n        return self.vae.dtype\n\nae_wrapper = {\n    'CausalVAEModel_D4_2x8x8': CausalVAEModelWrapper,\n    
'CausalVAEModel_D8_2x8x8': CausalVAEModelWrapper,\n    'CausalVAEModel_D4_4x8x8': CausalVAEModelWrapper,\n    'CausalVAEModel_D8_4x8x8': CausalVAEModelWrapper,\n    'WFVAEModel_D8_4x8x8': WFVAEModelWrapper,\n    'WFVAEModel_D16_4x8x8': WFVAEModelWrapper,\n    'WFVAEModel_D32_4x8x8': WFVAEModelWrapper,\n    'WFVAEModel_D32_8x8x8': WFVAEModelWrapper,\n}\n\nae_stride_config = {\n    'CausalVAEModel_D4_2x8x8': [2, 8, 8],\n    'CausalVAEModel_D8_2x8x8': [2, 8, 8],\n    'CausalVAEModel_D4_4x8x8': [4, 8, 8],\n    'CausalVAEModel_D8_4x8x8': [4, 8, 8],\n    'WFVAEModel_D8_4x8x8': [4, 8, 8],\n    'WFVAEModel_D16_4x8x8': [4, 8, 8],\n    'WFVAEModel_D32_4x8x8': [4, 8, 8],\n    'WFVAEModel_D32_8x8x8': [8, 8, 8],\n}\n\nae_channel_config = {\n    'CausalVAEModel_D4_2x8x8': 4,\n    'CausalVAEModel_D8_2x8x8': 8,\n    'CausalVAEModel_D4_4x8x8': 4,\n    'CausalVAEModel_D8_4x8x8': 8,\n    'WFVAEModel_D8_4x8x8': 8,\n    'WFVAEModel_D16_4x8x8': 16,\n    'WFVAEModel_D32_4x8x8': 32,\n    'WFVAEModel_D32_8x8x8': 32,\n}\n\nae_denorm = {\n    'CausalVAEModel_D4_2x8x8': lambda x: (x + 1.) / 2.,\n    'CausalVAEModel_D8_2x8x8': lambda x: (x + 1.) / 2.,\n    'CausalVAEModel_D4_4x8x8': lambda x: (x + 1.) / 2.,\n    'CausalVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2.,\n    'WFVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2.,\n    'WFVAEModel_D16_4x8x8': lambda x: (x + 1.) / 2.,\n    'WFVAEModel_D32_4x8x8': lambda x: (x + 1.) / 2.,\n    'WFVAEModel_D32_8x8x8': lambda x: (x + 1.) / 2.,\n}\n\nae_norm = {\n    'CausalVAEModel_D4_2x8x8': Lambda(lambda x: 2. * x - 1.),\n    'CausalVAEModel_D8_2x8x8': Lambda(lambda x: 2. * x - 1.),\n    'CausalVAEModel_D4_4x8x8': Lambda(lambda x: 2. * x - 1.),\n    'CausalVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.),\n    'WFVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.),\n    'WFVAEModel_D16_4x8x8': Lambda(lambda x: 2. * x - 1.),\n    'WFVAEModel_D32_4x8x8': Lambda(lambda x: 2. * x - 1.),\n    'WFVAEModel_D32_8x8x8': Lambda(lambda x: 2. * x - 1.),\n}"
  },
  {
    "path": "opensora/models/causalvideovae/dataset/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/causalvideovae/dataset/ddp_sampler.py",
    "content": "import math\nfrom typing import TypeVar, Optional, Iterator\n\nimport torch\nfrom torch.utils.data import Sampler, Dataset\nimport torch.distributed as dist\n\nT_co = TypeVar('T_co', covariant=True)\nclass CustomDistributedSampler(Sampler[T_co]):\n    r\"\"\"Sampler that restricts data loading to a subset of the dataset.\n\n    It is especially useful in conjunction with\n    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each\n    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a\n    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the\n    original dataset that is exclusive to it.\n\n    .. note::\n        Dataset is assumed to be of constant size and that any instance of it always\n        returns the same elements in the same order.\n\n    Args:\n        dataset: Dataset used for sampling.\n        num_replicas (int, optional): Number of processes participating in\n            distributed training. By default, :attr:`world_size` is retrieved from the\n            current distributed group.\n        rank (int, optional): Rank of the current process within :attr:`num_replicas`.\n            By default, :attr:`rank` is retrieved from the current distributed\n            group.\n        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the\n            indices.\n        seed (int, optional): random seed used to shuffle the sampler if\n            :attr:`shuffle=True`. This number should be identical across all\n            processes in the distributed group. Default: ``0``.\n        drop_last (bool, optional): if ``True``, then the sampler will drop the\n            tail of the data to make it evenly divisible across the number of\n            replicas. If ``False``, the sampler will add extra indices to make\n            the data evenly divisible across the replicas. Default: ``False``.\n\n    .. 
warning::\n        In distributed mode, calling the :meth:`set_epoch` method at\n        the beginning of each epoch **before** creating the :class:`DataLoader` iterator\n        is necessary to make shuffling work properly across multiple epochs. Otherwise,\n        the same ordering will always be used.\n\n    Example::\n\n        >>> # xdoctest: +SKIP\n        >>> sampler = CustomDistributedSampler(dataset) if is_distributed else None\n        >>> loader = DataLoader(dataset, shuffle=(sampler is None),\n        ...                     sampler=sampler)\n        >>> for epoch in range(start_epoch, n_epochs):\n        ...     if is_distributed:\n        ...         sampler.set_epoch(epoch)\n        ...     train(loader)\n    \"\"\"\n\n    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,\n                 rank: Optional[int] = None, shuffle: bool = True,\n                 seed: int = 0, drop_last: bool = False) -> None:\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        if rank >= num_replicas or rank < 0:\n            raise ValueError(\n                f\"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]\")\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.epoch = 0\n        self.current_index = 0\n        self.drop_last = drop_last\n        # If the dataset length is evenly divisible by # of replicas, then there\n        # is no need to drop any data, since the dataset will be split equally.\n        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]\n            # 
Split to nearest available length that is evenly divisible.\n            # This is to ensure each rank receives the same amount of data when\n            # using this Sampler.\n            self.num_samples = math.ceil(\n                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]\n            )\n        else:\n            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]\n        self.total_size = self.num_samples * self.num_replicas\n        self.shuffle = shuffle\n        self.seed = seed\n\n    def __iter__(self) -> Iterator[T_co]:\n        if self.shuffle:\n            # deterministically shuffle based on epoch and seed\n            g = torch.Generator()\n            g.manual_seed(self.seed + self.epoch)\n            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]\n        else:\n            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]\n\n        if not self.drop_last:\n            # add extra samples to make it evenly divisible\n            padding_size = self.total_size - len(indices)\n            if padding_size <= len(indices):\n                indices += indices[:padding_size]\n            else:\n                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]\n        else:\n            # remove tail of data to make it evenly divisible.\n            indices = indices[:self.total_size]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n        \n        while self.current_index < len(indices):\n            yield indices[self.current_index]\n            self.current_index += 1\n        self.current_index = 0\n        \n    def __len__(self) -> int:\n        return self.num_samples\n\n    def set_epoch(self, epoch: int) -> None:\n        
r\"\"\"\n        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas\n        use a different random ordering for each epoch. Otherwise, the next iteration of this\n        sampler will yield the same ordering.\n\n        Args:\n            epoch (int): Epoch number.\n        \"\"\"\n        self.epoch = epoch\n    \n    def state_dict(self) -> dict:\n        return {\n            'epoch': self.epoch,\n            'seed': self.seed,\n            'current_index': self.current_index\n        }\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        self.epoch = state_dict['epoch']\n        self.seed = state_dict['seed']\n        self.current_index = state_dict.get('current_index', 0)\n        "
  },
  {
    "path": "opensora/models/causalvideovae/dataset/transform.py",
    "content": "import torch\nimport random\nimport numbers\nimport numpy as np\nfrom PIL import Image\nfrom torchvision.transforms import RandomCrop, RandomResizedCrop\n\n\ndef _is_tensor_video_clip(clip):\n    if not torch.is_tensor(clip):\n        raise TypeError(\"clip should be Tensor. Got %s\" % type(clip))\n\n    if not clip.ndimension() == 4:\n        raise ValueError(\"clip should be 4D. Got %dD\" % clip.dim())\n\n    return True\n\n\ndef center_crop_arr(pil_image, image_size):\n    \"\"\"\n    Center cropping implementation from ADM.\n    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126\n    \"\"\"\n    while min(*pil_image.size) >= 2 * image_size:\n        pil_image = pil_image.resize(\n            tuple(x // 2 for x in pil_image.size), resample=Image.BOX\n        )\n\n    scale = image_size / min(*pil_image.size)\n    pil_image = pil_image.resize(\n        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC\n    )\n\n    arr = np.array(pil_image)\n    crop_y = (arr.shape[0] - image_size) // 2\n    crop_x = (arr.shape[1] - image_size) // 2\n    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])\n\n\ndef crop(clip, i, j, h, w):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n    \"\"\"\n    if len(clip.size()) != 4:\n        raise ValueError(\"clip should be a 4D tensor\")\n    return clip[..., i: i + h, j: j + w]\n\n\ndef resize(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be tuple (height, width), instead got {target_size}\")\n    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)\n\n\ndef resize_scale(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be tuple (height, width), instead got {target_size}\")\n    H, W = clip.size(-2), clip.size(-1)\n    scale_ = target_size[0] / min(H, W)\n    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)\n\n\ndef resized_crop(clip, i, j, h, w, size, interpolation_mode=\"bilinear\"):\n    \"\"\"\n    Spatially crop the video clip, then resize it.\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        i (int): Row coordinate of the upper-left corner of the crop.\n        j (int): Column coordinate of the upper-left corner of the crop.\n        h (int): Height of the cropped region.\n        w (int): Width of the cropped region.\n        size (tuple(int, int)): Height and width of the resized clip.\n    Returns:\n        clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    clip = crop(clip, i, j, h, w)\n    clip = resize(clip, size, interpolation_mode)\n    return clip\n\n\ndef center_crop(clip, crop_size):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    th, tw = crop_size\n    if h < th or w < tw:\n        raise ValueError(\"height and width must be no smaller than crop_size\")\n\n    i = int(round((h - th) / 2.0))\n    j = int(round((w - tw) / 2.0))\n    return crop(clip, i, j, th, tw)\n\n\ndef center_crop_using_short_edge(clip):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    if h < w:\n        th, tw = h, h\n        i = 0\n        j = int(round((w - tw) / 2.0))\n    else:\n        th, tw = w, w\n        i = int(round((h - th) / 2.0))\n        j = 0\n    return crop(clip, i, j, th, tw)\n\n\ndef random_shift_crop(clip):\n    '''\n    Slide along the long edge, with the short edge as crop size\n    '''\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n\n    if h <= w:\n        long_edge = w\n        short_edge = h\n    else:\n        long_edge = h\n        short_edge = w\n\n    th, tw = short_edge, short_edge\n\n    i = torch.randint(0, h - th + 1, size=(1,)).item()\n    j = torch.randint(0, w - tw + 1, size=(1,)).item()\n    return crop(clip, i, j, th, tw)\n\n\ndef to_tensor(clip):\n    \"\"\"\n    Convert tensor data type from uint8 to float, divide value by 255.0 and\n    permute the dimensions of clip tensor\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n    \"\"\"\n    
if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    if not clip.dtype == torch.uint8:\n        raise TypeError(\"clip tensor should have data type uint8. Got %s\" % str(clip.dtype))\n    # return clip.float().permute(3, 0, 1, 2) / 255.0\n    return clip.float() / 255.0\n\n\ndef normalize(clip, mean, std, inplace=False):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)\n        mean (tuple): pixel RGB mean. Size is (3)\n        std (tuple): pixel standard deviation. Size is (3)\n    Returns:\n        normalized clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    if not inplace:\n        clip = clip.clone()\n    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)\n    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)\n    # view the per-channel stats as (C, 1, 1) so they broadcast over dim 1 of (T, C, H, W)\n    clip.sub_(mean[:, None, None]).div_(std[:, None, None])\n    return clip\n\n\ndef hflip(clip):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)\n    Returns:\n        flipped clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    return clip.flip(-1)\n\n\nclass RandomCropVideo:\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: randomly cropped video clip.\n                size is (T, C, OH, OW)\n        \"\"\"\n        i, j, h, w = self.get_params(clip)\n        return crop(clip, i, j, h, w)\n\n    def get_params(self, clip):\n        h, w = clip.shape[-2:]\n        th, tw = self.size\n\n        if h < th or w < tw:\n            raise ValueError(f\"Required crop size {(th, tw)} is larger than input image size {(h, w)}\")\n\n        if w == tw and h == th:\n            return 0, 0, h, w\n\n        i = torch.randint(0, h - th + 1, size=(1,)).item()\n        j = torch.randint(0, w - tw + 1, size=(1,)).item()\n\n        return i, j, th, tw\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size})\"\n\n\nclass SpatialStrideCropVideo:\n    def __init__(self, stride):\n        self.stride = stride\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: video clip cropped from the top-left corner so that H and W are multiples of stride.\n                size is (T, C, OH, OW)\n        \"\"\"\n        i, j, h, w = self.get_params(clip)\n        return crop(clip, i, j, h, w)\n\n    def get_params(self, clip):\n        h, w = clip.shape[-2:]\n\n        th, tw = h // self.stride * self.stride, w // self.stride * self.stride\n\n        return 0, 0, th, tw  # from top-left\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(stride={self.stride})\"\n\n\nclass LongSideResizeVideo:\n    '''\n    Resize the clip, keeping its aspect ratio, so that the long side equals the specified size.\n    If skip_low_resolution is True, clips whose long side is already <= size are returned unchanged.\n    '''\n\n    def __init__(\n            self,\n            size,\n            skip_low_resolution=False,\n            interpolation_mode=\"bilinear\",\n    ):\n        self.size = size\n        self.skip_low_resolution = skip_low_resolution\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be resized. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: scale resized video clip.\n                size is (T, C, size, *) or (T, C, *, size)\n        \"\"\"\n        _, _, h, w = clip.shape\n        if self.skip_low_resolution and max(h, w) <= self.size:\n            return clip\n        if h > w:\n            w = int(w * self.size / h)\n            h = self.size\n        else:\n            h = int(h * self.size / w)\n            w = self.size\n        resize_clip = resize(clip, target_size=(h, w),\n                             interpolation_mode=self.interpolation_mode)\n        return resize_clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass CenterCropResizeVideo:\n    '''\n    Center crop a square whose side equals the short edge,\n    then resize it to the specified size\n    '''\n\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: scale resized / center cropped video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_center_crop = center_crop_using_short_edge(clip)\n        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,\n                                         interpolation_mode=self.interpolation_mode)\n        return clip_center_crop_resize\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass UCFCenterCropVideo:\n    '''\n    Scale the clip proportionally so that its short edge matches size[0],\n    then center crop it to the specified size\n    '''\n\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        Returns:\n            torch.tensor: scale resized / center cropped video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)\n        clip_center_crop = center_crop(clip_resize, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass KineticsRandomCropResizeVideo:\n    '''\n    Slide along the long edge, with the short edge as crop size. 
Then resize the crop to the desired size.\n    '''\n\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        clip_random_crop = random_shift_crop(clip)\n        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)\n        return clip_resize\n\n\nclass CenterCropVideo:\n    def __init__(\n            self,\n            size,\n            interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: center cropped video clip.\n                size is (T, C, size[0], size[1])\n        \"\"\"\n        clip_center_crop = center_crop(clip, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})\"\n\n\nclass NormalizeVideo:\n    \"\"\"\n    Normalize the video clip by mean subtraction and division by standard deviation\n    Args:\n        mean (3-tuple): pixel RGB mean\n        std (3-tuple): pixel RGB standard deviation\n        inplace (boolean): whether to normalize in place\n    \"\"\"\n\n    def __init__(self, mean, std, inplace=False):\n        self.mean = mean\n        self.std = std\n        self.inplace = inplace\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): video clip to be normalized. Size is (T, C, H, W)\n        \"\"\"\n        return normalize(clip, self.mean, self.std, self.inplace)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})\"\n\n\nclass ToTensorVideo:\n    \"\"\"\n    Convert tensor data type from uint8 to float and divide values by 255.0;\n    the (T, C, H, W) layout is left unchanged\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n        \"\"\"\n        return to_tensor(clip)\n\n    def __repr__(self) -> str:\n        return self.__class__.__name__\n\n\nclass RandomHorizontalFlipVideo:\n    \"\"\"\n    Flip the video clip along the horizontal direction with a given probability\n    Args:\n        p (float): probability of the clip being flipped. 
Default value is 0.5\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor): Size is (T, C, H, W)\n        \"\"\"\n        if random.random() < self.p:\n            clip = hflip(clip)\n        return clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(p={self.p})\"\n\n\n#  ------------------------------------------------------------\n#  ---------------------  Sampling  ---------------------------\n#  ------------------------------------------------------------\nclass TemporalRandomCrop(object):\n    \"\"\"Temporally crop the given frame indices at a random location.\n\n    Args:\n        size (int): Number of frames the model will see.\n    \"\"\"\n\n    def __init__(self, size):\n        self.size = size\n\n    def __call__(self, total_frames):\n        rand_end = max(0, total_frames - self.size - 1)\n        begin_index = random.randint(0, rand_end)\n        end_index = min(begin_index + self.size, total_frames)\n        return begin_index, end_index\n\n\nclass DynamicSampleDuration(object):\n    \"\"\"Randomly truncate a clip to between half and all of its frames.\n\n    Candidate durations are t // 2, t // 2 + t_stride, ... up to t; the crop always starts at frame 0.\n\n    Args:\n        t_stride (int): Step between candidate durations.\n        extra_1 (bool): If True, one frame is set aside before choosing the duration and added back afterwards.\n    \"\"\"\n\n    def __init__(self, t_stride, extra_1):\n        self.t_stride = t_stride\n        self.extra_1 = extra_1\n\n    def __call__(self, t, h, w):\n        if self.extra_1:\n            t = t - 1\n        truncate_t_list = list(range(t+1))[t//2:][::self.t_stride]  # need half at least\n        truncate_t = random.choice(truncate_t_list)\n        if self.extra_1:\n            truncate_t = truncate_t + 1\n        return 0, truncate_t\n\n\nif __name__ == '__main__':\n    from torchvision import transforms\n    import torchvision.io as io\n    import numpy as np\n    from torchvision.utils import save_image\n    import os\n\n    vframes, aframes, info = io.read_video(\n        filename='./v_Archery_g01_c03.avi',\n        pts_unit='sec',\n        output_format='TCHW'\n    )\n\n    trans = transforms.Compose([\n        ToTensorVideo(),\n        RandomHorizontalFlipVideo(),\n        UCFCenterCropVideo(512),\n        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)\n    ])\n\n    target_video_len = 32\n    frame_interval = 1\n    total_frames = len(vframes)\n    print(total_frames)\n\n    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)\n\n    # Sampling video frames\n    start_frame_ind, end_frame_ind = temporal_sample(total_frames)\n    assert end_frame_ind - start_frame_ind >= target_video_len\n    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)\n    print(frame_indice)\n\n    select_vframes = vframes[frame_indice]\n    print(select_vframes.shape)\n    print(select_vframes.dtype)\n\n    select_vframes_trans = trans(select_vframes)\n    print(select_vframes_trans.shape)\n    print(select_vframes_trans.dtype)\n\n    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)\n    print(select_vframes_trans_int.dtype)\n    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)\n\n    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)\n\n    # save_image does not create the output directory\n    os.makedirs('./test000', exist_ok=True)\n    for i in range(target_video_len):\n        save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,\n                   value_range=(-1, 1))\n"
  },
  {
    "path": "opensora/models/causalvideovae/dataset/video_dataset.py",
    "content": "import os.path as osp\nimport random\nfrom glob import glob\nfrom torchvision import transforms\nimport numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\nimport pickle\nimport decord\nfrom .transform import ToTensorVideo, CenterCropVideo\nfrom torchvision.transforms._transforms_video import CenterCropVideo as TVCenterCropVideo\nfrom torchvision.transforms import Lambda, Compose, Resize\nimport os\n\n\nclass DecordInit(object):\n    def __init__(self, num_threads=1):\n        self.num_threads = num_threads\n        self.ctx = decord.cpu(0)\n\n    def __call__(self, filename):\n        reader = decord.VideoReader(\n            filename, ctx=self.ctx, num_threads=self.num_threads\n        )\n        return reader\n\n    def __repr__(self):\n        repr_str = (\n            f\"{self.__class__.__name__}(\"\n            f\"num_threads={self.num_threads})\"\n        )\n        return repr_str\n\n\ndef TemporalRandomCrop(total_frames, size):\n    rand_end = max(0, total_frames - size - 1)\n    begin_index = random.randint(0, rand_end)\n    end_index = min(begin_index + size, total_frames)\n    return begin_index, end_index\n\n\ndef _format_video_shape(video, time_compress=4, spatial_compress=8):\n    \"\"\"Prepare a (C, T, H, W) video for the VAE: trim T to 1 + k * time_compress\n    frames and H, W to multiples of spatial_compress\"\"\"\n    time = video.shape[1]\n    height = video.shape[2]\n    width = video.shape[3]\n    new_time = (\n        (time - (time - 1) % time_compress) if (time - 1) % time_compress != 0 else time\n    )\n    new_height = (\n        (height - (height) % spatial_compress)\n        if height % spatial_compress != 0\n        else height\n    )\n    new_width = (\n        (width - (width) % spatial_compress) if width % spatial_compress != 0 else width\n    )\n    return video[:, :new_time, :new_height, :new_width]\n\n\nclass TrainVideoDataset(data.Dataset):\n    video_exts = [\"avi\", \"mp4\", \"webm\"]\n\n    def __init__(\n        self,\n        video_folder,\n        sequence_length,\n        train=True,\n        resolution=64,\n        sample_rate=1,\n        dynamic_sample=True,\n        cache_file=None,\n        is_main_process=False,\n    ):\n\n        self.train = train\n        self.sequence_length = sequence_length\n        self.sample_rate = sample_rate\n        self.resolution = resolution\n        self.v_decoder = DecordInit()\n        self.video_folder = video_folder\n        self.dynamic_sample = dynamic_sample\n        self.cache_file = cache_file\n        self.transform = transforms.Compose(\n            [\n                ToTensorVideo(),\n                Resize(self.resolution),\n                CenterCropVideo(self.resolution),\n                Lambda(lambda x: 2.0 * x - 1.0),\n            ]\n        )\n        print(\"Building datasets...\")\n        self.is_main_process = is_main_process\n        self.samples = self._make_dataset()\n\n    def _make_dataset(self):\n        # the file list is cached only when a cache file name is given\n        cache_file = osp.join(self.video_folder, self.cache_file) if self.cache_file else None\n\n        if cache_file is not None and osp.exists(cache_file):\n            with open(cache_file, \"rb\") as f:\n                samples = pickle.load(f)\n        else:\n            samples = []\n            samples += sum(\n                [\n                    glob(osp.join(self.video_folder, \"**\", f\"*.{ext}\"), recursive=True)\n                    for ext in self.video_exts\n                ],\n                [],\n            )\n            if cache_file is not None and self.is_main_process:\n                with open(cache_file, \"wb\") as f:\n                    pickle.dump(samples, f)\n        return samples\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __getitem__(self, idx):\n        video_path = self.samples[idx]\n        try:\n            video = self.decord_read(video_path)\n            video = self.transform(video)  # T C H W -> T C H W\n            video = video.transpose(0, 1)  # T C H W -> C T H W\n            return dict(video=video, 
label=\"\")\n        except Exception as e:\n            print(f\"Error with {e}, {video_path}\")\n            return self.__getitem__(random.randint(0, self.__len__() - 1))\n\n    def decord_read(self, path):\n        decord_vr = self.v_decoder(path)\n        total_frames = len(decord_vr)\n        # Sampling video frames\n        if self.dynamic_sample:\n            sample_rate = random.randint(1, self.sample_rate)\n        else:\n            sample_rate = self.sample_rate\n        size = self.sequence_length * sample_rate\n        start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)\n        frame_indice = np.linspace(\n            start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int\n        )\n\n        video_data = decord_vr.get_batch(frame_indice).asnumpy()\n        video_data = torch.from_numpy(video_data)\n        video_data = video_data.permute(0, 3, 1, 2)\n        return video_data\n\ndef resize(x, resolution):\n    height, width = x.shape[-2:]\n    aspect_ratio = width / height\n    if width <= height:\n        new_width = resolution\n        new_height = int(resolution / aspect_ratio)\n    else:\n        new_height = resolution\n        new_width = int(resolution * aspect_ratio)\n    resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)\n    return resized_x\n\nclass ValidVideoDataset(data.Dataset):\n    video_exts = [\"avi\", \"mp4\", \"webm\"]\n    \n    def __init__(\n        self,\n        real_video_dir,\n        num_frames,\n        sample_rate=1,\n        crop_size=None,\n        resolution=128,\n        is_main_process=False\n    ) -> None:\n        super().__init__()\n        self.is_main_process = is_main_process\n        self.real_video_files = self._make_dataset(real_video_dir)\n        \n        self.num_frames = num_frames\n        self.sample_rate = sample_rate\n        self.crop_size = crop_size\n        self.short_size = resolution\n        
self.v_decoder = DecordInit()\n        self.transform = Compose(\n            [\n                ToTensorVideo(),\n                Resize(resolution),\n                CenterCropVideo(resolution) if crop_size is not None else Lambda(lambda x: x),\n            ]\n        )\n\n    def _make_dataset(self, real_video_dir):\n        cache_file = osp.join(real_video_dir, \"idx.pkl\")\n\n        if osp.exists(cache_file):\n            with open(cache_file, \"rb\") as f:\n                samples = pickle.load(f)\n        else:\n            samples = []\n            samples += sum(\n                [\n                    glob(osp.join(real_video_dir, \"**\", f\"*.{ext}\"), recursive=True)\n                    for ext in self.video_exts\n                ],\n                [],\n            )\n            if self.is_main_process:\n                with open(cache_file, \"wb\") as f:\n                    pickle.dump(samples, f)\n        return samples\n\n    def __len__(self):\n        return len(self.real_video_files)\n\n    def __getitem__(self, index):\n        if index >= len(self):\n            raise IndexError(index)\n        real_video_file = self.real_video_files[index]\n        try:\n            real_video_tensor = self._load_video(real_video_file)\n            real_video_tensor = self.transform(real_video_tensor)\n            video_name = os.path.basename(real_video_file)\n            return {'video': real_video_tensor, 'file_name': video_name}\n        except Exception:\n            print(f\"Video error: {real_video_file}\")\n            return self.__getitem__(0)\n\n    def _load_video(self, video_path, sample_rate=None):\n        num_frames = self.num_frames\n        if not sample_rate:\n            sample_rate = self.sample_rate\n        try:\n            decord_vr = self.v_decoder(video_path)\n        except Exception as e:\n            raise RuntimeError(f\"failed to load {video_path}.\") from e\n        total_frames = len(decord_vr)\n        sample_frames_len = sample_rate * num_frames\n\n        if total_frames < sample_frames_len:\n            raise ValueError(f\"video too short: {total_frames} frames, need {sample_frames_len}\")\n\n        # take the first sample_frames_len frames and pick num_frames of them evenly\n        s = 0\n        e = s + sample_frames_len\n        frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)\n        video_data = decord_vr.get_batch(frame_id_list).asnumpy()\n        video_data = torch.from_numpy(video_data)\n        video_data = video_data.permute(3, 0, 1, 2)  # T H W C -> C T H W\n        return video_data\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/cal_fvd.py",
    "content": "import numpy as np\nimport torch\nfrom tqdm import tqdm\n\n\ndef trans(x):\n    # if greyscale images add channel\n    if x.shape[-3] == 1:\n        x = x.repeat(1, 1, 3, 1, 1)\n\n    # permute BTCHW -> BCTHW\n    x = x.permute(0, 2, 1, 3, 4)\n\n    return x\n\n\ndef calculate_fvd(videos1, videos2, device, method='styleganv'):\n\n    if method == 'styleganv':\n        from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained\n    elif method == 'videogpt':\n        from fvd.videogpt.fvd import load_i3d_pretrained\n        from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats\n        from fvd.videogpt.fvd import frechet_distance\n    else:\n        raise ValueError(f\"unknown FVD method: {method}\")\n\n    print(\"calculate_fvd...\")\n\n    # videos [batch_size, timestamps, channel, h, w]\n    assert videos1.shape == videos2.shape\n\n    i3d = load_i3d_pretrained(device=device)\n\n    # support grayscale input, if grayscale -> channel*3\n    # BTCHW -> BCTHW\n    # videos -> [batch_size, channel, timestamps, h, w]\n    videos1 = trans(videos1)\n    videos2 = trans(videos2)\n\n    fvd_results = {}\n\n    # FVD is only computed for clips of at least 10 frames, i.e. on prefixes of length 10..T\n    for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):\n        # get a video clip\n        # videos_clip [batch_size, channel, timestamps[:clip], h, w]\n        videos_clip1 = videos1[:, :, : clip_timestamp]\n        videos_clip2 = videos2[:, :, : clip_timestamp]\n\n        # get FVD features\n        feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)\n        feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)\n\n        # calculate FVD on the first clip_timestamp frames\n        fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)\n\n    result = {\n        \"value\": fvd_results,\n        \"video_setting\": videos1.shape,\n        \"video_setting_name\": \"batch_size, channel, time, height, width\",\n    }\n\n    return result\n\n\n# test code / usage example\n\ndef main():\n    NUMBER_OF_VIDEOS = 8\n    VIDEO_LENGTH = 50\n    CHANNEL = 3\n    SIZE = 64\n    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    import json\n    result = calculate_fvd(videos1, videos2, device, method='videogpt')\n    print(json.dumps(result, indent=4))\n\n    result = calculate_fvd(videos1, videos2, device, method='styleganv')\n    print(json.dumps(result, indent=4))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/cal_lpips.py",
    "content": "import numpy as np\nimport torch\nfrom tqdm import tqdm\nimport math\n\nimport lpips\n\nspatial = True         # Return a spatial map of perceptual distance.\n\n# Linearly calibrated models (LPIPS)\nloss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'\n# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'\n\n\ndef trans(x):\n    # if greyscale images add channel\n    if x.shape[-3] == 1:\n        x = x.repeat(1, 1, 3, 1, 1)\n\n    # value range [0, 1] -> [-1, 1]\n    x = x * 2 - 1\n\n    return x\n\n\ndef calculate_lpips(videos1, videos2, device):\n    # inputs are RGB in [0, 1]; trans() rescales them to the [-1, 1] range LPIPS expects\n    print(\"calculate_lpips...\")\n\n    assert videos1.shape == videos2.shape\n\n    # videos [batch_size, timestamps, channel, h, w]\n\n    # support grayscale input, if grayscale -> channel*3\n    # value range [0, 1] -> [-1, 1]\n    videos1 = trans(videos1)\n    videos2 = trans(videos2)\n\n    loss_fn.to(device)\n    lpips_results = []\n\n    for video_num in tqdm(range(videos1.shape[0])):\n        # get a video\n        # video [timestamps, channel, h, w]\n        video1 = videos1[video_num]\n        video2 = videos2[video_num]\n\n        lpips_results_of_a_video = []\n        for clip_timestamp in range(len(video1)):\n            # get one frame as a batch of size 1\n            # img [1, channel, h, w] tensor\n            img1 = video1[clip_timestamp].unsqueeze(0).to(device)\n            img2 = video2[clip_timestamp].unsqueeze(0).to(device)\n\n            # LPIPS of this frame, averaged over the spatial map\n            lpips_results_of_a_video.append(loss_fn(img1, img2).mean().item())\n        lpips_results.append(lpips_results_of_a_video)\n\n    lpips_results = np.array(lpips_results)  # [batch_size, num_frames]\n\n    lpips = {}\n    lpips_std = {}\n\n    for clip_timestamp in range(len(video1)):\n        lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])\n        lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])\n\n    result = {\n        \"value\": lpips,\n        \"value_std\": lpips_std,\n        \"video_setting\": video1.shape,\n        \"video_setting_name\": \"time, channel, height, width\",\n    }\n\n    return result\n\n\n# test code / usage example\n\ndef main():\n    NUMBER_OF_VIDEOS = 8\n    VIDEO_LENGTH = 50\n    CHANNEL = 3\n    SIZE = 64\n    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    import json\n    result = calculate_lpips(videos1, videos2, device)\n    print(json.dumps(result, indent=4))\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "opensora/models/causalvideovae/eval/cal_psnr.py",
    "content": "import numpy as np\nimport torch\nfrom tqdm import tqdm\nimport math\n\ndef img_psnr_cuda(img1, img2):\n    # [0,1]\n    # compute mse\n    # mse = np.mean((img1-img2)**2)\n    mse = torch.mean((img1 / 1.0 - img2 / 1.0) ** 2)\n    # compute psnr\n    if mse < 1e-10:\n        return 100\n    psnr = 20 * torch.log10(1 / torch.sqrt(mse))\n    return psnr\n\n\ndef img_psnr(img1, img2):\n    # [0,1]\n    # compute mse\n    # mse = np.mean((img1-img2)**2)\n    mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)\n    # compute psnr\n    if mse < 1e-10:\n        return 100\n    psnr = 20 * math.log10(1 / math.sqrt(mse))\n    return psnr\n\n\ndef trans(x):\n    return x\n\ndef calculate_psnr(videos1, videos2):\n    print(\"calculate_psnr...\")\n\n    # videos [batch_size, timestamps, channel, h, w]\n    \n    assert videos1.shape == videos2.shape\n\n    videos1 = trans(videos1)\n    videos2 = trans(videos2)\n\n    psnr_results = []\n    \n    for video_num in tqdm(range(videos1.shape[0])):\n        # get a video\n        # video [timestamps, channel, h, w]\n        video1 = videos1[video_num]\n        video2 = videos2[video_num]\n\n        psnr_results_of_a_video = []\n        for clip_timestamp in range(len(video1)):\n            # get a img\n            # img [timestamps[x], channel, h, w]\n            # img [channel, h, w] numpy\n\n            img1 = video1[clip_timestamp].numpy()\n            img2 = video2[clip_timestamp].numpy()\n            \n            # calculate psnr of a video\n            psnr_results_of_a_video.append(img_psnr(img1, img2))\n\n        psnr_results.append(psnr_results_of_a_video)\n    \n    psnr_results = np.array(psnr_results) # [batch_size, num_frames]\n    psnr = {}\n    psnr_std = {}\n\n    for clip_timestamp in range(len(video1)):\n        psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])\n        psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])\n\n    result = {\n        \"value\": psnr,\n        
\"value_std\": psnr_std,\n        \"video_setting\": video1.shape,\n        \"video_setting_name\": \"time, channel, height, width\",\n    }\n\n    return result\n\n\n# test code / usage example\n\ndef main():\n    NUMBER_OF_VIDEOS = 8\n    VIDEO_LENGTH = 50\n    CHANNEL = 3\n    SIZE = 64\n    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n\n    import json\n    result = calculate_psnr(videos1, videos2)\n    print(json.dumps(result, indent=4))\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "opensora/models/causalvideovae/eval/cal_ssim.py",
    "content": "import numpy as np\nimport torch\nfrom tqdm import tqdm\nimport cv2\n\n\ndef ssim(img1, img2):\n    C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    kernel = cv2.getGaussianKernel(11, 1.5)\n    window = np.outer(kernel, kernel.transpose())\n    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid\n    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]\n    mu1_sq = mu1 ** 2\n    mu2_sq = mu2 ** 2\n    mu1_mu2 = mu1 * mu2\n    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq\n    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq\n    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *\n                                                            (sigma1_sq + sigma2_sq + C2))\n    return ssim_map.mean()\n\n\ndef calculate_ssim_function(img1, img2):\n    # inputs are expected in [0, 1] (C1 and C2 above assume a data range of 1)\n    # ssim is the only metric extremely sensitive to gray being compared to b/w\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    if img1.ndim == 2:\n        return ssim(img1, img2)\n    elif img1.ndim == 3:\n        if img1.shape[0] == 3:\n            ssims = []\n            for i in range(3):\n                ssims.append(ssim(img1[i], img2[i]))\n            return np.array(ssims).mean()\n        elif img1.shape[0] == 1:\n            return ssim(np.squeeze(img1), np.squeeze(img2))\n        else:\n            raise ValueError('Input images must have 1 or 3 channels.')\n    else:\n        raise ValueError('Wrong input image dimensions.')\n\n\ndef trans(x):\n    return x\n\n\ndef calculate_ssim(videos1, videos2):\n    print(\"calculate_ssim...\")\n\n    # videos [batch_size, timestamps, channel, h, w]\n    assert videos1.shape == videos2.shape\n\n    videos1 = trans(videos1)\n    videos2 = trans(videos2)\n\n    ssim_results = []\n\n    for video_num in tqdm(range(videos1.shape[0])):\n        # get a video\n        # video [timestamps, channel, h, w]\n        video1 = videos1[video_num]\n        video2 = videos2[video_num]\n\n        ssim_results_of_a_video = []\n        for clip_timestamp in range(len(video1)):\n            # get one frame\n            # img [channel, h, w] numpy array\n            img1 = video1[clip_timestamp].numpy()\n            img2 = video2[clip_timestamp].numpy()\n\n            # calculate ssim of this frame\n            ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))\n\n        ssim_results.append(ssim_results_of_a_video)\n\n    ssim_results = np.array(ssim_results)  # [batch_size, num_frames]\n\n    ssim = {}\n    ssim_std = {}\n\n    for clip_timestamp in range(len(video1)):\n        ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])\n        ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])\n\n    result = {\n        \"value\": ssim,\n        \"value_std\": ssim_std,\n        \"video_setting\": video1.shape,\n        \"video_setting_name\": \"time, channel, height, width\",\n    }\n\n    return result\n\n\n# test code / usage example\n\ndef main():\n    NUMBER_OF_VIDEOS = 8\n    VIDEO_LENGTH = 50\n    CHANNEL = 3\n    SIZE = 64\n    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)\n\n    import json\n    result = calculate_ssim(videos1, videos2)\n    print(json.dumps(result, indent=4))\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "opensora/models/causalvideovae/eval/eval.py",
    "content": "import os\nfrom argparse import ArgumentDefaultsHelpFormatter, ArgumentParser\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Subset\nfrom tqdm import tqdm\nimport sys\nfrom glob import glob\n\nsys.path.append(\".\")\n\nfrom opensora.models.causalvideovae.eval.cal_lpips import calculate_lpips\nfrom opensora.models.causalvideovae.eval.cal_fvd import calculate_fvd\nfrom opensora.models.causalvideovae.eval.cal_psnr import calculate_psnr\nfrom opensora.models.causalvideovae.eval.cal_ssim import calculate_ssim\nfrom opensora.models.causalvideovae.dataset.video_dataset import (\n    ValidVideoDataset,\n    DecordInit,\n    Compose,\n    Lambda,\n    resize,\n    CenterCropVideo,\n    ToTensorVideo\n)\n\nclass EvalDataset(ValidVideoDataset):\n    def __init__(\n        self,\n        real_video_dir,\n        generated_video_dir,\n        num_frames,\n        sample_rate=1,\n        crop_size=None,\n        resolution=128,\n    ) -> None:\n        self.is_main_process = False\n        self.v_decoder = DecordInit()\n        self.real_video_files = []\n        self.generated_video_files = self._make_dataset(generated_video_dir)\n        for video_file in self.generated_video_files:\n            filename = os.path.basename(video_file)\n            if not os.path.exists(os.path.join(real_video_dir, filename)):\n                raise Exception(os.path.join(real_video_dir, filename))\n            self.real_video_files.append(os.path.join(real_video_dir, filename))\n        self.num_frames = num_frames\n        self.sample_rate = sample_rate\n        self.crop_size = crop_size\n        self.short_size = resolution\n        self.transform = Compose(\n            [\n                ToTensorVideo(),\n                Lambda(lambda x: resize(x, self.short_size)),\n                (\n                    CenterCropVideo(crop_size)\n                    if crop_size is not None\n                    else Lambda(lambda x: x)\n                ),\n  
          ]\n        )\n\n    def _make_dataset(self, real_video_dir):\n        samples = []\n        samples += sum(\n            [\n                glob(os.path.join(real_video_dir, f\"*.{ext}\"), recursive=True)\n                for ext in self.video_exts\n            ],\n            [],\n        )\n        return samples\n    \n    def __len__(self):\n        return len(self.real_video_files)\n\n    def __getitem__(self, index):\n        if index >= len(self):\n            raise IndexError\n        real_video_file = self.real_video_files[index]\n        generated_video_file = self.generated_video_files[index]\n        real_video_tensor = self._load_video(real_video_file, self.sample_rate)\n        generated_video_tensor = self._load_video(generated_video_file, 1)\n        return {\"real\": self.transform(real_video_tensor), \"generated\": self.transform(generated_video_tensor)}\n\n\ndef calculate_common_metric(args, dataloader, device):\n    score_list = []\n    for batch_data in tqdm(dataloader):\n        real_videos = batch_data[\"real\"].to(device)\n        generated_videos = batch_data[\"generated\"].to(device)\n\n        assert real_videos.shape[2] == generated_videos.shape[2]\n        if args.metric == \"fvd\":\n            tmp_list = list(\n                calculate_fvd(\n                    real_videos, generated_videos, device, method=args.fvd_method\n                )[\"value\"].values()\n            )\n        elif args.metric == \"ssim\":\n            tmp_list = list(\n                calculate_ssim(real_videos, generated_videos)[\"value\"].values()\n            )\n        elif args.metric == \"psnr\":\n            tmp_list = [calculate_psnr(real_videos, generated_videos)]\n        else:\n            tmp_list = [calculate_lpips(real_videos, generated_videos, device)]\n        score_list += tmp_list\n    return np.mean(score_list)\n\n\ndef main():\n\n
    if args.device is None:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    if args.num_workers is None:\n        try:\n            num_cpus = len(os.sched_getaffinity(0))\n        except AttributeError:\n            num_cpus = os.cpu_count()\n        num_workers = min(num_cpus, 8) if num_cpus is not None else 0\n    else:\n        num_workers = args.num_workers\n\n    dataset = EvalDataset(\n        args.real_video_dir,\n        args.generated_video_dir,\n        num_frames=args.num_frames,\n        sample_rate=args.sample_rate,\n        crop_size=args.crop_size,\n        resolution=args.resolution,\n    )\n\n    if args.subset_size:\n        indices = range(args.subset_size)\n        dataset = Subset(dataset, indices=indices)\n\n    dataloader = DataLoader(\n        dataset, args.batch_size, num_workers=num_workers, pin_memory=True\n    )\n\n    metric_score = calculate_common_metric(args, dataloader, device)\n    print(metric_score)\n\n\ndef parse_args():\n
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--batch_size\", type=int, default=2, help=\"Batch size to use\")\n    parser.add_argument(\"--real_video_dir\", type=str, help=(\"the path of real videos\"))\n    parser.add_argument(\n        \"--generated_video_dir\", type=str, help=(\"the path of generated videos\")\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=None,\n        help=\"Device to use, e.g. cuda, cuda:0 or cpu\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=None,\n        help=(\n            \"Number of processes to use for data loading. \"\n            \"Defaults to `min(8, num_cpus)`\"\n        ),\n    )\n    parser.add_argument(\"--sample_fps\", type=int, default=30)\n    parser.add_argument(\"--resolution\", type=int, default=336)\n    parser.add_argument(\"--crop_size\", type=int, default=None)\n    parser.add_argument(\"--num_frames\", type=int, default=100)\n    parser.add_argument(\"--sample_rate\", type=int, default=1)\n    parser.add_argument(\"--subset_size\", type=int, default=None)\n    parser.add_argument(\n        \"--metric\",\n        type=str,\n        default=\"fvd\",\n        choices=[\"fvd\", \"psnr\", \"ssim\", \"lpips\", \"flolpips\"],\n    )\n    parser.add_argument(\"--fvd_method\", type=str, default=\"styleganv\")\n    args = parser.parse_args()\n    return args\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main()\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py",
    "content": "import torch\nimport os\nimport math\nimport torch.nn.functional as F\n\n# https://github.com/universome/fvd-comparison\n\n\ndef load_i3d_pretrained(device=torch.device('cpu')):\n    i3D_WEIGHTS_URL = \"https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt\"\n    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')\n    print(filepath)\n    if not os.path.exists(filepath):\n        print(f\"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.\")\n        os.system(f\"wget {i3D_WEIGHTS_URL} -O {filepath}\")\n    i3d = torch.jit.load(filepath).eval().to(device)\n    i3d = torch.nn.DataParallel(i3d)\n    return i3d\n    \n\ndef get_feats(videos, detector, device, bs=10):\n    # videos : torch.tensor BCTHW [0, 1]\n    detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.\n    feats = np.empty((0, 400))\n    with torch.no_grad():\n        for i in range((len(videos)-1)//bs + 1):\n            feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])\n    return feats\n\n\ndef get_fvd_feats(videos, i3d, device, bs=10):\n    # videos in [0, 1] as torch tensor BCTHW\n    # videos = [preprocess_single(video) for video in videos]\n    embeddings = get_feats(videos, i3d, device, bs)\n    return embeddings\n\n\ndef preprocess_single(video, resolution=224, sequence_length=None):\n    # video: CTHW, [0, 1]\n    c, t, h, w = video.shape\n\n    # temporal crop\n    if sequence_length is not None:\n        assert sequence_length <= t\n        video = video[:, :sequence_length]\n\n    # scale shorter side to resolution\n    scale = resolution / min(h, w)\n    if h < w:\n        target_size = (resolution, math.ceil(w * scale))\n    else:\n        target_size = (math.ceil(h * scale), resolution)\n    video = F.interpolate(video, 
size=target_size, mode='bilinear', align_corners=False)\n\n    # center crop\n    c, t, h, w = video.shape\n    w_start = (w - resolution) // 2\n    h_start = (h - resolution) // 2\n    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]\n\n    # [0, 1] -> [-1, 1]\n    video = (video - 0.5) * 2\n\n    return video.contiguous()\n\n\n\"\"\"\nCopy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py\n\"\"\"\nfrom typing import Tuple\nfrom scipy.linalg import sqrtm\nimport numpy as np\n\n\ndef compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n    mu = feats.mean(axis=0) # [d]\n    sigma = np.cov(feats, rowvar=False) # [d, d]\n    return mu, sigma\n\n\ndef frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:\n    mu_gen, sigma_gen = compute_stats(feats_fake)\n    mu_real, sigma_real = compute_stats(feats_real)\n    m = np.square(mu_gen - mu_real).sum()\n    if feats_fake.shape[0]>1:\n        s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member\n        fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))\n    else:\n        fid = np.real(m)\n    return float(fid)"
  },
  {
    "path": "opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py",
    "content": "import torch\nimport os\nimport math\nimport torch.nn.functional as F\nimport numpy as np\nimport einops\n\ndef load_i3d_pretrained(device=torch.device('cpu')):\n    i3D_WEIGHTS_URL = \"https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI\"\n    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')\n    print(filepath)\n    if not os.path.exists(filepath):\n        print(f\"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.\")\n        os.system(f\"wget {i3D_WEIGHTS_URL} -O {filepath}\")\n    from .pytorch_i3d import InceptionI3d\n    i3d = InceptionI3d(400, in_channels=3).eval().to(device)\n    i3d.load_state_dict(torch.load(filepath, map_location=device))\n    i3d = torch.nn.DataParallel(i3d)\n    return i3d\n\ndef preprocess_single(video, resolution, sequence_length=None):\n    # video: THWC, {0, ..., 255}\n    video = video.permute(0, 3, 1, 2).float() / 255. 
# TCHW\n    t, c, h, w = video.shape\n\n    # temporal crop\n    if sequence_length is not None:\n        assert sequence_length <= t\n        video = video[:sequence_length]\n\n    # scale shorter side to resolution\n    scale = resolution / min(h, w)\n    if h < w:\n        target_size = (resolution, math.ceil(w * scale))\n    else:\n        target_size = (math.ceil(h * scale), resolution)\n    video = F.interpolate(video, size=target_size, mode='bilinear',\n                          align_corners=False)\n\n    # center crop\n    t, c, h, w = video.shape\n    w_start = (w - resolution) // 2\n    h_start = (h - resolution) // 2\n    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]\n    video = video.permute(1, 0, 2, 3).contiguous() # CTHW\n\n    video -= 0.5\n\n    return video\n\ndef preprocess(videos, target_resolution=224):\n    # convert videos in [0, 1] [b c t h w] as th.float\n    # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array\n    videos = einops.rearrange(videos, 'b c t h w -> b t h w c')\n    videos = (videos*255).cpu().numpy().astype(np.uint8)\n\n    b, t, h, w, c = videos.shape\n    videos = torch.from_numpy(videos)\n    videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])\n    return videos * 2 # [-0.5, 0.5] -> [-1, 1]\n\ndef get_fvd_logits(videos, i3d, device, bs=10):\n    videos = preprocess(videos)\n    embeddings = get_logits(i3d, videos, device, bs=bs)\n    return embeddings\n\n# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161\ndef _symmetric_matrix_square_root(mat, eps=1e-10):\n    u, s, v = torch.svd(mat)\n    si = torch.where(s < eps, s, torch.sqrt(s))\n    return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())\n\n# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400\ndef trace_sqrt_product(sigma, 
sigma_v):\n    sqrt_sigma = _symmetric_matrix_square_root(sigma)\n    sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))\n    return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))\n\n# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2\ndef cov(m, rowvar=False):\n    '''Estimate a covariance matrix given data.\n\n    Covariance indicates the level to which two variables vary together.\n    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,\n    then the covariance matrix element `C_{ij}` is the covariance of\n    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.\n\n    Args:\n        m: A 1-D or 2-D array containing multiple variables and observations.\n            Each row of `m` represents a variable, and each column a single\n            observation of all those variables.\n        rowvar: If `rowvar` is True, then each row represents a\n            variable, with observations in the columns. Otherwise, the\n            relationship is transposed: each column represents a variable,\n            while the rows contain observations.\n\n    Returns:\n        The covariance matrix of the variables.\n    '''\n    if m.dim() > 2:\n        raise ValueError('m has more than 2 dimensions')\n    if m.dim() < 2:\n        m = m.view(1, -1)\n    if not rowvar and m.size(0) != 1:\n        m = m.t()\n\n    fact = 1.0 / (m.size(1) - 1) # unbiased estimate\n    m -= torch.mean(m, dim=1, keepdim=True)\n    mt = m.t()  # if complex: mt = m.t().conj()\n    return fact * m.matmul(mt).squeeze()\n\n\ndef frechet_distance(x1, x2):\n    x1 = x1.flatten(start_dim=1)\n    x2 = x2.flatten(start_dim=1)\n    m, m_w = x1.mean(dim=0), x2.mean(dim=0)\n    sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)\n    mean = torch.sum((m - m_w) ** 2)\n    if x1.shape[0]>1:\n        sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)\n        trace = torch.trace(sigma + sigma_w) - 2.0 * 
sqrt_trace_component\n        fd = trace + mean\n    else:\n        fd = np.real(mean)\n    return float(fd)\n\n\ndef get_logits(i3d, videos, device, bs=10):\n    # assert videos.shape[0] % 16 == 0\n    with torch.no_grad():\n        logits = []\n        for i in range(0, videos.shape[0], bs):\n            batch = videos[i:i + bs].to(device)\n            # logits.append(i3d.module.extract_features(batch)) # wrong\n            logits.append(i3d(batch)) # right\n        logits = torch.cat(logits, dim=0)\n        return logits\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/fvd/videogpt/pytorch_i3d.py",
    "content": "# Original code from https://github.com/piergiaj/pytorch-i3d\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nclass MaxPool3dSamePadding(nn.MaxPool3d):\n\n    def compute_pad(self, dim, s):\n        if s % self.stride[dim] == 0:\n            return max(self.kernel_size[dim] - self.stride[dim], 0)\n        else:\n            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)\n\n    def forward(self, x):\n        # compute 'same' padding\n        (batch, channel, t, h, w) = x.size()\n        out_t = np.ceil(float(t) / float(self.stride[0]))\n        out_h = np.ceil(float(h) / float(self.stride[1]))\n        out_w = np.ceil(float(w) / float(self.stride[2]))\n        pad_t = self.compute_pad(0, t)\n        pad_h = self.compute_pad(1, h)\n        pad_w = self.compute_pad(2, w)\n\n        pad_t_f = pad_t // 2\n        pad_t_b = pad_t - pad_t_f\n        pad_h_f = pad_h // 2\n        pad_h_b = pad_h - pad_h_f\n        pad_w_f = pad_w // 2\n        pad_w_b = pad_w - pad_w_f\n\n        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)\n        x = F.pad(x, pad)\n        return super(MaxPool3dSamePadding, self).forward(x)\n\n\nclass Unit3D(nn.Module):\n\n    def __init__(self, in_channels,\n                 output_channels,\n                 kernel_shape=(1, 1, 1),\n                 stride=(1, 1, 1),\n                 padding=0,\n                 activation_fn=F.relu,\n                 use_batch_norm=True,\n                 use_bias=False,\n                 name='unit_3d'):\n\n        \"\"\"Initializes Unit3D module.\"\"\"\n        super(Unit3D, self).__init__()\n\n        self._output_channels = output_channels\n        self._kernel_shape = kernel_shape\n        self._stride = stride\n        self._use_batch_norm = use_batch_norm\n        self._activation_fn = activation_fn\n        self._use_bias = use_bias\n        self.name = name\n        self.padding = padding\n\n        self.conv3d = 
nn.Conv3d(in_channels=in_channels,\n                                out_channels=self._output_channels,\n                                kernel_size=self._kernel_shape,\n                                stride=self._stride,\n                                padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function\n                                bias=self._use_bias)\n\n        if self._use_batch_norm:\n            self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)\n\n    def compute_pad(self, dim, s):\n        if s % self._stride[dim] == 0:\n            return max(self._kernel_shape[dim] - self._stride[dim], 0)\n        else:\n            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)\n\n\n    def forward(self, x):\n        # compute 'same' padding\n        (batch, channel, t, h, w) = x.size()\n        out_t = np.ceil(float(t) / float(self._stride[0]))\n        out_h = np.ceil(float(h) / float(self._stride[1]))\n        out_w = np.ceil(float(w) / float(self._stride[2]))\n        pad_t = self.compute_pad(0, t)\n        pad_h = self.compute_pad(1, h)\n        pad_w = self.compute_pad(2, w)\n\n        pad_t_f = pad_t // 2\n        pad_t_b = pad_t - pad_t_f\n        pad_h_f = pad_h // 2\n        pad_h_b = pad_h - pad_h_f\n        pad_w_f = pad_w // 2\n        pad_w_b = pad_w - pad_w_f\n\n        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)\n        x = F.pad(x, pad)\n\n        x = self.conv3d(x)\n        if self._use_batch_norm:\n            x = self.bn(x)\n        if self._activation_fn is not None:\n            x = self._activation_fn(x)\n        return x\n\n\n\nclass InceptionModule(nn.Module):\n    def __init__(self, in_channels, out_channels, name):\n        super(InceptionModule, self).__init__()\n\n        self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,\n                         
name=name+'/Branch_0/Conv3d_0a_1x1')\n        self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,\n                          name=name+'/Branch_1/Conv3d_0a_1x1')\n        self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],\n                          name=name+'/Branch_1/Conv3d_0b_3x3')\n        self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,\n                          name=name+'/Branch_2/Conv3d_0a_1x1')\n        self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],\n                          name=name+'/Branch_2/Conv3d_0b_3x3')\n        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],\n                                stride=(1, 1, 1), padding=0)\n        self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,\n                          name=name+'/Branch_3/Conv3d_0b_1x1')\n        self.name = name\n\n    def forward(self, x):\n        b0 = self.b0(x)\n        b1 = self.b1b(self.b1a(x))\n        b2 = self.b2b(self.b2a(x))\n        b3 = self.b3b(self.b3a(x))\n        return torch.cat([b0,b1,b2,b3], dim=1)\n\n\nclass InceptionI3d(nn.Module):\n    \"\"\"Inception-v1 I3D architecture.\n    The model is introduced in:\n        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset\n        Joao Carreira, Andrew Zisserman\n        https://arxiv.org/pdf/1705.07750v1.pdf.\n    See also the Inception architecture, introduced in:\n        Going deeper with convolutions\n        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,\n        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.\n        http://arxiv.org/pdf/1409.4842v1.pdf.\n    \"\"\"\n\n    # Endpoints of the model in order. 
During construction, all the endpoints up\n    # to a designated `final_endpoint` are returned in a dictionary as the\n    # second return value.\n    VALID_ENDPOINTS = (\n        'Conv3d_1a_7x7',\n        'MaxPool3d_2a_3x3',\n        'Conv3d_2b_1x1',\n        'Conv3d_2c_3x3',\n        'MaxPool3d_3a_3x3',\n        'Mixed_3b',\n        'Mixed_3c',\n        'MaxPool3d_4a_3x3',\n        'Mixed_4b',\n        'Mixed_4c',\n        'Mixed_4d',\n        'Mixed_4e',\n        'Mixed_4f',\n        'MaxPool3d_5a_2x2',\n        'Mixed_5b',\n        'Mixed_5c',\n        'Logits',\n        'Predictions',\n    )\n\n    def __init__(self, num_classes=400, spatial_squeeze=True,\n                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):\n        \"\"\"Initializes I3D model instance.\n        Args:\n          num_classes: The number of outputs in the logit layer (default 400, which\n              matches the Kinetics dataset).\n          spatial_squeeze: Whether to squeeze the spatial dimensions for the logits\n              before returning (default True).\n          final_endpoint: The model contains many possible endpoints.\n              `final_endpoint` specifies the last endpoint for the model to be built\n              up to. In addition to the output at `final_endpoint`, all the outputs\n              at endpoints up to `final_endpoint` will also be returned, in a\n              dictionary. `final_endpoint` must be one of\n              InceptionI3d.VALID_ENDPOINTS (default 'Logits').\n          name: A string (optional). 
The name of this module.\n        Raises:\n          ValueError: if `final_endpoint` is not recognized.\n        \"\"\"\n\n        if final_endpoint not in self.VALID_ENDPOINTS:\n            raise ValueError('Unknown final endpoint %s' % final_endpoint)\n\n        super(InceptionI3d, self).__init__()\n        self._num_classes = num_classes\n        self._spatial_squeeze = spatial_squeeze\n        self._final_endpoint = final_endpoint\n        self.logits = None\n\n        if self._final_endpoint not in self.VALID_ENDPOINTS:\n            raise ValueError('Unknown final endpoint %s' % self._final_endpoint)\n\n        self.end_points = {}\n        end_point = 'Conv3d_1a_7x7'\n        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],\n                                            stride=(2, 2, 2), padding=(3,3,3),  name=name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'MaxPool3d_2a_3x3'\n        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),\n                                                             padding=0)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Conv3d_2b_1x1'\n        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,\n                                       name=name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Conv3d_2c_3x3'\n        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,\n                                       name=name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'MaxPool3d_3a_3x3'\n        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),\n                                                             padding=0)\n        if 
self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_3b'\n        self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_3c'\n        self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'MaxPool3d_4a_3x3'\n        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),\n                                                             padding=0)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_4b'\n        self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_4c'\n        self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_4d'\n        self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_4e'\n        self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_4f'\n        self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'MaxPool3d_5a_2x2'\n        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),\n                                                             padding=0)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_5b'\n        
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Mixed_5c'\n        self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)\n        if self._final_endpoint == end_point: return\n\n        end_point = 'Logits'\n        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],\n                                     stride=(1, 1, 1))\n        self.dropout = nn.Dropout(dropout_keep_prob)\n        self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,\n                             kernel_shape=[1, 1, 1],\n                             padding=0,\n                             activation_fn=None,\n                             use_batch_norm=False,\n                             use_bias=True,\n                             name='logits')\n\n        self.build()\n\n\n    def replace_logits(self, num_classes):\n        self._num_classes = num_classes\n        self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,\n                             kernel_shape=[1, 1, 1],\n                             padding=0,\n                             activation_fn=None,\n                             use_batch_norm=False,\n                             use_bias=True,\n                             name='logits')\n\n\n    def build(self):\n        for k in self.end_points.keys():\n            self.add_module(k, self.end_points[k])\n\n    def forward(self, x):\n        for end_point in self.VALID_ENDPOINTS:\n            if end_point in self.end_points:\n                x = self._modules[end_point](x) # use _modules to work with dataparallel\n\n        x = self.logits(self.dropout(self.avg_pool(x)))\n
        logits = x\n        if self._spatial_squeeze:\n            logits = logits.squeeze(3).squeeze(3)\n        # average over time (dim 2); with spatial squeeze, logits is batch X classes\n        logits = logits.mean(dim=2)\n        return logits\n\n\n    def extract_features(self, x):\n        for end_point in self.VALID_ENDPOINTS:\n            if end_point in self.end_points:\n                x = self._modules[end_point](x)\n        return self.avg_pool(x)"
  },
  {
    "path": "opensora/models/causalvideovae/eval/script/cal_clip_score.sh",
    "content": "# clip_score cross modality\npython eval_clip_score.py \\\n    --real_path path/to/image \\\n    --generated_path path/to/text \\\n    --batch-size 50 \\\n    --device \"cuda\"\n\n# clip_score within the same modality\npython eval_clip_score.py \\\n    --real_path path/to/textA \\\n    --generated_path path/to/textB \\\n    --real_flag txt \\\n    --generated_flag txt \\\n    --batch-size 50 \\\n    --device \"cuda\"\n\npython eval_clip_score.py \\\n    --real_path path/to/imageA \\\n    --generated_path path/to/imageB \\\n    --real_flag img \\\n    --generated_flag img \\\n    --batch-size 50 \\\n    --device \"cuda\"\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/script/cal_fvd.sh",
    "content": "python eval_common_metric.py \\\n    --real_video_dir path/to/videoA \\\n    --generated_video_dir path/to/videoB \\\n    --batch_size 10 \\\n    --crop_size 64 \\\n    --num_frames 20 \\\n    --device 'cuda' \\\n    --metric 'fvd' \\\n    --fvd_method 'styleganv'\n"
  },
  {
    "path": "opensora/models/causalvideovae/eval/script/cal_lpips.sh",
    "content": "python eval_common_metric.py \\\n    --real_video_dir path/to/videoA \\\n    --generated_video_dir path/to/videoB \\\n    --batch_size 10 \\\n    --num_frames 20 \\\n    --crop_size 64 \\\n    --device 'cuda' \\\n    --metric 'lpips'"
  },
  {
    "path": "opensora/models/causalvideovae/eval/script/cal_psnr.sh",
    "content": "python eval_common_metric.py \\\n    --real_video_dir path/to/videoA \\\n    --generated_video_dir path/to/videoB \\\n    --batch_size 10 \\\n    --num_frames 20 \\\n    --crop_size 64 \\\n    --device 'cuda' \\\n    --metric 'psnr'"
  },
  {
    "path": "opensora/models/causalvideovae/eval/script/cal_ssim.sh",
    "content": "python eval_common_metric.py \\\n    --real_video_dir path/to/videoA \\\n    --generated_video_dir path/to/videoB \\\n    --batch_size 10 \\\n    --num_frames 20 \\\n    --crop_size 64 \\\n    --device 'cuda' \\\n    --metric 'ssim'"
  },
  {
    "path": "opensora/models/causalvideovae/model/__init__.py",
    "content": "from .registry import ModelRegistry\nfrom .vae import (\n    CausalVAEModel, WFVAEModel\n)\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/configuration_videobase.py",
    "content": "import json\nimport yaml\nfrom typing import TypeVar, Dict, Any\nfrom diffusers import ConfigMixin\n\nT = TypeVar('T', bound='VideoBaseConfiguration')\nclass VideoBaseConfiguration(ConfigMixin):\n    config_name = \"VideoBaseConfiguration\"\n    _nested_config_fields: Dict[str, Any] = {}\n    \n    def __init__(self, **kwargs):\n        pass\n    \n    def to_dict(self) -> Dict[str, Any]:\n        d = {}\n        for key, value in vars(self).items():\n            if isinstance(value, VideoBaseConfiguration):\n                d[key] = value.to_dict()  # Serialize nested VideoBaseConfiguration instances\n            elif isinstance(value, tuple):\n                d[key] = list(value)\n            else:\n                d[key] = value\n        return d\n    \n    def to_yaml_file(self, yaml_path: str):\n        with open(yaml_path, 'w') as yaml_file:\n            yaml.dump(self.to_dict(), yaml_file, default_flow_style=False)\n    \n    @classmethod\n    def load_from_yaml(cls: T, yaml_path: str) -> T:\n        with open(yaml_path, 'r') as yaml_file:\n            config_dict = yaml.safe_load(yaml_file)\n        for field, field_type in cls._nested_config_fields.items():\n            if field in config_dict:\n                config_dict[field] = field_type.load_from_dict(config_dict[field])\n        return cls(**config_dict)\n\n    @classmethod\n    def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T:\n        # Process nested configuration objects\n        for field, field_type in cls._nested_config_fields.items():\n            if field in config_dict:\n                config_dict[field] = field_type.load_from_dict(config_dict[field])\n        return cls(**config_dict)"
  },
  {
    "path": "opensora/models/causalvideovae/model/dataset_videobase.py",
    "content": "import os.path as osp\nimport random\nfrom glob import glob\n\nfrom torchvision import transforms\nimport numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\nfrom torchvision.transforms import Lambda\n\nfrom ..dataset.transform import ToTensorVideo, CenterCropVideo\nfrom ..utils.dataset_utils import DecordInit\n\ndef TemporalRandomCrop(total_frames, size):\n    \"\"\"\n    Performs a random temporal crop on a video sequence.\n\n    This function randomly selects a contiguous frame sequence of length `size` from a video sequence\n    (shorter if the video has fewer than `size` frames).\n    `total_frames` indicates the total number of frames in the video sequence, and `size` represents the length of the frame sequence to be cropped.\n\n    Parameters:\n    - total_frames (int): The total number of frames in the video sequence.\n    - size (int): The length of the frame sequence to be cropped.\n\n    Returns:\n    - (int, int): A tuple containing two integers. The first integer is the starting frame index of the cropped sequence,\n                  and the second integer is the ending frame index (exclusive) of the cropped sequence.\n    \"\"\"\n    rand_end = max(0, total_frames - size - 1)\n    begin_index = random.randint(0, rand_end)\n    end_index = min(begin_index + size, total_frames)\n    return begin_index, end_index\n\ndef resize(x, resolution):\n    height, width = x.shape[-2:]\n    resolution = min(2 * resolution, height, width)\n    aspect_ratio = width / height\n    if width <= height:\n        new_width = resolution\n        new_height = int(resolution / aspect_ratio)\n    else:\n        new_height = resolution\n        new_width = int(resolution * aspect_ratio)\n    resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)\n    return resized_x\n\nclass VideoDataset(data.Dataset):\n    \"\"\" Generic dataset for video files stored in folders.\n    Each item is a dict whose 'video' entry is a CTHW tensor (BCTHW once batched) in the range [-1, 1] 
\"\"\"\n    video_exts = ['avi', 'mp4', 'webm']\n    def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True):\n\n        self.train = train\n        self.sequence_length = sequence_length\n        self.sample_rate = sample_rate\n        self.resolution = resolution\n        self.v_decoder = DecordInit()\n        self.video_folder = video_folder\n        self.dynamic_sample = dynamic_sample\n\n        self.transform = transforms.Compose([\n            ToTensorVideo(),\n            # Lambda(lambda x: resize(x, self.resolution)),\n            CenterCropVideo(self.resolution),\n            Lambda(lambda x: 2.0 * x - 1.0)\n        ])\n        print('Building datasets...')\n        self.samples = self._make_dataset()\n\n    def _make_dataset(self):\n        samples = []\n        samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True)\n                            for ext in self.video_exts], [])\n        return samples\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __getitem__(self, idx):\n        video_path = self.samples[idx]\n        try:\n            video = self.decord_read(video_path)\n            video = self.transform(video)  # T C H W -> T C H W\n            video = video.transpose(0, 1)  # T C H W -> C T H W\n            return dict(video=video, label=\"\")\n        except Exception as e:\n            print(f'Error with {e}, {video_path}')\n            return self.__getitem__(random.randint(0, self.__len__()-1))\n\n    def decord_read(self, path):\n        decord_vr = self.v_decoder(path)\n        total_frames = len(decord_vr)\n        # Sampling video frames\n        if self.dynamic_sample:\n            sample_rate = random.randint(1, self.sample_rate)\n        else:\n            sample_rate = self.sample_rate\n        size = self.sequence_length * sample_rate\n        start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, 
size)\n        # assert end_frame_ind - start_frame_ind >= self.num_frames\n        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int)\n\n        video_data = decord_vr.get_batch(frame_indice).asnumpy()\n        video_data = torch.from_numpy(video_data)\n        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)\n        return video_data"
  },
  {
    "path": "opensora/models/causalvideovae/model/ema_model.py",
    "content": "class EMA:\n    def __init__(self, model, decay):\n        self.model = model\n        self.decay = decay\n        self.shadow = {}\n        self.backup = {}\n        \n    def register(self):\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                self.shadow[name] = param.data.clone()\n\n    def update(self):\n        for name, param in self.model.named_parameters():\n            if name in self.shadow:\n                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]\n                self.shadow[name] = new_average.clone()\n\n    def apply_shadow(self):\n        for name, param in self.model.named_parameters():\n            if name in self.shadow:\n                self.backup[name] = param.data\n                param.data = self.shadow[name]\n\n    def restore(self):\n        for name, param in self.model.named_parameters():\n            if name in self.shadow:\n                param.data = self.backup[name]\n        self.backup = {}\n        \n        "
  },
  {
    "path": "opensora/models/causalvideovae/model/losses/__init__.py",
    "content": "from .perceptual_loss import LPIPSWithDiscriminator3D\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/losses/discriminator.py",
    "content": "import functools\nimport torch.nn as nn\nfrom ..modules.conv import CausalConv3d\nfrom einops import rearrange\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        nn.init.normal_(m.weight.data, 1.0, 0.02)\n        nn.init.constant_(m.bias.data, 0)\n\ndef weights_init_conv(m):\n    if hasattr(m, 'conv'):\n        m = m.conv\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        nn.init.normal_(m.weight.data, 1.0, 0.02)\n        nn.init.constant_(m.bias.data, 0)\n    \nclass NLayerDiscriminator3D(nn.Module):\n    \"\"\"Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.\"\"\"\n    def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):\n        \"\"\"\n        Construct a 3D PatchGAN discriminator\n\n        Parameters:\n            input_nc (int)  -- the number of channels in input volumes\n            ndf (int)       -- the number of filters in the last conv layer\n            n_layers (int)  -- the number of conv layers in the discriminator\n            use_actnorm (bool) -- flag to use actnorm instead of batchnorm\n        \"\"\"\n        super(NLayerDiscriminator3D, self).__init__()\n        if not use_actnorm:\n            norm_layer = nn.BatchNorm3d\n        else:\n            raise NotImplementedError(\"Not implemented.\")\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func != nn.BatchNorm3d\n        else:\n            use_bias = norm_layer != nn.BatchNorm3d\n\n        kw = 3\n        padw = 1\n        sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):  # gradually 
increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2 ** n, 8)\n            sequence += [\n                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2 ** n_layers, 8)\n        sequence += [\n            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map\n        self.main = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        \"\"\"Standard forward.\"\"\"\n        return self.main(input)\n\n\n\n\n\n# class NLayerDiscriminator3D(nn.Module):\n#     \"\"\"Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.\"\"\"\n#     def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):\n#         \"\"\"\n#         Construct a 3D PatchGAN discriminator\n\n#         Parameters:\n#             input_nc (int)  -- the number of channels in input volumes\n#             ndf (int)       -- the number of filters in the last conv layer\n#             n_layers (int)  -- the number of conv layers in the discriminator\n#             use_actnorm (bool) -- flag to use actnorm instead of batchnorm\n#         \"\"\"\n#         super(NLayerDiscriminator3D, self).__init__()\n#         if not use_actnorm:\n#             norm_layer = nn.BatchNorm3d\n#         else:\n#             raise NotImplementedError(\"Not implemented.\")\n#         if type(norm_layer) == functools.partial:\n#             use_bias = norm_layer.func != nn.BatchNorm3d\n#         else:\n#             use_bias = norm_layer != 
nn.BatchNorm3d\n\n#         kw = 4\n#         padw = 1\n#         sequence = [CausalConv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n#         nf_mult = 1\n#         nf_mult_prev = 1\n#         for n in range(1, n_layers):  # gradually increase the number of filters\n#             nf_mult_prev = nf_mult\n#             nf_mult = min(2 ** n, 8)\n#             sequence += [\n#                 CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias),\n#                 norm_layer(ndf * nf_mult),\n#                 nn.LeakyReLU(0.2, True)\n#             ]\n\n#         nf_mult_prev = nf_mult\n#         nf_mult = min(2 ** n_layers, 8)\n#         sequence += [\n#             CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),\n#             norm_layer(ndf * nf_mult),\n#             nn.LeakyReLU(0.2, True)\n#         ]\n\n#         sequence += [CausalConv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map\n#         self.main = nn.Sequential(*sequence)\n\n#     def forward(self, input):\n#         \"\"\"Standard forward.\"\"\"\n#         return self.main(input)"
  },
  {
    "path": "opensora/models/causalvideovae/model/losses/lpips.py",
    "content": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\nfrom collections import namedtuple\nfrom .....utils.taming_download import get_ckpt_path\n\nclass LPIPS(nn.Module):\n    # Learned perceptual metric\n    def __init__(self, use_dropout=True):\n        super().__init__()\n        self.scaling_layer = ScalingLayer()\n        self.chns = [64, 128, 256, 512, 512]  # vgg16 features\n        self.net = vgg16(pretrained=True, requires_grad=False)\n        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n        self.load_from_pretrained()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def load_from_pretrained(self, name=\"vgg_lpips\"):\n        ckpt = get_ckpt_path(name, \".cache/lpips\")\n        self.load_state_dict(torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False)\n        print(\"loaded pretrained LPIPS loss from {}\".format(ckpt))\n\n    @classmethod\n    def from_pretrained(cls, name=\"vgg_lpips\"):\n        if name != \"vgg_lpips\":\n            raise NotImplementedError\n        model = cls()\n        # use the same cache root as load_from_pretrained\n        ckpt = get_ckpt_path(name, \".cache/lpips\")\n        model.load_state_dict(torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False)\n        return model\n\n    def forward(self, input, target):\n        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))\n        outs0, outs1 = self.net(in0_input), self.net(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]\n        for kk in 
range(len(self.chns)):\n            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2\n\n        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]\n        val = res[0]\n        for l in range(1, len(self.chns)):\n            val += res[l]\n        return val\n\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def forward(self, inp):\n        return (inp - self.shift) / self.scale\n\n\nclass NetLinLayer(nn.Module):\n    \"\"\" A single linear layer which does a 1x1 conv \"\"\"\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n        layers = [nn.Dropout(), ] if (use_dropout) else []\n        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]\n        self.model = nn.Sequential(*layers)\n\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            
self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n        return out\n\n\ndef normalize_tensor(x,eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))\n    return x/(norm_factor+eps)\n\n\ndef spatial_average(x, keepdim=True):\n    return x.mean([2,3],keepdim=keepdim)\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/losses/perceptual_loss.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom .lpips import LPIPS\nfrom einops import rearrange\nfrom .discriminator import weights_init, NLayerDiscriminator3D\n\ndef hinge_d_loss(logits_real, logits_fake):\n    loss_real = torch.mean(F.relu(1.0 - logits_real))\n    loss_fake = torch.mean(F.relu(1.0 + logits_fake))\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\n\ndef vanilla_d_loss(logits_real, logits_fake):\n    d_loss = 0.5 * (\n        torch.mean(torch.nn.functional.softplus(-logits_real))\n        + torch.mean(torch.nn.functional.softplus(logits_fake))\n    )\n    return d_loss\n\n\ndef hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):\n    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]\n    loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])\n    loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])\n    loss_real = (weights * loss_real).sum() / weights.sum()\n    loss_fake = (weights * loss_fake).sum() / weights.sum()\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.0):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef measure_perplexity(predicted_indices, n_embed):\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n\n\ndef l1(x, y):\n    return torch.abs(x - y)\n\n\ndef l2(x, y):\n    return torch.pow((x - y), 2)\n\n\nclass LPIPSWithDiscriminator3D(nn.Module):\n    def __init__(\n        self,\n        disc_start,\n        logvar_init=0.0,\n        kl_weight=1.0,\n        pixelloss_weight=1.0,\n        perceptual_weight=1.0,\n        disc_num_layers=4,\n        disc_in_channels=3,\n        disc_factor=1.0,\n        disc_weight=1.0,\n        use_actnorm=False,\n        disc_conditional=False,\n        disc_loss=\"hinge\",\n        learn_logvar: bool = False,\n        wavelet_weight=0.01,\n        loss_type: str = \"l1\",\n    ):\n\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.wavelet_weight = wavelet_weight\n        self.kl_weight = kl_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        self.logvar = nn.Parameter(\n            torch.full((), logvar_init), requires_grad=learn_logvar\n        )\n        self.discriminator = NLayerDiscriminator3D(\n            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm\n        ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = hinge_d_loss if disc_loss == \"hinge\" else vanilla_d_loss\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n        self.loss_func = l1 if loss_type == \"l1\" else l2\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n    
    layer = last_layer if last_layer is not None else self.last_layer[0]\n\n        nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0]\n        g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(\n        self,\n        inputs,\n        reconstructions,\n        posteriors,\n        optimizer_idx,\n        global_step,\n        split=\"train\",\n        weights=None,\n        last_layer=None,\n        wavelet_coeffs=None,\n        cond=None,\n    ):\n        bs = inputs.shape[0]\n        t = inputs.shape[2]\n        if optimizer_idx == 0:\n            inputs = rearrange(inputs, \"b c t h w -> (b t) c h w\").contiguous()\n            reconstructions = rearrange(\n                reconstructions, \"b c t h w -> (b t) c h w\"\n            ).contiguous()\n            rec_loss = self.loss_func(inputs, reconstructions)\n            if self.perceptual_weight > 0:\n                p_loss = self.perceptual_loss(inputs, reconstructions)\n                rec_loss = rec_loss + self.perceptual_weight * p_loss\n            nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n            weighted_nll_loss = nll_loss\n            if weights is not None:\n                weighted_nll_loss = weights * nll_loss\n            weighted_nll_loss = (\n                torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]\n            )\n            nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n            kl_loss = posteriors.kl()\n            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n\n            if wavelet_coeffs:\n                wl_loss_l2 = torch.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs\n                wl_loss_l3 = torch.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs\n          
      wl_loss = wl_loss_l2 + wl_loss_l3\n            else:\n                wl_loss = torch.tensor(0.0)\n\n            inputs = rearrange(inputs, \"(b t) c h w -> b c t h w\", t=t).contiguous()\n            reconstructions = rearrange(\n                reconstructions, \"(b t) c h w -> b c t h w\", t=t\n            ).contiguous()\n\n            logits_fake = self.discriminator(reconstructions)\n            g_loss = -torch.mean(logits_fake)\n            if global_step >= self.discriminator_iter_start:\n                if self.disc_factor > 0.0:\n                    d_weight = self.calculate_adaptive_weight(\n                        nll_loss, g_loss, last_layer=last_layer\n                    )\n                else:\n                    d_weight = torch.tensor(1.0)\n            else:\n                d_weight = torch.tensor(0.0)\n                g_loss = torch.tensor(0.0, requires_grad=True)\n\n            disc_factor = adopt_weight(\n                self.disc_factor, global_step, threshold=self.discriminator_iter_start\n            )\n            loss = (\n                weighted_nll_loss\n                + self.kl_weight * kl_loss\n                + d_weight * disc_factor * g_loss\n                + self.wavelet_weight * wl_loss\n            )\n            log = {\n                \"{}/total_loss\".format(split): loss.clone().detach().mean(),\n                \"{}/logvar\".format(split): self.logvar.detach(),\n                \"{}/kl_loss\".format(split): kl_loss.detach().mean(),\n                \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                \"{}/rec_loss\".format(split): weighted_nll_loss.detach().mean(),\n                \"{}/wl_loss\".format(split): wl_loss.detach().mean(),\n                \"{}/d_weight\".format(split): d_weight.detach(),\n                \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                \"{}/g_loss\".format(split): g_loss.detach().mean(),\n            }\n            return loss, log\n   
     elif optimizer_idx == 1:\n            logits_real = self.discriminator(inputs.contiguous().detach())\n            logits_fake = self.discriminator(reconstructions.contiguous().detach())\n\n            disc_factor = adopt_weight(\n                self.disc_factor, global_step, threshold=self.discriminator_iter_start\n            )\n\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\n                \"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                \"{}/logits_fake\".format(split): logits_fake.detach().mean(),\n            }\n            return d_loss, log\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/modeling_videobase.py",
    "content": "import torch\nfrom diffusers import ModelMixin, ConfigMixin\nfrom torch import nn\nimport os\nimport json\nfrom diffusers.configuration_utils import ConfigMixin\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom typing import Optional, Union\nimport glob\n\n\nclass VideoBaseAE(ModelMixin, ConfigMixin):\n    config_name = \"config.json\"\n    \n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n    \n    def encode(self, x: torch.Tensor, *args, **kwargs):\n        pass\n\n    def decode(self, encoding: torch.Tensor, *args, **kwargs):\n        pass\n    \n    @property\n    def num_training_steps(self) -> int:\n        \"\"\"Total training steps inferred from datamodule and devices.\"\"\"\n        if self.trainer.max_steps:\n            return self.trainer.max_steps\n    \n        limit_batches = self.trainer.limit_train_batches\n        batches = len(self.train_dataloader())\n        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)     \n    \n        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)\n        if self.trainer.tpu_cores:\n            num_devices = max(num_devices, self.trainer.tpu_cores)\n    \n        effective_accum = self.trainer.accumulate_grad_batches * num_devices\n        return (batches // effective_accum) * self.trainer.max_epochs\n    \n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):\n        ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt'))\n        if ckpt_files:\n            # Adapt to checkpoint: glob order is arbitrary, so pick the most recently modified *.ckpt\n            last_ckpt_file = max(ckpt_files, key=os.path.getmtime)\n            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)\n            model = cls.from_config(config_file)\n            model.init_from_ckpt(last_ckpt_file)\n            return model\n        else:\n            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/__init__.py",
    "content": "from .block import Block\nfrom .attention import *\nfrom .conv import *\nfrom .normalize import *\nfrom .resnet_block import *\nfrom .updownsample import *\nfrom .wavelet import *"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/attention.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nfrom .normalize import Normalize\nfrom .conv import CausalConv3d\nimport torch\nfrom .block import Block\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config, set_run_dtype\nexcept:\n    torch_npu = None\n    npu_config = None\n    # from xformers import ops as xops\n\nclass AttnBlock3D(Block):\n    \"\"\"Kept for compatibility with older checkpoints; prefer AttnBlock3DFix for new models.\n\n    forward() reshapes (b, c, t, h, w) directly to (b*t, c, h*w) without first moving t next to b,\n    so channels and frames are mixed whenever t > 1.\n    \"\"\"\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, t, h, w = q.shape\n        q = q.reshape(b * t, c, h * w)\n        q = q.permute(0, 2, 1)  # b,hw,c\n        k = k.reshape(b * t, c, h * w)  # b,c,hw\n        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n        w_ = w_ * (int(c) ** (-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = v.reshape(b * t, c, h * w)\n        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)\n        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n        h_ = h_.reshape(b, c, t, h, w)\n\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\nclass AttnBlock3DFix(nn.Module):\n    \"\"\"\n    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.\n    \"\"\"\n    def __init__(self, in_channels, 
norm_type=\"groupnorm\"):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels, norm_type=norm_type)\n        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        b, c, t, h, w = q.shape\n        q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()\n        k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()\n        v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()\n        \n        if torch_npu is None:\n            # attn_output = xops.memory_efficient_attention(\n            #     q, k, v,\n            #     scale=c ** -0.5\n            # )\n            q = q.view(b * t, -1, 1, c).transpose(1, 2)\n            k = k.view(b * t, -1, 1, c).transpose(1, 2)\n            v = v.view(b * t, -1, 1, c).transpose(1, 2)\n\n            attn_output = F.scaled_dot_product_attention(\n                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False\n            )\n            attn_output = attn_output.transpose(1, 2).reshape(b * t, -1, 1 * c)\n\n        else:\n            # print('npu_config.enable_FA, q.dtype == torch.float32', npu_config.enable_FA, q.dtype == torch.float32)\n            if npu_config.enable_FA and q.dtype == torch.float32:\n                dtype = torch.bfloat16\n            else:\n                dtype = None\n            with set_run_dtype(q, dtype):\n                query, key, value = npu_config.set_current_run_dtype([q, k, v])\n                hidden_states = npu_config.run_attention(query, key, value, atten_mask=None, 
input_layout=\"BSH\",\n                                                            head_dim=c, head_num=1)\n\n                attn_output = npu_config.restore_dtype(hidden_states)\n\n        attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)\n        h_ = self.proj_out(attn_output)\n\n        return x + h_\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/block.py",
    "content": "import torch.nn as nn\n\nclass Block(nn.Module):\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/conv.py",
    "content": "try:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n    \nimport torch.nn as nn\nfrom typing import Union, Tuple\nimport torch\nfrom .block import Block\nfrom .ops import cast_tuple\nfrom .ops import video_to_image\nfrom torch.utils.checkpoint import checkpoint\nimport torch.nn.functional as F\nfrom collections import deque\n\nclass Conv2d(nn.Conv2d):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int]] = 3,\n        stride: Union[int, Tuple[int]] = 1,\n        padding: Union[str, int, Tuple[int]] = 0,\n        dilation: Union[int, Tuple[int]] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        device=None,\n        dtype=None,\n    ) -> None:\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n            padding_mode,\n            device,\n            dtype,\n        )\n\n    @video_to_image\n    def forward(self, x):\n        return super().forward(x)\n\n\n\nclass CausalConv3d(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        enable_cached=False,\n        bias=True,\n        **kwargs\n    ):\n        super().__init__()\n        self.kernel_size = cast_tuple(kernel_size, 3)\n        self.time_kernel_size = self.kernel_size[0]\n        self.chan_in = chan_in\n        self.chan_out = chan_out\n        self.stride = kwargs.pop(\"stride\", 1)\n        self.padding = kwargs.pop(\"padding\", 0)\n        self.padding = list(cast_tuple(self.padding, 3))\n        self.padding[0] = 0\n        self.stride = cast_tuple(self.stride, 3)\n        self.conv = nn.Conv3d(\n            chan_in,\n          
  chan_out,\n            self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=bias\n        )\n        self.enable_cached = enable_cached\n        \n        self.is_first_chunk = True\n        \n        self.causal_cached = deque()\n        self.cache_offset = 0\n\n    def forward(self, x):\n        if self.is_first_chunk:\n            first_frame_pad = x[:, :, :1, :, :].repeat(\n                (1, 1, self.time_kernel_size - 1, 1, 1)\n            )\n        else:\n            first_frame_pad = self.causal_cached.popleft()\n            \n        x = torch.concatenate((first_frame_pad, x), dim=2)\n\n        if self.enable_cached and self.time_kernel_size != 1:\n            if (self.time_kernel_size - 1) // self.stride[0] != 0:\n                if self.cache_offset == 0:\n                    self.causal_cached.append(x[:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone())\n                else:\n                    self.causal_cached.append(x[:, :, :-self.cache_offset][:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone())\n            else:\n                self.causal_cached.append(x[:, :, 0:0, :, :].clone())\n        elif self.enable_cached:\n            self.causal_cached.append(x[:, :, 0:0, :, :].clone())\n            \n        x = self.conv(x)\n        return x\n\n\nclass CausalConv3d_GC(CausalConv3d):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: Union[int, Tuple[int]],\n        init_method=\"random\",\n        **kwargs\n    ):\n        # init_method is accepted for backward compatibility but not forwarded:\n        # CausalConv3d's fourth positional parameter is enable_cached.\n        super().__init__(chan_in, chan_out, kernel_size, **kwargs)\n\n    def forward(self, x):\n        # 1 + 16   16 as video, 1 as image\n        first_frame_pad = x[:, :, :1, :, :].repeat(\n            (1, 1, self.time_kernel_size - 1, 1, 1)\n        )  # b c t h w\n        x = torch.concatenate((first_frame_pad, x), dim=2)  # 3 + 16\n        return checkpoint(self.conv, x)\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/normalize.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .block import Block\nfrom einops import rearrange\n\nclass GroupNorm(Block):\n    def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.norm = torch.nn.GroupNorm(\n            num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True\n        )\n    def forward(self, x):\n        return self.norm(x)\n\nclass LayerNorm(Block):\n    def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=True)\n    def forward(self, x):\n        if x.dim() == 5:\n            x = rearrange(x, \"b c t h w -> b t h w c\")\n            x = self.norm(x)\n            x = rearrange(x, \"b t h w c -> b c t h w\")\n        else:\n            x = rearrange(x, \"b c h w -> b h w c\")\n            x = self.norm(x)\n            x = rearrange(x, \"b h w c -> b c h w\")\n        return x\n\ndef Normalize(in_channels, num_groups=32, norm_type=\"groupnorm\"):\n    if norm_type == \"groupnorm\":\n        return torch.nn.GroupNorm(\n            num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True\n        )\n    elif norm_type == \"layernorm\":\n        return LayerNorm(num_channels=in_channels, eps=1e-6)\n    raise ValueError(f\"Unsupported norm_type: {norm_type!r}\")"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/ops.py",
    "content": "import torch\nfrom einops import rearrange\n\ndef video_to_image(func):\n    def wrapper(self, x, *args, **kwargs):\n        if x.dim() == 5:\n            t = x.shape[2]\n            if True:\n                x = rearrange(x, \"b c t h w -> (b t) c h w\")\n                x = func(self, x, *args, **kwargs)\n                x = rearrange(x, \"(b t) c h w -> b c t h w\", t=t)\n            else:\n                # Conv 2d slice infer\n                result = []\n                for i in range(t):\n                    frame = x[:, :, i, :, :]\n                    frame = func(self, frame, *args, **kwargs)\n                    result.append(frame.unsqueeze(2))\n                x = torch.concatenate(result, dim=2)\n        return x\n    return wrapper\n\ndef nonlinearity(x):\n    return x * torch.sigmoid(x)\n\ndef cast_tuple(t, length=1):\n    return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length)\n\ndef shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):\n    n_dims = len(x.shape)\n    if src_dim < 0:\n        src_dim = n_dims + src_dim\n    if dest_dim < 0:\n        dest_dim = n_dims + dest_dim\n    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims\n    dims = list(range(n_dims))\n    del dims[src_dim]\n    permutation = []\n    ctr = 0\n    for i in range(n_dims):\n        if i == dest_dim:\n            permutation.append(src_dim)\n        else:\n            permutation.append(dims[ctr])\n            ctr += 1\n    x = x.permute(permutation)\n    if make_contiguous:\n        x = x.contiguous()\n    return x"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/quant.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.distributed as dist\nimport numpy as np\nimport torch.nn.functional as F\nfrom .ops import shift_dim\n\nclass Codebook(nn.Module):\n    def __init__(self, n_codes, embedding_dim):\n        super().__init__()\n        self.register_buffer(\"embeddings\", torch.randn(n_codes, embedding_dim))\n        self.register_buffer(\"N\", torch.zeros(n_codes))\n        self.register_buffer(\"z_avg\", self.embeddings.data.clone())\n\n        self.n_codes = n_codes\n        self.embedding_dim = embedding_dim\n        self._need_init = True\n\n    def _tile(self, x):\n        d, ew = x.shape\n        if d < self.n_codes:\n            n_repeats = (self.n_codes + d - 1) // d\n            std = 0.01 / np.sqrt(ew)\n            x = x.repeat(n_repeats, 1)\n            x = x + torch.randn_like(x) * std\n        return x\n\n    def _init_embeddings(self, z):\n        # z: [b, c, t, h, w]\n        self._need_init = False\n        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)\n        y = self._tile(flat_inputs)\n\n        d = y.shape[0]\n        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]\n        if dist.is_initialized():\n            dist.broadcast(_k_rand, 0)\n        self.embeddings.data.copy_(_k_rand)\n        self.z_avg.data.copy_(_k_rand)\n        self.N.data.copy_(torch.ones(self.n_codes))\n\n    def forward(self, z):\n        # z: [b, c, t, h, w]\n        if self._need_init and self.training:\n            self._init_embeddings(z)\n        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)\n        distances = (\n            (flat_inputs**2).sum(dim=1, keepdim=True)\n            - 2 * flat_inputs @ self.embeddings.t()\n            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)\n        )\n\n        encoding_indices = torch.argmin(distances, dim=1)\n        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)\n        encoding_indices = 
encoding_indices.view(z.shape[0], *z.shape[2:])\n\n        embeddings = F.embedding(encoding_indices, self.embeddings)\n        embeddings = shift_dim(embeddings, -1, 1)\n\n        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())\n\n        # EMA codebook update\n        if self.training:\n            n_total = encode_onehot.sum(dim=0)\n            encode_sum = flat_inputs.t() @ encode_onehot\n            if dist.is_initialized():\n                dist.all_reduce(n_total)\n                dist.all_reduce(encode_sum)\n\n            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)\n            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)\n\n            n = self.N.sum()\n            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n\n            encode_normalized = self.z_avg / weights.unsqueeze(1)\n            self.embeddings.data.copy_(encode_normalized)\n\n            y = self._tile(flat_inputs)\n            _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]\n            if dist.is_initialized():\n                dist.broadcast(_k_rand, 0)\n\n            usage = (self.N.view(self.n_codes, 1) >= 1).float()\n            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))\n\n        embeddings_st = (embeddings - z).detach() + z\n\n        avg_probs = torch.mean(encode_onehot, dim=0)\n        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))\n\n        return dict(\n            embeddings=embeddings_st,\n            encodings=encoding_indices,\n            commitment_loss=commitment_loss,\n            perplexity=perplexity,\n        )\n\n    def dictionary_lookup(self, encodings):\n        embeddings = F.embedding(encodings, self.embeddings)\n        return embeddings"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/resnet_block.py",
    "content": "try:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n\nimport torch\nfrom .normalize import Normalize\nfrom .ops import nonlinearity, video_to_image\nfrom .conv import CausalConv3d\nfrom .block import Block\nfrom torch.utils.checkpoint import checkpoint\n\n\nclass ResnetBlock2D(Block):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        norm_type,\n        dropout,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels, norm_type=norm_type)\n        self.conv1 = torch.nn.Conv2d(\n            in_channels, self.out_channels, kernel_size=3, stride=1, padding=1\n        )\n        self.norm2 = Normalize(self.out_channels, norm_type=norm_type)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(\n            self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1\n        )\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=3, stride=1, padding=1\n                )\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=1, stride=1, padding=0\n                )\n\n    @video_to_image\n    def forward(self, x):\n        h = x\n        if npu_config is None:\n            h = self.norm1(h)\n        else:\n            h = npu_config.run_group_norm(self.norm1, h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n        if npu_config is None:\n            h = self.norm2(h)\n        else:\n            h = npu_config.run_group_norm(self.norm2, 
h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n        x = x + h\n        return x\n\n\nclass ResnetBlock3D(Block):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        dropout,\n        norm_type,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels, norm_type=norm_type)\n        self.conv1 = CausalConv3d(in_channels, self.out_channels, 3, padding=1)\n        self.norm2 = Normalize(self.out_channels, norm_type=norm_type)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = CausalConv3d(self.out_channels, self.out_channels, 3, padding=1)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = CausalConv3d(\n                    in_channels, out_channels, 3, padding=1\n                )\n            else:\n                self.nin_shortcut = CausalConv3d(\n                    in_channels, out_channels, 1, padding=0\n                )\n\n    def forward(self, x):\n        h = x\n        if npu_config is None:\n            h = self.norm1(h)\n        else:\n            h = npu_config.run_group_norm(self.norm1, h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n        if npu_config is None:\n            h = self.norm2(h)\n        else:\n            h = npu_config.run_group_norm(self.norm2, h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n        return x + h\n\n\nclass ResnetBlock3D_GC(Block):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        norm_type,\n        dropout,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels, norm_type=norm_type)\n        self.conv1 = CausalConv3d(in_channels, self.out_channels, 3, padding=1)\n        self.norm2 = Normalize(self.out_channels, norm_type=norm_type)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = CausalConv3d(self.out_channels, self.out_channels, 3, padding=1)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = CausalConv3d(\n                    in_channels, out_channels, 3, padding=1\n                )\n            else:\n                self.nin_shortcut = CausalConv3d(\n                    in_channels, out_channels, 1, padding=0\n                )\n\n    def forward(self, x):\n        return checkpoint(self._forward, x, use_reentrant=True)\n\n    def _forward(self, x):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n        return x + h\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/updownsample.py",
    "content": "from typing import Union, Tuple\nfrom collections import deque\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .ops import cast_tuple, video_to_image\nfrom .conv import CausalConv3d, CausalConv3d_GC\nfrom einops import rearrange\nfrom .block import Block\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n\n\nclass Upsample(Block):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.with_conv = True\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n            \n    @video_to_image\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\nclass Downsample(Block):\n    def __init__(self, in_channels, out_channels, undown=False):\n        super().__init__()\n        self.with_conv = True\n        self.undown = undown\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            if self.undown:\n                self.conv = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n            else:\n                self.conv = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=0)\n    @video_to_image\n    def 
forward(self, x):\n        if self.with_conv:\n            if self.undown:\n                if npu_config is not None and npu_config.on_npu:\n                    x_dtype = x.dtype\n                    x = x.to(npu_config.replaced_type)\n                    x = npu_config.run_conv3d(self.conv, x, x_dtype)\n                else:\n                    x = self.conv(x)\n            else:\n                pad = (0, 1, 0, 1)\n                if npu_config is not None and npu_config.on_npu:\n                    x_dtype = x.dtype\n                    x = x.to(npu_config.replaced_type)\n                    x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n                    x = npu_config.run_conv3d(self.conv, x, x_dtype)\n                else:\n                    x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n                    x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\nclass SpatialDownsample2x(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: Union[int, Tuple[int]] = (3, 3),\n        stride: Union[int, Tuple[int]] = (2, 2),\n        **kwargs\n    ):\n        super().__init__()\n        kernel_size = cast_tuple(kernel_size, 2)\n        stride = cast_tuple(stride, 2)\n        self.chan_in = chan_in\n        self.chan_out = chan_out\n        self.kernel_size = kernel_size\n        self.conv = CausalConv3d(\n            self.chan_in,\n            self.chan_out,\n            (1,) + self.kernel_size,\n            stride=(1, ) + stride,\n            padding=0\n        )\n\n    def forward(self, x):\n        pad = (0,1,0,1,0,0)\n        x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n        x = self.conv(x)\n        return x\n\nclass SpatialUpsample2x_GC(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: Union[int, Tuple[int]] = (3, 3),\n      
  stride: Union[int, Tuple[int]] = (1, 1),\n        unup=False,\n    ):\n        super().__init__()\n        self.chan_in = chan_in\n        self.chan_out = chan_out\n        self.kernel_size = kernel_size\n        self.unup = unup\n        self.conv = CausalConv3d_GC(\n            self.chan_in,\n            self.chan_out,\n            (1,) + self.kernel_size,\n            stride=(1, ) + stride,\n            padding=1\n        )\n\n    def forward(self, x):\n        if not self.unup:\n            t = x.shape[2]\n            x = rearrange(x, \"b c t h w -> b (c t) h w\")\n            x = F.interpolate(x, scale_factor=(2,2), mode=\"nearest\")\n            x = rearrange(x, \"b (c t) h w -> b c t h w\", t=t)\n        x = self.conv(x)\n        return x\n    \n\nclass SpatialUpsample2x(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: Union[int, Tuple[int]] = (3, 3),\n        stride: Union[int, Tuple[int]] = (1, 1),\n        unup=False,\n    ):\n        super().__init__()\n        self.chan_in = chan_in\n        self.chan_out = chan_out\n        self.kernel_size = kernel_size\n        self.unup = unup\n        self.conv = CausalConv3d(\n            self.chan_in,\n            self.chan_out,\n            (1,) + self.kernel_size,\n            stride=(1, ) + stride,\n            padding=1\n        )\n\n    def forward(self, x):\n        if not self.unup:\n            t = x.shape[2]\n            x = rearrange(x, \"b c t h w -> b (c t) h w\")\n            x = F.interpolate(x, scale_factor=(2,2), mode=\"nearest\")\n            x = rearrange(x, \"b (c t) h w -> b c t h w\", t=t)\n        x = self.conv(x)\n        return x\n    \nclass TimeDownsample2x(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out,\n        kernel_size: int = 3\n    ):\n        super().__init__()\n        self.kernel_size = kernel_size\n        if npu_config is not None and npu_config.on_npu:\n            self.avg_pool = 
nn.AvgPool2d((kernel_size, 1), stride=(2, 1))\n            self.pad = nn.ReplicationPad3d((0, 0, 0, 0, self.kernel_size - 1, 0))\n        else:\n            self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))\n        \n    def forward(self, x):\n        if npu_config is not None and npu_config.on_npu:\n            n, c, d, h, w = x.shape\n            x = self.pad(x)\n            x = x.view(n * c, -1, h * w)\n            pooled = self.avg_pool(x)\n            output = pooled.view(n, c, -1, h, w)\n            return output\n        else:\n            first_frame_pad = x[:, :, :1, :, :].repeat(\n                (1, 1, self.kernel_size - 1, 1, 1)\n            )\n            x = torch.concatenate((first_frame_pad, x), dim=2)\n            return self.conv(x)\n\nclass TimeUpsample2x(Block):\n    def __init__(\n        self,\n        chan_in,\n        chan_out\n    ):\n        super().__init__()\n    def forward(self, x):\n        if x.size(2) > 1:\n            x,x_= x[:,:,:1],x[:,:,1:]\n            x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear')\n            x = torch.concat([x, x_], dim=2)\n        return x\n    \nclass TimeDownsampleRes2x(Block):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size: int = 3,\n        mix_factor: float = 2.0,\n    ):\n        super().__init__()\n        self.kernel_size = cast_tuple(kernel_size, 3)\n        if npu_config is not None and npu_config.on_npu:\n            self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1))\n            self.pad = nn.ReplicationPad3d((0, 0, 0, 0, kernel_size - 1, 0))\n        else:\n            self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))\n        self.conv = nn.Conv3d(\n            in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1)\n        )\n        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))\n    \n    def forward(self, x):\n        alpha = 
torch.sigmoid(self.mix_factor)\n        if npu_config is not None and npu_config.on_npu:\n            n, c, d, h, w = x.shape\n            x_dtype = x.dtype\n            x = x.to(npu_config.replaced_type)\n            x = self.pad(x)\n            pad_x = x.view(n, c, -1, h, w)\n            avg_x = self.avg_pool(x.view(n * c, -1, h * w)).view(n, c, -1, h, w).to(x_dtype)\n            conv_x = npu_config.run_conv3d(self.conv, pad_x, x_dtype)\n            return alpha * avg_x + (1 - alpha) * conv_x\n        else:\n            first_frame_pad = x[:, :, :1, :, :].repeat(\n                (1, 1, self.kernel_size[0] - 1, 1, 1)\n            )\n            x = torch.concatenate((first_frame_pad, x), dim=2)\n            return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)\n\nclass TimeUpsampleRes2x(Block):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size: int = 3,\n        mix_factor: float = 2.0,\n    ):\n        super().__init__()\n        self.conv = CausalConv3d(\n            in_channels, out_channels, kernel_size, padding=1\n        )\n        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))\n        \n    def forward(self, x):\n        alpha = torch.sigmoid(self.mix_factor)\n        if x.size(2) > 1:\n            x,x_= x[:,:,:1],x[:,:,1:]\n            if npu_config is not None and npu_config.on_npu:\n                x_dtype = x_.dtype\n                x_ = x_.to(npu_config.replaced_type)\n                x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear')\n                x_ = x_.to(x_dtype)\n            else:\n                x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear')\n            x = torch.concat([x, x_], dim=2)\n        return alpha * x + (1-alpha) * self.conv(x)\n\nclass Spatial2xTime2x3DDownsample(Block):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, 
padding=0, stride=2)\n\n    def forward(self, x):\n        pad = (0,1,0,1,0,0)\n        x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n        x = self.conv(x)\n        return x\n\nclass Spatial2x3DDownsample(Block):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=(1,2,2))\n\n    def forward(self, x):\n        pad = (0,1,0,1,0,0)\n        x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n        x = self.conv(x)\n        return x\n    \n\nclass Spatial2x3DUpsample(Block):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        x = F.interpolate(x, scale_factor=(1,2,2), mode='trilinear')\n        return self.conv(x)\n\nclass Spatial2xTime2x3DUpsample(Block):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        t_interpolation=\"trilinear\",\n        enable_cached=False,\n    ):\n        super().__init__()\n        self.t_interpolation = t_interpolation\n        self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)\n        self.enable_cached = enable_cached\n        self.causal_cached = deque()\n\n    def forward(self, x):\n        if x.size(2) > 1 or len(self.causal_cached) > 0:\n            if self.enable_cached and len(self.causal_cached) > 0:\n                x = torch.cat([self.causal_cached.popleft(), x], dim=2)\n                self.causal_cached.append(x[:, :, -2:-1].clone())\n                x = F.interpolate(x, scale_factor=(2, 1, 1), mode=self.t_interpolation)\n                x = x[:, :, 2:]\n                x = F.interpolate(x, scale_factor=(1, 2, 2), mode=\"trilinear\")\n            else:\n                if self.enable_cached:\n                    self.causal_cached.append(x[:, :, 
-1:].clone())\n                x, x_ = x[:, :, :1], x[:, :, 1:]\n                x_ = F.interpolate(\n                    x_, scale_factor=(2, 1, 1), mode=self.t_interpolation\n                )\n                x_ = F.interpolate(x_, scale_factor=(1, 2, 2), mode=\"trilinear\")\n                x = F.interpolate(x, scale_factor=(1, 2, 2), mode=\"trilinear\")\n                x = torch.concat([x, x_], dim=2)\n        else:\n            if self.enable_cached:\n                self.causal_cached.append(x[:, :, -1:].clone())\n            x = F.interpolate(x, scale_factor=(1, 2, 2), mode=\"trilinear\")\n        return self.conv(x)\n    "
  },
  {
    "path": "opensora/models/causalvideovae/model/modules/wavelet.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom ..modules import CausalConv3d\nfrom ..modules.ops import video_to_image\n\nfrom einops import rearrange\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config, set_run_dtype\nexcept Exception as e:\n    torch_npu = None\n    npu_config = None\n    \nclass HaarWaveletTransform3D(nn.Module):\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536\n        g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536\n        hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536\n        gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536\n        h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536\n        g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536\n        hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536\n        gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536\n        h = h.view(1, 1, 2, 2, 2)\n        g = g.view(1, 1, 2, 2, 2)\n        hh = hh.view(1, 1, 2, 2, 2)\n        gh = gh.view(1, 1, 2, 2, 2)\n        h_v = h_v.view(1, 1, 2, 2, 2)\n        g_v = g_v.view(1, 1, 2, 2, 2)\n        hh_v = hh_v.view(1, 1, 2, 2, 2)\n        gh_v = gh_v.view(1, 1, 2, 2, 2)\n\n        self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        
self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n\n        self.h_conv.conv.weight.data = h\n        self.g_conv.conv.weight.data = g\n        self.hh_conv.conv.weight.data = hh\n        self.gh_conv.conv.weight.data = gh\n        self.h_v_conv.conv.weight.data = h_v\n        self.g_v_conv.conv.weight.data = g_v\n        self.hh_v_conv.conv.weight.data = hh_v\n        self.gh_v_conv.conv.weight.data = gh_v\n        self.h_conv.requires_grad_(False)\n        self.g_conv.requires_grad_(False)\n        self.hh_conv.requires_grad_(False)\n        self.gh_conv.requires_grad_(False)\n        self.h_v_conv.requires_grad_(False)\n        self.g_v_conv.requires_grad_(False)\n        self.hh_v_conv.requires_grad_(False)\n        self.gh_v_conv.requires_grad_(False)\n\n    def forward(self, x):\n        assert x.dim() == 5\n        \n        if torch_npu is not None:\n            dtype = x.dtype\n            x = x.to(npu_config.conv_dtype)\n            self.to(npu_config.conv_dtype)\n        \n        b = x.shape[0]\n        x = rearrange(x, \"b c t h w -> (b c) 1 t h w\")\n        low_low_low = self.h_conv(x)\n        low_low_low = rearrange(low_low_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_low_high = self.g_conv(x)\n        low_low_high = rearrange(low_low_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_high_low = self.hh_conv(x)\n        low_high_low = rearrange(low_high_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_high_high = self.gh_conv(x)\n        low_high_high = rearrange(low_high_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_low_low = self.h_v_conv(x)\n        high_low_low = rearrange(high_low_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_low_high = self.g_v_conv(x)\n        high_low_high = rearrange(high_low_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_high_low = self.hh_v_conv(x)\n        high_high_low = rearrange(high_high_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        
high_high_high = self.gh_v_conv(x)\n        high_high_high = rearrange(high_high_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        \n        output = torch.cat(\n            [\n                low_low_low,\n                low_low_high,\n                low_high_low,\n                low_high_high,\n                high_low_low,\n                high_low_high,\n                high_high_low,\n                high_high_high,\n            ],\n            dim=1,\n        )\n        \n        if torch_npu is not None:\n            x = x.to(dtype)\n            output = output.to(dtype)\n            self.to(dtype)\n        \n        return output\n\nclass InverseHaarWaveletTransform3D(nn.Module):\n    def __init__(self, enable_cached=False, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n\n        self.register_buffer('h', \n            torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('g', \n            torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('hh', \n            torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('gh', \n            torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('h_v', \n            torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('g_v', \n            torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('hh_v', \n            torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('gh_v', \n            torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        
self.enable_cached = enable_cached\n        self.is_first_chunk = True\n\n    def forward(self, coeffs):\n        assert coeffs.dim() == 5\n        \n        if torch_npu is not None:\n            dtype = coeffs.dtype\n            coeffs = coeffs.to(npu_config.conv_dtype)\n            self.h = self.h.to(npu_config.conv_dtype)\n            self.g = self.g.to(npu_config.conv_dtype)\n            self.hh = self.hh.to(npu_config.conv_dtype)\n            self.gh = self.gh.to(npu_config.conv_dtype)\n            self.h_v = self.h_v.to(npu_config.conv_dtype)\n            self.g_v = self.g_v.to(npu_config.conv_dtype)\n            self.hh_v = self.hh_v.to(npu_config.conv_dtype)\n            self.gh_v = self.gh_v.to(npu_config.conv_dtype)\n            \n        \n        b = coeffs.shape[0]\n\n        (\n            low_low_low,\n            low_low_high,\n            low_high_low,\n            low_high_high,\n            high_low_low,\n            high_low_high,\n            high_high_low,\n            high_high_high,\n        ) = coeffs.chunk(8, dim=1)\n\n        low_low_low = rearrange(low_low_low, \"b c t h w -> (b c) 1 t h w\")\n        low_low_high = rearrange(low_low_high, \"b c t h w -> (b c) 1 t h w\")\n        low_high_low = rearrange(low_high_low, \"b c t h w -> (b c) 1 t h w\")\n        low_high_high = rearrange(low_high_high, \"b c t h w -> (b c) 1 t h w\")\n        high_low_low = rearrange(high_low_low, \"b c t h w -> (b c) 1 t h w\")\n        high_low_high = rearrange(high_low_high, \"b c t h w -> (b c) 1 t h w\")\n        high_high_low = rearrange(high_high_low, \"b c t h w -> (b c) 1 t h w\")\n        high_high_high = rearrange(high_high_high, \"b c t h w -> (b c) 1 t h w\")\n\n        low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2)\n        low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2)\n        low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2)\n        low_high_high = F.conv_transpose3d(low_high_high, 
self.gh, stride=2)\n        high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2)\n        high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2)\n        high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2)\n        high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2)\n        \n        if self.enable_cached and not self.is_first_chunk:\n            reconstructed = (\n                low_low_low\n                + low_low_high\n                + low_high_low\n                + low_high_high\n                + high_low_low\n                + high_low_high\n                + high_high_low\n                + high_high_high\n            )\n        else:\n            reconstructed = (\n                low_low_low[:, :, 1:]\n                + low_low_high[:, :, 1:]\n                + low_high_low[:, :, 1:]\n                + low_high_high[:, :, 1:]\n                + high_low_low[:, :, 1:]\n                + high_low_high[:, :, 1:]\n                + high_high_low[:, :, 1:]\n                + high_high_high[:, :, 1:]\n            )\n            \n        reconstructed = rearrange(reconstructed, \"(b c) 1 t h w -> b c t h w\", b=b)\n        \n        if torch_npu is not None:\n            coeffs = coeffs.to(dtype)\n            reconstructed = reconstructed.to(dtype)\n            self.h = self.h.to(dtype)\n            self.g = self.g.to(dtype)\n            self.hh = self.hh.to(dtype)\n            self.gh = self.gh.to(dtype)\n            self.h_v = self.h_v.to(dtype)\n            self.g_v = self.g_v.to(dtype)\n            self.hh_v = self.hh_v.to(dtype)\n            self.gh_v = self.gh_v.to(dtype)\n        \n        return reconstructed\n\n\nclass HaarWaveletTransform2D(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 
2)\n        self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)\n\n    @video_to_image\n    def forward(self, x):\n        b, c, h, w = x.shape\n        x = x.reshape(b * c, 1, h, w)\n        low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2)\n        low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2)\n        high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2)\n        high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2)\n        coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1)\n        return coeffs\n\nclass InverseHaarWaveletTransform2D(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)\n\n    @video_to_image\n    def forward(self, coeffs):\n        low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1)\n        b, c, height_half, width_half = low_low.shape\n        height = height_half * 2\n        width = width_half * 2\n\n        low_low = F.conv_transpose2d(\n            low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2\n        )\n        low_high = F.conv_transpose2d(\n            low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2\n        )\n        high_low = F.conv_transpose2d(\n            high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2\n        )\n        high_high = F.conv_transpose2d(\n            high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2\n        )\n\n        
return (low_low + low_high + high_low + high_high).reshape(b, c, height, width)"
  },
  {
    "path": "opensora/models/causalvideovae/model/registry.py",
    "content": "class ModelRegistry:\n    _models = {}\n\n    @classmethod\n    def register(cls, model_name):\n        def decorator(model_class):\n            cls._models[model_name] = model_class\n            return model_class\n        return decorator\n\n    @classmethod\n    def get_model(cls, model_name):\n        return cls._models.get(model_name)"
  },
  {
    "path": "opensora/models/causalvideovae/model/trainer_videobase.py",
    "content": "from transformers import Trainer\nimport torch.nn.functional as F\nfrom typing import Optional\nimport os\nimport torch\nfrom transformers.utils import WEIGHTS_NAME\nimport json\n\nclass VideoBaseTrainer(Trainer):\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n        if state_dict is None:\n            state_dict = self.model.state_dict()\n        \n        # get model config\n        model_config = self.model.config.to_dict()\n        \n        # also record the model class name in the saved config\n        model_config['model'] = self.model.__class__.__name__\n        \n        with open(os.path.join(output_dir, \"config.json\"), \"w\") as file:\n            json.dump(model_config, file)\n        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n        torch.save(self.args, os.path.join(output_dir, \"training_args.bin\"))\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/distrib_utils.py",
    "content": "import torch\nimport numpy as np\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype)\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(torch.pow(self.mean, 2)\n                                       + self.var - 1.0 - self.logvar,\n                                       dim=[1, 2, 3])\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var - 1.0 - self.logvar + other.logvar,\n                    dim=[1, 2, 3])\n\n    def nll(self, sample, dims=[1,2,3]):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims)\n\n    def mode(self):\n        return self.mean\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/module_utils.py",
    "content": "import importlib\n\nModule = str\nMODULES_BASE = \"opensora.models.causalvideovae.model.modules.\"\n\ndef resolve_str_to_obj(str_val, append=True):\n    if append:\n        str_val = MODULES_BASE + str_val\n    module_name, class_name = str_val.rsplit('.', 1)\n    module = importlib.import_module(module_name)\n    return getattr(module, class_name)\n\ndef create_instance(module_class_str: str, **kwargs):\n    module_name, class_name = module_class_str.rsplit('.', 1)\n    module = importlib.import_module(module_name)\n    class_ = getattr(module, class_name)\n    return class_(**kwargs)"
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/scheduler_utils.py",
    "content": "import torch\n\ndef cosine_scheduler(step, max_steps, value_base=1, value_end=0):\n    step = torch.tensor(step)\n    cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps))\n    value = value_end + (value_base - value_end) * cosine_value\n    return value"
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/video_utils.py",
    "content": "import torch\nimport numpy as np\n\ndef tensor_to_video(x):\n    x = (x * 2 - 1).detach().cpu()\n    x = torch.clamp(x, -1, 1)\n    x = (x + 1) / 2\n    x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w\n    x = (255 * x).astype(np.uint8)\n    return x"
  },
  {
    "path": "opensora/models/causalvideovae/model/utils/wavelet_utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom ..modules import CausalConv3d\nfrom einops import rearrange\n\nclass HaarWaveletTransform3D(nn.Module):\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536\n        g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536\n        hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536\n        gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536\n        h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536\n        g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536\n        hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536\n        gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536\n        h = h.view(1, 1, 2, 2, 2)\n        g = g.view(1, 1, 2, 2, 2)\n        hh = hh.view(1, 1, 2, 2, 2)\n        gh = gh.view(1, 1, 2, 2, 2)\n        h_v = h_v.view(1, 1, 2, 2, 2)\n        g_v = g_v.view(1, 1, 2, 2, 2)\n        hh_v = hh_v.view(1, 1, 2, 2, 2)\n        gh_v = gh_v.view(1, 1, 2, 2, 2)\n\n        self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n        self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)\n\n        self.h_conv.conv.weight.data = h\n        self.g_conv.conv.weight.data = g\n        self.hh_conv.conv.weight.data = hh\n        
self.gh_conv.conv.weight.data = gh\n        self.h_v_conv.conv.weight.data = h_v\n        self.g_v_conv.conv.weight.data = g_v\n        self.hh_v_conv.conv.weight.data = hh_v\n        self.gh_v_conv.conv.weight.data = gh_v\n        self.h_conv.requires_grad_(False)\n        self.g_conv.requires_grad_(False)\n        self.hh_conv.requires_grad_(False)\n        self.gh_conv.requires_grad_(False)\n        self.h_v_conv.requires_grad_(False)\n        self.g_v_conv.requires_grad_(False)\n        self.hh_v_conv.requires_grad_(False)\n        self.gh_v_conv.requires_grad_(False)\n\n    def forward(self, x):\n        assert x.dim() == 5\n        b = x.shape[0]\n        x = rearrange(x, \"b c t h w -> (b c) 1 t h w\")\n        low_low_low = self.h_conv(x)\n        low_low_low = rearrange(low_low_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_low_high = self.g_conv(x)\n        low_low_high = rearrange(low_low_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_high_low = self.hh_conv(x)\n        low_high_low = rearrange(low_high_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        low_high_high = self.gh_conv(x)\n        low_high_high = rearrange(low_high_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_low_low = self.h_v_conv(x)\n        high_low_low = rearrange(high_low_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_low_high = self.g_v_conv(x)\n        high_low_high = rearrange(high_low_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_high_low = self.hh_v_conv(x)\n        high_high_low = rearrange(high_high_low, \"(b c) 1 t h w -> b c t h w\", b=b)\n        high_high_high = self.gh_v_conv(x)\n        high_high_high = rearrange(high_high_high, \"(b c) 1 t h w -> b c t h w\", b=b)\n        \n        output = torch.cat(\n            [\n                low_low_low,\n                low_low_high,\n                low_high_low,\n                low_high_high,\n                high_low_low,\n                high_low_high,\n                
high_high_low,\n                high_high_high,\n            ],\n            dim=1,\n        )\n        return output\n\nclass InverseHaarWaveletTransform3D(nn.Module):\n    def __init__(self, enable_cached=False, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n\n        self.register_buffer('h', \n            torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('g', \n            torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('hh', \n            torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('gh', \n            torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('h_v', \n            torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('g_v', \n            torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('hh_v', \n            torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.register_buffer('gh_v', \n            torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536\n        )\n        self.enable_cached = enable_cached\n        self.causal_cached = None\n\n    def forward(self, coeffs):\n        assert coeffs.dim() == 5\n        b = coeffs.shape[0]\n\n        (\n            low_low_low,\n            low_low_high,\n            low_high_low,\n            low_high_high,\n            high_low_low,\n            high_low_high,\n            high_high_low,\n            high_high_high,\n        ) = coeffs.chunk(8, dim=1)\n\n        low_low_low = rearrange(low_low_low, \"b c t h w -> (b c) 1 t h w\")\n        low_low_high = 
rearrange(low_low_high, \"b c t h w -> (b c) 1 t h w\")\n        low_high_low = rearrange(low_high_low, \"b c t h w -> (b c) 1 t h w\")\n        low_high_high = rearrange(low_high_high, \"b c t h w -> (b c) 1 t h w\")\n        high_low_low = rearrange(high_low_low, \"b c t h w -> (b c) 1 t h w\")\n        high_low_high = rearrange(high_low_high, \"b c t h w -> (b c) 1 t h w\")\n        high_high_low = rearrange(high_high_low, \"b c t h w -> (b c) 1 t h w\")\n        high_high_high = rearrange(high_high_high, \"b c t h w -> (b c) 1 t h w\")\n\n        low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2)\n        low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2)\n        low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2)\n        low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2)\n        high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2)\n        high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2)\n        high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2)\n        high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2)\n        if self.enable_cached and self.causal_cached:\n            reconstructed = (\n                low_low_low\n                + low_low_high\n                + low_high_low\n                + low_high_high\n                + high_low_low\n                + high_low_high\n                + high_high_low\n                + high_high_high\n            )\n        else:\n            reconstructed = (\n                low_low_low[:, :, 1:]\n                + low_low_high[:, :, 1:]\n                + low_high_low[:, :, 1:]\n                + low_high_high[:, :, 1:]\n                + high_low_low[:, :, 1:]\n                + high_low_high[:, :, 1:]\n                + high_high_low[:, :, 1:]\n                + high_high_high[:, :, 1:]\n            )\n            self.causal_cached = True\n        reconstructed = 
rearrange(reconstructed, \"(b c) 1 t h w -> b c t h w\", b=b)\n        return reconstructed\n\n\nclass HaarWaveletTransform2D(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        x = x.reshape(b * c, 1, h, w)\n        low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2)\n        low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2)\n        high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2)\n        high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2)\n        coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1)\n        return coeffs\n\nclass InverseHaarWaveletTransform2D(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)\n        self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)\n\n    def forward(self, coeffs):\n        low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1)\n        b, c, height_half, width_half = low_low.shape\n        height = height_half * 2\n        width = width_half * 2\n\n        low_low = F.conv_transpose2d(\n            low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2\n        )\n        low_high = F.conv_transpose2d(\n            low_high.reshape(b * c, 1, height_half, 
width_half), self.ad, stride=2\n        )\n        high_low = F.conv_transpose2d(\n            high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2\n        )\n        high_high = F.conv_transpose2d(\n            high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2\n        )\n\n        return (low_low + low_high + high_low + high_high).reshape(b, c, height, width)"
  },
  {
    "path": "opensora/models/causalvideovae/model/vae/__init__.py",
    "content": "from .modeling_causalvae import CausalVAEModel\nfrom .modeling_wfvae import WFVAEModel\nfrom einops import rearrange\nfrom torch import nn"
  },
  {
    "path": "opensora/models/causalvideovae/model/vae/modeling_causalvae.py",
    "content": "\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept Exception:\n    torch_npu = None\n    npu_config = None\nfrom ..modeling_videobase import VideoBaseAE\nfrom ..modules import Normalize\nfrom ..modules.ops import nonlinearity\nfrom typing import Tuple\nimport torch.nn as nn\nfrom ..utils.module_utils import resolve_str_to_obj, Module\nfrom ..utils.distrib_utils import DiagonalGaussianDistribution\nfrom ..registry import ModelRegistry\nimport torch\nfrom diffusers.configuration_utils import register_to_config\nfrom copy import deepcopy\nimport os\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        z_channels: int,\n        hidden_size: int,\n        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),\n        attn_resolutions: Tuple[int] = (16,),\n        conv_in: Module = \"Conv2d\",\n        conv_out: Module = \"CausalConv3d\",\n        attention: Module = \"AttnBlock\",\n        resnet_blocks: Tuple[Module] = (\n            \"ResnetBlock2D\",\n            \"ResnetBlock2D\",\n            \"ResnetBlock2D\",\n            \"ResnetBlock3D\",\n        ),\n        spatial_downsample: Tuple[Module] = (\n            \"Downsample\",\n            \"Downsample\",\n            \"Downsample\",\n            \"\",\n        ),\n        temporal_downsample: Tuple[Module] = (\"\", \"\", \"TimeDownsampleRes2x\", \"\"),\n        mid_resnet: Module = \"ResnetBlock3D\",\n        dropout: float = 0.0,\n        resolution: int = 256,\n        num_res_blocks: int = 2,\n        double_z: bool = True,\n        norm_type: str = \"groupnorm\",\n    ) -> None:\n        super().__init__()\n        assert len(resnet_blocks) == len(hidden_size_mult), (\n            hidden_size_mult, resnet_blocks\n        )\n        # ---- Config ----\n        self.num_resolutions = len(hidden_size_mult)\n        self.resolution = resolution\n        self.num_res_blocks = num_res_blocks\n\n        # ---- In ----\n        self.conv_in = 
resolve_str_to_obj(conv_in)(\n            3, hidden_size, kernel_size=3, stride=1, padding=1\n        )\n\n        # ---- Downsample ----\n        curr_res = resolution\n        in_ch_mult = (1,) + tuple(hidden_size_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = hidden_size * in_ch_mult[i_level]\n            block_out = hidden_size * hidden_size_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(\n                    resolve_str_to_obj(resnet_blocks[i_level])(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        dropout=dropout,\n                        norm_type=norm_type\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(resolve_str_to_obj(attention)(block_in))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if spatial_downsample[i_level]:\n                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(\n                    block_in, block_in\n                )\n                curr_res = curr_res // 2\n            if temporal_downsample[i_level]:\n                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(\n                    block_in, block_in\n                )\n            self.down.append(down)\n\n        # ---- Mid ----\n        self.mid = nn.Module()\n        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(\n            in_channels=block_in,\n            out_channels=block_in,\n            dropout=dropout,\n            norm_type=norm_type\n        )\n        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)\n        self.mid.block_2 = 
resolve_str_to_obj(mid_resnet)(\n            in_channels=block_in,\n            out_channels=block_in,\n            dropout=dropout,\n            norm_type=norm_type\n        )\n        # ---- Out ----\n        self.norm_out = Normalize(block_in, norm_type=norm_type)\n        self.conv_out = resolve_str_to_obj(conv_out)(\n            block_in,\n            2 * z_channels if double_z else z_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n    def forward(self, x):\n        h = self.conv_in(x)\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](h)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n            if hasattr(self.down[i_level], \"downsample\"):\n                h = self.down[i_level].downsample(h)\n            if hasattr(self.down[i_level], \"time_downsample\"):\n                h = self.down[i_level].time_downsample(h)\n\n        h = self.mid.block_1(h)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h)\n        if npu_config is None:\n            h = self.norm_out(h)\n        else:\n            h = npu_config.run_group_norm(self.norm_out, h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        z_channels: int,\n        hidden_size: int,\n        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),\n        attn_resolutions: Tuple[int] = (16,),\n        conv_in: Module = \"Conv2d\",\n        conv_out: Module = \"CausalConv3d\",\n        attention: Module = \"AttnBlock\",\n        resnet_blocks: Tuple[Module] = (\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n        ),\n        spatial_upsample: Tuple[Module] = (\n            \"\",\n            
\"SpatialUpsample2x\",\n            \"SpatialUpsample2x\",\n            \"SpatialUpsample2x\",\n        ),\n        temporal_upsample: Tuple[Module] = (\"\", \"\", \"\", \"TimeUpsampleRes2x\"),\n        mid_resnet: Module = \"ResnetBlock3D\",\n        dropout: float = 0.0,\n        resolution: int = 256,\n        num_res_blocks: int = 2,\n        norm_type: str = \"groupnorm\",\n    ):\n        super().__init__()\n        # ---- Config ----\n        self.num_resolutions = len(hidden_size_mult)\n        self.resolution = resolution\n        self.num_res_blocks = num_res_blocks\n\n        # ---- In ----\n        block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.conv_in = resolve_str_to_obj(conv_in)(\n            z_channels, block_in, kernel_size=3, padding=1\n        )\n\n        # ---- Mid ----\n        self.mid = nn.Module()\n        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(\n            in_channels=block_in,\n            out_channels=block_in,\n            dropout=dropout,\n            norm_type=norm_type\n        )\n        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, norm_type=norm_type)\n        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(\n            in_channels=block_in,\n            out_channels=block_in,\n            dropout=dropout,\n            norm_type=norm_type\n        )\n\n        # ---- Upsample ----\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = hidden_size * hidden_size_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                block.append(\n                    resolve_str_to_obj(resnet_blocks[i_level])(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        dropout=dropout,\n    
                    norm_type=norm_type\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(resolve_str_to_obj(attention)(block_in, norm_type=norm_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if spatial_upsample[i_level]:\n                up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(\n                    block_in, block_in\n                )\n                curr_res = curr_res * 2\n            if temporal_upsample[i_level]:\n                up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(\n                    block_in, block_in\n                )\n            self.up.insert(0, up)\n\n        # ---- Out ----\n        self.norm_out = Normalize(block_in, norm_type=norm_type)\n        self.conv_out = resolve_str_to_obj(conv_out)(\n            block_in, 3, kernel_size=3, padding=1\n        )\n\n    def forward(self, z):\n        h = self.conv_in(z)\n        h = self.mid.block_1(h)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h)\n\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](h)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if hasattr(self.up[i_level], \"upsample\"):\n                h = self.up[i_level].upsample(h)\n            if hasattr(self.up[i_level], \"time_upsample\"):\n                h = self.up[i_level].time_upsample(h)\n\n        if npu_config is None:\n            h = self.norm_out(h)\n        else:\n            h = npu_config.run_group_norm(self.norm_out, h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n@ModelRegistry.register(\"CausalVAE\")\nclass CausalVAEModel(VideoBaseAE):\n    
@register_to_config\n    def __init__(\n        self,\n        hidden_size: int = 128,\n        z_channels: int = 4,\n        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),\n        attn_resolutions: Tuple[int] = [],\n        dropout: float = 0.0,\n        resolution: int = 256,\n        double_z: bool = True,\n        embed_dim: int = 4,\n        num_res_blocks: int = 2,\n        q_conv: str = \"CausalConv3d\",\n        encoder_conv_in: Module = \"CausalConv3d\",\n        encoder_conv_out: Module = \"CausalConv3d\",\n        encoder_attention: Module = \"AttnBlock3D\",\n        encoder_resnet_blocks: Tuple[Module] = (\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n        ),\n        encoder_spatial_downsample: Tuple[Module] = (\n            \"SpatialDownsample2x\",\n            \"SpatialDownsample2x\",\n            \"SpatialDownsample2x\",\n            \"\",\n        ),\n        encoder_temporal_downsample: Tuple[Module] = (\n            \"\",\n            \"TimeDownsample2x\",\n            \"TimeDownsample2x\",\n            \"\",\n        ),\n        encoder_mid_resnet: Module = \"ResnetBlock3D\",\n        decoder_conv_in: Module = \"CausalConv3d\",\n        decoder_conv_out: Module = \"CausalConv3d\",\n        decoder_attention: Module = \"AttnBlock3D\",\n        decoder_resnet_blocks: Tuple[Module] = (\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n            \"ResnetBlock3D\",\n        ),\n        decoder_spatial_upsample: Tuple[Module] = (\n            \"\",\n            \"SpatialUpsample2x\",\n            \"SpatialUpsample2x\",\n            \"SpatialUpsample2x\",\n        ),\n        decoder_temporal_upsample: Tuple[Module] = (\n            \"\",\n            \"\",\n            \"TimeUpsample2x\",\n            \"TimeUpsample2x\",\n        ),\n        decoder_mid_resnet: Module = \"ResnetBlock3D\",\n        use_quant_layer: bool 
= True,\n        norm_type: str = \"groupnorm\",\n    ) -> None:\n        super().__init__()\n\n        self.tile_sample_min_size = 512000\n        self.tile_sample_min_size_t = 33\n\n        self.tile_sample_min_size_dec = 512\n        self.tile_sample_min_size_t_dec = 17\n        self.tile_latent_min_size = int(self.tile_sample_min_size_dec / (2 ** (len(hidden_size_mult) - 1)))\n        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t_dec-1) / 4) + 1\n\n        self.tile_overlap_t = 2\n\n        self.tile_overlap_factor = 0.125\n        self.use_tiling = False\n        \n        self.use_quant_layer = use_quant_layer\n\n        self.encoder = Encoder(\n            z_channels=z_channels,\n            hidden_size=hidden_size,\n            hidden_size_mult=hidden_size_mult,\n            attn_resolutions=attn_resolutions,\n            conv_in=encoder_conv_in,\n            conv_out=encoder_conv_out,\n            attention=encoder_attention,\n            resnet_blocks=encoder_resnet_blocks,\n            spatial_downsample=encoder_spatial_downsample,\n            temporal_downsample=encoder_temporal_downsample,\n            mid_resnet=encoder_mid_resnet,\n            dropout=dropout,\n            resolution=resolution,\n            num_res_blocks=num_res_blocks,\n            double_z=double_z,\n            norm_type=norm_type\n        )\n\n        self.decoder = Decoder(\n            z_channels=z_channels,\n            hidden_size=hidden_size,\n            hidden_size_mult=hidden_size_mult,\n            attn_resolutions=attn_resolutions,\n            conv_in=decoder_conv_in,\n            conv_out=decoder_conv_out,\n            attention=decoder_attention,\n            resnet_blocks=decoder_resnet_blocks,\n            spatial_upsample=decoder_spatial_upsample,\n            temporal_upsample=decoder_temporal_upsample,\n            mid_resnet=decoder_mid_resnet,\n            dropout=dropout,\n            resolution=resolution,\n            
num_res_blocks=num_res_blocks,\n            norm_type=norm_type\n        )\n        if self.use_quant_layer:\n            quant_conv_cls = resolve_str_to_obj(q_conv)\n            self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)\n            self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)\n\n    def get_encoder(self):\n        if self.use_quant_layer:\n            return [self.quant_conv, self.encoder]\n        return [self.encoder]\n\n    def get_decoder(self):\n        if self.use_quant_layer:\n            return [self.post_quant_conv, self.decoder]\n        return [self.decoder]\n\n    def encode(self, x):\n        if self.use_tiling and (\n            x.shape[-1] > self.tile_sample_min_size\n            or x.shape[-2] > self.tile_sample_min_size\n            or x.shape[-3] > self.tile_sample_min_size_t\n        ):\n            # import ipdb;ipdb.set_trace()\n            return self.tiled_encode(x)\n        h = self.encoder(x)\n        if self.use_quant_layer:\n            h = self.quant_conv(h)\n        posterior = DiagonalGaussianDistribution(h)\n        return posterior\n\n    def decode(self, z):\n        if self.use_tiling and (\n            z.shape[-1] > self.tile_latent_min_size\n            or z.shape[-2] > self.tile_latent_min_size\n            or z.shape[-3] > self.tile_latent_min_size_t\n        ):\n            return self.tiled_decode(z)\n        if self.use_quant_layer:\n            z = self.post_quant_conv(z)\n        dec = self.decoder(z)\n        return dec\n\n    def forward(self, input, sample_posterior=True):\n        posterior = self.encode(input)\n        if sample_posterior:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec, posterior\n\n    def on_train_start(self):\n        self.ema = deepcopy(self) if self.save_ema == True else None\n\n    def get_last_layer(self):\n        if hasattr(self.decoder.conv_out, \"conv\"):\n   
         return self.decoder.conv_out.conv.weight\n        else:\n            return self.decoder.conv_out.weight\n\n    def blend_v(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (\n                1 - y / blend_extent\n            ) + b[:, :, :, y, :] * (y / blend_extent)\n        return b\n\n    def blend_h(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[4], b.shape[4], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (\n                1 - x / blend_extent\n            ) + b[:, :, :, :, x] * (x / blend_extent)\n        return b\n\n    def tiled_encode(self, x):\n        t = x.shape[2]\n        t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)]\n        # print('tiled_encode', t_chunk_idx)\n        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:\n            t_chunk_start_end = [[0, t]]\n        else:\n            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)*4] for i in range(len(t_chunk_idx)-1)]\n            if t_chunk_start_end[-1][-1] > t:\n                t_chunk_start_end[-1][-1] = t\n            elif t_chunk_start_end[-1][-1] < t:\n                last_start_end = [t_chunk_idx[-1], t]\n                t_chunk_start_end.append(last_start_end)\n        moments = []\n        # print('tiled_encode t_chunk_start_end', t_chunk_start_end)\n        for idx, (start, end) in enumerate(t_chunk_start_end):\n            chunk_x = x[:, :, start: end]\n            if idx != 0:\n                moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1+(self.tile_overlap_t-1):]\n            else:\n                moment = 
self.tiled_encode2d(chunk_x, return_moments=True)\n            moments.append(moment)\n        moments = torch.cat(moments, dim=2)\n        posterior = DiagonalGaussianDistribution(moments)\n        return posterior\n    \n    def tiled_decode(self, x):\n        t = x.shape[2]\n        t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)]\n        # print('tiled_decode', t_chunk_idx)\n        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:\n            t_chunk_start_end = [[0, t]]\n        else:\n            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)] for i in range(len(t_chunk_idx)-1)]\n            if t_chunk_start_end[-1][-1] > t:\n                t_chunk_start_end[-1][-1] = t\n            elif t_chunk_start_end[-1][-1] < t:\n                last_start_end = [t_chunk_idx[-1], t]\n                t_chunk_start_end.append(last_start_end)\n        dec_ = []\n        # print('tiled_decode t_chunk_start_end', t_chunk_start_end)\n        for idx, (start, end) in enumerate(t_chunk_start_end):\n            # import ipdb;ipdb.set_trace()\n            chunk_x = x[:, :, start: end]\n            if idx != 0:\n                dec = self.tiled_decode2d(chunk_x)[:, :, 1+(self.tile_overlap_t-1)*4:]\n            else:\n                dec = self.tiled_decode2d(chunk_x)\n            # print(chunk_x.shape, dec.shape)\n            dec_.append(dec)\n        dec_ = torch.cat(dec_, dim=2)\n        return dec_\n\n    def tiled_encode2d(self, x, return_moments=False):\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n        # print('overlap_size, blend_extent, row_limit', overlap_size, blend_extent, row_limit)\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[3], 
overlap_size):\n            row = []\n            for j in range(0, x.shape[4], overlap_size):\n                # print(i, j)\n                tile = x[\n                    :,\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile)\n                if self.use_quant_layer:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=4))\n\n        moments = torch.cat(result_rows, dim=3)\n        posterior = DiagonalGaussianDistribution(moments)\n        if return_moments:\n            return moments\n        return posterior\n\n    def tiled_decode2d(self, z):\n\n        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_sample_min_size - blend_extent\n\n        # Split z into overlapping 64x64 tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        # print('tiled_decode2d', list(range(0, z.shape[3], overlap_size)), list(range(0, z.shape[4], overlap_size)))\n        # import ipdb;ipdb.set_trace()\n        rows = []\n        for i in range(0, z.shape[3], 
overlap_size):\n            row = []\n            for j in range(0, z.shape[4], overlap_size):\n                tile = z[\n                    :,\n                    :,\n                    :,\n                    i : i + self.tile_latent_min_size,\n                    j : j + self.tile_latent_min_size,\n                ]\n                if self.use_quant_layer:\n                    tile = self.post_quant_conv(tile)\n                decoded = self.decoder(tile)\n                row.append(decoded)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=4))\n\n        dec = torch.cat(result_rows, dim=3)\n        return dec\n\n    def enable_tiling(self, use_tiling: bool = True):\n        self.use_tiling = use_tiling\n\n    def disable_tiling(self):\n        self.enable_tiling(False)\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")\n        print(\"init from \" + path)\n\n        if (\n            \"ema_state_dict\" in sd\n            and len(sd[\"ema_state_dict\"]) > 0\n            and os.environ.get(\"NOT_USE_EMA_MODEL\", 0) == 0\n        ):\n            print(\"Load from ema model!\")\n            sd = sd[\"ema_state_dict\"]\n            sd = {key.replace(\"module.\", \"\"): value for key, value in sd.items()}\n        elif \"state_dict\" in sd:\n            print(\"Load from normal model!\")\n            if \"gen_model\" 
in sd[\"state_dict\"]:\n                sd = sd[\"state_dict\"][\"gen_model\"]\n            else:\n                sd = sd[\"state_dict\"]\n\n        keys = list(sd.keys())\n\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n\n        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)\n"
  },
  {
    "path": "opensora/models/causalvideovae/model/vae/modeling_wfvae.py",
    "content": "try:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n    \nfrom ..modeling_videobase import VideoBaseAE\nfrom diffusers.configuration_utils import register_to_config\nimport torch\nimport torch.nn as nn\nfrom ..modules import (\n    ResnetBlock2D,\n    ResnetBlock3D,\n    Conv2d,\n    HaarWaveletTransform3D,\n    InverseHaarWaveletTransform3D,\n    CausalConv3d,\n    Normalize,\n    AttnBlock3DFix,\n    nonlinearity,\n)\nimport torch.nn as nn\nfrom ..utils.distrib_utils import DiagonalGaussianDistribution\nimport torch\nfrom copy import deepcopy\nimport os\nfrom ..registry import ModelRegistry\nfrom einops import rearrange\nfrom collections import deque\nfrom ..utils.module_utils import resolve_str_to_obj, Module\nfrom typing import List\n\nclass Encoder(VideoBaseAE):\n\n    @register_to_config\n    def __init__(\n        self,\n        latent_dim: int = 8,\n        base_channels: int = 128,\n        num_resblocks: int = 2,\n        energy_flow_hidden_size: int = 64,\n        dropout: float = 0.0,\n        attention_type: str = \"AttnBlock3DFix\",\n        use_attention: bool = True,\n        norm_type: str = \"groupnorm\",\n        l1_dowmsample_block: str = \"Downsample\",\n        l1_downsample_wavelet: str = \"HaarWaveletTransform2D\",\n        l2_dowmsample_block: str = \"Spatial2xTime2x3DDownsample\",\n        l2_downsample_wavelet: str = \"HaarWaveletTransform3D\",\n    ) -> None:\n        super().__init__()\n        self.down1 = nn.Sequential(\n            Conv2d(24, base_channels, kernel_size=3, stride=1, padding=1),\n            *[\n                ResnetBlock2D(\n                    in_channels=base_channels,\n                    out_channels=base_channels,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for _ in range(num_resblocks)\n            ],\n            
resolve_str_to_obj(l1_dowmsample_block)(in_channels=base_channels, out_channels=base_channels),\n        )\n        self.down2 = nn.Sequential(\n            Conv2d(\n                base_channels + energy_flow_hidden_size,\n                base_channels * 2,\n                kernel_size=3,\n                stride=1,\n                padding=1,\n            ),\n            *[\n                ResnetBlock3D(\n                    in_channels=base_channels * 2,\n                    out_channels=base_channels * 2,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for _ in range(num_resblocks)\n            ],\n            resolve_str_to_obj(l2_dowmsample_block)(base_channels * 2, base_channels * 2),\n        )\n        # Connection\n        if l1_dowmsample_block == \"Downsample\": # Bad code. For temporal usage.\n            l1_channels = 12\n        else:\n            l1_channels = 24\n            \n        self.connect_l1 = Conv2d(\n            l1_channels, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1\n        )\n        self.connect_l2 = Conv2d(\n            24, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1\n        )\n        # Mid\n        mid_layers = [\n            ResnetBlock3D(\n                in_channels=base_channels * 2 + energy_flow_hidden_size,\n                out_channels=base_channels * 4,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n            ResnetBlock3D(\n                in_channels=base_channels * 4,\n                out_channels=base_channels * 4,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n        ]\n        if use_attention:\n            mid_layers.insert(\n                1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)\n            )\n        self.mid = nn.Sequential(*mid_layers)\n\n        self.norm_out = 
Normalize(base_channels * 4, norm_type=norm_type)\n        self.conv_out = CausalConv3d(\n            base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1\n        )\n        \n        self.wavelet_transform_in = HaarWaveletTransform3D()\n        self.wavelet_transform_l1 = resolve_str_to_obj(l1_downsample_wavelet)()\n        self.wavelet_transform_l2 = resolve_str_to_obj(l2_downsample_wavelet)()\n        \n        \n    def forward(self, x):\n        coeffs = self.wavelet_transform_in(x)\n        \n        l1_coeffs = coeffs[:, :3]\n        l1_coeffs = self.wavelet_transform_l1(l1_coeffs)\n        l1 = self.connect_l1(l1_coeffs)\n        l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3])\n        l2 = self.connect_l2(l2_coeffs)\n        \n        h = self.down1(coeffs)\n        h = torch.concat([h, l1], dim=1)\n        h = self.down2(h)\n        h = torch.concat([h, l2], dim=1)\n        h = self.mid(h)\n        \n        if npu_config is None:\n            h = self.norm_out(h)\n        else:\n            h = npu_config.run_group_norm(self.norm_out, h)\n            \n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        \n        return h, (l1_coeffs, l2_coeffs)\n\nclass Decoder(VideoBaseAE):\n\n    @register_to_config\n    def __init__(\n        self,\n        latent_dim: int = 8,\n        base_channels: int = 128,\n        num_resblocks: int = 2,\n        dropout: float = 0.0,\n        energy_flow_hidden_size: int = 128,\n        attention_type: str = \"AttnBlock3DFix\",\n        use_attention: bool = True,\n        norm_type: str = \"groupnorm\",\n        t_interpolation: str = \"nearest\",\n        connect_res_layer_num: int = 1,\n        l1_upsample_block: str = \"Upsample\",\n        l1_upsample_wavelet: str = \"InverseHaarWaveletTransform2D\",\n        l2_upsample_block: str = \"Spatial2xTime2x3DUpsample\",\n        l2_upsample_wavelet: str = \"InverseHaarWaveletTransform3D\",\n    ) -> None:\n        super().__init__()\n   
     self.energy_flow_hidden_size = energy_flow_hidden_size\n    \n        self.conv_in = CausalConv3d(\n            latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1\n        )\n        mid_layers = [\n            ResnetBlock3D(\n                in_channels=base_channels * 4,\n                out_channels=base_channels * 4,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n            ResnetBlock3D(\n                in_channels=base_channels * 4,\n                out_channels=base_channels * 4 + energy_flow_hidden_size,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n        ]\n        if use_attention:\n            mid_layers.insert(\n                1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)\n            )\n            \n        self.mid = nn.Sequential(*mid_layers)\n\n        self.up2 = nn.Sequential(\n            *[\n                ResnetBlock3D(\n                    in_channels=base_channels * 4,\n                    out_channels=base_channels * 4,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for _ in range(num_resblocks)\n            ],\n            resolve_str_to_obj(l2_upsample_block)(\n                base_channels * 4, base_channels * 4, t_interpolation=t_interpolation\n            ),\n            ResnetBlock3D(\n                in_channels=base_channels * 4,\n                out_channels=base_channels * 4 + energy_flow_hidden_size,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n        )\n        self.up1 = nn.Sequential(\n            *[\n                ResnetBlock3D(\n                    in_channels=base_channels * (4 if i == 0 else 2),\n                    out_channels=base_channels * 2,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n             
   for i in range(num_resblocks)\n            ],\n            resolve_str_to_obj(l1_upsample_block)(in_channels=base_channels * 2, out_channels=base_channels * 2),\n            ResnetBlock3D(\n                in_channels=base_channels * 2,\n                out_channels=base_channels * 2,\n                dropout=dropout,\n                norm_type=norm_type,\n            ),\n        )\n        self.layer = nn.Sequential(\n            *[\n                ResnetBlock3D(\n                    in_channels=base_channels * (2 if i == 0 else 1),\n                    out_channels=base_channels,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for i in range(2)\n            ],\n        )\n        # Connection\n        if l1_upsample_block == \"Upsample\": # Bad code. For temporal usage.\n            l1_channels = 12\n        else:\n            l1_channels = 24\n        self.connect_l1 = nn.Sequential(\n            *[\n                ResnetBlock3D(\n                    in_channels=energy_flow_hidden_size,\n                    out_channels=energy_flow_hidden_size,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for _ in range(connect_res_layer_num)\n            ],\n            Conv2d(energy_flow_hidden_size, l1_channels, kernel_size=3, stride=1, padding=1),\n        )\n        self.connect_l2 = nn.Sequential(\n            *[\n                ResnetBlock3D(\n                    in_channels=energy_flow_hidden_size,\n                    out_channels=energy_flow_hidden_size,\n                    dropout=dropout,\n                    norm_type=norm_type,\n                )\n                for _ in range(connect_res_layer_num)\n            ],\n            Conv2d(energy_flow_hidden_size, 24, kernel_size=3, stride=1, padding=1),\n        )\n        # Out\n        self.norm_out = Normalize(base_channels, norm_type=norm_type)\n        self.conv_out = 
Conv2d(base_channels, 24, kernel_size=3, stride=1, padding=1)\n        \n        self.inverse_wavelet_transform_out = InverseHaarWaveletTransform3D()\n        self.inverse_wavelet_transform_l1 = resolve_str_to_obj(l1_upsample_wavelet)()\n        self.inverse_wavelet_transform_l2 = resolve_str_to_obj(l2_upsample_wavelet)()\n        \n    def forward(self, z):\n        h = self.conv_in(z)\n        h = self.mid(h)\n        \n        l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :])\n        l2 = self.inverse_wavelet_transform_l2(l2_coeffs)\n        \n        h = self.up2(h[:, : -self.energy_flow_hidden_size])\n        \n        l1_coeffs = h[:, -self.energy_flow_hidden_size :]\n        l1_coeffs = self.connect_l1(l1_coeffs)\n        l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2\n        l1 = self.inverse_wavelet_transform_l1(l1_coeffs)\n\n        h = self.up1(h[:, : -self.energy_flow_hidden_size])\n        \n        h = self.layer(h)\n        if npu_config is None:\n            h = self.norm_out(h)\n        else:\n            h = npu_config.run_group_norm(self.norm_out, h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        h[:, :3] = h[:, :3] + l1\n        \n        dec = self.inverse_wavelet_transform_out(h)\n        return dec, (l1_coeffs, l2_coeffs)\n\n\n@ModelRegistry.register(\"WFVAE\")\nclass WFVAEModel(VideoBaseAE):\n\n    @register_to_config\n    def __init__(\n        self,\n        latent_dim: int = 8,\n        base_channels: int = 128,\n        encoder_num_resblocks: int = 2,\n        encoder_energy_flow_hidden_size: int = 64,\n        decoder_num_resblocks: int = 2,\n        decoder_energy_flow_hidden_size: int = 128,\n        attention_type: str = \"AttnBlock3DFix\",\n        use_attention: bool = True,\n        dropout: float = 0.0,\n        norm_type: str = \"groupnorm\",\n        t_interpolation: str = \"nearest\",\n        connect_res_layer_num: int = 1,\n        scale: List[float] = [0.18215, 0.18215, 0.18215, 0.18215, 
0.18215, 0.18215, 0.18215, 0.18215],\n        shift: List[float] = [0, 0, 0, 0, 0, 0, 0, 0],\n        # Module config\n        l1_dowmsample_block: str = \"Downsample\",\n        l1_downsample_wavelet: str = \"HaarWaveletTransform2D\",\n        l2_dowmsample_block: str = \"Spatial2xTime2x3DDownsample\",\n        l2_downsample_wavelet: str = \"HaarWaveletTransform3D\",\n        l1_upsample_block: str = \"Upsample\",\n        l1_upsample_wavelet: str = \"InverseHaarWaveletTransform2D\",\n        l2_upsample_block: str = \"Spatial2xTime2x3DUpsample\",\n        l2_upsample_wavelet: str = \"InverseHaarWaveletTransform3D\",\n    ) -> None:\n        super().__init__()\n        self.use_tiling = False\n        # Hardcode for now\n        self.t_chunk_enc = 8\n        self.t_chunk_dec = 2\n        \n        self.t_upsample_times = 4 // 2\n        self.use_quant_layer = False\n\n        self.encoder = Encoder(\n            latent_dim=latent_dim,\n            base_channels=base_channels,\n            num_resblocks=encoder_num_resblocks,\n            energy_flow_hidden_size=encoder_energy_flow_hidden_size,\n            dropout=dropout,\n            use_attention=use_attention,\n            norm_type=norm_type,\n            l1_dowmsample_block=l1_dowmsample_block,\n            l1_downsample_wavelet=l1_downsample_wavelet,\n            l2_dowmsample_block=l2_dowmsample_block,\n            l2_downsample_wavelet=l2_downsample_wavelet,\n            attention_type=attention_type\n        )\n        self.decoder = Decoder(\n            latent_dim=latent_dim,\n            base_channels=base_channels,\n            num_resblocks=decoder_num_resblocks,\n            energy_flow_hidden_size=decoder_energy_flow_hidden_size,\n            dropout=dropout,\n            use_attention=use_attention,\n            norm_type=norm_type,\n            t_interpolation=t_interpolation,\n            connect_res_layer_num=connect_res_layer_num,\n            l1_upsample_block=l1_upsample_block,\n            
l1_upsample_wavelet=l1_upsample_wavelet,\n            l2_upsample_block=l2_upsample_block,\n            l2_upsample_wavelet=l2_upsample_wavelet,\n            attention_type=attention_type\n        )\n\n        # Set cache offset for trilinear lossless upsample.\n        self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid], 1)\n        self._set_cache_offset([self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer], self.t_upsample_times)\n        \n    def get_encoder(self):\n        if self.use_quant_layer:\n            return [self.quant_conv, self.encoder]\n        return [self.encoder]\n\n    def get_decoder(self):\n        if self.use_quant_layer:\n            return [self.post_quant_conv, self.decoder]\n        return [self.decoder]\n\n    def _empty_causal_cached(self, parent):\n        for name, module in parent.named_modules():\n            if hasattr(module, 'causal_cached'):\n                module.causal_cached = deque()\n                \n    def _set_causal_cached(self, enable_cached=True):\n        for name, module in self.named_modules():\n            if hasattr(module, 'enable_cached'):\n                module.enable_cached = enable_cached\n    \n    def _set_cache_offset(self, modules, cache_offset=0):\n        for module in modules:\n            for submodule in module.modules():\n                if hasattr(submodule, 'cache_offset'):\n                    submodule.cache_offset = cache_offset\n    \n    def _set_first_chunk(self, is_first_chunk=True):\n        for module in self.modules():\n            if hasattr(module, 'is_first_chunk'):\n                module.is_first_chunk = is_first_chunk\n    \n    def build_chunk_start_end(self, t, decoder_mode=False):\n        start_end = [[0, 1]]\n        start = 1\n        end = start\n        while True:\n            if start >= t:\n                break\n            end = min(t, end + (self.t_chunk_dec if decoder_mode 
else self.t_chunk_enc) )\n            start_end.append([start, end])\n            start = end\n        return start_end\n    \n    def encode(self, x):\n        self._empty_causal_cached(self.encoder)\n        self._set_first_chunk(True)\n        \n        if self.use_tiling:\n            h = self.tile_encode(x)\n            l1, l2 = None, None\n        else:\n            h, (l1, l2) = self.encoder(x)\n            if self.use_quant_layer:\n                h = self.quant_conv(h)\n            \n        posterior = DiagonalGaussianDistribution(h)\n        return posterior\n    \n    \n    def tile_encode(self, x):\n        b, c, t, h, w = x.shape\n        \n        start_end = self.build_chunk_start_end(t)\n        result = []\n        for idx, (start, end) in enumerate(start_end):\n            self._set_first_chunk(idx == 0)\n            chunk = x[:, :, start:end, :, :]\n            chunk = self.encoder(chunk)[0]\n            if self.use_quant_layer:\n                chunk = self.quant_conv(chunk)\n            result.append(chunk)\n            \n        return torch.cat(result, dim=2)\n\n\n    def decode(self, z):\n        self._empty_causal_cached(self.decoder)\n        self._set_first_chunk(True)\n        \n        if self.use_tiling:\n            dec = self.tile_decode(z)\n            l1, l2 = None, None\n        else:\n            if self.use_quant_layer:\n                z = self.post_quant_conv(z)\n            dec, (l1, l2) = self.decoder(z)\n            \n        return dec\n    \n    def tile_decode(self, x):\n        b, c, t, h, w = x.shape\n        \n        start_end = self.build_chunk_start_end(t, decoder_mode=True)\n        \n        result = []\n        for idx, (start, end) in enumerate(start_end):\n            self._set_first_chunk(idx==0)\n            \n            if end + 1 < t:\n                chunk = x[:, :, start:end+1, :, :]\n            else:\n                chunk = x[:, :, start:end, :, :]\n                \n            if 
self.use_quant_layer:\n                chunk = self.post_quant_conv(chunk)\n            chunk = self.decoder(chunk)[0]\n            \n            if end + 1 < t:\n                chunk = chunk[:, :, :-4]\n                result.append(chunk.clone())\n            else:\n                result.append(chunk.clone())\n            \n        return torch.cat(result, dim=2)\n\n    def forward(self, input, sample_posterior=True):\n        posterior = self.encode(input)\n        if sample_posterior:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec, posterior\n\n    def get_last_layer(self):\n        if hasattr(self.decoder.conv_out, \"conv\"):\n            return self.decoder.conv_out.conv.weight\n        else:\n            return self.decoder.conv_out.weight\n\n    def enable_tiling(self, use_tiling: bool = True):\n        self.use_tiling = use_tiling\n        self._set_causal_cached(use_tiling)\n        \n    def disable_tiling(self):\n        self.enable_tiling(False)\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")\n        print(\"init from \" + path)\n\n        # Environment variables are strings, so compare against \"0\" (unset also means \"use EMA\").\n        if (\n            \"ema_state_dict\" in sd\n            and len(sd[\"ema_state_dict\"]) > 0\n            and os.environ.get(\"NOT_USE_EMA_MODEL\", \"0\") == \"0\"\n        ):\n            print(\"Load from ema model!\")\n            sd = sd[\"ema_state_dict\"]\n            sd = {key.replace(\"module.\", \"\"): value for key, value in sd.items()}\n        elif \"state_dict\" in sd:\n            print(\"Load from normal model!\")\n            if \"gen_model\" in sd[\"state_dict\"]:\n                sd = sd[\"state_dict\"][\"gen_model\"]\n            else:\n                sd = sd[\"state_dict\"]\n\n        keys = list(sd.keys())\n\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} 
from state_dict.\".format(k))\n                    del sd[k]\n\n        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)\n        print(missing_keys, unexpected_keys)"
  },
  {
    "path": "opensora/models/causalvideovae/sample/rec_video_vae.py",
    "content": "import argparse\nfrom tqdm import tqdm\nimport torch\nimport sys\nfrom torch.utils.data import DataLoader, Subset\nimport os\nfrom accelerate import Accelerator\n\nsys.path.append(\".\")\nfrom opensora.models.causalvideovae.model import *\nfrom opensora.models.causalvideovae.dataset.video_dataset import ValidVideoDataset\nfrom opensora.models.causalvideovae.utils.video_utils import custom_to_video\n\n@torch.no_grad()\ndef main(args: argparse.Namespace):\n    accelerator = Accelerator()\n    device = accelerator.device\n    \n    real_video_dir = args.real_video_dir\n    generated_video_dir = args.generated_video_dir\n    sample_rate = args.sample_rate\n    resolution = args.resolution\n    crop_size = args.crop_size\n    num_frames = args.num_frames\n    device = args.device\n    sample_fps = args.sample_fps\n    batch_size = args.batch_size\n    num_workers = args.num_workers\n    subset_size = args.subset_size\n    \n    if not os.path.exists(args.generated_video_dir):\n        os.makedirs(args.generated_video_dir, exist_ok=True)\n    \n    data_type = torch.bfloat16\n    \n    # ---- Load Model ----\n    model_cls = ModelRegistry.get_model(args.model_name)\n    vae = model_cls.from_pretrained(args.from_pretrained)\n    vae = vae.to(device).to(data_type)\n    if args.enable_tiling:\n        vae.enable_tiling()\n        vae.tile_overlap_factor = args.tile_overlap_factor\n\n    # ---- Prepare Dataset ----\n    dataset = ValidVideoDataset(\n        real_video_dir=real_video_dir,\n        num_frames=num_frames,\n        sample_rate=sample_rate,\n        crop_size=crop_size,\n        resolution=resolution,\n    )\n    if subset_size:\n        indices = range(subset_size)\n        dataset = Subset(dataset, indices=indices)\n        \n    dataloader = DataLoader(\n        dataset, batch_size=batch_size, pin_memory=False, num_workers=num_workers\n    )\n    dataloader = 
accelerator.prepare(dataloader)\n\n    # ---- Inference ----\n    for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):\n        x, file_names = batch['video'], batch['file_name']\n        x = x.to(device=device, dtype=data_type)  # b c t h w\n        x = x * 2 - 1\n        encode_result = vae.encode(x)\n        if isinstance(encode_result, tuple):\n            encode_result = encode_result[0]\n        latents = encode_result.sample().to(data_type)\n        video_recon = vae.decode(latents)\n        if isinstance(video_recon, tuple):\n            video_recon = video_recon[0]\n        for idx, video in enumerate(video_recon):\n            output_path = os.path.join(generated_video_dir, file_names[idx])\n            if args.output_origin:\n                os.makedirs(os.path.join(generated_video_dir, \"origin/\"), exist_ok=True)\n                origin_output_path = os.path.join(generated_video_dir, \"origin/\", file_names[idx])\n                custom_to_video(\n                    x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path\n                )\n            custom_to_video(\n                video, fps=sample_fps / sample_rate, output_file=output_path\n            )\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--real_video_dir\", type=str, default=\"\")\n    parser.add_argument(\"--generated_video_dir\", type=str, default=\"\")\n    parser.add_argument(\"--from_pretrained\", type=str, default=\"\")\n    parser.add_argument(\"--sample_fps\", type=int, default=30)\n    parser.add_argument(\"--resolution\", type=int, default=336)\n    parser.add_argument(\"--crop_size\", type=int, default=None)\n    parser.add_argument(\"--num_frames\", type=int, default=17)\n    parser.add_argument(\"--sample_rate\", type=int, default=1)\n    parser.add_argument(\"--batch_size\", type=int, default=1)\n    parser.add_argument(\"--num_workers\", type=int, default=8)\n    
parser.add_argument(\"--subset_size\", type=int, default=None)\n    parser.add_argument(\"--tile_overlap_factor\", type=float, default=0.25)\n    parser.add_argument('--enable_tiling', action='store_true')\n    parser.add_argument('--output_origin', action='store_true')\n    parser.add_argument(\"--model_name\", type=str, default=None, help=\"\")\n    parser.add_argument(\"--device\", type=str, default=\"cuda\")\n\n    args = parser.parse_args()\n    main(args)\n    \n"
  },
  {
    "path": "opensora/models/causalvideovae/utils/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/causalvideovae/utils/dataset_utils.py",
    "content": "import math\nfrom einops import rearrange\nimport decord\nfrom torch.nn import functional as F\nimport torch\n\n\nIMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\nclass DecordInit(object):\n    \"\"\"Use Decord (https://github.com/dmlc/decord) to initialize the video reader.\"\"\"\n\n    def __init__(self, num_threads=1):\n        self.num_threads = num_threads\n        self.ctx = decord.cpu(0)\n\n    def __call__(self, filename):\n        \"\"\"Open a video file with Decord.\n        Args:\n            filename (str): Path to the video file.\n        Returns:\n            decord.VideoReader: Reader for the video.\n        \"\"\"\n        reader = decord.VideoReader(filename,\n                                    ctx=self.ctx,\n                                    num_threads=self.num_threads)\n        return reader\n\n    def __repr__(self):\n        repr_str = (f'{self.__class__.__name__}('\n                    f'num_threads={self.num_threads})')\n        return repr_str\n\ndef pad_to_multiple(number, ds_stride):\n    remainder = number % ds_stride\n    if remainder == 0:\n        return number\n    else:\n        padding = ds_stride - remainder\n        return number + padding\n"
  },
  {
    "path": "opensora/models/causalvideovae/utils/downloader.py",
    "content": "import gdown\nimport os\n\nopensora_cache_home = os.path.expanduser(\n    os.getenv(\"OPENSORA_HOME\", os.path.join(\"~/.cache\", \"opensora\"))\n)\n\n\ndef gdown_download(id, fname, cache_dir=None):\n    cache_dir = opensora_cache_home if not cache_dir else cache_dir\n\n    os.makedirs(cache_dir, exist_ok=True)\n    destination = os.path.join(cache_dir, fname)\n    if os.path.exists(destination):\n        return destination\n\n    gdown.download(id=id, output=destination, quiet=False)\n    return destination\n"
  },
  {
    "path": "opensora/models/causalvideovae/utils/video_utils.py",
    "content": "import torch\nimport numpy as np\nimport numpy.typing as npt\nimport cv2\nfrom decord import VideoReader, cpu\n\ndef array_to_video(\n    image_array: npt.NDArray, fps: float = 30.0, output_file: str = \"output_video.mp4\"\n) -> None:\n    \"\"\"b h w c\"\"\"\n    height, width, channels = image_array[0].shape\n    fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))\n\n    for image in image_array:\n        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n        video_writer.write(image_rgb)\n\n    video_writer.release()\n\ndef custom_to_video(\n    x: torch.Tensor, fps: float = 2.0, output_file: str = \"output_video.mp4\"\n) -> None:\n    x = x.detach().cpu()\n    x = torch.clamp(x, -1, 1)\n    x = (x + 1) / 2\n    x = x.permute(1, 2, 3, 0).float().numpy()\n    x = (255 * x).astype(np.uint8)\n    array_to_video(x, fps=fps, output_file=output_file)\n    return\n\ndef read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:\n    decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)\n    total_frames = len(decord_vr)\n    sample_frames_len = sample_rate * num_frames\n\n    if total_frames > sample_frames_len:\n        s = 0\n        e = s + sample_frames_len\n        num_frames = num_frames\n    else:\n        s = 0\n        e = total_frames\n        num_frames = int(total_frames / sample_frames_len * num_frames)\n        print(\n            f\"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}\",\n            video_path,\n            total_frames,\n        )\n\n    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)\n    video_data = decord_vr.get_batch(frame_id_list).asnumpy()\n    video_data = torch.from_numpy(video_data)\n    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)\n    return video_data\n\ndef tensor_to_video(x):\n    \"\"\"[0-1] tensor to video\"\"\"\n    x = 
(x * 2 - 1).detach().cpu()\n    x = torch.clamp(x, -1, 1)\n    x = (x + 1) / 2\n    x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w\n    x = (255 * x).astype(np.uint8)\n    return x"
  },
  {
    "path": "opensora/models/diffusion/__init__.py",
    "content": "from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models\nfrom .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models\n\n\nDiffusion_models = {}\nDiffusion_models.update(OpenSora_v1_3_models)\nDiffusion_models.update(OpenSoraInpaint_v1_3_models)\n\nfrom .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models_class\nfrom .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models_class\n\nDiffusion_models_class = {}\nDiffusion_models_class.update(OpenSora_v1_3_models_class)\nDiffusion_models_class.update(OpenSoraInpaint_v1_3_models_class)\n    "
  },
  {
    "path": "opensora/models/diffusion/common.py",
    "content": "import torch\nfrom einops import rearrange, repeat\nfrom typing import Any, Dict, Optional, Tuple\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom diffusers.models.attention_processor import Attention as Attention_\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config, set_run_dtype\n    from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info\n    from opensora.acceleration.communications import all_to_all_SBH\nexcept:\n    torch_npu = None\n    npu_config = None\n    set_run_dtype = None\n    from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info\n    from opensora.utils.communications import all_to_all_SBH\n\nclass PatchEmbed2D(nn.Module):\n    \"\"\"2D Image to Patch Embedding but with video\"\"\"\n\n    def __init__(\n        self,\n        patch_size=16,\n        in_channels=3,\n        embed_dim=768,\n        bias=True,\n    ):\n        super().__init__()\n        self.proj = nn.Conv2d(\n            in_channels, embed_dim, \n            kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias\n        )\n\n    def forward(self, latent):\n        b, _, _, _, _ = latent.shape\n        latent = rearrange(latent, 'b c t h w -> (b t) c h w')\n        latent = self.proj(latent)\n        latent = rearrange(latent, '(b t) c h w -> b (t h w) c', b=b)\n        return latent\n\n\nclass PositionGetter3D(object):\n    \"\"\" return positions of patches \"\"\"\n\n    def __init__(self, ):\n        self.cache_positions = {}\n        \n    def __call__(self, b, t, h, w, device):\n        if not (b,t,h,w) in self.cache_positions:\n            x = torch.arange(w, device=device)\n            y = torch.arange(h, device=device)\n            z = torch.arange(t, device=device)\n            pos = torch.cartesian_prod(z, y, x)\n            # print('PositionGetter3D', PositionGetter3D)\n            pos = 
pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone()\n            poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())\n            max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))\n\n            self.cache_positions[b, t, h, w] = (poses, max_poses)\n        pos = self.cache_positions[b, t, h, w]\n\n        return pos\n    \nclass RoPE3D(torch.nn.Module):\n\n    def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):\n        super().__init__()\n        self.base = freq\n        self.F0 = F0\n        self.interpolation_scale_t = interpolation_scale_thw[0]\n        self.interpolation_scale_h = interpolation_scale_thw[1]\n        self.interpolation_scale_w = interpolation_scale_thw[2]\n        self.cache = {}\n\n    def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1):\n        if (D, seq_len, device, dtype) not in self.cache:\n            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))\n            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale\n            freqs = torch.einsum(\"i,j->ij\", t, inv_freq).to(dtype)\n            freqs = torch.cat((freqs, freqs), dim=-1)\n            cos = freqs.cos()  # (Seq, Dim)\n            sin = freqs.sin()\n            self.cache[D, seq_len, device, dtype] = (cos, sin)\n        return self.cache[D, seq_len, device, dtype]\n\n    @staticmethod\n    def rotate_half(x):\n        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]\n        return torch.cat((-x2, x1), dim=-1)\n\n    def apply_rope1d(self, tokens, pos1d, cos, sin):\n        assert pos1d.ndim == 2\n        # for (ntokens x batch_size x nheads x dim)\n        cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]\n        sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]\n\n        return (tokens * cos) + (self.rotate_half(tokens) * sin)\n\n    
def forward(self, tokens, positions):\n        \"\"\"\n        input:\n            * tokens: ntokens x batch_size x nheads x dim\n            * positions: tuple (poses, max_poses) from PositionGetter3D; poses holds three (ntokens x batch_size) index tensors for the t, y and x position of each token, max_poses holds their maxima\n        output:\n            * tokens after applying RoPE3D (ntokens x batch_size x nheads x dim)\n        \"\"\"\n        assert tokens.size(3) % 3 == 0, \"number of dimensions should be a multiple of three\"\n        D = tokens.size(3) // 3\n        poses, max_poses = positions\n        assert len(poses) == 3 and poses[0].ndim == 2  # each entry of poses: (Seq, Batch)\n        cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t)\n        cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h)\n        cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)\n        # split features into three along the feature dimension, and apply rope1d on each third\n        t, y, x = tokens.chunk(3, dim=-1)\n        t = self.apply_rope1d(t, poses[0], cos_t, sin_t)\n        y = self.apply_rope1d(y, poses[1], cos_y, sin_y)\n        x = self.apply_rope1d(x, poses[2], cos_x, sin_x)\n        tokens = torch.cat((t, y, x), dim=-1)\n        return tokens"
  },
  {
    "path": "opensora/models/diffusion/opensora_v1_3/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/diffusion/opensora_v1_3/modeling_inpaint.py",
    "content": "import os\nimport numpy as np\nfrom torch import nn\nimport torch\nfrom einops import rearrange, repeat\nfrom typing import Any, Dict, Optional, Tuple\nfrom diffusers.configuration_utils import register_to_config\nfrom opensora.models.diffusion.common import PatchEmbed2D\nfrom opensora.utils.utils import to_2tuple\n\n\nfrom opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3 as OpenSoraT2V\n\nimport glob\n\ndef zero_module(module):\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n\n\nclass OpenSoraInpaint_v1_3(OpenSoraT2V):\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = True,\n        sample_size_h: Optional[int] = None,\n        sample_size_w: Optional[int] = None,\n        sample_size_t: Optional[int] = None,\n        patch_size: Optional[int] = None,\n        patch_size_t: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_elementwise_affine: bool = False,\n        norm_eps: float = 1e-6,\n        caption_channels: int = None,\n        interpolation_scale_h: float = 1.0,\n        interpolation_scale_w: float = 1.0,\n        interpolation_scale_t: float = 1.0,\n        sparse1d: bool = False,\n        sparse_n: int = 2,\n        # inpaint\n        vae_scale_factor_t: int = 4,\n    ):\n        super().__init__(\n            num_attention_heads=num_attention_heads,\n            attention_head_dim=attention_head_dim,\n            in_channels=in_channels,\n            
out_channels=out_channels,\n            num_layers=num_layers,\n            dropout=dropout,\n            cross_attention_dim=cross_attention_dim,\n            attention_bias=attention_bias,\n            sample_size_h=sample_size_h,\n            sample_size_w=sample_size_w,\n            sample_size_t=sample_size_t,\n            patch_size=patch_size,\n            patch_size_t=patch_size_t,\n            activation_fn=activation_fn,\n            only_cross_attention=only_cross_attention,\n            double_self_attention=double_self_attention,\n            upcast_attention=upcast_attention,\n            norm_elementwise_affine=norm_elementwise_affine,\n            norm_eps=norm_eps,\n            caption_channels=caption_channels,\n            interpolation_scale_h=interpolation_scale_h,\n            interpolation_scale_w=interpolation_scale_w,\n            interpolation_scale_t=interpolation_scale_t,\n            sparse1d=sparse1d,\n            sparse_n=sparse_n,\n        )\n\n        self.vae_scale_factor_t = vae_scale_factor_t\n        # init masked_pixel_values and mask conv_in\n        self._init_patched_inputs_for_inpainting()\n\n    def _init_patched_inputs_for_inpainting(self):\n\n        self.config.sample_size = to_2tuple(self.config.sample_size)\n\n        self.pos_embed_masked_hidden_states = nn.ModuleList(\n            [\n                PatchEmbed2D(\n                    patch_size=self.config.patch_size,\n                    in_channels=self.config.in_channels,\n                    embed_dim=self.config.hidden_size,\n                ),\n                zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)),\n            ]\n        )\n\n        self.pos_embed_mask = nn.ModuleList(\n            [\n                PatchEmbed2D(\n                    patch_size=self.config.patch_size,\n                    in_channels=self.vae_scale_factor_t,\n                    embed_dim=self.config.hidden_size,\n                ),\n           
     zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)),\n            ]\n        )\n\n    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame):\n        # inpaint\n        assert hidden_states.shape[1] == 2 * self.config.in_channels + self.vae_scale_factor_t\n        in_channels = self.config.in_channels\n\n        input_hidden_states, input_masked_hidden_states, input_mask = hidden_states[:, :in_channels], hidden_states[:, in_channels: 2 * in_channels], hidden_states[:, 2 * in_channels:]\n\n        input_hidden_states = self.pos_embed(input_hidden_states.to(self.dtype))\n\n        input_masked_hidden_states = self.pos_embed_masked_hidden_states[0](input_masked_hidden_states.to(self.dtype))\n        input_masked_hidden_states = self.pos_embed_masked_hidden_states[1](input_masked_hidden_states)\n\n        input_mask = self.pos_embed_mask[0](input_mask.to(self.dtype))\n        input_mask = self.pos_embed_mask[1](input_mask)\n\n        hidden_states = input_hidden_states + input_masked_hidden_states + input_mask\n\n        added_cond_kwargs = {\"resolution\": None, \"aspect_ratio\": None}\n        timestep, embedded_timestep = self.adaln_single(\n            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype\n        )  # b 6d, b d\n\n        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # b, 1, l, d or b, 1, l, d\n        assert encoder_hidden_states.shape[1] == 1\n        encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d')\n\n        return hidden_states, encoder_hidden_states, timestep, embedded_timestep\n\ndef OpenSoraInpaint_v1_3_2B_122(**kwargs):\n   return OpenSoraInpaint_v1_3(\n        num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2,\n        caption_channels=4096, cross_attention_dim=2304, activation_fn=\"gelu-approximate\", **kwargs\n    
)\n\nOpenSoraInpaint_v1_3_models = {\n    \"OpenSoraInpaint_v1_3-2B/122\": OpenSoraInpaint_v1_3_2B_122,  # 2.7B\n}\n\nOpenSoraInpaint_v1_3_models_class = {\n    \"OpenSoraInpaint_v1_3-2B/122\": OpenSoraInpaint_v1_3,\n    \"OpenSoraInpaint_v1_3\": OpenSoraInpaint_v1_3,\n}\n"
  },
  {
    "path": "opensora/models/diffusion/opensora_v1_3/modeling_opensora.py",
    "content": "import os\nimport numpy as np\nfrom torch import nn\nimport torch\nfrom einops import rearrange, repeat\nfrom typing import Any, Dict, Optional, Tuple\nfrom torch.nn import functional as F\nfrom diffusers.models.modeling_outputs import Transformer2DModelOutput\nfrom diffusers.utils import is_torch_version, deprecate\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.normalization import AdaLayerNormSingle\nfrom diffusers.models.embeddings import PixArtAlphaTextProjection\nfrom opensora.models.diffusion.opensora_v1_3.modules import BasicTransformerBlock, Attention\nfrom opensora.models.diffusion.common import PatchEmbed2D\nfrom opensora.utils.utils import to_2tuple\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\n    from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info\nexcept:\n    torch_npu = None\n    npu_config = None\n    from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info\n\nclass OpenSoraT2V_v1_3(ModelMixin, ConfigMixin):\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = True,\n        sample_size_h: Optional[int] = None,\n        sample_size_w: Optional[int] = None,\n        sample_size_t: Optional[int] = None,\n        patch_size: Optional[int] = None,\n        patch_size_t: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        
norm_elementwise_affine: bool = False,\n        norm_eps: float = 1e-6,\n        caption_channels: int = None,\n        interpolation_scale_h: float = 1.0,\n        interpolation_scale_w: float = 1.0,\n        interpolation_scale_t: float = 1.0,\n        sparse1d: bool = False,\n        sparse_n: int = 2,\n    ):\n        super().__init__()\n        # Set some common variables used across the board.\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim\n        self.gradient_checkpointing = False\n        self._init_patched_inputs()\n\n    def _init_patched_inputs(self):\n\n        self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w)\n        interpolation_scale_thw = (\n            self.config.interpolation_scale_t, \n            self.config.interpolation_scale_h, \n            self.config.interpolation_scale_w\n            )\n        \n        self.caption_projection = PixArtAlphaTextProjection(\n            in_features=self.config.caption_channels, hidden_size=self.config.hidden_size\n        )\n\n        self.pos_embed = PatchEmbed2D(\n            patch_size=self.config.patch_size,\n            in_channels=self.config.in_channels,\n            embed_dim=self.config.hidden_size,\n        )\n        \n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    self.config.hidden_size,\n                    self.config.num_attention_heads,\n                    self.config.attention_head_dim,\n                    dropout=self.config.dropout,\n                    cross_attention_dim=self.config.cross_attention_dim,\n                    activation_fn=self.config.activation_fn,\n                    attention_bias=self.config.attention_bias,\n                    only_cross_attention=self.config.only_cross_attention,\n                    
double_self_attention=self.config.double_self_attention,\n                    upcast_attention=self.config.upcast_attention,\n                    norm_elementwise_affine=self.config.norm_elementwise_affine,\n                    norm_eps=self.config.norm_eps,\n                    interpolation_scale_thw=interpolation_scale_thw, \n                    sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, \n                    sparse_n=self.config.sparse_n, \n                    sparse_group=i % 2 == 1, \n                )\n                for i in range(self.config.num_layers)\n            ]\n        )\n        self.norm_out = nn.LayerNorm(self.config.hidden_size, elementwise_affine=False, eps=1e-6)\n        self.scale_shift_table = nn.Parameter(torch.randn(2, self.config.hidden_size) / self.config.hidden_size**0.5)\n        self.proj_out = nn.Linear(\n            self.config.hidden_size, self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels\n        )\n        self.adaln_single = AdaLayerNormSingle(self.config.hidden_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        timestep: Optional[torch.LongTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n        **kwargs, \n    ):\n        \n        batch_size, c, frame, h, w = hidden_states.shape\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.\n        #   we may have done this conversion already, e.g. 
if we came here via UNet2DConditionModel#forward.\n        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None and attention_mask.ndim == 4:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #   (keep = +0,     discard = -10000.0)\n            # b, frame, h, w -> a video\n            # b, 1, h, w -> only images\n            attention_mask = attention_mask.to(self.dtype)\n\n            attention_mask = attention_mask.unsqueeze(1)  # b 1 t h w\n            attention_mask = F.max_pool3d(\n                attention_mask, \n                kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), \n                stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size)\n                )\n            attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)') \n            attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0\n\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:  \n            # b, 1, l\n            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0\n\n\n        # 1. 
Input\n        frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t  # patchify\n        height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size\n\n\n        hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(\n            hidden_states, encoder_hidden_states, timestep, batch_size, frame\n        )\n\n        # To\n        # x            (t*h*w b d) or (t//sp*h*w b d)\n        # cond_1       (l b d) or (l//sp b d)\n        hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous()\n        encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous()\n        timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous()\n\n        sparse_mask = {}\n        if npu_config is None:\n            if get_sequence_parallel_state():\n                head_num = self.config.num_attention_heads // nccl_info.world_size\n            else:\n                head_num = self.config.num_attention_heads\n        else:\n            head_num = None\n        for sparse_n in [1, 4]:\n            sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num)\n        # 2. 
Blocks\n        for i, block in enumerate(self.transformer_blocks):\n            if i > 1 and i < 30:\n                attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group]\n            else:\n                attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group]\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                \n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    hidden_states,\n                    attention_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    timestep,\n                    frame, \n                    height, \n                    width, \n                    **ckpt_kwargs,\n                )\n            else:\n                hidden_states = block(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    timestep=timestep,\n                    frame=frame, \n                    height=height, \n                    width=width, \n                )\n\n        # To (b, t*h*w, h) or (b, t//sp*h*w, h)\n        hidden_states = rearrange(hidden_states, 's b h -> b s h', 
b=batch_size).contiguous()\n\n        # 3. Output\n        output = self._get_output_for_patched_inputs(\n            hidden_states=hidden_states,\n            timestep=timestep,\n            embedded_timestep=embedded_timestep,\n            num_frames=frame, \n            height=height,\n            width=width,\n        )  # b c t h w\n\n        if not return_dict:\n            return (output,)\n\n        return Transformer2DModelOutput(sample=output)\n\n\n    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame):\n        \n        hidden_states = self.pos_embed(hidden_states.to(self.dtype))\n\n        added_cond_kwargs = {\"resolution\": None, \"aspect_ratio\": None}\n        timestep, embedded_timestep = self.adaln_single(\n            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype\n        )  # b 6d, b d\n\n        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # b, 1, l, d\n        assert encoder_hidden_states.shape[1] == 1\n        encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d')\n\n        return hidden_states, encoder_hidden_states, timestep, embedded_timestep\n\n    \n    \n    def _get_output_for_patched_inputs(\n        self, hidden_states, timestep, embedded_timestep, num_frames, height, width\n    ):  \n        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)\n        hidden_states = self.norm_out(hidden_states)\n        # Modulation\n        hidden_states = hidden_states * (1 + scale) + shift\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = hidden_states.squeeze(1)\n\n        # unpatchify\n        hidden_states = hidden_states.reshape(\n            shape=(-1, num_frames, height, width, self.config.patch_size_t, self.config.patch_size, self.config.patch_size, self.out_channels)\n        )\n        hidden_states = 
torch.einsum(\"nthwopqc->nctohpwq\", hidden_states)\n        output = hidden_states.reshape(\n            shape=(-1, self.out_channels, \n                   num_frames * self.config.patch_size_t, height * self.config.patch_size, width * self.config.patch_size)\n        )\n        return output\n\ndef OpenSoraT2V_v1_3_2B_122(**kwargs):\n    kwargs.pop('skip_connection', None)\n    return OpenSoraT2V_v1_3(\n        num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2,\n        caption_channels=4096, cross_attention_dim=2304, activation_fn=\"gelu-approximate\", **kwargs\n        )\n\nOpenSora_v1_3_models = {\n    \"OpenSoraT2V_v1_3-2B/122\": OpenSoraT2V_v1_3_2B_122,  # 2.7B\n}\n\nOpenSora_v1_3_models_class = {\n    \"OpenSoraT2V_v1_3-2B/122\": OpenSoraT2V_v1_3,\n    \"OpenSoraT2V_v1_3\": OpenSoraT2V_v1_3,\n}\n\nif __name__ == '__main__':\n    from opensora.models.causalvideovae import ae_stride_config, ae_channel_config\n    from opensora.models.causalvideovae import ae_norm, ae_denorm\n    from opensora.models import CausalVAEModelWrapper\n\n    args = type('args', (), \n    {\n        'ae': 'WFVAEModel_D8_4x8x8', \n        'model_max_length': 300, \n        'max_height': 176,\n        'max_width': 176,\n        'num_frames': 33,\n        'compress_kv_factor': 1, \n        'interpolation_scale_t': 1,\n        'interpolation_scale_h': 1,\n        'interpolation_scale_w': 1,\n        \"sparse1d\": True, \n        \"sparse_n\": 4, \n        \"rank\": 64, \n    }\n    )\n    b = 2\n    c = 8\n    cond_c = 4096\n    num_timesteps = 1000\n    ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]\n    latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)\n    num_frames = (args.num_frames - 1) // ae_stride_t + 1\n\n    device = torch.device('cuda:0')\n    model = OpenSoraT2V_v1_3_2B_122(\n        in_channels=c, \n        out_channels=c, \n        sample_size_h=latent_size, \n        
sample_size_w=latent_size, \n        sample_size_t=num_frames, \n        activation_fn=\"gelu-approximate\",\n        attention_bias=True,\n        double_self_attention=False,\n        norm_elementwise_affine=False,\n        norm_eps=1e-06,\n        only_cross_attention=False,\n        upcast_attention=False,\n        interpolation_scale_t=args.interpolation_scale_t, \n        interpolation_scale_h=args.interpolation_scale_h, \n        interpolation_scale_w=args.interpolation_scale_w, \n        sparse1d=args.sparse1d, \n        sparse_n=args.sparse_n\n    )\n    \n    try:\n        path = \"/storage/ongoing/new/7.19anyres/Open-Sora-Plan/bs32x8x1_anyx93x640x640_fps16_lr1e-5_snr5_ema9999_sparse1d4_dit_l_mt5xxl_vpred_zerosnr/checkpoint-43000/model_ema/diffusion_pytorch_model.safetensors\"\n        # ckpt = torch.load(path, map_location=\"cpu\")\n        from safetensors.torch import load_file as safe_load\n        ckpt = safe_load(path, device=\"cpu\")\n        msg = model.load_state_dict(ckpt, strict=True)\n        print(msg)\n    except Exception as e:\n        print(e)\n    print(model)\n    print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B')\n    # import sys;sys.exit()\n    model = model.to(device)\n    x = torch.randn(b, c,  1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w).to(device)\n    cond = torch.randn(b, 1, args.model_max_length, cond_c).to(device)\n    attn_mask = torch.randint(0, 2, (b, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w)).to(device)  # B L or B 1+num_images L\n    cond_mask = torch.randint(0, 2, (b, 1, args.model_max_length)).to(device)  # B L or B 1+num_images L\n    timestep = torch.randint(0, 1000, (b,), device=device)\n    model_kwargs = dict(\n        hidden_states=x, encoder_hidden_states=cond, attention_mask=attn_mask, \n        encoder_attention_mask=cond_mask, timestep=timestep\n        )\n    with 
torch.no_grad():\n        output = model(**model_kwargs)\n    print(output[0].shape)\n\n"
  },
  {
    "path": "opensora/models/diffusion/opensora_v1_3/modules.py",
    "content": "import torch\nfrom einops import rearrange, repeat\nfrom typing import Any, Dict, Optional, Tuple\nimport torch.nn.functional as F\nfrom torch import nn\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import maybe_allow_in_graph\nfrom diffusers.models.attention import FeedForward\n\nfrom diffusers.models.attention_processor import Attention as Attention_\nfrom diffusers.models.embeddings import Timesteps, TimestepEmbedding\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config, set_run_dtype\n    from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info\n    from opensora.acceleration.communications import all_to_all_SBH\nexcept:\n    torch_npu = None\n    npu_config = None\n    set_run_dtype = None\n    from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info\n    from opensora.utils.communications import all_to_all_SBH\n\nfrom ..common import RoPE3D, PositionGetter3D\n\nlogger = logging.get_logger(__name__)\n\n\nclass Attention(Attention_):\n    def __init__(\n            self, interpolation_scale_thw, sparse1d, sparse_n, \n            sparse_group, is_cross_attn, **kwargs\n            ):\n        processor = OpenSoraAttnProcessor2_0(\n            interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, \n            sparse_group=sparse_group, is_cross_attn=is_cross_attn\n            )\n        super().__init__(processor=processor, **kwargs)\n\n    @staticmethod\n    def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num):\n        attention_mask = attention_mask.unsqueeze(1)\n        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n        l = attention_mask.shape[-1]\n        if l % (sparse_n * sparse_n) == 0:\n            pad_len = 0\n        else:\n            pad_len = sparse_n * sparse_n - l % (sparse_n * 
sparse_n)\n\n        attention_mask_sparse = F.pad(attention_mask, (0, pad_len, 0, 0), value=-9980.0)\n        attention_mask_sparse_1d = rearrange(\n            attention_mask_sparse, \n            'b 1 1 (g k) -> (k b) 1 1 g', \n            k=sparse_n\n            )\n        attention_mask_sparse_1d_group = rearrange(\n            attention_mask_sparse, \n            'b 1 1 (n m k) -> (m b) 1 1 (n k)',\n            m=sparse_n, \n            k=sparse_n\n            )\n        encoder_attention_mask_sparse = encoder_attention_mask.repeat(sparse_n, 1, 1, 1)\n        if npu_config is not None:\n            attention_mask_sparse_1d = npu_config.get_attention_mask(\n                attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1]\n                )\n            attention_mask_sparse_1d_group = npu_config.get_attention_mask(\n                attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1]\n                )\n            \n            encoder_attention_mask_sparse_1d = npu_config.get_attention_mask(\n                encoder_attention_mask_sparse, attention_mask_sparse_1d.shape[-1]\n                )\n            encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d\n        else:\n            attention_mask_sparse_1d = attention_mask_sparse_1d.repeat_interleave(head_num, dim=1)\n            attention_mask_sparse_1d_group = attention_mask_sparse_1d_group.repeat_interleave(head_num, dim=1)\n\n            encoder_attention_mask_sparse_1d = encoder_attention_mask_sparse.repeat_interleave(head_num, dim=1)\n            encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d\n\n        return {\n                    False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d),\n                    True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group)\n                }\n\n    def prepare_attention_mask(\n        self, attention_mask: torch.Tensor, target_length: int, 
batch_size: int, out_dim: int = 3\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepare the attention mask for the attention computation.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                The attention mask to prepare.\n            target_length (`int`):\n                The target length of the attention mask. This is the length of the attention mask after padding.\n            batch_size (`int`):\n                The batch size, which is used to repeat the attention mask.\n            out_dim (`int`, *optional*, defaults to `3`):\n                The output dimension of the attention mask. Can be either `3` or `4`.\n\n        Returns:\n            `torch.Tensor`: The prepared attention mask.\n        \"\"\"\n        head_size = self.heads\n        if get_sequence_parallel_state():\n            head_size = head_size // xccl_info.world_size  # e.g, 24 // 8\n        \n        if attention_mask is None:  # b 1 t*h*w in sa, b 1 l in ca\n            return attention_mask\n\n        current_length: int = attention_mask.shape[-1]\n        if current_length != target_length:\n            print(f'attention_mask.shape, {attention_mask.shape}, current_length, {current_length}, target_length, {target_length}')\n            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n\n        if out_dim == 3:\n            if attention_mask.shape[0] < batch_size * head_size:\n                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)\n        elif out_dim == 4:\n            attention_mask = attention_mask.unsqueeze(1)\n            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)\n\n        return attention_mask\n\nclass OpenSoraAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self, interpolation_scale_thw=(1, 1, 1), \n                 sparse1d=False, sparse_n=2, 
sparse_group=False, is_cross_attn=True):\n        self.sparse1d = sparse1d\n        self.sparse_n = sparse_n\n        self.sparse_group = sparse_group\n        self.is_cross_attn = is_cross_attn\n        self.interpolation_scale_thw = interpolation_scale_thw\n        \n        self._init_rope(interpolation_scale_thw)\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def _init_rope(self, interpolation_scale_thw):\n        self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)\n        self.position_getter = PositionGetter3D()\n    \n    def _sparse_1d(self, x, frame, height, width):\n        \"\"\"\n        require the shape of (ntokens x batch_size x dim)\n        \"\"\"\n        l = x.shape[0]\n        assert l == frame*height*width\n        pad_len = 0\n        if l % (self.sparse_n * self.sparse_n) != 0:\n            pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n)\n        if pad_len != 0:\n            x = F.pad(x, (0, 0, 0, 0, 0, pad_len))\n        if not self.sparse_group:\n            x = rearrange(x, '(g k) b d -> g (k b) d', k=self.sparse_n)\n        else:\n            x = rearrange(x, '(n m k) b d -> (n k) (m b) d', m=self.sparse_n, k=self.sparse_n)\n        return x, pad_len\n    \n    def _reverse_sparse_1d(self, x, frame, height, width, pad_len):\n        \"\"\"\n        require the shape of (ntokens x batch_size x dim)\n        \"\"\"\n        assert x.shape[0] == (frame*height*width+pad_len) // self.sparse_n\n        if not self.sparse_group:\n            x = rearrange(x, 'g (k b) d -> (g k) b d', k=self.sparse_n)\n        else:\n            x = rearrange(x, '(n k) (m b) d -> (n m k) b d', m=self.sparse_n, k=self.sparse_n)\n        x = x[:frame*height*width, :, :]\n        return x\n    \n    def _sparse_1d_kv(self, x):\n        \"\"\"\n        require the shape of (ntokens x 
batch_size x dim)\n        \"\"\"\n        x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n)\n        return x\n    \n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        frame: int = 8, \n        height: int = 16, \n        width: int = 16, \n        *args,\n        **kwargs,\n    ) -> torch.FloatTensor:\n\n        residual = hidden_states\n\n        sequence_length, batch_size, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        # if attention_mask is not None:\n        #     if npu_config is None:\n        #         # scaled_dot_product_attention expects attention_mask shape to be\n        #         # (batch, heads, source_length, target_length)\n        #         if get_sequence_parallel_state():\n        #             # sequence_length has been split, so we need sequence_length * nccl_info.world_size\n        #             # (sp*b 1 s), where s has not been split\n        #             # (sp*b 1 s) -prepare-> (sp*b*head 1 s) -> (sp*b head 1 s), where head has been split (e.g, 24 // 8)\n        #             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length * xccl_info.world_size, batch_size)\n        #             attention_mask = attention_mask.view(batch_size, attn.heads // xccl_info.world_size, -1, attention_mask.shape[-1])\n        #         else:\n        #             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n        #             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        
key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n        FA_head_num = attn.heads\n        total_frame = frame\n\n        if get_sequence_parallel_state():\n            sp_size = xccl_info.world_size\n            FA_head_num = attn.heads // sp_size\n            total_frame = frame * sp_size\n            # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d]\n            query = all_to_all_SBH(query.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)\n            key = all_to_all_SBH(key.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)\n            value = all_to_all_SBH(value.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)\n        query = query.view(-1, batch_size, FA_head_num, head_dim)\n        key = key.view(-1, batch_size, FA_head_num, head_dim)\n\n        if not self.is_cross_attn:\n            # require the shape of (ntokens x batch_size x nheads x dim)\n            pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width, device=query.device)\n\n            query = self.rope(query, pos_thw)\n            key = self.rope(key, pos_thw)\n            \n            # query = rearrange(query, 's b h d -> b h s d')\n            # key = rearrange(key, 's b h d -> b h s d')\n            # dtype = query.dtype\n\n            # query = self.rope(query.to(torch.float16), pos_thw)\n            # key = self.rope(key.to(torch.float16), pos_thw)\n\n            # query = rearrange(query, 'b h s d -> s b h d').to(dtype)\n            # key = rearrange(key, 'b h s d -> s b h d').to(dtype)\n\n        query = query.view(-1, batch_size, FA_head_num * head_dim)\n        key = key.view(-1, batch_size, FA_head_num * head_dim)\n        value = value.view(-1, batch_size, FA_head_num * head_dim)\n        # print(f'q {query.shape}, k {key.shape}, v {value.shape}')\n 
       if self.sparse1d:\n            query, pad_len = self._sparse_1d(query, total_frame, height, width)\n            if self.is_cross_attn:\n                key = self._sparse_1d_kv(key)\n                value = self._sparse_1d_kv(value)\n            else:\n                key, pad_len = self._sparse_1d(key, total_frame, height, width)\n                value, pad_len = self._sparse_1d(value, total_frame, height, width)\n\n        # print(f'after sparse q {query.shape}, k {key.shape}, v {value.shape}')\n        if npu_config is not None:\n            hidden_states = npu_config.run_attention(query, key, value, attention_mask, \"SBH\", head_dim, FA_head_num)\n        else:\n            query = rearrange(query, 's b (h d) -> b h s d', h=FA_head_num)\n            key = rearrange(key, 's b (h d) -> b h s d', h=FA_head_num)\n            value = rearrange(value, 's b (h d) -> b h s d', h=FA_head_num)\n            # 0, -10000 ->(bool) False, True ->(any) True ->(not) False\n            # 0, 0 ->(bool) False, False ->(any) False ->(not) True\n            # if attention_mask is None or not torch.any(attention_mask.bool()):  # 0 mean visible\n            #     attention_mask = None\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):\n                hidden_states = F.scaled_dot_product_attention(\n                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n                )\n            hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=FA_head_num)\n\n        if self.sparse1d:\n            hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len)\n\n        # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]\n        if get_sequence_parallel_state():\n            hidden_states = all_to_all_SBH(hidden_states.reshape(-1, FA_head_num, 
head_dim), scatter_dim=0, gather_dim=1)\n            hidden_states = hidden_states.view(-1, batch_size, inner_dim)\n\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        # if attn.residual_connection:\n        #     print('attn.residual_connection')\n        #     hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\n@maybe_allow_in_graph\nclass BasicTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        attention_bias: bool = False,\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_elementwise_affine: bool = True,\n        norm_eps: float = 1e-5,\n        final_dropout: bool = False,\n        ff_inner_dim: Optional[int] = None,\n        ff_bias: bool = True,\n        attention_out_bias: bool = True,\n        interpolation_scale_thw: Tuple[int] = (1, 1, 1), \n        sparse1d: bool = False,\n        sparse_n: int = 2,\n        sparse_group: bool = False,\n    ):\n        super().__init__()\n\n        # Define 3 blocks: self-attention, cross-attention and feed-forward.\n        # norm1 pre-normalizes self-attention and norm2 pre-normalizes the feed-forward;\n        # cross-attention is applied without its own normalization layer.\n        # 1. 
Self-Attn\n        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)\n\n        self.attn1 = Attention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n            upcast_attention=upcast_attention,\n            out_bias=attention_out_bias,\n            interpolation_scale_thw=interpolation_scale_thw, \n            sparse1d=sparse1d,\n            sparse_n=sparse_n,\n            sparse_group=sparse_group,\n            is_cross_attn=False,\n        )\n\n        # 2. Cross-Attn\n        self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)\n\n        self.attn2 = Attention(\n            query_dim=dim,\n            cross_attention_dim=cross_attention_dim if not double_self_attention else None,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            upcast_attention=upcast_attention,\n            out_bias=attention_out_bias,\n            interpolation_scale_thw=interpolation_scale_thw, \n            sparse1d=sparse1d,\n            sparse_n=sparse_n,\n            sparse_group=sparse_group,\n            is_cross_attn=True,\n        )  # is self-attn if encoder_hidden_states is none\n\n        # 3. Feed-forward\n        self.ff = FeedForward(\n            dim,\n            dropout=dropout,\n            activation_fn=activation_fn,\n            final_dropout=final_dropout,\n            inner_dim=ff_inner_dim,\n            bias=ff_bias,\n        )\n\n        # 4. 
Scale-shift.\n        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)\n\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        frame: int = None, \n        height: int = None, \n        width: int = None, \n    ) -> torch.FloatTensor:\n        \n        # 1. Self-Attention (adaLN-modulated, gated residual)\n        batch_size = hidden_states.shape[1]\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n                self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1)\n        ).chunk(6, dim=0)\n\n        norm_hidden_states = self.norm1(hidden_states)\n\n        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa\n\n        attn_output = self.attn1(\n            norm_hidden_states,\n            encoder_hidden_states=None,\n            attention_mask=attention_mask, frame=frame, height=height, width=width, \n        )\n\n        attn_output = gate_msa * attn_output\n\n        hidden_states = attn_output + hidden_states\n\n        # 2. Cross-Attention (no pre-norm: the hidden states are used as the query directly)\n        norm_hidden_states = hidden_states\n\n        attn_output = self.attn2(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            attention_mask=encoder_attention_mask, frame=frame, height=height, width=width,\n        )\n        hidden_states = attn_output + hidden_states\n\n        # 3. Feed-forward (norm2 is the pre-FF normalization)\n        norm_hidden_states = self.norm2(hidden_states)\n\n        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp\n\n        ff_output = self.ff(norm_hidden_states)\n\n        ff_output = gate_mlp * ff_output\n\n        hidden_states = ff_output + hidden_states\n\n        return hidden_states\n"
  },
  {
    "path": "opensora/models/frame_interpolation/cfgs/AMT-G.yaml",
    "content": "\nseed: 2023\n\nnetwork:\n  name: networks.AMT-G.Model\n  params:\n    corr_radius: 3\n    corr_lvls: 4\n    num_flows: 5"
  },
  {
    "path": "opensora/models/frame_interpolation/interpolation.py",
    "content": "# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py\nfrom json import load\nimport os\nimport cv2\nimport sys\nimport glob\nimport torch\nimport argparse\nimport numpy as np\nimport os.path as osp\nfrom warnings import warn\nfrom omegaconf import OmegaConf\nfrom torchvision.utils import make_grid\nsys.path.append('.')\nfrom utils.utils import (\n    read, write,\n    img2tensor, tensor2img,\n    check_dim_and_resize\n    )\nfrom utils.build_utils import build_from_cfg\nfrom utils.utils import InputPadder\n\n\nAMT_G = {\n    'name': 'networks.AMT-G.Model',\n    'params':{\n        'corr_radius': 3,\n        'corr_lvls': 4,\n        'num_flows': 5,\n    }\n}\n\n\n\ndef init(device=\"cuda\"):\n\n    '''\n        Compute the anchor resolution and memory budget for the given device.\n        These values decide how much the input frames are downscaled.\n    '''\n\n    if device == 'cuda':\n        anchor_resolution = 1024 * 512\n        anchor_memory = 1500 * 1024**2\n        anchor_memory_bias = 2500 * 1024**2\n        vram_avail = torch.cuda.get_device_properties(device).total_memory\n        print(\"VRAM available: {:.1f} MB\".format(vram_avail / 1024 ** 2))\n    else:\n        # Do not resize in cpu mode\n        anchor_resolution = 8192*8192\n        anchor_memory = 1\n        anchor_memory_bias = 0\n        vram_avail = 1\n    \n    return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail\n\ndef get_input_video_from_path(input_path, device=\"cuda\"):\n\n    '''\n        Get the input video from the input_path.\n\n        params:\n            input_path: str, the path of the input video.\n            device: str, the device to run the model on.\n        returns:\n            inputs: list, the list of the input frames.\n            scale: float, the scale of the input frames.\n            padder: InputPadder, the padder to pad the input frames.\n    '''\n\n    anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device)\n\n    if osp.splitext(input_path)[-1] in ['.mp4', 
'.avi', '.mov', '.mkv', '.flv', '.wmv', \n                                        '.webm', '.MP4', '.AVI', '.MOV', '.MKV', '.FLV', \n                                        '.WMV', '.WEBM']:\n\n        vcap = cv2.VideoCapture(input_path)\n\n        inputs = []\n        w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))\n        h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)\n        scale = 1 if scale > 1 else scale\n        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16\n        if scale < 1:\n            print(f\"Due to the limited VRAM, the video will be scaled by {scale:.2f}\")\n        padding = int(16 / scale)\n        padder = InputPadder((h, w), padding)\n        while True:\n            ret, frame = vcap.read()\n            if ret is False:\n                break\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame_t = img2tensor(frame).to(device)\n            frame_t = padder.pad(frame_t)\n            inputs.append(frame_t)\n        vcap.release()\n        print(f'Loaded the [video] from {input_path}, number of frames: [{len(inputs)}]')\n    else:\n        raise TypeError(\"Input should be a video.\")\n    \n    return inputs, scale, padder\n\n\ndef load_model(ckpt_path, device=\"cuda\"):\n\n    '''\n        Load the AMT-G frame interpolation model from ckpt_path.\n    '''\n    network_cfg = AMT_G\n    network_name = network_cfg['name']\n    print(f'Loading [{network_name}] from [{ckpt_path}]...')\n    model = build_from_cfg(network_cfg)\n    # map to CPU first so a GPU-saved checkpoint also loads on CPU-only machines\n    ckpt = torch.load(ckpt_path, map_location='cpu')\n    model.load_state_dict(ckpt['state_dict'])\n    model = model.to(device)\n    model.eval()\n    return model\n\ndef interpolater(model, inputs, scale, padder, iters=1):\n\n    '''\n        Interpolate frames with the frame interpolation model.\n\n        params:\n            model: nn.Module, the frame interpolation model.\n            inputs: list, the list of the input frames.\n            scale: float, the 
scale of the input frames.\n            padder: InputPadder, the padder used to remove the padding from the output frames.\n            iters: int, the number of interpolation iterations. Given m input frames, the model outputs 2 ** iters * (m - 1) + 1 frames.\n        returns:\n            outputs: list, the list of the output frames.\n    '''\n\n    # run on the device the model lives on instead of relying on a global 'device'\n    device = next(model.parameters()).device\n    print('Start frame interpolation:')\n    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)\n\n    for i in range(iters):\n        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')\n        outputs = [inputs[0]]\n        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):\n            in_0 = in_0.to(device)\n            in_1 = in_1.to(device)\n            with torch.no_grad():\n                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']\n            outputs += [imgt_pred.cpu(), in_1.cpu()]\n        inputs = outputs\n\n    outputs = padder.unpad(*outputs)\n\n    return outputs\n\ndef write(outputs, input_path, output_path, frame_rate=30):\n    '''\n        Write the output frames as an .mp4 video under output_path.\n    '''\n\n    os.makedirs(output_path, exist_ok=True)\n\n    size = outputs[0].shape[2:][::-1]\n\n    _, file_name_with_extension = os.path.split(input_path)\n    file_name, _ = os.path.splitext(file_name_with_extension)\n\n    save_video_path = f'{output_path}/output_{file_name}.mp4'\n    writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*\"mp4v\"), \n                        frame_rate, size)\n\n    for i, imgt_pred in enumerate(outputs):\n        imgt_pred = tensor2img(imgt_pred)\n        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)\n        writer.write(imgt_pred)\n\n    writer.release()\n    print(f\"Demo video is saved to [{save_video_path}]\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--ckpt', type=str, default='amt-g.pth', help=\"Path to the pretrained AMT-G checkpoint.\")\n    parser.add_argument('--niters', type=int, 
default=1, help=\"Number of interpolation iterations; each iteration roughly doubles the number of frames.\")\n    parser.add_argument('--input', default=\"test.mp4\", help=\"Input video.\")\n    parser.add_argument('--output_path', type=str, default='results', help=\"Output directory.\")\n    parser.add_argument('--frame_rate', type=int, default=30, help=\"Frame rate of the output video.\")\n\n    args = parser.parse_args()\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    ckpt_path = args.ckpt\n    input_path = args.input\n    output_path = args.output_path\n    iters = int(args.niters)\n    frame_rate = int(args.frame_rate)\n\n    inputs, scale, padder = get_input_video_from_path(input_path, device)\n    model = load_model(ckpt_path, device)\n    outputs = interpolater(model, inputs, scale, padder, iters)\n    write(outputs, input_path, output_path, frame_rate)\n"
  },
  {
    "path": "opensora/models/frame_interpolation/networks/AMT-G.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom networks.blocks.raft import (\n    coords_grid,\n    BasicUpdateBlock, BidirCorrBlock\n)\nfrom networks.blocks.feat_enc import (\n    LargeEncoder\n)\nfrom networks.blocks.ifrnet import (\n    resize,\n    Encoder,\n    InitDecoder,\n    IntermediateDecoder\n)\nfrom networks.blocks.multi_flow import (\n    multi_flow_combine,\n    MultiFlowDecoder\n)\n\n\nclass Model(nn.Module):\n    def __init__(self, \n                 corr_radius=3, \n                 corr_lvls=4, \n                 num_flows=5, \n                 channels=[84, 96, 112, 128], \n                 skip_channels=84):\n        super(Model, self).__init__()\n        self.radius = corr_radius\n        self.corr_levels = corr_lvls\n        self.num_flows = num_flows\n\n        self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)\n        self.encoder = Encoder(channels, large=True)\n        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)\n        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)\n        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)\n        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)\n\n        self.update4 = self._get_updateblock(112, None)\n        self.update3_low = self._get_updateblock(96, 2.0)\n        self.update2_low = self._get_updateblock(84, 4.0)\n        \n        self.update3_high = self._get_updateblock(96, None)\n        self.update2_high = self._get_updateblock(84, None)\n        \n        self.comb_block = nn.Sequential(\n            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),\n            nn.PReLU(6*self.num_flows),\n            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),\n        )\n\n    def _get_updateblock(self, cdim, scale_factor=None):\n        return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, \n                       
         corr_dim=256, corr_dim2=192, fc_dim=188, \n                                scale_factor=scale_factor, corr_levels=self.corr_levels, \n                                radius=self.radius)\n\n    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):\n        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0\n        # based on linear assumption\n        t1_scale = 1. / embt\n        t0_scale = 1. / (1. - embt)\n        if downsample != 1:\n            inv = 1 / downsample\n            flow0 = inv * resize(flow0, scale_factor=inv)\n            flow1 = inv * resize(flow1, scale_factor=inv)\n            \n        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) \n        corr = torch.cat([corr0, corr1], dim=1)\n        flow = torch.cat([flow0, flow1], dim=1)\n        return corr, flow\n    \n    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):\n        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)\n        img0 = img0 - mean_\n        img1 = img1 - mean_\n        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0\n        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1\n        b, _, h, w = img0_.shape\n        coord = coords_grid(b, h // 8, w // 8, img0.device)\n        \n        fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]\n        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)\n\n        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]\n        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]\n        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)\n        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)\n\n        ######################################### the 4th decoder #########################################\n        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)\n        corr_4, flow_4 = 
self._corr_scale_lookup(corr_fn, coord, \n                                                 up_flow0_4, up_flow1_4, \n                                                 embt, downsample=1)\n\n        # residue update with lookup corr\n        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)\n        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)\n        up_flow0_4 = up_flow0_4 + delta_flow0_4\n        up_flow1_4 = up_flow1_4 + delta_flow1_4\n        ft_3_ = ft_3_ + delta_ft_3_\n\n        ######################################### the 3rd decoder #########################################\n        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)\n        corr_3, flow_3 = self._corr_scale_lookup(corr_fn, \n                                                 coord, up_flow0_3, up_flow1_3, \n                                                 embt, downsample=2)\n\n        # residue update with lookup corr\n        delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3)\n        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)\n        up_flow0_3 = up_flow0_3 + delta_flow0_3\n        up_flow1_3 = up_flow1_3 + delta_flow1_3\n        ft_2_ = ft_2_ + delta_ft_2_\n        \n        # residue update with lookup corr (hr)\n        corr_3 = resize(corr_3, scale_factor=2.0)\n        up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1)\n        delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3)\n        ft_2_ += delta_ft_2_\n        up_flow0_3 += delta_up_flow_3[:, 0:2]\n        up_flow1_3 += delta_up_flow_3[:, 2:4]\n        \n        ######################################### the 2nd decoder #########################################\n        up_flow0_2, up_flow1_2, ft_1_  = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)\n        corr_2, flow_2 = self._corr_scale_lookup(corr_fn, \n                                                 coord, up_flow0_2, 
up_flow1_2, \n                                                 embt, downsample=4)\n        \n        # residue update with lookup corr\n        delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2)\n        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)\n        up_flow0_2 = up_flow0_2 + delta_flow0_2\n        up_flow1_2 = up_flow1_2 + delta_flow1_2\n        ft_1_ = ft_1_ + delta_ft_1_\n        \n        # residue update with lookup corr (hr)\n        corr_2 = resize(corr_2, scale_factor=4.0)\n        up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1)\n        delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2)\n        ft_1_ += delta_ft_1_\n        up_flow0_2 += delta_up_flow_2[:, 0:2]\n        up_flow1_2 += delta_up_flow_2[:, 2:4]\n        \n        ######################################### the 1st decoder #########################################\n        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)\n        \n        if scale_factor != 1.0: \n            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)\n            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)\n            mask = resize(mask, scale_factor=(1.0/scale_factor))\n            img_res = resize(img_res, scale_factor=(1.0/scale_factor))\n\n        # Merge multiple predictions \n        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, \n                                                                        mask, img_res, mean_)\n        imgt_pred = torch.clamp(imgt_pred, 0, 1)\n\n        if eval:\n            return  { 'imgt_pred': imgt_pred, }\n        else:\n            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)\n            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)\n            return {\n                'imgt_pred': imgt_pred,\n                
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],\n                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],\n                'ft_pred': [ft_1_, ft_2_, ft_3_],\n            }\n"
  },
  {
    "path": "opensora/models/frame_interpolation/networks/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/frame_interpolation/networks/blocks/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/frame_interpolation/networks/blocks/feat_enc.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass BottleneckBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(BottleneckBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)\n        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)\n        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes//4)\n            self.norm2 = nn.BatchNorm2d(planes//4)\n            self.norm3 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes//4)\n            self.norm2 = nn.InstanceNorm2d(planes//4)\n            self.norm3 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            self.norm3 = nn.Sequential()\n            if not stride == 1:\n                self.norm4 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 
self.norm4)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n        y = self.relu(self.norm3(self.conv3(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(ResidualBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)\n\n\n    def forward(self, x):\n        y = x\n        y = 
self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\n\nclass SmallEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(SmallEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(32)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(32)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 32\n        self.layer1 = self._make_layer(32,  stride=1)\n        self.layer2 = self._make_layer(64, stride=2)\n        self.layer3 = self._make_layer(96, stride=2)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n    \n        self.in_planes = dim\n        return 
nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(BasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64,  stride=1)\n        self.layer2 = self._make_layer(72, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n   
             if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\nclass LargeEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(LargeEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64, stride=1)\n        self.layer2 = self._make_layer(112, stride=2)\n        self.layer3 = self._make_layer(160, 
stride=2)\n        self.layer3_2 = self._make_layer(160, stride=1)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer3_2(x)\n\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n"
  },
  {
    "path": "opensora/models/frame_interpolation/networks/blocks/ifrnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom utils.flow_utils import warp\n\n\ndef resize(x, scale_factor):\n    return F.interpolate(x, scale_factor=scale_factor, mode=\"bilinear\", align_corners=False)\n\ndef convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):\n    return nn.Sequential(\n        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), \n        nn.PReLU(out_channels)\n    )\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_channels, side_channels, bias=True):\n        super(ResBlock, self).__init__()\n        self.side_channels = side_channels\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), \n            nn.PReLU(in_channels)\n        )\n        self.conv2 = nn.Sequential(\n            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), \n            nn.PReLU(side_channels)\n        )\n        self.conv3 = nn.Sequential(\n            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), \n            nn.PReLU(in_channels)\n        )\n        self.conv4 = nn.Sequential(\n            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), \n            nn.PReLU(side_channels)\n        )\n        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)\n        self.prelu = nn.PReLU(in_channels)\n\n    def forward(self, x):\n        out = self.conv1(x)\n\n        res_feat = out[:, :-self.side_channels, ...]\n        side_feat = out[:, -self.side_channels:, :, :]\n        side_feat = self.conv2(side_feat)\n        out = self.conv3(torch.cat([res_feat, side_feat], 1))\n\n        res_feat = out[:, :-self.side_channels, ...]\n        side_feat = out[:, -self.side_channels:, :, :]\n        
side_feat = self.conv4(side_feat)\n        out = self.conv5(torch.cat([res_feat, side_feat], 1))\n\n        out = self.prelu(x + out)\n        return out\n    \nclass Encoder(nn.Module):\n    def __init__(self, channels, large=False):\n        super(Encoder, self).__init__()\n        self.channels = channels        \n        prev_ch = 3\n        for idx, ch in enumerate(channels, 1):\n            k = 7 if large and idx == 1 else 3\n            p = 3 if k ==7 else 1\n            self.register_module(f'pyramid{idx}', \n            nn.Sequential(\n                convrelu(prev_ch, ch, k, 2, p),\n                convrelu(ch, ch, 3, 1, 1)\n            ))\n            prev_ch = ch\n                \n    def forward(self, in_x):\n        fs = []\n        for idx in range(len(self.channels)):\n            out_x = getattr(self, f'pyramid{idx+1}')(in_x)\n            fs.append(out_x)\n            in_x = out_x\n        return fs\n    \nclass InitDecoder(nn.Module):\n    def __init__(self, in_ch, out_ch, skip_ch) -> None:\n        super().__init__()\n        self.convblock = nn.Sequential(\n            convrelu(in_ch*2+1, in_ch*2), \n            ResBlock(in_ch*2, skip_ch), \n            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)\n        )\n    def forward(self, f0, f1, embt):\n        h, w = f0.shape[2:]\n        embt = embt.repeat(1, 1, h, w)\n        out = self.convblock(torch.cat([f0, f1, embt], 1))\n        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)\n        ft_ = out[:, 4:, ...]\n        return flow0, flow1, ft_\n    \nclass IntermediateDecoder(nn.Module):\n    def __init__(self, in_ch, out_ch, skip_ch) -> None:\n        super().__init__()\n        self.convblock = nn.Sequential(\n            convrelu(in_ch*3+4, in_ch*3), \n            ResBlock(in_ch*3, skip_ch), \n            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)\n        )\n    def forward(self, ft_, f0, f1, flow0_in, flow1_in):\n        f0_warp = warp(f0, flow0_in)\n        
f1_warp = warp(f1, flow1_in)\n        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)\n        out = self.convblock(f_in)\n        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)\n        ft_ = out[:, 4:, ...]\n        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)\n        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)\n        return flow0, flow1, ft_"
  },
  {
    "path": "opensora/models/frame_interpolation/networks/blocks/multi_flow.py",
    "content": "import torch\nimport torch.nn as nn\nfrom utils.flow_utils import warp\nfrom networks.blocks.ifrnet import (\n    convrelu, resize,\n    ResBlock,\n)\n\n\ndef multi_flow_combine(comb_block, img0, img1, flow0, flow1, \n                       mask=None, img_res=None, mean=None):\n        '''\n            A parallel implementation of multiple-flow-field warping.\n            comb_block: An nn.Sequential object.\n            img shape: [b, c, h, w]\n            flow shape: [b, 2*num_flows, h, w]\n            mask (opt):\n                If 'mask' is None, the function takes a simple average of the two warped images.\n            img_res (opt):\n                If 'img_res' is None, the function adds zero instead.\n            mean (opt):\n                If 'mean' is None, the function adds zero instead.\n        '''\n        b, c, h, w = flow0.shape\n        num_flows = c // 2\n        flow0   =   flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)\n        flow1   =   flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)\n        \n        # a constant mask of 0.5 averages the two warped images when no mask is given\n        mask    =    mask.reshape(b, num_flows, 1, h, w\n                            ).reshape(-1, 1, h, w) if mask is not None else 0.5\n        img_res = img_res.reshape(b, num_flows, 3, h, w\n                            ).reshape(-1, 3, h, w)  if img_res is not None else 0\n        img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)\n        img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)\n        mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1\n                                                    ) if mean is not None else 0\n        \n        img0_warp = warp(img0, flow0)\n        img1_warp = warp(img1, flow1)\n        img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res\n        img_warps = img_warps.reshape(b, num_flows, 3, h, w)\n        imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))\n        return imgt_pred\n\n\nclass 
MultiFlowDecoder(nn.Module):\n    def __init__(self, in_ch, skip_ch, num_flows=3):\n        super(MultiFlowDecoder, self).__init__()\n        self.num_flows = num_flows\n        self.convblock = nn.Sequential(\n            convrelu(in_ch*3+4, in_ch*3), \n            ResBlock(in_ch*3, skip_ch), \n            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)\n        )\n        \n    def forward(self, ft_, f0, f1, flow0, flow1):\n        n = self.num_flows\n        f0_warp = warp(f0, flow0)\n        f1_warp = warp(f1, flow1)\n        out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))\n        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)\n        mask = torch.sigmoid(mask)\n        \n        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0\n                                           ).repeat(1, self.num_flows, 1, 1)\n        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0\n                                           ).repeat(1, self.num_flows, 1, 1)\n        \n        return flow0, flow1, mask, img_res"
  },
  {
    "path": "opensora/models/frame_interpolation/networks/blocks/raft.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef resize(x, scale_factor):\n    return F.interpolate(x, scale_factor=scale_factor, mode=\"bilinear\", align_corners=False)\n\n\ndef bilinear_sampler(img, coords, mask=False):\n    \"\"\" Wrapper for grid_sample, uses pixel coordinates \"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1,1], dim=-1)\n    xgrid = 2*xgrid/(W-1) - 1\n    ygrid = 2*ygrid/(H-1) - 1\n\n    grid = torch.cat([xgrid, ygrid], dim=-1)\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\ndef coords_grid(batch, ht, wd, device):\n    coords = torch.meshgrid(torch.arange(ht, device=device), \n                            torch.arange(wd, device=device), \n                            indexing='ij')\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\nclass SmallUpdateBlock(nn.Module):\n    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,\n                 corr_levels=4, radius=3, scale_factor=None):\n        super(SmallUpdateBlock, self).__init__()\n        cor_planes = corr_levels * (2 * radius + 1) **2\n        self.scale_factor = scale_factor\n\n        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)\n        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)\n        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)\n        self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)\n\n        self.gru = nn.Sequential(\n            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n        )\n\n        self.feat_head = nn.Sequential(\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n            
nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, cdim, 3, padding=1),\n        )\n\n        self.flow_head = nn.Sequential(\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, 4, 3, padding=1),\n        )\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n            \n    def forward(self, net, flow, corr):\n        net = resize(net, 1 / self.scale_factor\n                      ) if self.scale_factor is not None else net\n        cor = self.lrelu(self.convc1(corr))\n        flo = self.lrelu(self.convf1(flow))\n        flo = self.lrelu(self.convf2(flo))\n        cor_flo = torch.cat([cor, flo], dim=1)\n        inp = self.lrelu(self.conv(cor_flo))\n        inp = torch.cat([inp, flow, net], dim=1)\n\n        out = self.gru(inp)\n        delta_net = self.feat_head(out)\n        delta_flow = self.flow_head(out)\n        \n        if self.scale_factor is not None:\n            delta_net = resize(delta_net, scale_factor=self.scale_factor)\n            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)\n        \n        return delta_net, delta_flow\n\n\nclass BasicUpdateBlock(nn.Module):\n    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, \n                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):\n        super(BasicUpdateBlock, self).__init__()\n        cor_planes = corr_levels * (2 * radius + 1) **2\n\n        self.scale_factor = scale_factor\n        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)\n        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)\n        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)\n        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)\n        self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)\n\n        self.gru = nn.Sequential(\n            
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n        )\n\n        self.feat_head = nn.Sequential(\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, cdim, 3, padding=1),\n        )\n\n        self.flow_head = nn.Sequential(\n            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),\n        )\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n            \n    def forward(self, net, flow, corr):\n        net = resize(net, 1 / self.scale_factor\n                      ) if self.scale_factor is not None else net\n        cor = self.lrelu(self.convc1(corr))\n        cor = self.lrelu(self.convc2(cor))\n        flo = self.lrelu(self.convf1(flow))\n        flo = self.lrelu(self.convf2(flo))\n        cor_flo = torch.cat([cor, flo], dim=1)\n        inp = self.lrelu(self.conv(cor_flo))\n        inp = torch.cat([inp, flow, net], dim=1)\n\n        out = self.gru(inp)\n        delta_net = self.feat_head(out)\n        delta_flow = self.flow_head(out)\n        \n        if self.scale_factor is not None:\n            delta_net = resize(delta_net, scale_factor=self.scale_factor)\n            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)\n        return delta_net, delta_flow\n\n\nclass BidirCorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n        self.corr_pyramid = []\n        self.corr_pyramid_T = []\n\n        corr = BidirCorrBlock.corr(fmap1, fmap2)\n        batch, h1, w1, dim, h2, w2 = corr.shape\n        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)\n\n        corr = 
corr.reshape(batch*h1*w1, dim, h2, w2)\n        corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)\n        \n        self.corr_pyramid.append(corr)\n        self.corr_pyramid_T.append(corr_T)\n\n        for _ in range(self.num_levels-1):\n            corr = F.avg_pool2d(corr, 2, stride=2)\n            corr_T = F.avg_pool2d(corr_T, 2, stride=2)\n            self.corr_pyramid.append(corr)\n            self.corr_pyramid_T.append(corr_T)\n\n    def __call__(self, coords0, coords1):\n        r = self.radius\n        coords0 = coords0.permute(0, 2, 3, 1)\n        coords1 = coords1.permute(0, 2, 3, 1)\n        assert coords0.shape == coords1.shape, f\"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]\"\n        batch, h1, w1, _ = coords0.shape\n\n        out_pyramid = []\n        out_pyramid_T = []\n        for i in range(self.num_levels):\n            corr = self.corr_pyramid[i]\n            corr_T = self.corr_pyramid_T[i]\n\n            dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)\n            dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)\n            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)\n            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)\n\n            centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i\n            centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i\n            coords_lvl_0 = centroid_lvl_0 + delta_lvl\n            coords_lvl_1 = centroid_lvl_1 + delta_lvl\n\n            corr = bilinear_sampler(corr, coords_lvl_0)\n            corr_T = bilinear_sampler(corr_T, coords_lvl_1)\n            corr = corr.view(batch, h1, w1, -1)\n            corr_T = corr_T.view(batch, h1, w1, -1)\n            out_pyramid.append(corr)\n            out_pyramid_T.append(corr_T)\n\n        out = torch.cat(out_pyramid, dim=-1)\n        out_T = torch.cat(out_pyramid_T, dim=-1)\n        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 
2).contiguous().float()\n\n    @staticmethod\n    def corr(fmap1, fmap2):\n        batch, dim, ht, wd = fmap1.shape\n        fmap1 = fmap1.view(batch, dim, ht*wd)\n        fmap2 = fmap2.view(batch, dim, ht*wd) \n        \n        corr = torch.matmul(fmap1.transpose(1,2), fmap2)\n        corr = corr.view(batch, ht, wd, 1, ht, wd)\n        return corr  / torch.sqrt(torch.tensor(dim).float())"
  },
  {
    "path": "opensora/models/frame_interpolation/readme.md",
    "content": "#### Frame Interpolation\r\n\r\nWe use [AMT](https://github.com/MCG-NKU/AMT) as our frame interpolation model (thanks to the AMT authors). After sampling, you can use it to interpolate your video for smoother motion.\r\n\r\n1. Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT). We recommend the largest model, AMT-G, for the best performance.\r\n2. Run the frame interpolation script:\r\n```\r\npython opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/folder --frame_rate 30\r\n```\r\n3. The output video is saved in `output_path`. Its duration equals `(total number of frames after interpolation) / (frame rate)`.\r\n\r\n##### Frame Interpolation Specific Settings\r\n\r\n* `--ckpt`: Path to the pretrained [AMT](https://github.com/MCG-NKU/AMT) checkpoint. We use AMT-G as our frame interpolation model.\r\n* `--niters`: Number of interpolation iterations. With $m$ input frames, `--niters` $=n$ produces $2^n\\times (m-1)+1$ output frames.\r\n* `--input`: Path to the input video.\r\n* `--output_path`: Folder for the output video.\r\n* `--frame_rate`: Frame rate of the output video.\r\n"
  },
  {
    "path": "opensora/models/frame_interpolation/utils/__init__.py",
    "content": ""
  },
  {
    "path": "opensora/models/frame_interpolation/utils/build_utils.py",
    "content": "import importlib\n\n\ndef base_build_fn(module, cls, params):\n    return getattr(importlib.import_module(\n                    module, package=None), cls)(**params)\n\n\ndef build_from_cfg(config):\n    module, cls = config['name'].rsplit(\".\", 1)\n    params = config.get('params', {})\n    return base_build_fn(module, cls, params)\n"
  },
  {
    "path": "opensora/models/frame_interpolation/utils/dist_utils.py",
    "content": "import os\nimport torch\n\n\ndef get_world_size():\n    \"\"\"Find OMPI world size without calling mpi functions\n    :rtype: int\n    \"\"\"\n    if os.environ.get('PMI_SIZE') is not None:\n        return int(os.environ.get('PMI_SIZE') or 1)\n    elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:\n        return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)\n    else:\n        return torch.cuda.device_count()\n\n\ndef get_global_rank():\n    \"\"\"Find OMPI world rank without calling mpi functions\n    :rtype: int\n    \"\"\"\n    if os.environ.get('PMI_RANK') is not None:\n        return int(os.environ.get('PMI_RANK') or 0)\n    elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:\n        return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)\n    else:\n        return 0\n\n\ndef get_local_rank():\n    \"\"\"Find OMPI local rank without calling mpi functions\n    :rtype: int\n    \"\"\"\n    if os.environ.get('MPI_LOCALRANKID') is not None:\n        return int(os.environ.get('MPI_LOCALRANKID') or 0)\n    elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:\n        return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)\n    else:\n        return 0\n\n\ndef get_master_ip():\n    if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:\n        return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]\n    elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:\n        return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')\n    else:\n        return \"127.0.0.1\"\n\n"
  },
  {
    "path": "opensora/models/frame_interpolation/utils/flow_utils.py",
    "content": "import numpy as np\nimport torch\nfrom PIL import ImageFile\nimport torch.nn.functional as F\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\ndef warp(img, flow):\n    B, _, H, W = flow.shape\n    xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)\n    yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)\n    grid = torch.cat([xx, yy], 1).to(img)\n    flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)\n    grid_ = (grid + flow_).permute(0, 2, 3, 1)\n    output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)\n    return output\n\n\ndef make_colorwheel():\n    \"\"\"\n    Generates a color wheel for optical flow visualization as presented in:\n        Baker et al. \"A Database and Evaluation Methodology for Optical Flow\" (ICCV, 2007)\n        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf\n    Code follows the original C++ source code of Daniel Scharstein.\n    Code follows the Matlab source code of Deqing Sun.\n    Returns:\n        np.ndarray: Color wheel\n    \"\"\"\n\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n    colorwheel = np.zeros((ncols, 3))\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)\n    col = col+RY\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)\n    colorwheel[col:col+YG, 1] = 255\n    col = col+YG\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)\n    col = col+GC\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)\n    colorwheel[col:col+CB, 2] = 255\n    col = col+CB\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)\n    col = col+BM\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)\n    colorwheel[col:col+MR, 0] = 255\n    return colorwheel\n\ndef flow_uv_to_colors(u, v, convert_to_bgr=False):\n    \"\"\"\n    Applies the flow color wheel to (possibly clipped) flow components u and v.\n    According to the C++ source code of Daniel Scharstein\n    According to the Matlab source code of Deqing Sun\n    Args:\n        u (np.ndarray): Input horizontal flow of shape [H,W]\n        v (np.ndarray): Input vertical flow of shape [H,W]\n        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)\n    colorwheel = make_colorwheel()  # shape [55x3]\n    ncols = colorwheel.shape[0]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    a = np.arctan2(-v, -u)/np.pi\n    fk = (a+1) / 2*(ncols-1)\n    k0 = np.floor(fk).astype(np.int32)\n    k1 = k0 + 1\n    k1[k1 == ncols] = 0\n    f = fk - k0\n    for i in range(colorwheel.shape[1]):\n        tmp = colorwheel[:,i]\n        col0 = tmp[k0] / 255.0\n        col1 = tmp[k1] / 255.0\n        col = (1-f)*col0 + f*col1\n        idx = (rad <= 1)\n        col[idx]  = 1 - rad[idx] * (1-col[idx])\n        col[~idx] = col[~idx] * 0.75   # out of range\n        # Note the 2-i => BGR instead of RGB\n        ch_idx = 2-i if convert_to_bgr else i\n        flow_image[:,:,ch_idx] = np.floor(255 * col)\n    return flow_image\n\ndef flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):\n    \"\"\"\n    Expects a two-dimensional flow image of shape [H,W,2].\n    Args:\n        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]\n        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.\n        convert_to_bgr (bool, optional): Convert output image to BGR.
Defaults to False.\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    assert flow_uv.ndim == 3, 'input flow must have three dimensions'\n    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'\n    if clip_flow is not None:\n        flow_uv = np.clip(flow_uv, 0, clip_flow)\n    u = flow_uv[:,:,0]\n    v = flow_uv[:,:,1]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    rad_max = np.max(rad)\n    epsilon = 1e-5\n    u = u / (rad_max + epsilon)\n    v = v / (rad_max + epsilon)\n    return flow_uv_to_colors(u, v, convert_to_bgr)"
  },
  {
    "path": "opensora/models/frame_interpolation/utils/utils.py",
    "content": "import re\nimport sys\nimport torch\nimport random\nimport numpy as np\nfrom PIL import ImageFile\nimport torch.nn.functional as F\nfrom imageio import imread, imwrite\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\nclass AverageMeter():\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0.\n        self.avg = 0.\n        self.sum = 0.\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\nclass AverageMeterGroups:\n    def __init__(self) -> None:\n        self.meter_dict = dict()\n    \n    def update(self, dict, n=1):\n        for name, val in dict.items():\n            if self.meter_dict.get(name) is None:\n                self.meter_dict[name] = AverageMeter()\n            self.meter_dict[name].update(val, n)\n    \n    def reset(self, name=None):\n        if name is None:\n            for v in self.meter_dict.values():\n                v.reset()\n        else:\n            meter = self.meter_dict.get(name)\n            if meter is not None:\n                meter.reset()\n    \n    def avg(self, name):\n        meter = self.meter_dict.get(name)\n        if meter is not None:\n            return meter.avg\n\n\nclass InputPadder:\n    \"\"\" Pads images such that dimensions are divisible by divisor \"\"\"\n    def __init__(self, dims, divisor=16):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor\n        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor\n        self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]\n\n    def pad(self, *inputs):\n        if len(inputs) == 1:\n            return F.pad(inputs[0], self._pad, mode='replicate')\n        else:\n            return [F.pad(x, self._pad, mode='replicate') for x in inputs]\n\n    def unpad(self, *inputs):\n        if 
len(inputs) == 1:\n            return self._unpad(inputs[0])\n        else:\n            return [self._unpad(x) for x in inputs]\n    \n    def _unpad(self, x):\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]\n        return x[..., c[0]:c[1], c[2]:c[3]]\n\n\ndef img2tensor(img):\n    if img.shape[-1] > 3:\n        img = img[:,:,:3]\n    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0\n\n\ndef tensor2img(img_t):\n    return (img_t * 255.).detach(\n                        ).squeeze(0).permute(1, 2, 0).cpu().numpy(\n                        ).clip(0, 255).astype(np.uint8)\n\ndef seed_all(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef read(file):\n    if file.endswith('.float3'): return readFloat(file)\n    elif file.endswith('.flo'): return readFlow(file)\n    elif file.endswith('.ppm'): return readImage(file)\n    elif file.endswith('.pgm'): return readImage(file)\n    elif file.endswith('.png'): return readImage(file)\n    elif file.endswith('.jpg'): return readImage(file)\n    elif file.endswith('.pfm'): return readPFM(file)[0]\n    else: raise Exception('don\\'t know how to read %s' % file)\n\n\ndef write(file, data):\n    if file.endswith('.float3'): return writeFloat(file, data)\n    elif file.endswith('.flo'): return writeFlow(file, data)\n    elif file.endswith('.ppm'): return writeImage(file, data)\n    elif file.endswith('.pgm'): return writeImage(file, data)\n    elif file.endswith('.png'): return writeImage(file, data)\n    elif file.endswith('.jpg'): return writeImage(file, data)\n    elif file.endswith('.pfm'): return writePFM(file, data)\n    else: raise Exception('don\\'t know how to write %s' % file)\n\n\ndef readPFM(file):\n    file = open(file, 'rb')\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if 
header.decode(\"ascii\") == 'PF':\n        color = True\n    elif header.decode(\"ascii\") == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode(\"ascii\"))\n    if dim_match:\n        width, height = list(map(int, dim_match.groups()))\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().decode(\"ascii\").rstrip())\n    if scale < 0:\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data, scale\n\n\ndef writePFM(file, image, scale=1):\n    file = open(file, 'wb')\n\n    color = None\n\n    if image.dtype.name != 'float32':\n        raise Exception('Image dtype must be float32.')\n\n    image = np.flipud(image)\n\n    if len(image.shape) == 3 and image.shape[2] == 3:\n        color = True\n    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:\n        color = False\n    else:\n        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')\n\n    # Encode the header in both branches; the file is opened in binary mode.\n    file.write(('PF\\n' if color else 'Pf\\n').encode())\n    file.write('%d %d\\n'.encode() % (image.shape[1], image.shape[0]))\n\n    endian = image.dtype.byteorder\n\n    if endian == '<' or endian == '=' and sys.byteorder == 'little':\n        scale = -scale\n\n    file.write('%f\\n'.encode() % scale)\n\n    image.tofile(file)\n\n\ndef readFlow(name):\n    if name.endswith('.pfm') or name.endswith('.PFM'):\n        return readPFM(name)[0][:,:,0:2]\n\n    f = open(name, 'rb')\n\n    header = f.read(4)\n    if header.decode(\"utf-8\") != 'PIEH':\n        raise Exception('Flow file header does not contain PIEH')\n\n    width = np.fromfile(f, np.int32, 1).squeeze()\n    height = np.fromfile(f, np.int32, 1).squeeze()\n\n    flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))\n\n    return flow.astype(np.float32)\n\n\ndef readImage(name):\n    if name.endswith('.pfm') or name.endswith('.PFM'):\n        data = readPFM(name)[0]\n        if len(data.shape)==3:\n            return data[:,:,0:3]\n        else:\n            return data\n    return imread(name)\n\n\ndef writeImage(name, data):\n    if name.endswith('.pfm') or name.endswith('.PFM'):\n        return writePFM(name, data, 1)\n    return imwrite(name, data)\n\n\ndef writeFlow(name, flow):\n    f = open(name, 'wb')\n    f.write('PIEH'.encode('utf-8'))\n    np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)\n    flow = flow.astype(np.float32)\n    flow.tofile(f)\n\n\ndef readFloat(name):\n    f = open(name, 'rb')\n\n    if(f.readline().decode(\"utf-8\"))  != 'float\\n':\n        raise Exception('float file %s did not contain <float> keyword' % name)\n\n    dim = int(f.readline())\n\n    dims = []\n    count = 1\n    for i in range(0, dim):\n        d = int(f.readline())\n        dims.append(d)\n        count *= d\n\n    dims = list(reversed(dims))\n\n    data = np.fromfile(f, np.float32, count).reshape(dims)\n    if dim > 2:\n        data = np.transpose(data, (2, 1, 0))\n        data = np.transpose(data, (1, 0, 2))\n\n    return data\n\n\ndef writeFloat(name, data):\n    f = open(name, 'wb')\n\n    dim=len(data.shape)\n    if dim>3:\n        raise Exception('bad float file dimension: %d' % dim)\n\n    f.write(('float\\n').encode('ascii'))\n    f.write(('%d\\n' % dim).encode('ascii'))\n\n    if dim == 1:\n        f.write(('%d\\n' % data.shape[0]).encode('ascii'))\n    else:\n        f.write(('%d\\n' % data.shape[1]).encode('ascii'))\n        f.write(('%d\\n' % data.shape[0]).encode('ascii'))\n        for i in range(2, dim):\n            f.write(('%d\\n' % data.shape[i]).encode('ascii'))\n\n    data = data.astype(np.float32)\n    if dim==2:\n        data.tofile(f)\n\n    else:\n        np.transpose(data, (2, 0, 1)).tofile(f)\n\n\ndef check_dim_and_resize(tensor_list):\n    shape_list = []\n    for t in tensor_list:\n        shape_list.append(t.shape[2:])\n\n    if len(set(shape_list)) > 1:\n        desired_shape = shape_list[0]\n        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')\n        \n        resize_tensor_list = []\n        for t in tensor_list:\n            resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear'))\n\n        tensor_list = resize_tensor_list\n\n    return tensor_list\n\n"
  },
  {
    "path": "opensora/models/prompt_refiner/inference.py",
    "content": "from transformers import AutoModelForCausalLM, AutoTokenizer\nimport torch\nfrom tqdm import tqdm\nimport argparse\n\ndef get_output(prompt):\n    template = \"Refine the sentence: \\\"{}\\\" to contain subject description, action, scene description. \" \\\n            \"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. \" \\\n            \"Make sure it is a fluent sentence, not nonsense.\"\n    prompt = template.format(prompt)\n    messages = [\n            {\"role\": \"system\", \"content\": \"You are a caption refiner.\"},\n            {\"role\": \"user\", \"content\": prompt}\n    ]\n\n    input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    model_inputs = tokenizer([input_ids], return_tensors=\"pt\").to(device)\n    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)\n    generated_ids = [\n        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n    ]\n    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n    print('\\nInput:\\n', prompt)\n    print('\\nOutput:\\n', response)\n    return response\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--mode_path\", type=str, default=\"llama3_8B_lora_merged_cn\")\n    parser.add_argument(\"--prompt\", type=str, default='a dog is running.')\n    args = parser.parse_args()\n    return args\n\nif __name__ == '__main__':\n    args = parse_args()\n    device = torch.device('cuda')\n    tokenizer = AutoTokenizer.from_pretrained(args.mode_path, trust_remote_code=True)\n    model = AutoModelForCausalLM.from_pretrained(args.mode_path,torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()\n\n    response = get_output(args.prompt)"
  },
  {
    "path": "opensora/models/prompt_refiner/merge.py",
    "content": "import os\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import PeftModel\nimport torch\nimport argparse\n\n\ndef get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path):\n    model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map=\"auto\",trust_remote_code=True)\n    model = PeftModel.from_pretrained(model, lora_model_input_path)\n    merged_model = model.merge_and_unload()\n    merged_model.save_pretrained(lora_model_output_path, safe_serialization=True)\n    print(\"Merge lora to base model\")\n\n    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)\n    tokenizer.save_pretrained(lora_model_output_path)\n    print(\"Save tokenizer\")\n\ndef get_model_result(base_model_path, fintune_model_path):\n    tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n    device = \"cuda\"\n\n    fintune_model = AutoModelForCausalLM.from_pretrained(\n        fintune_model_path,\n        device_map=\"auto\",\n        torch_dtype=torch.bfloat16,\n    ).eval()\n\n    base_model = AutoModelForCausalLM.from_pretrained(\n        base_model_path,\n        device_map=\"auto\",\n        torch_dtype=torch.bfloat16,\n    ).eval()\n\n    template = \"Refine the sentence: \\\"{}\\\" to contain subject description, action, scene description. \" \\\n        \"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. 
\" \\\n        \"Make sure it is a fluent sentence, not nonsense.\"\n\n    prompt = \"a dog和一只猫\"\n    prompt = template.format(prompt)\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are a caption refiner.\"},\n        {\"role\": \"user\", \"content\": prompt}\n    ]\n    text = tokenizer.apply_chat_template(\n        messages,\n        tokenize=False,\n        add_generation_prompt=True\n    )\n\n    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n\n    def get_result(model_inputs, model):\n        generated_ids = model.generate(\n            model_inputs.input_ids,\n            max_new_tokens=512,\n            eos_token_id=tokenizer.get_vocab()[\"<|eot_id|>\"]\n        )\n        generated_ids = [\n            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n        ]\n\n        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n        return response\n\n    base_model_response = get_result(model_inputs, base_model)\n    fintune_model_response = get_result(model_inputs, fintune_model)\n    print(\"\\nInput\\n\", prompt)\n    print(\"\\nResult before fine-tune:\\n\", base_model_response)\n    print(\"\\nResult after fine-tune:\\n\", fintune_model_response)\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--base_path\", type=str, default=\"Meta-Llama-3___1-8B-Instruct\")\n    parser.add_argument(\"--lora_in_path\", type=str, default=\"llama3_1_instruct_lora/checkpoint-1008\")\n    parser.add_argument(\"--lora_out_path\", type=str, default=\"llama3_1_instruct_lora/llama3_8B_lora_merged_cn\")\n    args = parser.parse_args()\n    return args\n\nif __name__ == '__main__':\n    args = parse_args()\n    get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path)\n    get_model_result(args.base_path, args.lora_out_path)"
  },
  {
    "path": "opensora/models/prompt_refiner/train.py",
    "content": "from datasets import Dataset\nimport pandas as pd\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig\nfrom peft import LoraConfig, TaskType, get_peft_model\nimport torch\nimport argparse\n\nins = \"Refine the sentence to contain subject description, action, scene description. \" \\\n        \"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. \" \\\n        \"Make sure it is a fluent sentence, not nonsense.\"\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data_path\", type=str, default='refine_32255.json')\n    parser.add_argument(\"--model_path\", type=str, default='Meta-Llama-3___1-8B-Instruct')\n    parser.add_argument(\"--lora_out_path\", type=str, default=\"llama3_1_instruct_lora\")\n    args = parser.parse_args()\n    return args\n\nargs = parse_args()\n\n\ndf = pd.read_json(args.data_path)\nds = Dataset.from_pandas(df)\ntokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)\ntokenizer.pad_token = tokenizer.eos_token\n\ndef process_func(example):\n    MAX_LENGTH = 2048   \n    input_ids, attention_mask, labels = [], [], []\n    instruction = tokenizer(f\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\", add_special_tokens=False)  # add_special_tokens=False: do not add special tokens at the start (the template already begins with <|begin_of_text|>)\n    response = tokenizer(f\"{example['output']}<|eot_id|>\", add_special_tokens=False)\n    input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n    attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1] \n    labels = [-100] * len(instruction[\"input_ids\"]) + 
response[\"input_ids\"] + [tokenizer.pad_token_id]  \n    if len(input_ids) > MAX_LENGTH: \n        input_ids = input_ids[:MAX_LENGTH]\n        attention_mask = attention_mask[:MAX_LENGTH]\n        labels = labels[:MAX_LENGTH]\n    return {\n        \"input_ids\": input_ids,\n        \"attention_mask\": attention_mask,\n        \"labels\": labels\n    }\n\ntokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n\n\nmodel = AutoModelForCausalLM.from_pretrained(args.model_path, device_map=\"auto\",torch_dtype=torch.bfloat16)\nprint(model)\nmodel.enable_input_require_grads()\n\nconfig = LoraConfig(\n    task_type=TaskType.CAUSAL_LM, \n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n    inference_mode=False,\n    r=64,\n    lora_alpha=64,\n    lora_dropout=0.1\n)\nprint(config)\n\nmodel = get_peft_model(model, config)\nmodel.print_trainable_parameters()\n\nargs = TrainingArguments(\n    output_dir=args.lora_out_path,\n    per_device_train_batch_size=32,\n    gradient_accumulation_steps=1,\n    logging_steps=1,\n    num_train_epochs=1,\n    save_steps=20, \n    dataloader_num_workers=4, \n    learning_rate=1.5e-4,\n    warmup_ratio=0.1, \n    save_on_each_node=True,\n    gradient_checkpointing=True, \n    report_to='wandb', \n)\n\ntrainer = Trainer(\n    model=model,\n    args=args,\n    train_dataset=tokenized_id,\n    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n)\n\ntrainer.train()"
  },
  {
    "path": "opensora/models/text_encoder/__init__.py",
    "content": "from opensora.models.text_encoder.clip import CLIPWrapper\nfrom opensora.models.text_encoder.t5 import T5Wrapper\n\ntext_encoder = {\n    'google/mt5-xl': T5Wrapper,\n    'google/mt5-xxl': T5Wrapper,\n    'google/umt5-xl': T5Wrapper,\n    'google/umt5-xxl': T5Wrapper,\n    'google/t5-v1_1-xl': T5Wrapper,\n    'DeepFloyd/t5-v1_1-xxl': T5Wrapper,\n    'openai/clip-vit-large-patch14': CLIPWrapper, \n    'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k': CLIPWrapper\n}\n\ndef get_text_warpper(text_encoder_name):\n    \"\"\"Deprecated. Return the wrapper class whose registry key is a substring of text_encoder_name.\"\"\"\n    encoder_key = None\n    for key in text_encoder.keys():\n        if key in text_encoder_name:\n            encoder_key = key\n            break\n    text_enc = text_encoder.get(encoder_key, None)\n    assert text_enc is not None, f\"Unsupported text encoder: {text_encoder_name}\"\n    return text_enc\n"
  },
  {
    "path": "opensora/models/text_encoder/clip.py",
    "content": "import torch\nfrom torch import nn\nfrom transformers import CLIPTextModelWithProjection\n\ntry:\n    import torch_npu\nexcept:\n    torch_npu = None\n\nclass CLIPWrapper(nn.Module):\n    def __init__(self, args, **kwargs):\n        super(CLIPWrapper, self).__init__()\n        # Load from the configured model path rather than a hard-coded, machine-specific directory\n        self.model_name = args.text_encoder_name_2\n        print(f'Loading CLIP model from {self.model_name}...')\n        self.text_enc = CLIPTextModelWithProjection.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval()\n\n    def forward(self, input_ids, attention_mask):\n        # attention_mask is accepted for interface parity with T5Wrapper but is not passed to CLIP;\n        # index [0] of the output is text_embeds, the projected pooled text embedding\n        text_encoder_embs = self.text_enc(input_ids=input_ids, output_hidden_states=True)[0]\n        return text_encoder_embs.detach()\n"
  },
  {
    "path": "opensora/models/text_encoder/t5.py",
    "content": "import torch\nfrom torch import nn\nfrom transformers import T5EncoderModel\n\ntry:\n    import torch_npu\nexcept:\n    torch_npu = None\n\nclass T5Wrapper(nn.Module):\n    def __init__(self, args, **kwargs):\n        super(T5Wrapper, self).__init__()\n        self.model_name = args.text_encoder_name_1\n        print(f'Loading T5 model from {self.model_name}...')\n        self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval()\n\n    def forward(self, input_ids, attention_mask):\n        text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']\n        return text_encoder_embs.detach()\n"
  },
  {
    "path": "opensora/npu_config.py",
    "content": "import math\nimport mmap\nimport os\nimport pickle\nimport random\nimport numpy as np\nimport torch\nimport subprocess\nimport sys\nimport threading\nimport gc\nimport torch.distributed as dist\n\nfrom opensora.adaptor.zp_manager import zp_manager\n\ntry:\n    import torch_npu\n\n    npu_is_available = True\n    from torch_npu.contrib import transfer_to_npu\nexcept:\n    npu_is_available = False\n\nfrom contextlib import contextmanager\nimport types\n\n\ndef compress_video(input_file, output_file, out_size):\n    \"\"\"Compress a video file with ffmpeg.\"\"\"\n    command = [\n        'ffmpeg',\n        '-i', input_file,\n        '-vf', f\"scale='min({out_size},iw)':'min({out_size},ih)':force_original_aspect_ratio=decrease\",\n        '-c:v', 'libx264',\n        '-crf', '18',\n        '-preset', 'slow',\n        '-c:a', 'copy',\n        output_file\n    ]\n    subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n\n\n@contextmanager\ndef set_run_dtype(x, dtype=None):\n    # Record the original dtype of x\n    npu_config.original_run_dtype = x.dtype\n    # Set the dtype to run with inside the `with` block\n    npu_config.current_run_dtype = dtype\n    try:\n        # Yield control back to the body of the `with` statement\n        yield\n    finally:\n        # Clear both dtypes on exit\n        npu_config.current_run_dtype = None\n        npu_config.original_run_dtype = None\n\n\nclass NPUConfig:\n    N_NPU_PER_NODE = 8\n\n    def __init__(self):\n        self.on_npu = npu_is_available\n        self.node_world_size = self.N_NPU_PER_NODE\n        self.profiling = False\n        self.profiling_step = 5\n        self.enable_FA = True\n        self.enable_FP32 = False\n        self.load_pickle = True\n        self.use_small_dataset = False\n        self.current_run_dtype = None\n        self.original_run_dtype = None\n        self.zp_manager = zp_manager\n        self.replaced_type = torch.float32\n        self.conv_dtype = torch.float16\n        if self.enable_FA and self.enable_FP32:\n            self.inf_float = 
-10000.0\n        else:\n            self.inf_float = -10000.0\n\n        if self.use_small_dataset:\n            self.load_pickle = False\n\n        self._loss = []\n        self.work_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\n        self.pickle_save_path = f\"{self.work_path}/pickles\"\n        self.mm = dict()\n\n        if self.on_npu:\n            import deepspeed\n            import sys\n            torch_npu.npu.set_compile_mode(jit_compile=False)\n\n            import deepspeed.runtime.utils as utils\n            from opensora.adaptor.utils import all_gather_dp_groups, all_gather_into_tensor_dp_groups\n            utils.all_gather_dp_groups = all_gather_dp_groups\n\n            import deepspeed.runtime.bf16_optimizer as bf16_optimizer\n            from opensora.adaptor.bf16_optimizer import BF16_Optimizer\n            self.replace_methods(bf16_optimizer.BF16_Optimizer, BF16_Optimizer)\n\n            from opensora.adaptor.stage_1_and_2 import DeepSpeedZeroOptimizer\n            import deepspeed.runtime.zero.stage_1_and_2 as stage_1_and_2\n            self.replace_methods(stage_1_and_2.DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer, ['_has_inf_or_nan'])\n\n            import deepspeed.runtime.engine as engine\n            from opensora.adaptor.engine import DeepSpeedEngine\n            self.replace_methods(engine.DeepSpeedEngine, DeepSpeedEngine, skip_fcns=['__init__', '_copy_recovery_script', '_change_recovery_script_permissions'])\n\n        if \"RANK\" in os.environ:\n            self.rank = int(os.environ[\"RANK\"])\n            self.world_size = int(os.environ[\"WORLD_SIZE\"])\n            torch_npu.npu.set_device(self.get_local_rank())\n        else:\n            self.rank = torch.cuda.current_device()\n            self.world_size = self.N_NPU_PER_NODE\n        self.print_with_rank(f\"The npu_config.on_npu is {self.on_npu}\")\n        self.bind_thread_to_cpu()\n        gc.set_threshold(700, 10, 10000)\n\n    def 
get_total_cores(self):\n        try:\n            total_cores = os.sysconf('SC_NPROCESSORS_ONLN')\n        except (AttributeError, ValueError):\n            total_cores = os.cpu_count()\n        return total_cores\n\n\n    def bind_thread_to_cpu(self):\n        total_cores = self.get_total_cores()\n        # Number of CPU cores assigned to each device\n        cores_per_rank = total_cores // self.N_NPU_PER_NODE\n        # Local rank of this process on its node\n        local_rank = self.rank % self.N_NPU_PER_NODE\n        # CPU core range for the current rank\n        start_core = local_rank * cores_per_rank\n        end_core = start_core + cores_per_rank - 1\n        # Build the core-range string for taskset\n        cpu_cores_range = f\"{start_core}-{end_core}\"\n        pid = os.getpid()\n        command = f\"taskset -cp {cpu_cores_range} {pid}\"\n\n        subprocess.run(command, shell=True, check=True)\n        return f\"Binding Cores:{self.rank}:{pid}:{cpu_cores_range}\"\n\n    def replace_methods(self, target_class, source_class, skip_fcns=[], only_include_fcns=None):\n        for attr_name in dir(source_class):\n            attr_value = getattr(source_class, attr_name)\n            if attr_name in source_class.__dict__:\n                attr_class_value = source_class.__dict__[attr_name]\n            else:\n                attr_class_value = attr_value\n            if (isinstance(attr_class_value, staticmethod) or isinstance(attr_class_value, classmethod)\n                    or attr_name in skip_fcns):\n                print(f\"skip replace {attr_name}\")\n                continue\n\n            if only_include_fcns is not None and attr_name not in only_include_fcns:\n                continue\n\n            elif isinstance(attr_value, types.FunctionType):\n                setattr(target_class, attr_name, attr_value)\n\n    def get_attention_mask(self, attention_mask, repeat_num):\n        if self.on_npu and attention_mask is not None:\n            if npu_config.enable_FA:\n                attention_mask = attention_mask.to(torch.bool)\n            attention_mask = 
attention_mask.repeat_interleave(repeat_num, dim=-2)\n        return attention_mask\n    def set_current_run_dtype(self, variables):\n        if variables[0].dtype != self.current_run_dtype and self.current_run_dtype is not None:\n            for index, var in enumerate(variables):\n                variables[index] = var.to(self.current_run_dtype)\n        return tuple(variables)\n\n    def restore_dtype(self, x):\n        if x.dtype != self.original_run_dtype and self.original_run_dtype is not None:\n            x = x.to(self.original_run_dtype)\n        return x\n\n    def get_output_video_path(self, name):\n        os.makedirs(f\"{self.work_path}/output_videos\", exist_ok=True)\n        return f\"{self.work_path}/output_videos/{name}\"\n\n    def get_node_id(self):\n        return self.rank // self.node_world_size\n\n    def get_node_size(self):\n        return self.world_size // self.node_world_size\n\n    def get_local_rank(self):\n        return self.rank % self.N_NPU_PER_NODE\n\n    def get_pickle_path(self, file_name):\n        return f\"{self.pickle_save_path}/{file_name}_local_n63\"\n\n    def free_mm(self):\n        for key, value in self.mm.items():\n            value.close()\n        self.mm.clear()\n\n    def __del__(self):\n        self.free_mm()\n\n    def try_load_pickle(self, file_name, function):\n        file_name = self.get_pickle_path(file_name)\n        if os.path.exists(file_name) and self.load_pickle:\n            with open(file_name, 'rb') as file:\n                # self.mm[file_name] = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)\n                # # Read the data using mmap\n                # loaded_data = pickle.loads(self.mm[file_name][:])\n                loaded_data = pickle.load(file)\n                return loaded_data\n        else:\n            data = function()\n            if not self.use_small_dataset:\n                if self.rank % self.N_NPU_PER_NODE == 0:\n                    # Only the first rank on each node saves the file\n                    os.makedirs(self.pickle_save_path, exist_ok=True)\n                    with open(file_name, 'wb') as file:\n                        pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)
\n            return data\n\n    def try_get_vid_path(self, file, out_size=1024):\n        output_file = file.rsplit(\".\", 1)[0] + f\"_resize{out_size}.mp4\"\n        if not os.path.exists(output_file):\n            return file\n        #     compress_video(file, output_file, out_size)\n        return output_file\n\n    def npu_format_cast(self, x):\n        return torch_npu.npu_format_cast(x, 2)\n\n    def calc_grad_norm(self, model):\n        # Compute the gradient norm\n        # model_engine = accelerator.deepspeed_engine_wrapped.engine\n        # gradients = model_engine.get_gradients()\n        # grad_norm = get_grad_norm(gradients)\n        # NOTE: the actual computation is commented out, so this method currently returns 0\n        grad_norm = 0\n        n_grad = 0\n        # for name, param in model.named_parameters():\n        #     grad_data = deepspeed.utils.safe_get_full_grad(param)\n        #     # self.print_tensor_stats(grad_data, name=name)\n        #\n        #     if grad_data is not None:\n        #         param_norm = grad_data.norm(2)\n        #         grad_norm += param_norm.item() ** 2\n        #         n_grad += 1\n        # grad_norm = (grad_norm / n_grad) ** (1. 
/ 2)\n\n        return grad_norm\n\n    def _run(self, operator, x, tmp_dtype, out_dtype=None, out_nd_format=False):\n        if self.on_npu:\n            if out_dtype is None:\n                out_dtype = x.dtype\n\n            with torch.cuda.amp.autocast(enabled=False):\n                x = operator.to(device=x.device, dtype=tmp_dtype)(x.to(tmp_dtype))\n                x = x.to(out_dtype)\n                if out_nd_format:\n                    return self.npu_format_cast(x)\n                else:\n                    return x\n        else:\n            return operator(x)\n\n    def run_group_norm(self, operator, x):\n        return self._run(operator, x, torch.float32)\n\n    def run_layer_norm(self, operator, x):\n        return self._run(operator, x, torch.float32)\n\n    def print_tensor_stats(self, tensor, name=\"Tensor\", rank=None):\n        if rank is not None and rank != self.rank:\n            return\n\n        if tensor is None:\n            self.print_msg(f\"Tensor {name} is None.\")\n            return\n\n        x_dtype = tensor.dtype\n        tensor = tensor.to(torch.bfloat16)\n        max_val = tensor.max().item()\n        min_val = tensor.min().item()\n        abs_max_val = max(abs(max_val), abs(min_val))\n        mean_val = tensor.mean().item()\n        median_val = tensor.median().item()\n        std_val = tensor.std().item()\n        shape = tensor.shape\n        self.print_msg(\n            f\"{name} - Max: {max_val}, Min: {min_val}, Mean: {mean_val}, AbsMax: {abs_max_val}, \"\n            f\"Median: {median_val}, Std: {std_val}, Shape: {shape}, Type: {x_dtype}\")\n\n    def run_conv3d(self, operator, x, out_dtype):\n        return self._run(operator, x, self.conv_dtype, out_dtype, out_nd_format=True)\n\n    def run_pool_2d(self, operator, x):\n        return self._run(operator, x, self.replaced_type)\n\n    def run_pad_2d(self, operator, x, pad, mode=\"constant\"):\n        if self.on_npu:\n            x_dtype = x.dtype\n            x = 
x.to(self.replaced_type)\n            x = operator(x, pad, mode)\n            x = x.to(x_dtype)\n        else:\n            x = operator(x, pad, mode)\n        return x\n\n    def seed_everything(self, seed=100):\n        seed += self.rank\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n    def print_with_rank(self, msg, rank=0, save=False):\n        if self.rank == rank:\n            print(f\"{msg}\", flush=True)\n            if save:\n                self._loss.append(msg)\n\n    def print_msg(self, msg, on=True, rank=None):\n        if on:\n            if self.rank == rank or rank is None:\n                print(f\"[RANK-{self.rank}]: {msg}\", flush=True)\n\n    def save_loss(self, filename, rank=0):\n        if self.rank == rank:\n            import json\n            with open(filename, 'w') as file:\n                json.dump(self._loss, file, indent=4)\n\n    def run_attention(self, query, key, value, atten_mask, input_layout, head_dim, head_num):\n        if self.enable_FA:\n            hidden_states = torch_npu.npu_fusion_attention(query, key, value,\n                                                           atten_mask=atten_mask,\n                                                           input_layout=input_layout,\n                                                           scale=1 / math.sqrt(head_dim),\n                                                           head_num=head_num)[0]\n        else:\n            hidden_states = self.scaled_dot_product_attention(query, key, value,\n                                                              atten_mask=atten_mask,\n                                                              input_layout=input_layout,\n                                                              scale=1 / 
math.sqrt(head_dim),\n                                                              head_num=head_num)\n        return hidden_states\n\n    def scaled_dot_product_attention(self, query, key, value, input_layout, head_num=None,\n                                     atten_mask=None, scale=None, dropout_p=0.0, is_causal=False) -> torch.Tensor:\n        # L, S = query.size(-2), key.size(-2)\n        def trans_tensor_shape(x, layout, head_num):\n            if layout == \"BSH\":\n                batch = x.shape[0]\n                x = x.view(batch, -1, head_num, x.shape[-1] // head_num).transpose(1, 2).contiguous()\n            elif layout == \"SBH\":\n                batch = x.shape[1]\n                x = x.view(-1, batch * head_num, x.shape[-1] // head_num).transpose(0, 1).contiguous()\n                x = x.view(batch, head_num, -1, x.shape[-1])\n            return x\n\n        query = trans_tensor_shape(query, input_layout, head_num)\n        key = trans_tensor_shape(key, input_layout, head_num)\n        value = trans_tensor_shape(value, input_layout, head_num)\n\n        if scale is None:\n            scale = 1 / math.sqrt(query.shape[-1])\n        attn_weight = query @ key.transpose(-2, -1) * scale\n        attn_bias = torch.zeros_like(attn_weight, dtype=query.dtype, device=query.device)\n        if is_causal:\n            assert atten_mask is None\n            # Keep the lower triangle (including the diagonal) and mask out future positions\n            temp_mask = torch.ones_like(attn_weight, dtype=torch.bool, device=query.device).tril(diagonal=0)\n            attn_bias.masked_fill_(temp_mask.logical_not(), npu_config.inf_float)\n\n        if atten_mask is not None:\n            assert (not self.enable_FA) and atten_mask.dtype != torch.bool, \\\n                \"atten_mask must be an additive (non-bool) mask when this function is used\"\n            # Apply the additive mask, as in the PyTorch reference SDPA for non-bool masks\n            attn_bias = attn_bias + atten_mask.to(attn_bias.dtype)\n\n        attn_weight += attn_bias\n        attn_weight = torch.softmax(attn_weight, dim=-1)\n        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\n        output = attn_weight @ value\n        if input_layout == \"BSH\":\n            output = 
output.transpose(1, 2).contiguous().view(output.shape[0], -1, head_num * output.shape[-1])\n        else:\n            output = output.view(output.shape[0] * head_num, -1, output.shape[-1]).transpose(0, 1).contiguous()\n            output = output.view(output.shape[0], -1, head_num * output.shape[-1])\n        return output\n\n    def print_tensor_with_rank(self, name, tensor, rank=[0], dim_print_cnt=[]):\n        if type(rank) is not list:\n            rank = [rank]\n        if self.rank in rank:\n            def print_dim(tensor_, indices):\n                if tensor_.dim() == len(indices):\n                    return '{0:10.5f} '.format(tensor[tuple(indices)].detach().item())\n                else:\n                    cur_dim = len(indices)\n                    ret = ''\n                    for x in range(0, tensor_.size(cur_dim), tensor_.size(cur_dim) // dim_print_cnt[cur_dim]):\n                        ret += print_dim(tensor_, indices + [x])\n                    return ret + '\\n'\n\n            print(name, tensor.size(), self.rank, '\\n', print_dim(tensor, []))\n\n\nnpu_config = NPUConfig()\n"
  },
  {
    "path": "opensora/sample/caption_refiner.py",
    "content": "import torch\nfrom torch import nn\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\n\n\n\nTEMPLATE = \"\"\"\nRefine the sentence: \\\"{}\\\" to contain subject description, action, scene description. \\\n(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. \\\nMake sure it is a fluent sentence, not nonsense.\n\"\"\"\n\nclass OpenSoraCaptionRefiner(nn.Module):\n    def __init__(self, args, dtype, device):\n        super().__init__()\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            args.caption_refiner, trust_remote_code=True\n            )\n        self.model = AutoModelForCausalLM.from_pretrained(\n            args.caption_refiner, torch_dtype=dtype, trust_remote_code=True\n            ).to(device).eval()\n        self.device = device\n        \n    def get_refiner_output(self, prompt):\n        prompt = TEMPLATE.format(prompt)\n        messages = [\n                {\"role\": \"system\", \"content\": \"You are a caption refiner.\"},\n                {\"role\": \"user\", \"content\": prompt}\n        ]\n        prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n        model_inputs = self.tokenizer([prompt_text], return_tensors=\"pt\").to(self.device)\n        generated_ids = self.model.generate(model_inputs.input_ids, attention_mask=model_inputs.attention_mask, max_new_tokens=512)\n        # Strip the prompt tokens so only the newly generated text is decoded\n        generated_ids = [\n            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n        ]\n        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n        return response"
  },
  {
    "path": "opensora/sample/pipeline_inpaint.py",
    "content": "\nimport inspect\nimport os\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\nfrom dataclasses import dataclass\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom PIL import Image\nimport decord\n\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel\nimport torch.nn.functional as F\nfrom torchvision.transforms import Compose, Lambda, Resize\n\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, HunyuanDiT2DModel\nfrom diffusers.models.embeddings import get_2d_rotary_pos_embed\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDPMScheduler\nfrom diffusers.utils import logging, BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\n\nfrom opensora.models.diffusion.opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3\nfrom opensora.sample.pipeline_opensora import OpenSoraPipeline, OpenSoraPipelineOutput, rescale_noise_cfg\nfrom opensora.dataset.transform import CenterCropResizeVideo, SpatialStrideCropVideo,ToTensorAfterResize, maxhwresize\nfrom opensora.utils.mask_utils import MaskProcessor, MaskCompressor, GaussianNoiseAdder, MaskType, STR_TO_TYPE, TYPE_TO_STR\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\n    from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info\nexcept:\n    torch_npu = None\n    npu_config = None\n    from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\ndef 
is_video_file(file_path):\n    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg', '.3gp'}\n    file_extension = os.path.splitext(file_path)[1].lower()\n    return file_extension in video_extensions\n\ndef is_image_file(file_path):\n    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}\n    file_extension = os.path.splitext(file_path)[1].lower()\n    return file_extension in image_extensions\n\ndef open_image(file_path):\n    image = Image.open(file_path).convert(\"RGB\")\n    return image\n\ndef open_video(file_path, start_frame_idx, num_frames, frame_interval=1):\n\n    decord_vr = decord.VideoReader(file_path, ctx=decord.cpu(0), num_threads=1)\n\n    total_frames = len(decord_vr)\n    frame_indices = list(range(start_frame_idx, min(start_frame_idx + num_frames * frame_interval, total_frames), frame_interval))\n\n    if len(frame_indices) == 0:\n        raise ValueError(\"No frames selected. Check your start_frame_idx and num_frames.\")\n    \n    if len(frame_indices) < num_frames:\n        raise ValueError(f\"Requested {num_frames} frames but only {len(frame_indices)} frames are available, please adjust the start_frame_idx and num_frames or decrease the frame_interval.\")\n        \n    video_data = decord_vr.get_batch(frame_indices).asnumpy()\n    video_data = torch.from_numpy(video_data)\n    video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)\n    return video_data\n\n\ndef get_pixel_values(file_path, num_frames):\n    if is_image_file(file_path[0]):\n        pixel_values = [open_image(path) for path in file_path]\n        pixel_values = [torch.from_numpy(np.array(image)) for image in pixel_values]\n        pixel_values = [rearrange(image, 'h w c -> c h w').unsqueeze(0) for image in pixel_values]\n    elif is_video_file(file_path[0]):\n        pixel_values = [open_video(video_path, 0, num_frames) for video_path in file_path]\n    return pixel_values\n\n\nclass 
OpenSoraInpaintPipeline(OpenSoraPipeline):\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: T5EncoderModel,\n        tokenizer: MT5Tokenizer,\n        transformer: OpenSoraInpaint_v1_3,\n        scheduler: DDPMScheduler,\n        text_encoder_2: CLIPTextModelWithProjection = None,\n        tokenizer_2: CLIPTokenizer = None,\n    ):\n        super().__init__(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n            text_encoder_2=text_encoder_2,\n            tokenizer_2=tokenizer_2,\n        )\n        \n        # If performing continuation or random, the default mask is half of the frame, which can be modified\n        self.mask_processor = MaskProcessor(min_clear_ratio=0.5, max_clear_ratio=0.5) \n\n        self.mask_compressor = MaskCompressor(ae_stride_t=self.vae.vae_scale_factor[0], ae_stride_h=self.vae.vae_scale_factor[1], ae_stride_w=self.vae.vae_scale_factor[2])\n        \n        self.noise_adder = None\n\n    def check_inputs(\n        self,\n        conditional_pixel_values_path,\n        conditional_pixel_values_indices,\n        mask_type,\n        max_hxw,\n        noise_strength,\n        prompt,\n        num_frames,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n        prompt_embeds_2=None,\n        negative_prompt_embeds_2=None,\n        prompt_attention_mask_2=None,\n        negative_prompt_attention_mask_2=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if conditional_pixel_values_path is None:\n            raise ValueError(\"conditional_pixel_values_path should be provided\")\n        else:\n            if not isinstance(conditional_pixel_values_path, list) or not 
isinstance(conditional_pixel_values_path[0], str):\n                raise ValueError(\"conditional_pixel_values_path should be a list of strings\")\n    \n        if not is_image_file(conditional_pixel_values_path[0]) and not is_video_file(conditional_pixel_values_path[0]):\n            raise ValueError(\"conditional_pixel_values_path should be an image or video file path\")  \n        \n        if is_video_file(conditional_pixel_values_path[0]) and len(conditional_pixel_values_path) > 1:\n            raise ValueError(\"conditional_pixel_values_path should be a list of image file paths or a single video file path\")\n        \n        if conditional_pixel_values_indices is not None \\\n            and (not isinstance(conditional_pixel_values_indices, list) or not isinstance(conditional_pixel_values_indices[0], int) \\\n                 or len(conditional_pixel_values_indices) != len(conditional_pixel_values_path)):\n            raise ValueError(\"conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path\")\n        \n        if mask_type is not None and mask_type not in STR_TO_TYPE.keys() and mask_type not in STR_TO_TYPE.values():\n            raise ValueError(f\"Invalid mask type: {mask_type}\")\n        \n        if not isinstance(max_hxw, int) or not (max_hxw >= 102400 and max_hxw <= 236544):\n            raise ValueError(\"max_hxw should be an integer between 102400 and 236544\")\n        \n        if not isinstance(noise_strength, float) or not (noise_strength >= 0 and noise_strength <= 1):\n            raise ValueError(\"noise_strength should be a float between 0 and 1\")\n        \n        super().check_inputs(\n            prompt,\n            num_frames,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n            prompt_embeds_2,\n    
        negative_prompt_embeds_2,\n            prompt_attention_mask_2,\n            negative_prompt_attention_mask_2,\n            callback_on_step_end_tensor_inputs,\n        )\n\n    def get_resize_transform(\n        self, \n        ori_height,\n        ori_width,\n        height=None, \n        width=None, \n        crop_for_hw=False, \n        hw_stride=32, \n        max_hxw=236544, # 236544 pixels in total, roughly 486 x 486\n    ):\n        if crop_for_hw:\n            assert height is not None and width is not None\n            transform = CenterCropResizeVideo((height, width))\n        else:\n            new_height, new_width = maxhwresize(ori_height, ori_width, max_hxw)\n            transform = Compose(\n                [\n                    CenterCropResizeVideo((new_height, new_width)), # CenterCropResizeVideo gives every conditioning image the same height and width, so the crop shape stays consistent when several images are provided\n                    SpatialStrideCropVideo(stride=hw_stride), \n                ]\n            )\n        return transform\n        \n        \n    def get_video_transform(self):\n        norm_fun = Lambda(lambda x: 2. 
* x - 1.)\n        transform = Compose([\n            ToTensorAfterResize(),\n            norm_fun\n        ])\n        return transform\n\n    def get_mask_type_cond_indices(self, mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames):\n        if mask_type is not None and mask_type in STR_TO_TYPE.keys():\n            mask_type = STR_TO_TYPE[mask_type]\n        if is_image_file(conditional_pixel_values_path[0]):\n            if len(conditional_pixel_values_path) == 1:\n                mask_type = MaskType.i2v if mask_type is None else mask_type\n                if num_frames > 1:\n                    conditional_pixel_values_indices = [0] if conditional_pixel_values_indices is None else conditional_pixel_values_indices\n                    assert len(conditional_pixel_values_indices) == 1, \"conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path\"\n            elif len(conditional_pixel_values_path) == 2:\n                mask_type = MaskType.transition if mask_type is None else mask_type\n                if num_frames > 1:\n                    conditional_pixel_values_indices = [0, -1] if conditional_pixel_values_indices is None else conditional_pixel_values_indices\n                    assert len(conditional_pixel_values_indices) == 2, \"conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path\"\n            else:\n                if num_frames > 1:\n                    assert conditional_pixel_values_indices is not None and len(conditional_pixel_values_path) == len(conditional_pixel_values_indices), \"conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path\"\n                    mask_type = MaskType.random_temporal if mask_type is None else mask_type\n        elif is_video_file(conditional_pixel_values_path[0]):\n            # When the input is a 
video, video continuation is executed by default, with a continuation rate of double\n            mask_type = MaskType.continuation if mask_type is None else mask_type\n        return mask_type, conditional_pixel_values_indices\n\n\n    def get_masked_pixel_values_mask(\n        self, \n        conditional_pixel_values,\n        conditional_pixel_values_indices,\n        mask_type, \n        batch_size, \n        num_samples_per_prompt, \n        num_frames, \n        height, \n        width, \n        video_transform,\n        weight_dtype,\n        device\n    ):\n        if device is None:\n            device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')\n\n        conditional_pixel_values = conditional_pixel_values.to(device=device, dtype=weight_dtype)\n\n        if conditional_pixel_values.shape[0] == num_frames:\n            inpaint_cond_data = self.mask_processor(conditional_pixel_values, mask_type=mask_type)\n            masked_pixel_values, mask = inpaint_cond_data['masked_pixel_values'], inpaint_cond_data['mask']\n        else:\n            input_pixel_values = torch.zeros([num_frames, 3, height, width], device=device, dtype=weight_dtype)\n            input_mask = torch.ones([num_frames, 1, height, width], device=device, dtype=weight_dtype)\n            input_pixel_values[conditional_pixel_values_indices] = conditional_pixel_values\n            input_mask[conditional_pixel_values_indices] = 0\n            masked_pixel_values = input_pixel_values * (input_mask < 0.5)\n            mask = input_mask\n\n        print('conditional_pixel_values_indices', conditional_pixel_values_indices)\n        print('mask_type', TYPE_TO_STR[mask_type])\n\n        masked_pixel_values = video_transform(masked_pixel_values)\n\n        masked_pixel_values = masked_pixel_values.unsqueeze(0).repeat(batch_size * num_samples_per_prompt, 1, 1, 1, 1).transpose(1, 2).contiguous() # b c t h w\n        mask = 
mask.unsqueeze(0).repeat(batch_size * num_samples_per_prompt, 1, 1, 1, 1).transpose(1, 2).contiguous() # b c t h w\n        \n        if self.noise_adder is not None:\n            # add some noise to improve motion strength\n            masked_pixel_values = self.noise_adder(masked_pixel_values, mask)\n        \n        masked_pixel_values = masked_pixel_values.to(self.vae.vae.dtype)\n        masked_pixel_values = self.vae.encode(masked_pixel_values)\n\n        mask = self.mask_compressor(mask)\n    \n        masked_pixel_values = torch.cat([masked_pixel_values] * 2) if self.do_classifier_free_guidance else masked_pixel_values\n        mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask\n\n        masked_pixel_values = masked_pixel_values.to(weight_dtype)\n        mask = mask.to(weight_dtype)\n\n        return masked_pixel_values, mask\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        conditional_pixel_values_path: Union[str, List[str]] = None,\n        conditional_pixel_values_indices: Union[int, List[int]] = None,\n        mask_type: Union[str, MaskType] = None,\n        crop_for_hw: bool = False,\n        max_hxw: int = 236544,\n        noise_strength: Optional[float] = 0.0,\n        prompt: Union[str, List[str]] = None,\n        num_frames: Optional[int] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_samples_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_embeds_2: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        
negative_prompt_embeds_2: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        guidance_rescale: float = 0.0,\n        max_sequence_length: int = 512,\n        device = None, \n    ):\n        \n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. default height and width\n        num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1\n        height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]\n        width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            conditional_pixel_values_path,\n            conditional_pixel_values_indices,\n            mask_type,\n            max_hxw,\n            noise_strength,\n            prompt,\n            num_frames, \n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n            prompt_embeds_2,\n            negative_prompt_embeds_2,\n            prompt_attention_mask_2,\n            negative_prompt_attention_mask_2,\n            callback_on_step_end_tensor_inputs,\n        )\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')\n\n\n        # 3. 
Encode input prompt\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            dtype=self.transformer.dtype,\n            num_samples_per_prompt=num_samples_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=max_sequence_length,\n            text_encoder_index=0,\n        )\n        if self.tokenizer_2 is not None:\n            (\n                prompt_embeds_2,\n                negative_prompt_embeds_2,\n                prompt_attention_mask_2,\n                negative_prompt_attention_mask_2,\n            ) = self.encode_prompt(\n                prompt=prompt,\n                device=device,\n                dtype=self.transformer.dtype,\n                num_samples_per_prompt=num_samples_per_prompt,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                negative_prompt=negative_prompt,\n                prompt_embeds=prompt_embeds_2,\n                negative_prompt_embeds=negative_prompt_embeds_2,\n                prompt_attention_mask=prompt_attention_mask_2,\n                negative_prompt_attention_mask=negative_prompt_attention_mask_2,\n                max_sequence_length=77,\n                text_encoder_index=1,\n            )\n        else:\n            prompt_embeds_2 = None\n            negative_prompt_embeds_2 = None\n            prompt_attention_mask_2 = None\n            negative_prompt_attention_mask_2 = None\n\n        # ==================prepare inpaint 
data=====================================\n        if noise_strength != 0:\n            self.noise_adder = GaussianNoiseAdder(mean=np.log(noise_strength), std=0.01, clear_ratio=0)\n\n        mask_type, conditional_pixel_values_indices = self.get_mask_type_cond_indices(mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames)\n\n        conditional_pixel_values = get_pixel_values(conditional_pixel_values_path, num_frames)\n\n        min_height = min([pixels.shape[2] for pixels in conditional_pixel_values])\n        min_width = min([pixels.shape[3] for pixels in conditional_pixel_values])\n\n        resize_transform = self.get_resize_transform(\n            ori_height=min_height, \n            ori_width=min_width, \n            height=height, \n            width=width, \n            crop_for_hw=crop_for_hw,\n            max_hxw=max_hxw,\n        )\n\n        video_transform = self.get_video_transform()\n        conditional_pixel_values = torch.cat([resize_transform(pixels) for pixels in conditional_pixel_values])\n        real_height, real_width = conditional_pixel_values.shape[-2], conditional_pixel_values.shape[-1]\n        # ==================prepare inpaint data=====================================\n        \n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. 
Prepare latent variables\n        if get_sequence_parallel_state():\n            world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_samples_per_prompt,\n            num_channels_latents,\n            (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, \n            real_height,\n            real_width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # ==============================create mask=====================================\n        masked_pixel_values, mask = self.get_masked_pixel_values_mask(\n            conditional_pixel_values,\n            conditional_pixel_values_indices,\n            mask_type, \n            batch_size, \n            num_samples_per_prompt, \n            num_frames, \n            real_height,\n            real_width,\n            video_transform,\n            prompt_embeds.dtype,\n            device\n        )\n        # ==============================create mask=====================================\n\n        # 7 create image_rotary_emb, style embedding & time ids\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])\n            if self.tokenizer_2 is not None:\n                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])\n                prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])\n\n        prompt_embeds = 
prompt_embeds.to(device=device)\n        prompt_attention_mask = prompt_attention_mask.to(device=device)\n        if self.tokenizer_2 is not None:\n            prompt_embeds_2 = prompt_embeds_2.to(device=device)\n            prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)\n\n        # ==================make sp=====================================\n        if get_sequence_parallel_state():\n            prompt_embeds = rearrange(\n                prompt_embeds, \n                'b (n x) h -> b n x h', \n                n=world_size,\n                x=prompt_embeds.shape[1] // world_size\n                ).contiguous()\n            rank = hccl_info.rank if torch_npu is not None else nccl_info.rank\n            prompt_embeds = prompt_embeds[:, rank, :, :]\n\n            latents_num_frames = latents.shape[2]\n            masked_pixel_values = masked_pixel_values[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]\n            mask = mask[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]\n        # ==================make sp=====================================\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # inpaint\n                latent_model_input = torch.cat([latent_model_input, masked_pixel_values, mask], dim=1)\n\n                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input\n                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(\n                    dtype=latent_model_input.dtype\n                )\n\n                # ==================prepare my shape=====================================\n                # predict the noise residual\n                if prompt_embeds.ndim == 3:\n                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d\n                if prompt_attention_mask.ndim == 2:\n                    prompt_attention_mask = prompt_attention_mask.unsqueeze(1)  # b l -> b 1 l\n                if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2:\n                    prompt_embeds_2 = prompt_embeds_2.unsqueeze(1)  # b d -> b 1 d\n                \n                attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device)\n                # ==================prepare my shape=====================================\n\n                # ==================make sp=====================================\n                if get_sequence_parallel_state():\n                    attention_mask = attention_mask.repeat(1, 
world_size, 1, 1)\n                # ==================make sp=====================================\n\n                noise_pred = self.transformer(\n                    latent_model_input,\n                    attention_mask=attention_mask, \n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    timestep=t_expand,\n                    pooled_projections=prompt_embeds_2,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    prompt_embeds_2 = callback_outputs.pop(\"prompt_embeds_2\", prompt_embeds_2)\n                    
negative_prompt_embeds_2 = callback_outputs.pop(\n                        \"negative_prompt_embeds_2\", negative_prompt_embeds_2\n                    )\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        # ==================make sp=====================================\n        if get_sequence_parallel_state():\n            latents_shape = list(latents.shape)  # b c t//sp h w\n            full_shape = [latents_shape[0] * world_size] + latents_shape[1:]  # b*sp c t//sp h w\n            all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)\n            torch.distributed.all_gather_into_tensor(all_latents, latents)\n            latents_list = list(all_latents.chunk(world_size, dim=0))\n            latents = torch.cat(latents_list, dim=2)\n        # ==================make sp=====================================\n\n        if output_type != \"latent\":\n            videos = self.decode_latents(latents)\n            videos = videos[:, :num_frames, :real_height, :real_width]\n        else:\n            videos = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (videos, )\n\n        return OpenSoraPipelineOutput(videos=videos)\n"
  },
  {
    "path": "opensora/sample/pipeline_opensora.py",
    "content": "\nimport inspect\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\nfrom dataclasses import dataclass\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel\n\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, HunyuanDiT2DModel\nfrom diffusers.models.embeddings import get_2d_rotary_pos_embed\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDPMScheduler, FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import logging, BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\n\nfrom opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3\ntry:\n    import torch_npu\n    from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info\nexcept:\n    torch_npu = None\n    from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n@dataclass\nclass OpenSoraPipelineOutput(BaseOutput):\n    videos: Union[List[torch.FloatTensor], np.ndarray]\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass OpenSoraPipeline(DiffusionPipeline):\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = [\n        \"text_encoder_2\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"tokenizer\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"prompt_embeds_2\",\n        \"negative_prompt_embeds_2\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: T5EncoderModel,\n        tokenizer: MT5Tokenizer,\n        transformer: OpenSoraT2V_v1_3,\n        scheduler: DDPMScheduler,\n        text_encoder_2: CLIPTextModelWithProjection = None,\n        tokenizer_2: CLIPTokenizer = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n            text_encoder_2=text_encoder_2,\n        )\n\n    def encode_prompt(\n        self,\n        prompt: str,\n        device: torch.device = None,\n        dtype: torch.dtype = None,\n        num_samples_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: Optional[str] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        
negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: Optional[int] = None,\n        text_encoder_index: int = 0,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            dtype (`torch.dtype`):\n                torch dtype\n            num_samples_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.\n            negative_prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the negative prompt. 
Required when `negative_prompt_embeds` is passed directly.\n            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.\n            text_encoder_index (`int`, *optional*):\n                Index of the text encoder to use. `0` for T5 and `1` for clip.\n        \"\"\"\n        if dtype is None:\n            if self.text_encoder_2 is not None:\n                dtype = self.text_encoder_2.dtype\n            elif self.transformer is not None:\n                dtype = self.transformer.dtype\n            else:\n                dtype = None\n\n        if device is None:\n            device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')\n\n\n        tokenizers = [self.tokenizer, self.tokenizer_2]\n        text_encoders = [self.text_encoder, self.text_encoder_2]\n\n        tokenizer = tokenizers[text_encoder_index]\n        text_encoder = text_encoders[text_encoder_index]\n\n        if max_sequence_length is None:\n            if text_encoder_index == 0:\n                max_length = 512\n            if text_encoder_index == 1:\n                max_length = 77\n        else:\n            max_length = max_sequence_length\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_attention_mask=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= 
text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_attention_mask = text_inputs.attention_mask.to(device)\n            prompt_embeds = text_encoder(\n                text_input_ids.to(device),\n                attention_mask=prompt_attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n            if text_encoder_index == 1:\n                prompt_embeds = prompt_embeds.unsqueeze(1)  # b d -> b 1 d for clip\n\n            prompt_attention_mask = prompt_attention_mask.repeat(num_samples_per_prompt, 1)\n\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_samples_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_samples_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n       
         uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # max_length = prompt_embeds.shape[1]\n            uncond_input = tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)\n            negative_prompt_embeds = text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=negative_prompt_attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n            if text_encoder_index == 1:\n                negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1)  # b d -> b 1 d for clip\n            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_samples_per_prompt, 1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_samples_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_samples_per_prompt, seq_len, -1)\n\n        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, 
negative_prompt_attention_mask\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        num_frames,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n        prompt_embeds_2=None,\n        negative_prompt_embeds_2=None,\n        prompt_attention_mask_2=None,\n        negative_prompt_attention_mask_2=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if (num_frames - 1) % 4 != 0:\n            raise ValueError(f\"`num_frames - 1` has to be divisible by 4, but `num_frames` is {num_frames}.\")\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8, but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is None and prompt_embeds_2 is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:\n            raise ValueError(\"Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:\n            raise ValueError(\n                \"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`.\"\n            )\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n        if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:\n            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:\n                raise ValueError(\n                    \"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`\"\n                    f\" {negative_prompt_embeds_2.shape}.\"\n                )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            (int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1, \n            int(height) // self.vae.vae_scale_factor[1],\n            
int(width) // self.vae.vae_scale_factor[2],\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n            # scale the initial noise by the standard deviation required by the scheduler\n            latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_frames: Optional[int] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: Optional[int] = 50,\n        timesteps: List[int] = None,\n        guidance_scale: Optional[float] = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_samples_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        
prompt_embeds_2: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds_2: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        guidance_rescale: float = 0.0,\n        max_sequence_length: int = 512,\n        device = None, \n    ):\n        \n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. default height and width\n        num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1\n        height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]\n        width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            num_frames, \n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n            prompt_embeds_2,\n            negative_prompt_embeds_2,\n            prompt_attention_mask_2,\n            negative_prompt_attention_mask_2,\n            callback_on_step_end_tensor_inputs,\n        )\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')\n\n\n        # 3. 
Encode input prompt\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            dtype=self.transformer.dtype,\n            num_samples_per_prompt=num_samples_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=max_sequence_length,\n            text_encoder_index=0,\n        )\n        if self.tokenizer_2 is not None:\n            (\n                prompt_embeds_2,\n                negative_prompt_embeds_2,\n                prompt_attention_mask_2,\n                negative_prompt_attention_mask_2,\n            ) = self.encode_prompt(\n                prompt=prompt,\n                device=device,\n                dtype=self.transformer.dtype,\n                num_samples_per_prompt=num_samples_per_prompt,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                negative_prompt=negative_prompt,\n                prompt_embeds=prompt_embeds_2,\n                negative_prompt_embeds=negative_prompt_embeds_2,\n                prompt_attention_mask=prompt_attention_mask_2,\n                negative_prompt_attention_mask=negative_prompt_attention_mask_2,\n                max_sequence_length=77,\n                text_encoder_index=1,\n            )\n        else:\n            prompt_embeds_2 = None\n            negative_prompt_embeds_2 = None\n            prompt_attention_mask_2 = None\n            negative_prompt_attention_mask_2 = None\n\n        # 4. 
Prepare timesteps\n        if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n            self.scheduler.set_timesteps(num_inference_steps, device=device)\n            timesteps = self.scheduler.timesteps\n            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n            self._num_timesteps = len(timesteps)\n        else:\n            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n            num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n            self._num_timesteps = len(timesteps)\n        # 5. Prepare latent variables\n        if get_sequence_parallel_state():\n            world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_samples_per_prompt,\n            num_channels_latents,\n            (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, \n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n        else:\n            extra_step_kwargs = {}\n        # 7. Concatenate unconditional and conditional embeddings for classifier-free guidance\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])\n            if self.tokenizer_2 is not None:\n                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])\n                prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])\n\n        prompt_embeds = prompt_embeds.to(device=device)\n        prompt_attention_mask = prompt_attention_mask.to(device=device)\n        if self.tokenizer_2 is not None:\n            prompt_embeds_2 = prompt_embeds_2.to(device=device)\n            prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)\n\n\n        # ==================make sp=====================================\n        # shard the text tokens across sequence-parallel ranks; each rank keeps only its own slice\n        if get_sequence_parallel_state():\n            prompt_embeds = rearrange(\n                prompt_embeds, \n                'b (n x) h -> b n x h', \n                n=world_size,\n                x=prompt_embeds.shape[1] // world_size\n                ).contiguous()\n            rank = hccl_info.rank if torch_npu is not None else nccl_info.rank\n            prompt_embeds = prompt_embeds[:, rank, :, :]\n        # ==================make sp=====================================\n\n        # 8. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input\n                if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n                    timestep = torch.tensor([t] * latent_model_input.shape[0], device=device).to(\n                        dtype=latent_model_input.dtype\n                    )\n                else:\n                    timestep = t.expand(latent_model_input.shape[0])\n\n                # ==================prepare my shape=====================================\n                # predict the noise residual\n                if prompt_embeds.ndim == 3:\n                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d\n                if prompt_attention_mask.ndim == 2:\n                    prompt_attention_mask = prompt_attention_mask.unsqueeze(1)  # b l -> b 1 l\n                if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2:\n                    prompt_embeds_2 = prompt_embeds_2.unsqueeze(1)  # b d -> b 1 d\n                \n                attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device)\n                # ==================prepare my shape=====================================\n\n                # ==================make sp=====================================\n                if get_sequence_parallel_state():\n                    attention_mask = 
attention_mask.repeat(1, world_size, 1, 1)\n                # ==================make sp=====================================\n\n                noise_pred = self.transformer(\n                    latent_model_input,\n                    attention_mask=attention_mask, \n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    timestep=timestep,\n                    pooled_projections=prompt_embeds_2,\n                    return_dict=False,\n                )[0]\n                assert not torch.any(torch.isnan(noise_pred))\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and guidance_rescale > 0.0 and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):\n                    # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    prompt_embeds_2 = callback_outputs.pop(\"prompt_embeds_2\", prompt_embeds_2)\n                    negative_prompt_embeds_2 = callback_outputs.pop(\n                        \"negative_prompt_embeds_2\", negative_prompt_embeds_2\n                    )\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        # ==================make sp=====================================\n        if get_sequence_parallel_state():\n            latents_shape = list(latents.shape)  # b c t//sp h w\n            full_shape = [latents_shape[0] * world_size] + latents_shape[1:]  # # b*sp c t//sp h w\n            all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)\n            torch.distributed.all_gather_into_tensor(all_latents, latents)\n            latents_list = list(all_latents.chunk(world_size, dim=0))\n            latents = torch.cat(latents_list, dim=2)\n        # ==================make 
sp=====================================\n\n        if output_type != \"latent\":\n            videos = self.decode_latents(latents)\n            videos = videos[:, :num_frames, :height, :width]\n        else:\n            videos = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (videos, )\n\n        return OpenSoraPipelineOutput(videos=videos)\n\n    \n    def decode_latents(self, latents):\n        print(f'before vae decode {latents.shape}', torch.max(latents).item(), torch.min(latents).item(), torch.mean(latents).item(), torch.std(latents).item())\n        video = self.vae.decode(latents.to(self.vae.vae.dtype))\n        print(f'after vae decode {video.shape}', torch.max(video).item(), torch.min(video).item(), torch.mean(video).item(), torch.std(video).item())\n        # map [-1, 1] to uint8 [0, 255] and reorder (b, t, c, h, w) -> (b, t, h, w, c)\n        video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() # b t h w c\n        return video"
  },
  {
    "path": "opensora/sample/rec_image.py",
    "content": "import sys\nsys.path.append(\".\")\nfrom PIL import Image\nimport torch\nfrom torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda\nfrom torch.nn import functional as F\nimport argparse\nimport numpy as np\nfrom opensora.models.causalvideovae import ae_wrapper\n\ndef preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:\n    transform = Compose(\n        [\n            ToTensor(),\n            Lambda(lambda x: 2. * x - 1.), \n            Resize(size=short_size),\n        ]\n    )\n    outputs = transform(video_data)\n    outputs = outputs.unsqueeze(0).unsqueeze(2)\n    return outputs\n\ndef main(args: argparse.Namespace):\n    image_path = args.image_path\n    short_size = args.short_size\n    device = args.device\n    kwarg = {}\n    \n    # vae = getae_wrapper(args.ae)(args.model_path, subfolder=\"vae\", cache_dir='cache_dir', **kwarg).to(device)\n    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)\n    if args.enable_tiling:\n        vae.vae.enable_tiling()\n        vae.vae.tile_overlap_factor = args.tile_overlap_factor\n    vae.eval()\n    vae = vae.to(device)\n    vae = vae.half()\n    \n    with torch.no_grad():\n        x_vae = preprocess(Image.open(image_path), short_size)\n        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w\n        latents = vae.encode(x_vae)\n        latents = latents.to(torch.float16)\n        image_recon = vae.decode(latents)  # b t c h w\n    x = image_recon[0, 0, :, :, :]\n    x = x.squeeze()\n    x = x.detach().cpu().numpy()\n    x = np.clip(x, -1, 1)\n    x = (x + 1) / 2\n    x = (255*x).astype(np.uint8)\n    x = x.transpose(1,2,0)\n    image = Image.fromarray(x)\n    image.save(args.rec_path)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_path', type=str, default='')\n    parser.add_argument('--rec_path', type=str, default='')\n    parser.add_argument('--ae', type=str, 
default='')\n    parser.add_argument('--ae_path', type=str, default='')\n    parser.add_argument('--model_path', type=str, default='results/pretrained')\n    parser.add_argument('--short_size', type=int, default=336)\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)\n    parser.add_argument('--enable_tiling', action='store_true')\n    \n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "opensora/sample/rec_video.py",
    "content": "import math\nimport random\nimport argparse\nfrom typing import Optional\n\nimport cv2\nimport numpy as np\nimport numpy.typing as npt\nimport torch\nfrom PIL import Image\nfrom decord import VideoReader, cpu\nfrom torch.nn import functional as F\nfrom pytorchvideo.transforms import ShortSideScale\nfrom torchvision.transforms import Lambda, Compose\nimport sys\nfrom opensora.models.causalvideovae import ae_wrapper\nfrom opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo\n\n\ndef array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:\n    height, width, channels = image_array[0].shape\n    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))\n\n    for image in image_array:\n        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n        video_writer.write(image_rgb)\n\n    video_writer.release()\n\n\ndef custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:\n    x = x.detach().cpu()\n    x = torch.clamp(x, -1, 1)\n    x = (x + 1) / 2\n    x = x.permute(0, 2, 3, 1).numpy()\n    x = (255 * x).astype(np.uint8)\n    array_to_video(x, fps=fps, output_file=output_file)\n    return\n\n\ndef read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:\n    decord_vr = VideoReader(video_path, ctx=cpu(0))\n    total_frames = len(decord_vr)\n    sample_frames_len = sample_rate * num_frames\n\n    # if total_frames > sample_frames_len:\n    #     s = random.randint(0, total_frames - sample_frames_len - 1)\n    #     s = 0\n    #     e = s + sample_frames_len\n    #     num_frames = num_frames\n    # else:\n    # s = 0\n    # e = total_frames\n    # num_frames = int(total_frames / sample_frames_len * num_frames)\n    s = 0\n    e = sample_frames_len\n    print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', 
video_path,\n            total_frames)\n\n    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)\n    video_data = decord_vr.get_batch(frame_id_list).asnumpy()\n    video_data = torch.from_numpy(video_data)\n    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)\n    return video_data\n\n\ndef preprocess(video_data: torch.Tensor, height: int = 128, width: int = 128) -> torch.Tensor:\n    transform = Compose(\n        [\n            ToTensorVideo(),\n            CenterCropResizeVideo((height, width)),\n            Lambda(lambda x: 2. * x - 1.)\n        ]\n    )\n\n    video_outputs = transform(video_data)\n    video_outputs = torch.unsqueeze(video_outputs, 0)\n\n    return video_outputs\n\n\ndef main(args: argparse.Namespace):\n    device = args.device\n    kwarg = {}\n    # vae = getae_wrapper(args.ae)(args.model_path, subfolder=\"vae\", cache_dir='cache_dir', **kwarg).to(device)\n    # vae = CausalVAEModelWrapper(args.ae_path, **kwarg).to(device)\n    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)\n    if args.enable_tiling:\n        vae.vae.enable_tiling()\n        vae.vae.tile_overlap_factor = args.tile_overlap_factor\n        # vae.vae.tile_sample_min_size = 512\n        # vae.vae.tile_latent_min_size = 64\n        # vae.vae.tile_sample_min_size_t = 29\n        # vae.vae.tile_latent_min_size_t = 8\n        # if args.save_memory:\n        #     vae.vae.tile_sample_min_size = 256\n        #     vae.vae.tile_latent_min_size = 32\n        #     vae.vae.tile_sample_min_size_t = 9\n        #     vae.vae.tile_latent_min_size_t = 3\n    dtype = torch.float32\n    vae.eval()\n    vae = vae.to(device, dtype=dtype)\n    \n    with torch.no_grad():\n        x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height,\n                           args.width)\n        print(x_vae.shape)\n        x_vae = x_vae.to(device, dtype=dtype)  # b c t h w\n        # for i in range(10000):\n        
latents = vae.encode(x_vae)\n        print(latents.shape)\n        latents = latents.to(dtype)\n        video_recon = vae.decode(latents)  # b t c h w\n        print(video_recon.shape)\n\n\n    \n    # vae = vae.half()\n    # from tqdm import tqdm\n    # with torch.no_grad():\n    #     x_vae = torch.rand(1, 3, 93, 720, 1280)\n    #     print(x_vae.shape)\n    #     x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w\n    #     # x_vae = x_vae.to(device)  # b c t h w\n    #     for i in tqdm(range(100000)):\n    #         latents = vae.encode(x_vae)\n    #     print(latents.shape)\n    #     latents = latents.to(torch.float16)\n    #     video_recon = vae.decode(latents)  # b t c h w\n    #     print(video_recon.shape)\n\n\n    custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--video_path', type=str, default='')\n    parser.add_argument('--rec_path', type=str, default='')\n    parser.add_argument('--ae', type=str, default='')\n    parser.add_argument('--ae_path', type=str, default='')\n    parser.add_argument('--model_path', type=str, default='results/pretrained')\n    parser.add_argument('--fps', type=int, default=30)\n    parser.add_argument('--height', type=int, default=336)\n    parser.add_argument('--width', type=int, default=336)\n    parser.add_argument('--num_frames', type=int, default=100)\n    parser.add_argument('--sample_rate', type=int, default=1)\n    parser.add_argument('--device', type=str, default=\"cuda\")\n    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)\n    parser.add_argument('--tile_sample_min_size', type=int, default=512)\n    parser.add_argument('--tile_sample_min_size_t', type=int, default=33)\n    parser.add_argument('--tile_sample_min_size_dec', type=int, default=256)\n    parser.add_argument('--tile_sample_min_size_dec_t', type=int, default=33)\n    
parser.add_argument('--enable_tiling', action='store_true')\n    parser.add_argument('--save_memory', action='store_true')\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "opensora/sample/sample.py",
    "content": "import os\nimport torch\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\n    pass\nfrom opensora.utils.sample_utils import (\n    init_gpu_env, init_npu_env, prepare_pipeline, get_args, \n    run_model_and_save_samples, run_model_and_save_samples_npu\n)\nfrom opensora.sample.caption_refiner import OpenSoraCaptionRefiner\n\nif __name__ == \"__main__\":\n    args = get_args()\n    dtype = torch.float16\n\n    if torch_npu is not None:\n        npu_config.print_msg(args)\n        npu_config.conv_dtype = dtype\n        init_npu_env(args)\n    else:\n        args = init_gpu_env(args)\n\n    device = torch.cuda.current_device()\n    if args.num_frames != 1 and args.enhance_video is not None:\n        from opensora.sample.VEnhancer.enhance_a_video import VEnhancer\n        enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device)\n    else:\n        enhance_video_model = None\n    pipeline = prepare_pipeline(args, dtype, device)\n    if args.caption_refiner is not None:\n        caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)\n    else:\n        caption_refiner_model = None\n\n    if npu_config is not None and npu_config.on_npu and npu_config.profiling:\n        run_model_and_save_samples_npu(args, pipeline, caption_refiner_model, enhance_video_model)\n    else:\n        run_model_and_save_samples(args, pipeline, caption_refiner_model, enhance_video_model)\n"
  },
  {
    "path": "opensora/serve/gradio_utils.py",
    "content": "import random\n\nimport imageio\nimport uuid\nimport torch\n\nimport numpy as np\n\n\nPOS_PROMPT = \"\"\"\n    high quality, high aesthetic, {}\n    \"\"\"\n\nNEG_PROMPT = \"\"\"\n    nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, \n    low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.\n    \"\"\"\n\nNUM_IMAGES_PER_PROMPT = 1\nMAX_SEED = np.iinfo(np.int32).max\n\ndef randomize_seed_fn(seed: int, randomize_seed: bool) -> int:\n    if randomize_seed:\n        seed = random.randint(0, MAX_SEED)\n    return seed\n\n\nLOGO = \"\"\"\n    <center><img src='https://s21.ax1x.com/2024/07/14/pk5pLBF.jpg' alt='Open-Sora Plan logo' style=\"width:220px; margin-bottom:1px\"></center>\n\"\"\"\nTITLE = \"\"\"\n    <div style=\"text-align: center; font-size: 45px; font-weight: bold; margin-bottom: 5px;\">\n        Open-Sora Plan🤗\n    </div>\n\"\"\"\nDESCRIPTION = \"\"\"\n    <div style=\"text-align: center; font-size: 16px; font-weight: bold; margin-bottom: 5px;\">\n        Support Chinese and English; 支持中英双语\n    </div>\n    <div style=\"text-align: center; font-size: 16px; font-weight: bold; margin-bottom: 5px;\">\n        Welcome to Star🌟 our <a href='https://github.com/PKU-YuanGroup/Open-Sora-Plan' target='_blank'><b>GitHub</b></a>\n    </div>\n\"\"\"\n\nt2v_prompt_examples = [\n    \"动画场景特写中，一个矮小、毛茸茸的怪物跪在一根融化的红蜡烛旁。三维写实的艺术风格注重光照和纹理的相互作用，在整个场景中投射出引人入胜的阴影。怪物睁着好奇的大眼睛注视着火焰，它的皮毛在温暖闪烁的光芒中轻轻拂动。镜头慢慢拉近，捕捉到怪物皮毛的复杂细节和精致的熔蜡液滴。怪物试探性地伸出一只爪子，似乎想要触碰火焰，而烛光则在它周围闪烁舞动，气氛充满了惊奇和好奇。\", \n    \"An animated scene features a close-up of a short, fluffy monster kneeling beside a melting red candle. The 3D, realistic art style focuses on the interplay of lighting and texture, casting intriguing shadows across the scene. The monster gazes at the flame with wide, curious eyes, its fur gently ruffling in the warm, flickering glow. 
The camera slowly zooms in, capturing the intricate details of the monster's fur and the delicate, molten wax droplets. The atmosphere is filled with a sense of wonder and curiosity, as the monster tentatively reaches out a paw, as if to touch the flame, while the candlelight dances and flickers around it.\", \n    \"特写镜头捕捉到一只维多利亚皇冠鸽，其醒目的蓝色羽毛和鲜艳的红色胸部格外显眼。这只鸽子精致的花边鸽冠和醒目的红眼更增添了它的威严。鸽子的头部略微偏向一侧，给人一种威严的感觉。背景被模糊处理，使人们的注意力集中在鸽子引人注目的特征上。柔和的光线洒在画面上，投下柔和的阴影，增强了鸽子羽毛的质感。鸽子微微扇动翅膀，嘴角向上翘起，似乎在好奇地观察周围的环境，营造出一种动感迷人的氛围。\", \n    \"A close-up shot captures a Victoria crowned pigeon, its striking blue plumage and vibrant red chest standing out prominently. The bird's delicate, lacy crest and striking red eye add to its regal appearance. The pigeon's head is tilted slightly to the side, giving it a majestic look. The background is blurred, drawing attention to the bird's striking features. Soft light bathes the scene, casting gentle shadows that enhance the texture of its feathers. The pigeon flutters its wings slightly, and its beak tilts upwards, as if curiously observing the surroundings, creating a dynamic and captivating atmosphere.\", \n    \"一架无人机捕捉到了大苏尔加雷角海滩上海浪拍打着崎岖悬崖的壮丽景色。湛蓝的海水拍打出白色的浪花，夕阳的金光照亮了岩石海岸，投下长长的阴影，营造出温暖宁静的氛围。远处矗立着一座小岛，岛上有一座灯塔，更增添了画面的魅力。海鸥在头顶上滑翔，海风吹过附近的植被，沙沙作响，给宁静的海岸景观带来了勃勃生机。\", \n    \"A drone captures a breathtaking view of waves crashing against the rugged cliffs along Big Sur's Garay Point Beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore, casting long shadows and creating a warm, serene atmosphere. A small island with a lighthouse stands in the distance, adding to the scene's charm. 
Seagulls glide overhead as the ocean breeze rustles through the nearby vegetation, bringing life to the tranquil coastal landscape.\", \n    \"一个二十岁出头的年轻人，头发蓬松，鼻梁上架着一副眼镜，安详地坐在高高飘扬的蓬松白云上。他全神贯注地读着一本书，偶尔抬起头看一眼周围翱翔的鸟儿。阳光透过飘渺的云层，在这幅画面上洒下柔和的金色光芒，并在他的脸上投下俏皮的影子。当他翻开书页时，一阵微风吹过，书页沙沙作响，他微笑着，感受着失重和自由的快感。\", \n    \"A young man in his early twenties, with tousled hair and a pair of glasses perched on the end of his nose, sits serenely on a fluffy, white cloud floating high in the sky. He is engrossed in a book, occasionally glancing up to watch the birds soar around him. The sunlight filters through the wispy clouds, casting a soft, golden glow over the scene and creating playful shadows that dance on his face. As he turns a page, a gentle breeze rustles the pages, and he smiles, feeling the thrill of weightlessness and freedom.\", \n    \"三维动画描绘了一只圆滚滚、毛茸茸的小动物，它有一双富于表情的大眼睛，正在探索一片生机勃勃的魔法森林。这个异想天开的生物是兔子和松鼠的混合体，长着柔软的蓝色皮毛和浓密的条纹尾巴。它沿着波光粼粼的溪流蹦蹦跳跳，眼睛睁得大大的，充满了好奇。森林里充满了神奇的元素：会发光和变色的花朵、长着紫色和银色树叶的树木，还有像萤火虫一样的小浮光。它跳着跳着，停了下来，与一群围着蘑菇圈跳舞的小精灵嬉戏互动。然后，它抬头敬畏地看着一棵发光的大树，这棵树似乎是森林的核心。摄像机平稳地摇镜头，捕捉到这只小动物好奇地伸手触摸一朵发光的花朵，花朵随之变色。整个场景沐浴在柔和、空灵的光线中，背景中的阴影轻轻舞动，营造出一种令人陶醉和惊奇的氛围。小动物的嬉戏打闹和神奇的氛围让森林变得生机勃勃，仿佛每一刻都是一次发现和喜悦。\", \n    \"A 3D animation depicts a small, round, fluffy creature with big, expressive eyes exploring a vibrant, enchanted forest. This whimsical creature, a blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. As the creature hops, it pauses to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. It then looks up in awe at a large, glowing tree that seems to be the heart of the forest. 
The camera pans smoothly to capture the creature's curiosity as it reaches out to touch a glowing flower, causing it to change colors. The scene is bathed in a soft, ethereal light, with shadows dancing gently in the background, creating an atmosphere of enchantment and wonder. The creature's playful antics and the magical ambiance make the forest come alive, as if every moment is a discovery and a delight.\", \n    \"一架无人机优雅地环绕着阿马尔菲海岸崎岖不平的山顶上一座历史悠久的教堂，拍摄其宏伟的建筑细节以及层层叠叠的小径和天井。下方，海浪拍打着岩石，地平线延伸至意大利的沿海水域和丘陵地貌。远处的身影在天井中漫步，欣赏着壮丽的海景，营造出一幅动感十足的画面。午后和煦的阳光让整个场景沐浴在神奇而浪漫的光影中，投下长长的阴影，为迷人的景色增添了深度。镜头不时拉近以突出教堂错综复杂的细节，然后拉远以展示广阔的海岸线，营造出引人入胜的视觉叙事效果。\", \n    \"A drone camera gracefully circles a historic church perched on a rugged outcropping along the Amalfi Coast, capturing its magnificent architectural details and tiered pathways and patios. Below, waves crash against the rocks, while the horizon stretches out over the coastal waters and hilly landscapes of Italy. Distant figures stroll and enjoy the breathtaking ocean views from the patios, creating a dynamic scene. The warm glow of the afternoon sun bathes the scene in a magical and romantic light, casting long shadows and adding depth to the stunning vista. The camera occasionally zooms in to highlight the intricate details of the church, then pans out to showcase the expansive coastline, creating a captivating visual narrative.\", \n    \"一个特写镜头捕捉到一位 60 多岁、留着胡子的白发老人，他坐在巴黎的一家咖啡馆里陷入沉思，思考着宇宙的历史。他的眼睛紧紧盯着屏幕外走动的人们，而自己却一动不动。他身着羊毛大衣、纽扣衬衫、棕色贝雷帽，戴着一副眼镜，散发着教授的风范。他偶尔瞥一眼四周，目光停留在背景中熙熙攘攘的巴黎街道和城市景观上。场景沐浴在金色的光线中，让人联想到 35 毫米电影胶片。当他微微前倾时，眼睛睁大，露出顿悟的瞬间，并微微闭口微笑，暗示他已经找到了生命奥秘的答案。景深营造出光影交错的动态效果，烘托出智慧沉思的氛围。\", \n    \"An extreme close-up captures a gray-haired man with a beard in his 60s, deep in thought as he sits at a Parisian cafe, contemplating the history of the universe. His eyes focus intently on people walking offscreen, while he remains mostly motionless. 
Dressed in a wool coat, a button-down shirt, a brown beret, and glasses, he exudes a professorial demeanor. The man occasionally glances around, his gaze lingering on the bustling Parisian streets and cityscape in the background. The scene is bathed in golden light, reminiscent of a cinematic 35mm film. As he leans forward slightly, his eyes widen in a moment of epiphany, and he offers a subtle, closed-mouth smile, suggesting he has found the answer to the mystery of life. The depth of field creates a dynamic interplay of light and shadow, enhancing the atmosphere of intellectual contemplation.\", \n    \"一只欢快的水獭穿着明黄色的救生衣，自信地在冲浪板上保持平衡，在郁郁葱葱的热带岛屿附近波光粼粼的绿松石水域中滑行。该场景采用三维数字艺术风格渲染，阳光在水面上投下俏皮的阴影。水獭不时将爪子伸入水中，溅起的水珠捕捉到光线，为宁静的氛围增添了动感和刺激。\", \n    \"A cheerful otter confidently balances on a surfboard, donning a bright yellow lifejacket, as it glides through the shimmering turquoise waters near lush tropical islands. The scene is rendered in a 3D digital art style, with the sunlight casting playful shadows on the water's surface. The otter occasionally dips its paws into the water, sending up sprays of droplets that catch the light, adding a sense of motion and excitement to the tranquil atmosphere.\", \n    \"在这幅迷人的特写镜头中，一只变色龙展示了它非凡的变色能力，在柔和的散射光中，它鲜艳的色调微妙地变换着。模糊的背景凸显了变色龙醒目的外表，而光影的交错则突出了变色龙皮肤的复杂细节。\", \n    \"In this captivating close-up shot, a chameleon displays its remarkable color-changing abilities, its vibrant hues shifting subtly in the soft, diffused light. The blurred background highlights the animal's striking appearance, while the interplay of light and shadow accentuates the intricate details of its skin.\", \n    \"圣托里尼在蓝色时刻的壮丽鸟瞰图捕捉到了白色基克拉迪建筑与蓝色圆顶的迷人建筑，在黄昏的天空中投射出长长的阴影。火山口的景色令人惊叹，光与影的交织营造出宁静的氛围。当太阳落到地平线以下时，夕阳的余晖将整个场景笼罩在温暖的金色中，海鸥在空中优雅地翱翔，几艘帆船在下方的火山口悠闲地漂流。\", \n    \"A breathtaking aerial view of Santorini during the blue hour captures the stunning architecture of white Cycladic buildings with blue domes, casting long shadows against the twilight sky. 
The caldera views are awe-inspiring, with the interplay of light and shadow creating a serene atmosphere. As the sun dips below the horizon, the gentle glow of the setting sun bathes the scene in a warm, golden hue, while seagulls soar gracefully through the air and a few sailboats drift lazily in the caldera below.\", \n    \"一群羊驼在鲜艳的涂鸦墙前自信地摆着姿势，每只羊驼都穿着五颜六色的羊毛针织衫，戴着时尚的太阳镜。在正午明媚的阳光下，它们嬉戏互动，有的好奇地东张西望，有的则亲昵地偎依在一起。光与影的鲜明对比增强了这一场景的动感活力，营造出一种融合了都市前卫与奇异魅力的氛围。\", \n    \"A group of alpacas, each donning colorful knit wool sweaters and stylish sunglasses, pose confidently against a vibrant graffiti-covered wall. Under the bright midday sun, they interact playfully with one another, some glancing around curiously while others nuzzle affectionately. The scene's dynamic energy is heightened by the stark interplay of light and shadow, creating an atmosphere that blends urban edginess with whimsical charm.\", \n    \"一只充满活力的动画兔子，身穿俏皮的粉色滑雪服，在湛蓝的天空下，熟练地从积雪的山坡上滑下。兔子充满活力地跳跃和旋转，在闪闪发光的雪地上投下动态阴影，而阳光的明亮光线则凸显了闪闪发光的景观，营造出一种欢快的氛围。当兔子下降时，它的流畅动作被广角镜头捕捉到，增加了速度感和刺激感。\", \n    \"A vibrant animated rabbit, dressed in a playful pink snowboarding outfit, expertly carves its way down a snowy mountain slope under a clear blue sky. The rabbit performs energetic jumps and spins, casting dynamic shadows on the glistening snow, while the sun's bright rays highlight the sparkling landscape, creating an atmosphere of joyful exhilaration. As the rabbit descends, its fluid motions are captured in a sweeping camera angle, adding to the sense of speed and excitement.\", \n    \"食物镜头，完美的汉堡，配上奶酪和生菜，微距拍摄，旋转拍摄，推拉镜头\", \n    \"food shot, a perfect burger in a bun with cheese and lettuce, macro shot, rotating shot, dolly in\",  \n    \"这幅肖像画描绘了一只长着蓝眼睛的橘色猫，缓缓旋转，灵感来自维米尔的《戴珍珠耳环的少女》。这只猫戴着珍珠耳环，棕色的皮毛像荷兰帽一样，背景为黑色，在工作室灯光的映衬下显得格外明亮。\", \n    \"This portrait depicts an orange cat with blue eyes, slowly rotating, inspired by Vermeer’s ‘Girl with a Pearl Earring’. 
The cat is adorned with pearl earrings and has brown fur styled like a Dutch cap against a black background, illuminated by studio lighting.\", \n    \"一只熊猫在竹林下弹奏吉他，它的爪子轻轻拨动琴弦，一群着迷的兔子观看着，音乐与竹叶的沙沙声融为一体。高清。\",  \n    \"A panda strumming a guitar under a bamboo grove, its paws gently plucking the strings as a group of mesmerized rabbits watch, the music blending with the rustle of bamboo leaves. HD.\", \n    \"雪花玻璃球摇晃后，会呈现出一座微型城市，雪花实际上是闪闪发光的星星。建筑物亮起，反射着天上的雪花，微小的人影在街道上移动，他们的路径被柔和的星光照亮，营造出神奇、宁静的城市景观。高清。\", \n    \"A snow globe, when shaken, reveals a miniature city where the snowflakes are actually glowing stars. The buildings light up, reflecting the celestial snowfall, and tiny figures move through the streets, their paths illuminated by the gentle starlight, creating a magical, peaceful urban landscape. HD.\",  \n    \"魔术师水晶球的特写，展现了水晶球内部的未来城市景观。摩天大楼的光影直冲云霄，飞行汽车在空中飞驰，在水晶球表面投射出霓虹灯的反光。8K。\", \n    \"A close-up of a magician’s crystal ball that reveals a futuristic cityscape within. Skyscrapers of light stretch towards the heavens, and flying cars zip through the air, casting neon reflections across the ball’s surface. 8K.\", \n]\n\n\nstyle_list = [\n    {\n        \"name\": \"(Default)\",\n        \"prompt\": \"(masterpiece), (best quality), (ultra-detailed), (unwatermarked), {prompt}\",\n        \"negative_prompt\": NEG_PROMPT,\n    },\n    {\n        \"name\": \"Cinematic\",\n        \"prompt\": \"cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy\",\n        \"negative_prompt\": \"anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured. \",\n    },\n    {\n        \"name\": \"Photographic\",\n        \"prompt\": \"cinematic photo, a close-up of  {prompt} . 
35mm photograph, film, bokeh, professional, 4k, highly detailed\",\n        \"negative_prompt\": \"drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly. \",\n    },\n    {\n        \"name\": \"Anime\",\n        \"prompt\": \"anime artwork {prompt} . anime style, key visual, vibrant, studio anime,  highly detailed\",\n        \"negative_prompt\": \"photo, deformed, black and white, realism, disfigured, low contrast. \",\n    },\n    {\n        \"name\": \"Manga\",\n        \"prompt\": \"manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style\",\n        \"negative_prompt\": \"ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style. \",\n    },\n    {\n        \"name\": \"Digital Art\",\n        \"prompt\": \"concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed\",\n        \"negative_prompt\": \"photo, photorealistic, realism, ugly. \",\n    },\n    {\n        \"name\": \"Pixel art\",\n        \"prompt\": \"pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics\",\n        \"negative_prompt\": \"sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic. \",\n    },\n    {\n        \"name\": \"Fantasy art\",\n        \"prompt\": \"ethereal fantasy concept art of  {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy\",\n        \"negative_prompt\": \"photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white. \",\n    },\n    {\n        \"name\": \"Neonpunk\",\n        \"prompt\": \"neonpunk style {prompt} . 
cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional\",\n        \"negative_prompt\": \"painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured. \",\n    },\n    {\n        \"name\": \"3D Model\",\n        \"prompt\": \"professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting\",\n        \"negative_prompt\": \"ugly, deformed, noisy, low poly, blurry, painting. \",\n    },\n]\n"
  },
  {
    "path": "opensora/serve/gradio_web_server.py",
    "content": "import gradio as gr\nimport os\nimport torch\nfrom einops import rearrange\nimport torch.distributed as dist\nfrom torchvision.utils import save_image\nimport imageio\nimport math\nimport argparse\nimport random\nimport numpy as np\nimport string\n\nfrom opensora.sample.caption_refiner import OpenSoraCaptionRefiner\nfrom opensora.utils.sample_utils import (\n    prepare_pipeline, save_video_grid, init_gpu_env\n)\nfrom .gradio_utils import *\n\n\n\n@torch.no_grad()\n@torch.inference_mode()\ndef generate(\n        prompt: str,\n        seed: int = 0,\n        num_frames: int = 29, \n        num_samples: int = 1, \n        guidance_scale: float = 4.5,\n        num_inference_steps: int = 25,\n        randomize_seed: bool = False,\n        progress=gr.Progress(track_tqdm=False),\n):\n    seed = int(randomize_seed_fn(seed, randomize_seed))\n    if seed is not None:\n        torch.manual_seed(seed)\n    if not os.path.exists(args.save_img_path):\n        os.makedirs(args.save_img_path, exist_ok=True)\n\n    video_grids = []\n    text_prompt = [prompt]\n\n    \n\n    for index, prompt in enumerate(text_prompt):\n        if caption_refiner_model is not None:\n            refine_prompt = caption_refiner_model.get_refiner_output(prompt)\n            print(f'\\nOriginal prompt: {prompt}\\n->\\nRefined prompt: {refine_prompt}')\n            prompt = refine_prompt\n        input_prompt = POS_PROMPT.format(prompt)\n        videos = pipeline(\n            input_prompt, \n            negative_prompt=NEG_PROMPT, \n            num_frames=num_frames,\n            height=352,\n            width=640,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            num_samples_per_prompt=num_samples,\n            max_sequence_length=512,\n            device=device, \n            ).videos\n        if num_frames != 1 and enhance_video_model is not None:\n            # b t h w c\n            videos = 
enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250)\n        if num_frames == 1:\n            videos = rearrange(videos, 'b t h w c -> (b t) c h w')\n            if num_samples != 1:\n                for i, image in enumerate(videos):\n                    save_image(\n                        image / 255.0, \n                        os.path.join(\n                            args.save_img_path, \n                            f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}_i{i}.jpg'\n                            ),\n                        nrow=math.ceil(math.sqrt(videos.shape[0])), \n                        normalize=True, \n                        value_range=(0, 1)\n                        )  # b c h w\n            save_image(\n                videos / 255.0, \n                os.path.join(\n                    args.save_img_path, \n                    f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.jpg'\n                    ),\n                nrow=math.ceil(math.sqrt(videos.shape[0])), \n                normalize=True, \n                value_range=(0, 1)\n                )  # b c h w\n        else:\n            if num_samples == 1:\n                imageio.mimwrite(\n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.mp4'\n                    ), \n                    videos[0],\n                    fps=args.fps, \n                    quality=6\n                    )  # highest quality is 10, lowest is 0\n            else:\n                for i in range(num_samples):\n                    imageio.mimwrite(\n                        os.path.join(\n                            args.save_img_path,\n                            f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}_i{i}.mp4'\n                        ), videos[i],\n                        
fps=args.fps, \n                        quality=6\n                        )  # highest quality is 10, lowest is 0\n                    \n                videos = save_video_grid(videos)\n                imageio.mimwrite(\n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.mp4'\n                    ), \n                    videos,\n                    fps=args.fps, \n                    quality=6\n                    )  # highest quality is 10, lowest is 0\n                videos = videos.unsqueeze(0) # 1 t h w c\n        video_grids.append(videos)\n\n    video_grids = torch.cat(video_grids, dim=0)\n    \n    final_path = os.path.join(\n                    args.save_img_path,\n                    f'{args.sample_method}_gs{guidance_scale}_s{num_inference_steps}'\n                    )\n\n    random_string = ''.join(random.choices(string.ascii_letters, k=4))\n    if num_frames == 1:\n        final_path = final_path + f'_{random_string}.jpg'\n        save_image(\n            video_grids / 255.0, \n            final_path, \n            nrow=math.ceil(math.sqrt(len(video_grids))), \n            normalize=True, \n            value_range=(0, 1)\n            )\n    else:\n        video_grids = save_video_grid(video_grids)\n        final_path = final_path + f'_{random_string}.mp4'\n        imageio.mimwrite(\n            final_path, \n            video_grids, \n            fps=args.fps, \n            quality=6\n            )\n    print('save path {}'.format(args.save_img_path))\n    return final_path, seed\n\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model_path\", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')\nparser.add_argument(\"--version\", type=str, default='v1_3', choices=['v1_3', 'v1_5'])\nparser.add_argument(\"--caption_refiner\", type=str, default=None)\nparser.add_argument(\"--ae\", type=str, 
default='CausalVAEModel_4x8x8')\nparser.add_argument(\"--ae_path\", type=str, default='CausalVAEModel_4x8x8')\nparser.add_argument(\"--text_encoder_name_1\", type=str, default='DeepFloyd/t5-v1_1-xxl')\nparser.add_argument(\"--text_encoder_name_2\", type=str, default=None)\nparser.add_argument(\"--save_img_path\", type=str, default=\"./test_gradio\")\nparser.add_argument(\"--fps\", type=int, default=18)\nparser.add_argument('--enable_tiling', action='store_true')\nparser.add_argument('--save_memory', action='store_true')\nparser.add_argument('--compile', action='store_true') \nparser.add_argument(\"--gradio_port\", type=int, default=11900)\nparser.add_argument(\"--local_rank\", type=int, default=0)\nparser.add_argument(\"--enhance_video\", type=str, default=None)\nparser.add_argument(\"--model_type\", type=str, default='t2v')\n\nparser.add_argument(\"--cache_dir\", type=str, default=\"cache_dir\")\n\nparser.add_argument(\"--prediction_type\", type=str, default=\"v_prediction\")\nparser.add_argument('--v1_5_scheduler', action='store_true') \n\nparser.add_argument('--sample_method', type=str, default='EulerAncestralDiscrete') \nargs = parser.parse_args()\n\n\nargs.sp = False\nargs.rescale_betas_zero_snr = True\n\ndtype = torch.bfloat16\n# args = init_gpu_env(args)\ndevice = torch.cuda.current_device()\n\nif args.enhance_video is not None:\n    from opensora.sample.VEnhancer.enhance_a_video import VEnhancer\n    enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device)\nelse:\n    enhance_video_model = None\n\npipeline = prepare_pipeline(args, dtype, device)\nif args.caption_refiner is not None:\n    caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)\nelse:\n    caption_refiner_model = None\n\nwith gr.Blocks(css=\"style.css\") as demo:\n    gr.Markdown(LOGO)\n    gr.Markdown(TITLE)\n    gr.Markdown(DESCRIPTION)\n\n    with gr.Row(equal_height=False):\n        with gr.Group():\n            with gr.Row():\n               
 seed = gr.Slider(\n                    label=\"Seed\",\n                    minimum=0,\n                    maximum=MAX_SEED,\n                    step=1,\n                    value=0,\n                )\n            randomize_seed = gr.Checkbox(label=\"Randomize seed\", value=True)\n            with gr.Row():\n                num_frames = gr.Slider(\n                        label=\"Num Frames\",\n                        minimum=1,\n                        maximum=93,\n                        step=16,\n                        value=29,\n                    )\n                num_samples = gr.Slider(\n                        label=\"Num Samples\",\n                        minimum=1,\n                        maximum=4,\n                        step=1,\n                        value=1,\n                    )\n            with gr.Row():\n                guidance_scale = gr.Slider(\n                    label=\"Guidance scale\",\n                    minimum=1,\n                    maximum=10,\n                    step=0.1,\n                    value=7.5,\n                )\n                inference_steps = gr.Slider(\n                    label=\"Inference steps\",\n                    minimum=10,\n                    maximum=200,\n                    step=1,\n                    value=50,\n                )\n        with gr.Group():\n            with gr.Row():\n                prompt = gr.Text(\n                    label=\"Prompt\",\n                    show_label=False,\n                    max_lines=1,\n                    placeholder=\"Enter your prompt\",\n                    container=False,\n                )\n                run_button = gr.Button(\"Run\", scale=0)\n            result = gr.Video(autoplay=True, label=\"Result\")\n            # result = gr.Gallery(label=\"Result\", columns=NUM_IMAGES_PER_PROMPT,  show_label=False)\n\n\n    \n\n    with gr.Row(), gr.Column():\n        gr.Markdown(\"## Examples (Text-to-Video)\")\n        examples = [[i, 42, 93, 1, 
7.5, 100, True] for i in t2v_prompt_examples]\n        gr.Examples(\n            examples=examples, \n            inputs=[\n                prompt, seed, num_frames, num_samples, \n                guidance_scale, inference_steps, randomize_seed\n                ],\n            label='Text-to-Video', \n            cache_examples=False, \n            outputs=[result, seed],\n            fn=generate\n            )\n\n\n    gr.on(\n        triggers=[\n            prompt.submit,\n            run_button.click,\n        ],\n        fn=generate,\n        inputs=[\n            prompt,\n            seed,\n            num_frames, \n            num_samples, \n            guidance_scale,\n            inference_steps,\n            randomize_seed,\n        ],\n        outputs=[result, seed],\n        api_name=\"run\",\n    )\n\n\n\n# if __name__ == \"__main__\":\ndemo.queue(max_size=20).launch(\n    server_name=\"0.0.0.0\", \n    server_port=args.gradio_port+args.local_rank, \n    debug=True\n    )\n"
  },
  {
    "path": "opensora/serve/gradio_web_server_i2v.py",
    "content": "import gradio as gr\nimport os\nimport torch\nfrom einops import rearrange\nimport torch.distributed as dist\nfrom torchvision.utils import save_image\nimport imageio\nimport math\nimport argparse\nimport random\nimport numpy as np\nimport string\n\nfrom opensora.sample.caption_refiner import OpenSoraCaptionRefiner\nfrom opensora.utils.sample_utils import (\n    prepare_pipeline, save_video_grid, init_gpu_env\n)\nfrom .gradio_utils import *\n\n\n\n@torch.no_grad()\n@torch.inference_mode()\ndef generate(\n        prompt: str,\n        image_1: str, \n        image_2: str = None, \n        seed: int = 0,\n        num_frames: int = 29, \n        num_samples: int = 1, \n        guidance_scale: float = 4.5,\n        num_inference_steps: int = 25,\n        randomize_seed: bool = False,\n        progress=gr.Progress(track_tqdm=True),\n):\n    seed = int(randomize_seed_fn(seed, randomize_seed))\n    if seed is not None:\n        torch.manual_seed(seed)\n    if not os.path.exists(args.save_img_path):\n        os.makedirs(args.save_img_path, exist_ok=True)\n\n    video_grids = []\n    text_prompt = [prompt]\n    images = [[image_1] if image_2 is None else [image_1, image_2]]\n    \n\n    for index, (image, prompt) in enumerate(zip(images, text_prompt)):\n        if caption_refiner_model is not None:\n            refine_prompt = caption_refiner_model.get_refiner_output(prompt)\n            print(f'\\nOriginal prompt: {prompt}\\n->\\nRefined prompt: {refine_prompt}')\n            prompt = refine_prompt\n        input_prompt = POS_PROMPT.format(prompt)\n        print(image)\n        videos = pipeline(\n            conditional_images=image, \n            prompt=input_prompt, \n            negative_prompt=NEG_PROMPT, \n            num_frames=num_frames,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            num_samples_per_prompt=num_samples,\n            max_sequence_length=512,\n            device=device, \n   
         ).videos\n        if num_frames != 1 and enhance_video_model is not None:\n            # b t h w c\n            videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250)\n        if num_frames == 1:\n            videos = rearrange(videos, 'b t h w c -> (b t) c h w')\n            if num_samples != 1:\n                for i, image in enumerate(videos):\n                    save_image(\n                        image / 255.0, \n                        os.path.join(\n                            args.save_img_path, \n                            f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}_i{i}.jpg'\n                            ),\n                        nrow=math.ceil(math.sqrt(videos.shape[0])), \n                        normalize=True, \n                        value_range=(0, 1)\n                        )  # b c h w\n            save_image(\n                videos / 255.0, \n                os.path.join(\n                    args.save_img_path, \n                    f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.jpg'\n                    ),\n                nrow=math.ceil(math.sqrt(videos.shape[0])), \n                normalize=True, \n                value_range=(0, 1)\n                )  # b c h w\n        else:\n            if num_samples == 1:\n                imageio.mimwrite(\n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.mp4'\n                    ), \n                    videos[0],\n                    fps=args.fps, \n                    quality=6\n                    )  # highest quality is 10, lowest is 0\n            else:\n                for i in range(num_samples):\n                    imageio.mimwrite(\n                        os.path.join(\n                            args.save_img_path,\n                            
f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}_i{i}.mp4'\n                        ), videos[i],\n                        fps=args.fps, \n                        quality=6\n                        )  # highest quality is 10, lowest is 0\n                    \n                videos = save_video_grid(videos)\n                imageio.mimwrite(\n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_{index}_gs{guidance_scale}_s{num_inference_steps}.mp4'\n                    ), \n                    videos,\n                    fps=args.fps, \n                    quality=6\n                    )  # highest quality is 10, lowest is 0\n                videos = videos.unsqueeze(0) # 1 t h w c\n        video_grids.append(videos)\n\n    video_grids = torch.cat(video_grids, dim=0)\n    \n    final_path = os.path.join(\n                    args.save_img_path,\n                    f'{args.sample_method}_gs{guidance_scale}_s{num_inference_steps}'\n                    )\n\n    random_string = ''.join(random.choices(string.ascii_letters, k=4))\n    if num_frames == 1:\n        final_path = final_path + f'_{random_string}.jpg'\n        save_image(\n            video_grids / 255.0, \n            final_path, \n            nrow=math.ceil(math.sqrt(len(video_grids))), \n            normalize=True, \n            value_range=(0, 1)\n            )\n    else:\n        video_grids = save_video_grid(video_grids)\n        final_path = final_path + f'_{random_string}.mp4'\n        imageio.mimwrite(\n            final_path, \n            video_grids, \n            fps=args.fps, \n            quality=6\n            )\n    print('save path {}'.format(args.save_img_path))\n    return final_path, seed\n\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model_path\", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')\nparser.add_argument(\"--version\", type=str, default='v1_3', 
choices=['v1_3', 'v1_5'])\nparser.add_argument(\"--caption_refiner\", type=str, default=None)\nparser.add_argument(\"--ae\", type=str, default='CausalVAEModel_4x8x8')\nparser.add_argument(\"--ae_path\", type=str, default='CausalVAEModel_4x8x8')\nparser.add_argument(\"--text_encoder_name_1\", type=str, default='DeepFloyd/t5-v1_1-xxl')\nparser.add_argument(\"--text_encoder_name_2\", type=str, default=None)\nparser.add_argument(\"--save_img_path\", type=str, default=\"./sample_videos/t2v\")\nparser.add_argument(\"--fps\", type=int, default=24)\nparser.add_argument('--enable_tiling', action='store_true')\nparser.add_argument('--save_memory', action='store_true')\nparser.add_argument('--compile', action='store_true') \nparser.add_argument(\"--gradio_port\", type=int, default=11900)\nparser.add_argument(\"--enhance_video\", type=str, default=None)\nparser.add_argument(\"--model_type\", type=str, default='i2v')\nargs = parser.parse_args()\n\nargs.model_path = \"/storage/gyy/hw/Open-Sora-Plan/runs/inpaint_93x1280x1280_stage3_gpu/checkpoint-1692/model_ema\"\nargs.version = \"v1_3\"\nargs.caption_refiner = \"/storage/ongoing/refine_model/llama3_1_instruct_lora/llama3_8B_lora_merged_cn\"\nargs.ae = \"WFVAEModel_D8_4x8x8\"\nargs.ae_path = \"/storage/lcm/wf-vae_trilinear\"\nargs.text_encoder_name_1 = \"/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl\"\nargs.text_encoder_name_2 = None\nargs.save_img_path = \"./test_gradio\"\nargs.fps = 18\n\nargs.prediction_type = \"v_prediction\"\nargs.rescale_betas_zero_snr = True\nargs.cache_dir = \"./cache_dir\"\nargs.sample_method = 'EulerAncestralDiscrete'\nargs.sp = False\nargs.crop_for_hw = False\nargs.max_hw_square = 1048576\nargs.enable_tiling = True\n\ndtype = torch.bfloat16\nargs = init_gpu_env(args)\ndevice = torch.cuda.current_device()\n\nif args.enhance_video is not None:\n    from opensora.sample.VEnhancer.enhance_a_video import VEnhancer\n    enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', 
device=device)\nelse:\n    enhance_video_model = None\n\npipeline = prepare_pipeline(args, dtype, device)\nif args.caption_refiner is not None:\n    caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)\nelse:\n    caption_refiner_model = None\n\nwith gr.Blocks(css=\"style.css\") as demo:\n    gr.Markdown(LOGO)\n    gr.Markdown(TITLE)\n    gr.Markdown(DESCRIPTION)\n\n    with gr.Row(equal_height=False):\n        with gr.Group():\n            with gr.Row():\n                image_1 = gr.Image(type=\"filepath\", label='Image 1')\n                image_2 = gr.Image(type=\"filepath\", label='Image 2')\n            with gr.Row():\n                seed = gr.Slider(\n                    label=\"Seed\",\n                    minimum=0,\n                    maximum=MAX_SEED,\n                    step=1,\n                    value=0,\n                )\n            randomize_seed = gr.Checkbox(label=\"Randomize seed\", value=True)\n            with gr.Row():\n                num_frames = gr.Slider(\n                        label=\"Num Frames\",\n                        minimum=29,\n                        maximum=93,\n                        step=16,\n                        value=29,\n                    )\n                num_samples = gr.Slider(\n                        label=\"Num Samples\",\n                        minimum=1,\n                        maximum=4,\n                        step=1,\n                        value=1,\n                    )\n            with gr.Row():\n                guidance_scale = gr.Slider(\n                    label=\"Guidance scale\",\n                    minimum=1,\n                    maximum=10,\n                    step=0.1,\n                    value=7.5,\n                )\n                inference_steps = gr.Slider(\n                    label=\"Inference steps\",\n                    minimum=10,\n                    maximum=200,\n                    step=1,\n                    value=50,\n                )\n        
with gr.Group():\n            with gr.Row():\n                prompt = gr.Text(\n                    label=\"Prompt\",\n                    show_label=False,\n                    max_lines=1,\n                    placeholder=\"Enter your prompt\",\n                    container=False,\n                )\n                run_button = gr.Button(\"Run\", scale=0)\n            result = gr.Video(autoplay=True, label=\"Result\")\n            # result = gr.Gallery(label=\"Result\", columns=NUM_IMAGES_PER_PROMPT,  show_label=False)\n\n\n    \n\n    # with gr.Row(), gr.Column():\n    #     gr.Markdown(\"## Examples (Text-to-Video)\")\n    #     examples = [[i, 42, 93, 1, 7.5, 100, False] for i in t2v_prompt_examples]\n    #     gr.Examples(\n    #         examples=examples, \n    #         inputs=[\n    #             prompt, seed, num_frames, num_samples, \n    #             guidance_scale, inference_steps, randomize_seed\n    #             ],\n    #         label='Text-to-Video', \n    #         cache_examples=False, \n    #         outputs=[result, seed],\n    #         fn=generate\n    #         )\n\n\n    gr.on(\n        triggers=[\n            prompt.submit,\n            run_button.click,\n        ],\n        fn=generate,\n        inputs=[\n            prompt,\n            image_1, \n            image_2, \n            seed,\n            num_frames, \n            num_samples, \n            guidance_scale,\n            inference_steps,\n            randomize_seed,\n        ],\n        outputs=[result, seed],\n        api_name=\"run\",\n    )\n\n\n\n# if __name__ == \"__main__\":\ndemo.queue(max_size=20).launch(\n    server_name=\"0.0.0.0\", \n    server_port=args.gradio_port+args.local_rank, \n    debug=True\n    )"
  },
  {
    "path": "opensora/serve/style.css",
    "content": ".gradio-container{width:1280px!important}"
  },
  {
    "path": "opensora/train/train_causalvae.py",
    "content": "import os\nimport torch\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nfrom torch.utils.data import DataLoader, DistributedSampler, Subset\nimport argparse\nimport logging\nfrom colorlog import ColoredFormatter\nimport tqdm\nfrom itertools import chain\nimport wandb\nimport random\nimport numpy as np\nfrom pathlib import Path\nfrom einops import rearrange\nimport time\n\ntry:\n    import lpips\nexcept:\n    raise Exception(\"Need lpips to valid.\")\n\nimport sys\nsys.path.append(\".\")\nfrom opensora.models.causalvideovae.model import *\nfrom opensora.models.causalvideovae.model.ema_model import EMA\nfrom opensora.models.causalvideovae.dataset.ddp_sampler import CustomDistributedSampler\nfrom opensora.models.causalvideovae.dataset.video_dataset import TrainVideoDataset, ValidVideoDataset\nfrom opensora.models.causalvideovae.model.utils.module_utils import resolve_str_to_obj\nfrom opensora.models.causalvideovae.utils.video_utils import tensor_to_video\n\n\ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\ndef ddp_setup():\n    dist.init_process_group(backend=\"nccl\")\n    torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n\ndef setup_logger(rank):\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    formatter = ColoredFormatter(\n        f\"[rank{rank}] %(log_color)s%(asctime)s - %(levelname)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        log_colors={\n            \"DEBUG\": \"cyan\",\n            \"INFO\": \"green\",\n            \"WARNING\": \"yellow\",\n            \"ERROR\": \"red\",\n            \"CRITICAL\": \"bold_red\",\n        },\n        reset=True,\n        style=\"%\",\n    )\n    stream_handler = logging.StreamHandler()\n    stream_handler.setLevel(logging.DEBUG)\n    stream_handler.setFormatter(formatter)\n\n   
 if not logger.handlers:\n        logger.addHandler(stream_handler)\n\n    return logger\n\ndef check_unused_params(model):\n    unused_params = []\n    for name, param in model.named_parameters():\n        if param.grad is None:\n            unused_params.append(name)\n    return unused_params\n\ndef set_requires_grad_optimizer(optimizer, requires_grad):\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            param.requires_grad = requires_grad\n\ndef total_params(model):\n    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    total_params_in_millions = total_params / 1e6\n    return int(total_params_in_millions)\n\n\ndef get_exp_name(args):\n    return f\"{args.exp_name}-lr{args.lr:.2e}-bs{args.batch_size}-rs{args.resolution}-sr{args.sample_rate}-fr{args.num_frames}\"\n\n\ndef set_train(modules):\n    for module in modules:\n        module.train()\n\n\ndef set_eval(modules):\n    for module in modules:\n        module.eval()\n\n\ndef set_modules_requires_grad(modules, requires_grad):\n    for module in modules:\n        module.requires_grad_(requires_grad)\n\n\ndef save_checkpoint(\n    epoch,\n    current_step,\n    optimizer_state,\n    state_dict,\n    scaler_state,\n    sampler_state,\n    checkpoint_dir,\n    filename=\"checkpoint.ckpt\",\n    ema_state_dict={},\n):\n    filepath = checkpoint_dir / Path(filename)\n    torch.save(\n        {\n            \"epoch\": epoch,\n            \"current_step\": current_step,\n            \"optimizer_state\": optimizer_state,\n            \"state_dict\": state_dict,\n            \"ema_state_dict\": ema_state_dict,\n            \"scaler_state\": scaler_state,\n            \"sampler_state\": sampler_state,\n        },\n        filepath,\n    )\n    return filepath\n\n\ndef valid(global_rank, rank, model, val_dataloader, precision, args):\n    if args.eval_lpips:\n        lpips_model = lpips.LPIPS(net=\"alex\", spatial=True)\n        
lpips_model.to(rank)\n        lpips_model = DDP(lpips_model, device_ids=[rank])\n        lpips_model.requires_grad_(False)\n        lpips_model.eval()\n\n    bar = None\n    if global_rank == 0:\n        bar = tqdm.tqdm(total=len(val_dataloader), desc=\"Validation...\")\n\n    psnr_list = []\n    lpips_list = []\n    video_log = []\n    num_video_log = args.eval_num_video_log\n\n    with torch.no_grad():\n        for batch_idx, batch in enumerate(val_dataloader):\n            inputs = batch[\"video\"].to(rank)\n            with torch.cuda.amp.autocast(dtype=precision):\n                outputs = model(inputs)\n                video_recon = outputs[0]\n\n            # Upload videos\n            if global_rank == 0:\n                for i in range(len(video_recon)):\n                    if num_video_log <= 0:\n                        break\n                    video = tensor_to_video(video_recon[i])\n                    video_log.append(video)\n                    num_video_log -= 1\n            inputs = rearrange(inputs, \"b c t h w -> (b t) c h w\").contiguous()\n            video_recon = rearrange(\n                video_recon, \"b c t h w -> (b t) c h w\"\n            ).contiguous()\n\n            # Calculate PSNR\n            mse = torch.mean(torch.square(inputs - video_recon), dim=(1, 2, 3))\n            psnr = 20 * torch.log10(1 / torch.sqrt(mse))\n            psnr = psnr.mean().detach().cpu().item()\n\n            # Calculate LPIPS\n            if args.eval_lpips:\n                lpips_score = (\n                    lpips_model.forward(inputs, video_recon)\n                    .mean()\n                    .detach()\n                    .cpu()\n                    .item()\n                )\n                lpips_list.append(lpips_score)\n\n            psnr_list.append(psnr)\n            if global_rank == 0:\n                bar.update()\n            # Release gpus memory\n            torch.cuda.empty_cache()\n    return psnr_list, lpips_list, 
video_log\n\n\ndef gather_valid_result(psnr_list, lpips_list, video_log_list, rank, world_size):\n    gathered_psnr_list = [None for _ in range(world_size)]\n    gathered_lpips_list = [None for _ in range(world_size)]\n    gathered_video_logs = [None for _ in range(world_size)]\n\n    dist.all_gather_object(gathered_psnr_list, psnr_list)\n    dist.all_gather_object(gathered_lpips_list, lpips_list)\n    dist.all_gather_object(gathered_video_logs, video_log_list)\n    return (\n        np.array(gathered_psnr_list).mean(),\n        np.array(gathered_lpips_list).mean(),\n        list(chain(*gathered_video_logs)),\n    )\n\n\ndef train(args):\n    # Setup logger\n    ddp_setup()\n    rank = int(os.environ[\"LOCAL_RANK\"])\n    global_rank = dist.get_rank()\n    logger = setup_logger(rank)\n\n    # Init\n    ckpt_dir = Path(args.ckpt_dir) / Path(get_exp_name(args))\n    if global_rank == 0:\n        try:\n            ckpt_dir.mkdir(exist_ok=False, parents=True)\n        except:\n            logger.warning(f\"`{ckpt_dir}` exists!\")\n            time.sleep(5)\n    dist.barrier()\n\n    # Load generator model\n    model_cls = ModelRegistry.get_model(args.model_name)\n\n    if not model_cls:\n        raise ModuleNotFoundError(\n            f\"`{args.model_name}` not in {str(ModelRegistry._models.keys())}.\"\n        )\n\n    if args.pretrained_model_name_or_path is not None:\n        if global_rank == 0:\n            logger.warning(\n                f\"You are loading a checkpoint from `{args.pretrained_model_name_or_path}`.\"\n            )\n        model = model_cls.from_pretrained(\n            args.pretrained_model_name_or_path,\n            ignore_mismatched_sizes=args.ignore_mismatched_sizes,\n            low_cpu_mem_usage=False,\n            device_map=None,\n        )\n    else:\n        if global_rank == 0:\n            logger.warning(f\"Model will be inited randomly.\")\n        model = model_cls.from_config(args.model_config)\n    \n    if global_rank == 0:\n     
   logger.warning(\"Connecting to WANDB...\")\n        model_config = dict(**model.config)\n        args_config = dict(**vars(args))\n        if 'resolution' in model_config:\n            del model_config['resolution']\n        \n        wandb.init(\n            project=os.environ.get(\"WANDB_PROJECT\", \"causalvideovae\"),\n            config=dict(**model_config, **args_config),\n            name=get_exp_name(args),\n        )\n    \n    dist.barrier()\n    \n    # Load discriminator model\n    disc_cls = resolve_str_to_obj(args.disc_cls, append=False)\n    logger.warning(\n        f\"disc_class: {args.disc_cls} perceptual_weight: {args.perceptual_weight}  loss_type: {args.loss_type}\"\n    )\n    disc = disc_cls(\n        disc_start=args.disc_start,\n        disc_weight=args.disc_weight,\n        kl_weight=args.kl_weight,\n        logvar_init=args.logvar_init,\n        perceptual_weight=args.perceptual_weight,\n        loss_type=args.loss_type,\n        wavelet_weight=args.wavelet_weight\n    )\n\n    # DDP\n    model = model.to(rank, )\n    model = DDP(\n        model, device_ids=[rank], find_unused_parameters=args.find_unused_parameters\n    )\n    disc = disc.to(rank)\n    disc = DDP(\n        disc, device_ids=[rank], find_unused_parameters=args.find_unused_parameters\n    )\n\n    # Load dataset\n    dataset = TrainVideoDataset(\n        args.video_path,\n        sequence_length=args.num_frames,\n        resolution=args.resolution,\n        sample_rate=args.sample_rate,\n        dynamic_sample=args.dynamic_sample,\n        cache_file=\"idx.pkl\",\n        is_main_process=global_rank == 0,\n    )\n    ddp_sampler = CustomDistributedSampler(dataset)\n    dataloader = DataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        sampler=ddp_sampler,\n        pin_memory=True,\n        num_workers=args.dataset_num_worker,\n    )\n    val_dataset = ValidVideoDataset(\n        real_video_dir=args.eval_video_path,\n        
num_frames=args.eval_num_frames,\n        sample_rate=args.eval_sample_rate,\n        crop_size=args.eval_resolution,\n        resolution=args.eval_resolution,\n    )\n    indices = range(args.eval_subset_size)\n    val_dataset = Subset(val_dataset, indices=indices)\n    val_sampler = CustomDistributedSampler(val_dataset)\n    val_dataloader = DataLoader(\n        val_dataset,\n        batch_size=args.eval_batch_size,\n        sampler=val_sampler,\n        pin_memory=True,\n    )\n\n    # Optimizer\n    modules_to_train = [module for module in model.module.get_decoder()]\n    if not args.freeze_encoder:\n        modules_to_train += [module for module in model.module.get_encoder()]\n    else:\n        for module in model.module.get_encoder():\n            module.eval()\n            module.requires_grad_(False)\n        logger.warning(\"Encoder is freezed!\")\n\n    parameters_to_train = []\n    for module in modules_to_train:\n        parameters_to_train += list(filter(lambda p: p.requires_grad, module.parameters()))\n\n    gen_optimizer = torch.optim.AdamW(parameters_to_train, lr=args.lr, weight_decay=1e-4)\n    disc_optimizer = torch.optim.AdamW(\n        filter(lambda p: p.requires_grad, disc.module.discriminator.parameters()), lr=args.lr, weight_decay=0.01\n    )\n\n    # AMP scaler\n    scaler = torch.cuda.amp.GradScaler()\n    precision = torch.bfloat16\n    if args.mix_precision == \"fp16\":\n        precision = torch.float16\n    elif args.mix_precision == \"fp32\":\n        precision = torch.float32\n    print(precision)\n    \n    # Load from checkpoint\n    start_epoch = 0\n    current_step = 0\n    if args.resume_from_checkpoint:\n        if not os.path.isfile(args.resume_from_checkpoint):\n            raise Exception(\n                f\"Make sure `{args.resume_from_checkpoint}` is a ckpt file.\"\n            )\n        checkpoint = torch.load(args.resume_from_checkpoint, map_location=\"cpu\")\n        
model.module.load_state_dict(checkpoint[\"state_dict\"][\"gen_model\"], strict=False)\n        disc.module.load_state_dict(checkpoint[\"state_dict\"][\"dics_model\"])\n        scaler.load_state_dict(checkpoint[\"scaler_state\"])\n        gen_optimizer.load_state_dict(checkpoint[\"optimizer_state\"][\"gen_optimizer\"])\n        disc_optimizer.load_state_dict(checkpoint[\"optimizer_state\"][\"disc_optimizer\"])\n        ddp_sampler.load_state_dict(checkpoint[\"sampler_state\"])\n        start_epoch = checkpoint[\"sampler_state\"][\"epoch\"]\n        current_step = checkpoint[\"current_step\"]\n        logger.info(\n            f\"Checkpoint loaded from {args.resume_from_checkpoint}, starting from epoch {start_epoch} step {current_step}\"\n        )\n\n    if args.ema:\n        logger.warning(f\"Start with EMA. EMA decay = {args.ema_decay}.\")\n        ema = EMA(model, args.ema_decay)\n        ema.register()\n\n    # Training loop\n    logger.info(\"Prepared!\")\n    dist.barrier()\n    if global_rank == 0:\n        logger.info(f\"=== Model Params ===\")\n        logger.info(f\"Generator:\\t\\t{total_params(model.module)}M\")\n        logger.info(f\"\\t- Encoder:\\t{total_params(model.module.encoder):d}M\")\n        logger.info(f\"\\t- Decoder:\\t{total_params(model.module.decoder):d}M\")\n        logger.info(f\"Discriminator:\\t{total_params(disc.module):d}M\")\n        logger.info(f\"===========\")\n        logger.info(f\"Precision is set to: {args.mix_precision}!\")\n        logger.info(\"Start training!\")\n\n    # Training Bar\n    bar_desc = \"\"\n    bar = None\n    if global_rank == 0:\n        max_steps = (\n            args.epochs * len(dataloader) if args.max_steps is None else args.max_steps\n        )\n        bar = tqdm.tqdm(total=max_steps, desc=bar_desc.format(current_epoch=0, loss=0))\n        bar.update(current_step)\n        bar_desc = \"Epoch: {current_epoch}, Loss: {loss}\"\n        logger.warning(\"Training Details: \")\n        
logger.warning(f\" Max steps: {max_steps}\")\n        logger.warning(f\" Batches per epoch: {len(dataloader)}\")\n        logger.warning(\n            f\" Total Batch Size: {args.batch_size} * {os.environ['WORLD_SIZE']}\"\n        )\n    dist.barrier()\n\n    # Training Loop\n    num_epochs = args.epochs\n\n    def update_bar(bar):\n        if global_rank == 0:\n            bar.desc = bar_desc.format(current_epoch=epoch, loss=f\"-\")\n            bar.update()\n\n    for epoch in range(num_epochs):\n        set_train(modules_to_train)\n        ddp_sampler.set_epoch(epoch)  # Shuffle data at every epoch\n        for batch_idx, batch in enumerate(dataloader):\n            inputs = batch[\"video\"].to(rank)\n\n            if (\n                current_step % 2 == 1\n                and current_step >= disc.module.discriminator_iter_start\n            ):\n                set_modules_requires_grad(modules_to_train, False)\n                step_gen = False\n                step_dis = True\n            else:\n                set_modules_requires_grad(modules_to_train, True)\n                step_gen = True\n                step_dis = False\n\n            assert (\n                step_gen or step_dis\n            ), \"You should backward either Gen or Dis in a step.\"\n\n            with torch.cuda.amp.autocast(dtype=precision):\n                outputs = model(inputs)\n                recon = outputs[0]\n                posterior = outputs[1]\n                if len(outputs) == 3: # which means there is wavelet output\n                    wavelet_coeffs = outputs[2] if args.wavelet_loss else None\n                else:\n                    wavelet_coeffs = None\n\n            # Generator Step\n            if step_gen:\n                with torch.cuda.amp.autocast(dtype=precision):\n                    g_loss, g_log = disc(\n                        inputs,\n                        recon,\n                        posterior,\n                        optimizer_idx=0,\n          
              global_step=current_step,\n                        last_layer=model.module.get_last_layer(),\n                        wavelet_coeffs=wavelet_coeffs,\n                        split=\"train\",\n                    )\n                gen_optimizer.zero_grad()\n                scaler.scale(g_loss).backward()\n                # scaler.unscale_(gen_optimizer)\n                # torch.nn.utils.clip_grad_norm_(parameters_to_train, 5e6)\n                scaler.step(gen_optimizer)\n                scaler.update()\n                if args.ema:\n                    ema.update()\n                if global_rank == 0 and current_step % args.log_steps == 0:\n                    wandb.log(\n                        {\"train/generator_loss\": g_loss.item()}, step=current_step\n                    )\n                    wandb.log(\n                        {\"train/rec_loss\": g_log['train/rec_loss']}, step=current_step\n                    )\n                    wandb.log(\n                        {\"train/latents_std\": posterior.sample().std().item()}, step=current_step\n                    )\n                    if 'train/sb_loss' in g_log:\n                        wandb.log(\n                            {\"train/sb_loss\": g_log['train/sb_loss']}, step=current_step\n                        )\n                    if 'train/wl_loss' in g_log:\n                        wandb.log(\n                            {\"train/wl_loss\": g_log['train/wl_loss']}, step=current_step\n                        )\n\n            # Discriminator Step\n            if step_dis:\n                with torch.cuda.amp.autocast(dtype=precision):\n                    d_loss, d_log = disc(\n                        inputs,\n                        recon,\n                        posterior,\n                        optimizer_idx=1,\n                        global_step=current_step,\n                        last_layer=None,\n                        split=\"train\",\n                    )\n             
   disc_optimizer.zero_grad()\n                scaler.scale(d_loss).backward()\n                scaler.unscale_(disc_optimizer)\n                torch.nn.utils.clip_grad_norm_(disc.module.discriminator.parameters(), 1.0)\n                scaler.step(disc_optimizer)\n                scaler.update()\n                if global_rank == 0 and current_step % args.log_steps == 0:\n                    wandb.log(\n                        {\"train/discriminator_loss\": d_loss.item()}, step=current_step\n                    )\n\n            update_bar(bar)\n            current_step += 1\n\n            def valid_model(model, name=\"\"):\n                set_eval(modules_to_train)\n                psnr_list, lpips_list, video_log = valid(\n                    global_rank, rank, model, val_dataloader, precision, args\n                )\n                valid_psnr, valid_lpips, valid_video_log = gather_valid_result(\n                    psnr_list, lpips_list, video_log, rank, dist.get_world_size()\n                )\n                if global_rank == 0:\n                    name = \"_\" + name if name != \"\" else name\n                    wandb.log(\n                        {\n                            f\"val{name}/recon\": wandb.Video(\n                                np.array(valid_video_log), fps=10\n                            )\n                        },\n                        step=current_step,\n                    )\n                    wandb.log({f\"val{name}/psnr\": valid_psnr}, step=current_step)\n                    wandb.log({f\"val{name}/lpips\": valid_lpips}, step=current_step)\n                    logger.info(f\"{name} Validation done.\")\n\n            if current_step % args.eval_steps == 0 or current_step == 1:\n                if global_rank == 0:\n                    logger.info(\"Starting validation...\")\n                valid_model(model)\n                if args.ema:\n                    ema.apply_shadow()\n                    valid_model(model, 
\"ema\")\n                    ema.restore()\n\n            # Checkpoint\n            if current_step % args.save_ckpt_step == 0 and global_rank == 0:\n                file_path = save_checkpoint(\n                    epoch,\n                    current_step,\n                    {\n                        \"gen_optimizer\": gen_optimizer.state_dict(),\n                        \"disc_optimizer\": disc_optimizer.state_dict(),\n                    },\n                    {\n                        \"gen_model\": model.module.state_dict(),\n                        \"dics_model\": disc.module.state_dict(),\n                    },\n                    scaler.state_dict(),\n                    ddp_sampler.state_dict(),\n                    ckpt_dir,\n                    f\"checkpoint-{current_step}.ckpt\",\n                    ema_state_dict=ema.shadow if args.ema else {},\n                )\n                logger.info(f\"Checkpoint has been saved to `{file_path}`.\")\n\n    dist.destroy_process_group()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Distributed Training\")\n    # Exp setting\n    parser.add_argument(\n        \"--exp_name\", type=str, default=\"test\", help=\"number of epochs to train\"\n    )\n    parser.add_argument(\"--seed\", type=int, default=1234, help=\"seed\")\n    # Training setting\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"number of epochs to train\"\n    )\n    parser.add_argument(\n        \"--max_steps\", type=int, default=None, help=\"number of epochs to train\"\n    )\n    parser.add_argument(\"--save_ckpt_step\", type=int, default=1000, help=\"\")\n    parser.add_argument(\"--ckpt_dir\", type=str, default=\"./results/\", help=\"\")\n    parser.add_argument(\n        \"--batch_size\", type=int, default=1, help=\"batch size for training\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=1e-5, help=\"learning rate\")\n    parser.add_argument(\"--log_steps\", type=int, 
default=5, help=\"log steps\")\n    parser.add_argument(\"--freeze_encoder\", action=\"store_true\", help=\"\")\n    parser.add_argument(\"--clip_grad_norm\", type=float, default=1e5, help=\"\")\n\n    # Data\n    parser.add_argument(\"--video_path\", type=str, default=None, help=\"\")\n    parser.add_argument(\"--num_frames\", type=int, default=17, help=\"\")\n    parser.add_argument(\"--resolution\", type=int, default=256, help=\"\")\n    parser.add_argument(\"--sample_rate\", type=int, default=2, help=\"\")\n    parser.add_argument(\"--dynamic_sample\", action=\"store_true\", help=\"\")\n    # Generator model\n    parser.add_argument(\"--ignore_mismatched_sizes\", action=\"store_true\", help=\"\")\n    parser.add_argument(\"--find_unused_parameters\", action=\"store_true\", help=\"\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\", type=str, default=None, help=\"\"\n    )\n    parser.add_argument(\"--model_name\", type=str, default=None, help=\"\")\n    parser.add_argument(\"--resume_from_checkpoint\", type=str, default=None, help=\"\")\n    parser.add_argument(\"--not_resume_training_process\", action=\"store_true\", help=\"\")\n    parser.add_argument(\"--model_config\", type=str, default=None, help=\"\")\n    parser.add_argument(\n        \"--mix_precision\",\n        type=str,\n        default=\"bf16\",\n        choices=[\"fp16\", \"bf16\", \"fp32\"],\n        help=\"precision for training\",\n    )\n    parser.add_argument(\"--wavelet_loss\", action=\"store_true\", help=\"\")\n    parser.add_argument(\"--wavelet_weight\", type=float, default=0.1, help=\"\")\n    # Discriminator Model\n    parser.add_argument(\"--load_disc_from_checkpoint\", type=str, default=None, help=\"\")\n    parser.add_argument(\n        \"--disc_cls\",\n        type=str,\n        default=\"opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator3D\",\n        help=\"\",\n    )\n    parser.add_argument(\"--disc_start\", type=int, default=5, 
help=\"\")\n    parser.add_argument(\"--disc_weight\", type=float, default=0.5, help=\"\")\n    parser.add_argument(\"--kl_weight\", type=float, default=1e-06, help=\"\")\n    parser.add_argument(\"--perceptual_weight\", type=float, default=1.0, help=\"\")\n    parser.add_argument(\"--loss_type\", type=str, default=\"l1\", help=\"\")\n    parser.add_argument(\"--logvar_init\", type=float, default=0.0, help=\"\")\n\n    # Validation\n    parser.add_argument(\"--eval_steps\", type=int, default=1000, help=\"\")\n    parser.add_argument(\"--eval_video_path\", type=str, default=None, help=\"\")\n    parser.add_argument(\"--eval_num_frames\", type=int, default=17, help=\"\")\n    parser.add_argument(\"--eval_resolution\", type=int, default=256, help=\"\")\n    parser.add_argument(\"--eval_sample_rate\", type=int, default=1, help=\"\")\n    parser.add_argument(\"--eval_batch_size\", type=int, default=8, help=\"\")\n    parser.add_argument(\"--eval_subset_size\", type=int, default=100, help=\"\")\n    parser.add_argument(\"--eval_num_video_log\", type=int, default=2, help=\"\")\n    parser.add_argument(\"--eval_lpips\", action=\"store_true\", help=\"\")\n\n    # Dataset\n    parser.add_argument(\"--dataset_num_worker\", type=int, default=4, help=\"\")\n\n    # EMA\n    parser.add_argument(\"--ema\", action=\"store_true\", help=\"\")\n    parser.add_argument(\"--ema_decay\", type=float, default=0.999, help=\"\")\n\n    args = parser.parse_args()\n\n    set_random_seed(args.seed)\n    train(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "opensora/train/train_inpaint.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nA minimal training script for DiT using PyTorch DDP.\n\"\"\"\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Optional\nimport gc\nimport numpy as np\nfrom einops import rearrange\nimport torch.utils\nimport torch.utils.data\nfrom tqdm import tqdm\nimport yaml\n\nfrom opensora.adaptor.modules import replace_with_fp32_forwards\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\n    from opensora.acceleration.parallel_states import initialize_sequence_parallel_state, \\\n        destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state\n    from opensora.acceleration.communications import prepare_parallel_data, broadcast\nexcept:\n    torch_npu = None\n    npu_config = None\n    from opensora.utils.parallel_states import initialize_sequence_parallel_state, \\\n        destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state\n    from opensora.utils.communications import prepare_parallel_data, broadcast\n    pass\nimport time\nfrom dataclasses import field, dataclass\nfrom torch.utils.data import DataLoader\nfrom copy import deepcopy\nimport accelerate\nimport torch\nimport json\nfrom torch.nn import functional as F\nimport transformers\nfrom transformers.utils import ContextManagers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom accelerate.state import AcceleratorState\nfrom packaging import version\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler\nfrom 
diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\n\nfrom opensora.models.causalvideovae import ae_stride_config, ae_channel_config\nfrom opensora.models.causalvideovae import ae_norm, ae_denorm\nfrom opensora.models import CausalVAEModelWrapper\nfrom opensora.models.text_encoder import get_text_warpper\nfrom opensora.dataset import getdataset\nfrom opensora.models import CausalVAEModelWrapper\nfrom opensora.models.diffusion import Diffusion_models, Diffusion_models_class\nfrom opensora.utils.dataset_utils import Collate, LengthGroupedSampler\nfrom opensora.utils.utils import explicit_uniform_sampling\nfrom opensora.sample.pipeline_opensora import OpenSoraPipeline\nfrom opensora.models.causalvideovae import ae_stride_config, ae_wrapper\nfrom opensora.utils.mask_utils import MaskCompressor, GaussianNoiseAdder\n\n# from opensora.utils.utils import monitor_npu_power\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.24.0\")\nlogger = get_logger(__name__)\nfrom torch.utils.data import _utils\n_utils.MP_STATUS_CHECK_INTERVAL = 1800.0  # DataLoader multiprocessing status-check interval (default 5.0s); we increase it to 1800s.\n\n\nclass ProgressInfo:\n    def __init__(self, global_step, train_loss=0.0):\n        self.global_step = global_step\n        self.train_loss = train_loss\n\n\n#################################################################################\n#                                  Training Loop                                #\n#################################################################################\n\ndef main(args):\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch_npu is not None and npu_config is not None:\n        npu_config.print_msg(args)\n        npu_config.seed_everything(args.seed)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.num_frames != 1:\n        initialize_sequence_parallel_state(args.sp_size)\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        
transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed, device_specific=True)\n\n    # Create the output directory\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n            # Back up the mask config file\n            shutil.copy(args.mask_config, os.path.join(args.output_dir, \"mask_config.yaml\"))\n\n    # For mixed precision training we cast all non-trainable weights to half precision;\n    # as these weights are only used for inference, keeping them in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Create model:\n    \n    # Currently Accelerate doesn't know how to handle multiple models under DeepSpeed ZeRO stage 3.\n    # For this to work properly all models must be run through `accelerate.prepare`. 
But accelerate\n    # will try to assign the same optimizer with the same weights to all models during\n    # `deepspeed.initialize`, which of course doesn't work.\n    #\n    # For now the following workaround will partially support DeepSpeed ZeRO-3, by excluding the\n    # frozen models (VAE and text encoders) from being partitioned during `zero.Init`, which gets called during\n    # `from_pretrained`. So the frozen models will not enjoy parameter sharding\n    # across multiple GPUs and only the diffusion model will get ZeRO sharded.\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        Return either a context list that includes one that will disable zero.Init or an empty context list.\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        kwargs = {}\n        ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval()\n        \n        if args.enable_tiling:\n            ae.vae.enable_tiling()\n\n        kwargs = {\n            'torch_dtype': weight_dtype, \n            'low_cpu_mem_usage': False\n            }\n        text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args, **kwargs).eval()\n\n        text_enc_2 = None\n        if args.text_encoder_name_2 is not None:\n            text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args, **kwargs).eval()\n    \n    ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]\n    ae.vae_scale_factor = (ae_stride_t, ae_stride_h, ae_stride_w)\n    assert ae_stride_h == ae_stride_w, f\"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})\"\n    args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = 
ae_stride_t, ae_stride_h, ae_stride_w\n    args.ae_stride = args.ae_stride_h\n    patch_size = args.model[-3:]\n    patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2])\n    args.patch_size = patch_size_h\n    args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w\n    assert patch_size_h == patch_size_w, f\"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})\"\n    assert (args.num_frames - 1) % ae_stride_t == 0, f\"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t}).\"\n    assert args.max_height % ae_stride_h == 0, f\"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h}).\"\n    assert args.max_width % ae_stride_w == 0, f\"Width must be divisible by ae_stride_w, but found Width ({args.max_width}), ae_stride_w ({ae_stride_w}).\"\n\n    args.stride_t = ae_stride_t * patch_size_t\n    args.stride = ae_stride_h * patch_size_h\n    ae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)\n    args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1\n\n    mask_compressor = MaskCompressor(ae_stride_h=ae_stride_h, ae_stride_w=ae_stride_w, ae_stride_t=ae_stride_t)\n    noise_adder = None\n    if args.add_noise_to_condition:\n        noise_adder = GaussianNoiseAdder(mean=-3.0, std=0.5, clear_ratio=0.05)\n\n    model_kwargs = {'vae_scale_factor_t': ae_stride_t}\n\n    model = Diffusion_models[args.model](\n        in_channels=ae_channel_config[args.ae],\n        out_channels=ae_channel_config[args.ae],\n        sample_size_h=latent_size,\n        sample_size_w=latent_size,\n        sample_size_t=latent_size_t,\n        interpolation_scale_h=args.interpolation_scale_h,\n        interpolation_scale_w=args.interpolation_scale_w,\n        
interpolation_scale_t=args.interpolation_scale_t,\n        sparse1d=args.sparse1d, \n        sparse_n=args.sparse_n,\n        **model_kwargs,\n    )\n\n    # # use pretrained model?\n    if args.pretrained:\n        model_state_dict = model.state_dict()\n        print(f'Load from {args.pretrained}')\n        if args.pretrained.endswith('.safetensors'):  \n            from safetensors.torch import load_file as safe_load\n            pretrained_checkpoint = safe_load(args.pretrained, device=\"cpu\")\n            pretrained_keys = set(list(pretrained_checkpoint.keys()))\n            model_keys = set(list(model_state_dict.keys()))\n            common_keys = list(pretrained_keys & model_keys)\n            checkpoint = {k: pretrained_checkpoint[k] for k in common_keys if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()}\n            if not 'pos_embed_masked_hidden_states.0.proj.weight' in checkpoint:\n                checkpoint['pos_embed_masked_hidden_states.0.proj.weight'] = checkpoint['pos_embed.proj.weight']\n                checkpoint['pos_embed_masked_hidden_states.0.proj.bias'] = checkpoint['pos_embed.proj.bias']\n            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)\n        elif os.path.isdir(args.pretrained):\n            if os.path.exists(os.path.join(args.pretrained, 'config.json')):\n                with open(os.path.join(args.pretrained, 'config.json')) as f:\n                    config = json.load(f)\n                class_name = config['_class_name']\n                print(f'Load from {args.pretrained} with class_name {class_name}')\n                load_model = Diffusion_models_class[class_name].from_pretrained(args.pretrained)\n                missing_keys, unexpected_keys = model.load_state_dict(load_model.state_dict(), strict=False)\n                if 'pos_embed_masked_hidden_states.0.proj.weight' in missing_keys:\n                    model.pos_embed_masked_hidden_states[0].proj.weight.data = 
deepcopy(load_model.pos_embed.proj.weight.data)\n                    model.pos_embed_masked_hidden_states[0].proj.bias.data = deepcopy(load_model.pos_embed.proj.bias.data)\n                    assert torch.equal(model.pos_embed_masked_hidden_states[0].proj.weight.data, load_model.pos_embed.proj.weight.data)\n                    assert torch.equal(model.pos_embed_masked_hidden_states[0].proj.bias.data, load_model.pos_embed.proj.bias.data)\n                    missing_keys.remove('pos_embed_masked_hidden_states.0.proj.weight')\n                    missing_keys.remove('pos_embed_masked_hidden_states.0.proj.bias')\n                del load_model\n            else:\n                raise ValueError(f'Invalid pretrained model path: {args.pretrained}; provide a directory that contains a valid config.json file!')\n        else:\n            pretrained_checkpoint = torch.load(args.pretrained, map_location='cpu')\n            # Some checkpoints wrap the state dict under a 'model' key.\n            if 'model' in pretrained_checkpoint:\n                pretrained_checkpoint = pretrained_checkpoint['model']\n            pretrained_keys = set(list(pretrained_checkpoint.keys()))\n            model_keys = set(list(model_state_dict.keys()))\n            common_keys = list(pretrained_keys & model_keys)\n            checkpoint = {k: pretrained_checkpoint[k] for k in common_keys if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()}\n            if 'pos_embed_masked_hidden_states.0.proj.weight' not in checkpoint:\n                checkpoint['pos_embed_masked_hidden_states.0.proj.weight'] = checkpoint['pos_embed.proj.weight']\n                checkpoint['pos_embed_masked_hidden_states.0.proj.bias'] = checkpoint['pos_embed.proj.bias']\n            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)\n\n        print(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}')\n        print(f'Successfully loaded {len(model_state_dict) - 
len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!')\n\n    model.gradient_checkpointing = args.gradient_checkpointing\n    # Freeze vae and text encoders.\n    ae.vae.requires_grad_(False)\n    text_enc_1.requires_grad_(False)\n    if text_enc_2 is not None:\n        text_enc_2.requires_grad_(False)\n    # Put the diffusion model in training mode.\n    model.train()\n\n    kwargs = dict(\n        prediction_type=args.prediction_type, \n        rescale_betas_zero_snr=args.rescale_betas_zero_snr\n    )\n    if args.cogvideox_scheduler:\n        noise_scheduler = CogVideoXDDIMScheduler(**kwargs)\n    elif args.v1_5_scheduler:\n        kwargs['beta_start'] = 0.00085\n        kwargs['beta_end'] = 0.0120\n        kwargs['beta_schedule'] = \"scaled_linear\"\n        noise_scheduler = DDPMScheduler(**kwargs)\n    else:\n        noise_scheduler = DDPMScheduler(**kwargs)\n    # Move the VAE and text encoders to the device and cast them to weight_dtype.\n    # If args.vae_fp32 is set, the VAE stays in float32 to avoid NaN losses.\n    if not args.extra_save_mem:\n        ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)\n        text_enc_1.to(accelerator.device, dtype=weight_dtype)\n        if text_enc_2 is not None:\n            text_enc_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Create an EMA copy of the diffusion model.\n    if args.use_ema:\n        ema_model = deepcopy(model)\n        ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step,\n                             model_cls=Diffusion_models_class[args.model], model_config=ema_model.config)\n\n    # `accelerate` >= 0.16.0 supports customized saving via state hooks\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                
if args.use_ema:\n                    ema_model.save_pretrained(os.path.join(output_dir, \"model_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"model\"))\n                    if weights:  # Don't pop if empty\n                        # make sure to pop weight so that corresponding model is not saved again\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"model_ema\"), Diffusion_models_class[args.model])\n                ema_model.load_state_dict(load_model.state_dict())\n                ema_model.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load the diffusers-format weights into the model\n                load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder=\"model\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    params_to_optimize = model.parameters()\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        
args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n    logger.info(f\"optimizer: {optimizer}\")\n\n    # Setup data:\n    if args.trained_data_global_step is not None:\n        initial_global_step_for_sampler = args.trained_data_global_step\n    else:\n        initial_global_step_for_sampler = 0\n    \n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size\n    args.total_batch_size = total_batch_size\n    if args.min_hxw is None:\n        args.min_hxw = args.max_hxw // 4\n    train_dataset = getdataset(args)\n    sampler = LengthGroupedSampler(\n                args.train_batch_size,\n                world_size=accelerator.num_processes, \n                gradient_accumulation_size=args.gradient_accumulation_steps, \n                initial_global_step=initial_global_step_for_sampler, \n                lengths=train_dataset.lengths, \n                group_data=args.group_data, \n            )\n    train_dataloader = DataLoader(\n        train_dataset,\n        shuffle=False,\n        # pin_memory=True,\n        collate_fn=Collate(args),\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        sampler=sampler, \n        drop_last=True, \n        # prefetch_factor=4\n    )\n    logger.info(f'after train_dataloader')\n\n    # Scheduler and math around the number of training 
steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    # model.requires_grad_(False)\n    # model.pos_embed.requires_grad_(True)\n    # model.patch_embed.requires_grad_(True)\n\n    logger.info(f'before accelerator.prepare')\n    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, lr_scheduler\n    )\n    logger.info(f'after accelerator.prepare')\n    \n    if args.use_ema:\n        ema_model.to(accelerator.device)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    # NOTE wandb\n    if accelerator.is_main_process:\n        logger.info(\"init trackers...\")\n        project_name = os.getenv('PROJECT', os.path.basename(args.output_dir))\n        entity = os.getenv('ENTITY', None)\n        run_name = os.getenv('WANDB_NAME', None)\n        init_kwargs = 
{\n            # accelerate expects per-tracker kwargs keyed by tracker name; wandb.init takes `name` for the run name\n            \"wandb\": {\"entity\": entity, \"name\": run_name},\n        }\n        accelerator.init_trackers(project_name=project_name, config=vars(args), init_kwargs=init_kwargs)\n\n    # Train!\n    print(f\"  Args = {args}\")\n    logger.info(f\"  Args = {args}\")\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Model = {model}\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    logger.info(f\"  Update steps per epoch = {num_update_steps_per_epoch}\")\n    logger.info(f\"  Total training parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B\")\n    \n    logger.info(f\"  AutoEncoder = {args.ae}; Dtype = {ae.vae.dtype}; Parameters = {sum(p.numel() for p in ae.parameters()) / 1e9} B\")\n    logger.info(f\"  Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B\")\n    if args.text_encoder_name_2 is not None:\n        logger.info(f\"  Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B\")\n\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            
dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_info = ProgressInfo(global_step, train_loss=0.0)\n\n    def sync_gradients_info(loss):\n        # Checks if the accelerator has performed an optimization step behind the scenes\n        if args.use_ema:\n            ema_model.step(model.parameters())\n        progress_bar.update(1)\n        progress_info.global_step += 1\n        end_time = time.time()\n        one_step_duration = end_time - start_time\n        accelerator.log({\"train_loss\": progress_info.train_loss}, step=progress_info.global_step)\n        if torch_npu is not None and npu_config is not None:\n            npu_config.print_msg(f\"Step: [{progress_info.global_step}], local_loss={loss.detach().item()}, \"\n                                 f\"train_loss={progress_info.train_loss}, time_cost={one_step_duration}\",\n                                 rank=0)\n        progress_info.train_loss = 0.0\n\n        # DeepSpeed 
requires saving weights on every device; saving weights only on the main process would cause issues.\n        if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n            if progress_info.global_step % args.checkpointing_steps == 0 or progress_info.global_step == args.after_one_epoch_global_step:\n                # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                if accelerator.is_main_process and args.checkpoints_total_limit is not None:\n                    checkpoints = os.listdir(args.output_dir)\n                    checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                    # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                    if len(checkpoints) >= args.checkpoints_total_limit:\n                        num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                        removing_checkpoints = checkpoints[0:num_to_remove]\n\n                        logger.info(\n                            f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                        )\n                        logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                        for removing_checkpoint in removing_checkpoints:\n                            removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                            shutil.rmtree(removing_checkpoint)\n\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{progress_info.global_step}\")\n                accelerator.save_state(save_path)\n                logger.info(f\"Saved state to {save_path}\")\n\n        logs = {\"step_loss\": loss.detach().item(), \"lr\": 
lr_scheduler.get_last_lr()[0]}\n        progress_bar.set_postfix(**logs)\n        # Regularly releasing memory\n        # if progress_info.global_step % 100 == 0:\n        #     torch.cuda.empty_cache()\n        #     gc.collect()\n\n    def run(model_input, model_kwargs, prof):\n        global start_time\n        start_time = time.time()\n\n        try:\n            in_channels = ae_channel_config[args.ae]\n            model_input, masked_input, video_mask = model_input[:, 0:in_channels], model_input[:, in_channels:2 * in_channels], model_input[:, 2 * in_channels:]\n        except Exception as e:\n            raise ValueError(\"Failed to split model_input into latents, masked latents and video mask!\") from e\n\n\n        noise = torch.randn_like(model_input)\n        if args.noise_offset:\n            # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n            noise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1, 1),\n                                                     device=model_input.device)\n\n        bsz = model_input.shape[0]\n        # Sample a random timestep for each sample without bias.\n        if accelerator.num_processes > noise_scheduler.config.num_train_timesteps: \n            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)\n        else:\n            timesteps = explicit_uniform_sampling(\n                T=noise_scheduler.config.num_train_timesteps, \n                n=accelerator.num_processes, \n                rank=accelerator.process_index, \n                bsz=bsz, device=model_input.device, \n                )\n        # print(f'rank: {accelerator.process_index}, timesteps: {timesteps}')\n        if get_sequence_parallel_state():  # images do not need SP; it is disabled for image batches\n            broadcast(timesteps)\n\n        # Add noise to the model input according to the noise magnitude at each timestep\n        # (this is the forward diffusion process)\n\n        noisy_model_input = 
noise_scheduler.add_noise(model_input, noise, timesteps)\n        model_pred = model(\n            torch.cat([noisy_model_input, masked_input, video_mask], dim=1),\n            timesteps,\n            **model_kwargs,\n        )[0]\n        # Get the target for loss depending on the prediction type\n        if noise_scheduler.config.prediction_type == \"epsilon\":\n            target = noise\n        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n            target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n        elif noise_scheduler.config.prediction_type == \"sample\":\n            # We set the target to latents here, but the model_pred will return the noise sample prediction.\n            target = model_input\n            # We will have to subtract the noise residual from the prediction to get the target sample.\n            model_pred = model_pred - noise\n        else:\n            raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n        mask = model_kwargs.get('attention_mask', None)\n        if get_sequence_parallel_state():\n            if torch.all(mask.bool()):\n                mask = None\n            # mask    (sp_bs*b t h w)\n            assert mask is None\n        b, c, _, _, _ = model_pred.shape\n        if mask is not None:\n            mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float()  # b t h w -> b c t h w\n            mask = mask.reshape(b, -1)\n        if args.snr_gamma is None:\n            # model_pred: b c t h w, attention_mask: b t h w\n            loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n            loss = loss.reshape(b, -1)\n            if mask is not None:\n                loss = (loss * mask).sum() / mask.sum()  # mean loss on unpad patches\n            else:\n                loss = loss.mean()\n        else:\n            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.\n            # 
Since we predict the noise instead of x_0, the original formulation is slightly changed.\n            # This is discussed in Section 4.2 of the same paper.\n            snr = compute_snr(noise_scheduler, timesteps)\n            mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                dim=1\n            )[0]\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                mse_loss_weights = mse_loss_weights / snr\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                mse_loss_weights = mse_loss_weights / (snr + 1)\n            else:\n                raise NameError(f'{noise_scheduler.config.prediction_type}')\n            loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n            loss = loss.reshape(b, -1)\n            mse_loss_weights = mse_loss_weights.reshape(b, 1)\n            if mask is not None:\n                loss = (loss * mask * mse_loss_weights).sum() / mask.sum()  # mean loss on unpad patches\n            else:\n                loss = (loss * mse_loss_weights).mean()\n        # Gather the losses across all processes for logging (if we use distributed training).\n        avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n        progress_info.train_loss += avg_loss.detach().item() / args.gradient_accumulation_steps\n\n        # Backpropagate\n        accelerator.backward(loss)\n        if accelerator.sync_gradients:\n            params_to_clip = model.parameters()\n            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n        optimizer.step()\n        lr_scheduler.step()\n        optimizer.zero_grad()\n\n        if accelerator.sync_gradients:\n            sync_gradients_info(loss)\n\n        if prof is not None:\n            prof.step()\n\n\n        return loss\n\n    def train_one_step(step_, data_item_, prof_=None):\n        train_loss = 0.0\n        x, attn_mask, 
input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_\n        if accelerator.is_main_process:\n            print(f'\\nstep: {step_}, x: {x.shape}, dtype: {x.dtype}')\n        # assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))'\n        # print('after data collate')\n        # print(f'x: {x.shape}, attn_mask: {attn_mask.shape}, input_ids_1: {input_ids_1.shape}, cond_mask_1: {cond_mask_1.shape}, input_ids_2: {input_ids_2.shape}, cond_mask_2: {cond_mask_2.shape}')\n\n        if args.extra_save_mem:\n            torch.cuda.empty_cache()\n            ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)\n            text_enc_1.to(accelerator.device, dtype=weight_dtype)\n            if text_enc_2 is not None:\n                text_enc_2.to(accelerator.device, dtype=weight_dtype)\n\n        x = x.to(accelerator.device, dtype=ae.vae.dtype)  # B C T H W\n        # x = x.to(accelerator.device, dtype=torch.float32)  # B C T H W\n        attn_mask = attn_mask.to(accelerator.device)  # B T H W\n        input_ids_1 = input_ids_1.to(accelerator.device)  # B 1 L\n        cond_mask_1 = cond_mask_1.to(accelerator.device)  # B 1 L\n        input_ids_2 = input_ids_2.to(accelerator.device) if input_ids_2 is not None else input_ids_2 # B 1 L\n        cond_mask_2 = cond_mask_2.to(accelerator.device) if cond_mask_2 is not None else cond_mask_2 # B 1 L\n\n        with torch.no_grad():\n            B, N, L = input_ids_1.shape  # B 1 L\n            # use batch inference\n            input_ids_1 = input_ids_1.reshape(-1, L)\n            cond_mask_1 = cond_mask_1.reshape(-1, L)\n            cond_1 = text_enc_1(input_ids_1, cond_mask_1)  # B L D\n            cond_1 = cond_1.reshape(B, N, L, -1)\n            cond_mask_1 = cond_mask_1.reshape(B, N, L)\n            if text_enc_2 is not None:\n                B_, N_, L_ = input_ids_2.shape  # B 1 L\n                input_ids_2 = input_ids_2.reshape(-1, L_)\n                cond_2 = 
text_enc_2(input_ids_2, cond_mask_2)  # B D\n                cond_2 = cond_2.reshape(B_, 1, -1)  # B 1 D\n            else:\n                cond_2 = None\n\n            # Map input images to latent space + normalize latents\n            x, masked_x, mask = x[:, :3], x[:, 3:6], x[:, 6:7]\n            # Adding noise to control frames enhances generalization ability.\n            if noise_adder is not None:\n                masked_x = noise_adder(masked_x, mask)\n            x, masked_x = ae.encode(x), ae.encode(masked_x)\n            mask = mask_compressor(mask)\n            x = torch.cat([x, masked_x, mask], dim=1) \n        \n        if args.extra_save_mem:\n            ae.vae.to('cpu')\n            text_enc_1.to('cpu')\n            if text_enc_2 is not None:\n                text_enc_2.to('cpu')\n            torch.cuda.empty_cache()\n\n        current_step_frame = x.shape[2]\n        current_step_sp_state = get_sequence_parallel_state()\n        if args.sp_size != 1:  # enable sp\n            if current_step_frame == 1:  # but image do not need sp\n                set_sequence_parallel_state(False)\n            else:\n                set_sequence_parallel_state(True)\n        if get_sequence_parallel_state():\n            x, cond_1, attn_mask, cond_mask_1, cond_2 = prepare_parallel_data(\n                x, cond_1, attn_mask, cond_mask_1, cond_2\n                )        \n            # x            (b c t h w)   -gather0-> (sp*b c t h w)   -scatter2-> (sp*b c t//sp h w)\n            # cond_1       (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)\n            # attn_mask    (b t*sp h w)  -gather0-> (sp*b t*sp h w)  -scatter1-> (sp*b t h w)\n            # cond_mask_1  (b sp l)      -gather0-> (sp*b sp l)      -scatter1-> (sp*b 1 l)\n            # cond_2       (b sp d)      -gather0-> (sp*b sp d)      -scatter1-> (sp*b 1 d)\n            for iter in range(args.train_batch_size * args.sp_size // args.train_sp_batch_size):\n                with 
accelerator.accumulate(model):\n                    # x            (sp_bs*b c t//sp h w)\n                    # cond_1       (sp_bs*b 1 l/sp d)\n                    # attn_mask    (sp_bs*b t h w)\n                    # cond_mask_1  (sp_bs*b 1 l)\n                    # cond_2       (sp_bs*b 1 d)\n                    st_idx = iter * args.train_sp_batch_size\n                    ed_idx = (iter + 1) * args.train_sp_batch_size\n                    model_kwargs = dict(\n                        encoder_hidden_states=cond_1[st_idx: ed_idx],\n                        attention_mask=attn_mask[st_idx: ed_idx],\n                        encoder_attention_mask=cond_mask_1[st_idx: ed_idx], \n                        pooled_projections=cond_2[st_idx: ed_idx] if cond_2 is not None else None, \n                        )\n                    run(x[st_idx: ed_idx], model_kwargs, prof_)\n\n        else:\n            with accelerator.accumulate(model):\n                # assert not torch.any(torch.isnan(x)), 'after vae'\n                x = x.to(weight_dtype)\n                model_kwargs = dict(\n                    encoder_hidden_states=cond_1, attention_mask=attn_mask, \n                    encoder_attention_mask=cond_mask_1, \n                    pooled_projections=cond_2\n                    )\n                run(x, model_kwargs, prof_)\n\n        set_sequence_parallel_state(current_step_sp_state)  # in case the next step use sp, which need broadcast(timesteps)\n\n        if progress_info.global_step >= args.max_train_steps:\n            return True\n\n        return False\n\n    def train_one_epoch(prof_=None):\n        # for epoch in range(first_epoch, args.num_train_epochs):\n        progress_info.train_loss = 0.0\n        if progress_info.global_step >= args.max_train_steps:\n            return True\n        \n        args.after_one_epoch_global_step = progress_info.global_step + len(train_dataloader) // args.gradient_accumulation_steps - 1\n\n        for step, data_item in 
enumerate(train_dataloader):\n            if train_one_step(step, data_item, prof_):\n                break\n\n            if step >= 2 and torch_npu is not None and npu_config is not None:\n                npu_config.free_mm()\n\n    if npu_config is not None and npu_config.on_npu and npu_config.profiling:\n        experimental_config = torch_npu.profiler._ExperimentalConfig(\n            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,\n            aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization\n        )\n        profile_output_path = f\"/home/image_data/npu_profiling_t2v/{os.getenv('PROJECT_NAME', 'local')}\"\n        os.makedirs(profile_output_path, exist_ok=True)\n\n        with torch_npu.profiler.profile(\n                activities=[\n                    torch_npu.profiler.ProfilerActivity.CPU, \n                    torch_npu.profiler.ProfilerActivity.NPU, \n                    ],\n                with_stack=True,\n                record_shapes=True,\n                profile_memory=True,\n                experimental_config=experimental_config,\n                schedule=torch_npu.profiler.schedule(\n                    wait=npu_config.profiling_step, warmup=0, active=1, repeat=1, skip_first=0\n                    ),\n                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f\"{profile_output_path}/\")\n        ) as prof:\n            train_one_epoch(prof)\n    else:\n        if args.enable_profiling:\n            with torch.profiler.profile(\n                activities=[\n                    # torch.profiler.ProfilerActivity.CPU, \n                    torch.profiler.ProfilerActivity.CUDA, \n                    ], \n                schedule=torch.profiler.schedule(wait=5, warmup=1, active=1, repeat=1, skip_first=0),\n                on_trace_ready=torch.profiler.tensorboard_trace_handler('./gpu_profiling_active_1_delmask_delbkmask_andvaemask_curope_gpu'),\n                record_shapes=True,\n                
profile_memory=True,\n                with_stack=True\n            ) as prof:\n                train_one_epoch(prof)\n        else:\n            train_one_epoch()\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n    if get_sequence_parallel_state():\n        destroy_sequence_parallel_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    # dataset & dataloader\n    parser.add_argument(\"--dataset\", type=str, required=True)\n    parser.add_argument(\"--data\", type=str, required='')\n    parser.add_argument(\"--sample_rate\", type=int, default=1)\n    parser.add_argument(\"--train_fps\", type=int, default=24)\n    parser.add_argument(\"--drop_short_ratio\", type=float, default=1.0)\n    parser.add_argument(\"--speed_factor\", type=float, default=1.0)\n    parser.add_argument(\"--num_frames\", type=int, default=65)\n    parser.add_argument(\"--max_height\", type=int, default=320)\n    parser.add_argument(\"--max_width\", type=int, default=240)\n    parser.add_argument(\"--max_hxw\", type=int, default=None)\n    parser.add_argument(\"--min_hxw\", type=int, default=None)\n    parser.add_argument(\"--ood_img_ratio\", type=float, default=0.0)\n    parser.add_argument(\"--use_img_from_vid\", action=\"store_true\")\n    parser.add_argument(\"--model_max_length\", type=int, default=512)\n    parser.add_argument('--cfg', type=float, default=0.1)\n    parser.add_argument(\"--dataloader_num_workers\", type=int, default=10, help=\"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\")\n    parser.add_argument(\"--group_data\", action=\"store_true\")\n    parser.add_argument(\"--hw_stride\", type=int, default=32)\n    parser.add_argument(\"--force_resolution\", action=\"store_true\")\n    parser.add_argument(\"--trained_data_global_step\", type=int, default=None)\n    parser.add_argument(\"--use_decord\", action=\"store_true\")\n\n    # text encoder & vae & diffusion model\n    parser.add_argument('--vae_fp32', action='store_true')\n    parser.add_argument('--extra_save_mem', action='store_true')\n    parser.add_argument(\"--model\", type=str, choices=list(Diffusion_models.keys()), default=\"Latte-XL/122\")\n    parser.add_argument('--enable_tiling', action='store_true')\n    parser.add_argument('--interpolation_scale_h', type=float, default=1.0)\n    parser.add_argument('--interpolation_scale_w', type=float, default=1.0)\n    parser.add_argument('--interpolation_scale_t', type=float, default=1.0)\n    parser.add_argument(\"--ae\", type=str, default=\"stabilityai/sd-vae-ft-mse\")\n    parser.add_argument(\"--ae_path\", type=str, default=\"stabilityai/sd-vae-ft-mse\")\n    parser.add_argument(\"--text_encoder_name_1\", type=str, default='DeepFloyd/t5-v1_1-xxl')\n    parser.add_argument(\"--text_encoder_name_2\", type=str, default=None)\n    parser.add_argument(\"--cache_dir\", type=str, default='./cache_dir')\n    parser.add_argument(\"--pretrained\", type=str, default=None)\n    parser.add_argument('--sparse1d', action='store_true')\n    parser.add_argument('--sparse_n', type=int, default=2)\n    parser.add_argument('--cogvideox_scheduler', action='store_true')\n    parser.add_argument('--v1_5_scheduler', action='store_true')\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\", help=\"Whether or not to use gradient checkpointing 
to save memory at the expense of slower backward pass.\")\n\n    # diffusion setting\n    parser.add_argument(\"--snr_gamma\", type=float, default=None, help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.\")\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\"--ema_decay\", type=float, default=0.9999)\n    parser.add_argument(\"--ema_start_step\", type=int, default=0)\n    parser.add_argument(\"--noise_offset\", type=float, default=0.0, help=\"The scale of noise offset.\")\n    parser.add_argument(\"--prediction_type\", type=str, default='epsilon', help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\")\n    parser.add_argument('--rescale_betas_zero_snr', action='store_true')\n\n    # validation & logs\n    parser.add_argument(\"--enable_profiling\", action=\"store_true\")\n    parser.add_argument(\"--num_sampling_steps\", type=int, default=20)\n    parser.add_argument('--guidance_scale', type=float, default=4.5)\n    parser.add_argument(\"--enable_tracker\", action=\"store_true\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\"--output_dir\", type=str, default=None, help=\"The output directory where the model predictions and checkpoints will be written.\")\n    parser.add_argument(\"--checkpoints_total_limit\", type=int, default=None, help=(\"Max number of checkpoints to store.\"))\n    parser.add_argument(\"--checkpointing_steps\", type=int, default=500,\n                        help=(\n                            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n                            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n                            \" training using `--resume_from_checkpoint`.\"\n                        ),\n                        )\n    parser.add_argument(\"--resume_from_checkpoint\", type=str, default=None,\n                        help=(\n                            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n                            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n                        ),\n                        )\n    parser.add_argument(\"--logging_dir\", type=str, default=\"logs\",\n                        help=(\n                            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n                            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n                        ),\n                        )\n    parser.add_argument(\"--report_to\", type=str, default=\"tensorboard\",\n                        help=(\n                            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n                            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n                        ),\n                        )\n    \n    # optimizer & scheduler\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\"--max_train_steps\", type=int, default=1000000, help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\")\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=1, help=\"Number of update steps to accumulate before performing a backward/update pass.\")\n    parser.add_argument(\"--optimizer\", type=str, default=\"adamW\", help='The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]')\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"Initial learning rate (after the potential warmup period) to use.\")\n    parser.add_argument(\"--lr_warmup_steps\", type=int, default=0, help=\"Number of steps for the warmup in the lr scheduler.\")\n    parser.add_argument(\"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\")\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW-style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-02, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\"--adam_weight_decay_text_encoder\", type=float, default=None, help=\"Weight decay to use for text_encoder\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-15, help=\"Epsilon value for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--prodigy_use_bias_correction\", type=bool, default=True, help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\")\n    parser.add_argument(\"--prodigy_safeguard_warmup\", type=bool, default=True, help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
Ignored if optimizer is adamW\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--prodigy_beta3\", type=float, default=None,\n                        help=\"Coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n                             \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n                        )\n    parser.add_argument(\"--lr_scheduler\", type=str, default=\"constant\",\n                        help=(\n                            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n                            ' \"constant\", \"constant_with_warmup\"]'\n                        ),\n                        )\n    parser.add_argument(\"--allow_tf32\", action=\"store_true\",\n                        help=(\n                            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n                            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n                        ),\n                        )\n    parser.add_argument(\"--mixed_precision\", type=str, default=None, choices=[\"no\", \"fp16\", \"bf16\"],\n                        help=(\n                            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n                            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n                            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n                        ),\n                        )\n\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--sp_size\", type=int, default=1, help=\"For sequence parallel\")\n    parser.add_argument(\"--train_sp_batch_size\", type=int, default=1, help=\"Batch size for sequence parallel training\")\n\n    # inpaint\n    parser.add_argument(\"--mask_config\", type=str, default=None)\n    parser.add_argument(\"--add_noise_to_condition\", action='store_true')\n    parser.add_argument(\"--default_text_ratio\", type=float, default=0.5) # for inpainting mode\n\n    args = parser.parse_args()\n\n    assert args.mask_config is not None, 'mask_config is required!'\n    with open(args.mask_config, 'r') as f:\n        yaml_config = yaml.safe_load(f)\n    \n    for key, value in yaml_config.items():\n        if not hasattr(args, key):\n            setattr(args, key, value)\n\n    main(args)\n"
  },
  {
    "path": "opensora/train/train_t2v_diffusers.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nA minimal training script for DiT using PyTorch DDP.\n\"\"\"\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Optional\nimport gc\nimport numpy as np\nfrom einops import rearrange\nimport torch.utils\nimport torch.utils.data\nfrom tqdm import tqdm\nimport time\n\nfrom opensora.adaptor.modules import replace_with_fp32_forwards\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\n    from opensora.acceleration.parallel_states import initialize_sequence_parallel_state, \\\n        destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state\n    from opensora.acceleration.communications import prepare_parallel_data, broadcast\nexcept:\n    torch_npu = None\n    npu_config = None\n    from opensora.utils.parallel_states import initialize_sequence_parallel_state, \\\n        destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state\n    from opensora.utils.communications import prepare_parallel_data, broadcast\n    pass\n\nfrom dataclasses import field, dataclass\nfrom torch.utils.data import DataLoader\nfrom copy import deepcopy\nimport accelerate\nimport torch\nfrom torch.nn import functional as F\nimport transformers\nfrom transformers.utils import ContextManagers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom accelerate.state import AcceleratorState\nfrom packaging import version\nfrom tqdm.auto import tqdm\n\nimport copy\nimport diffusers\nfrom diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler, 
FlowMatchEulerDiscreteScheduler\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3\n\nfrom opensora.models.causalvideovae import ae_stride_config, ae_channel_config\nfrom opensora.models.causalvideovae import ae_norm, ae_denorm\nfrom opensora.models import CausalVAEModelWrapper\nfrom opensora.models.text_encoder import get_text_warpper\nfrom opensora.dataset import getdataset\nfrom opensora.models import CausalVAEModelWrapper\nfrom opensora.models.diffusion import Diffusion_models, Diffusion_models_class\nfrom opensora.utils.dataset_utils import Collate, LengthGroupedSampler\nfrom opensora.utils.utils import explicit_uniform_sampling\nfrom opensora.sample.pipeline_opensora import OpenSoraPipeline\nfrom opensora.models.causalvideovae import ae_stride_config, ae_wrapper\n\n# from opensora.utils.utils import monitor_npu_power\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.24.0\")\nlogger = get_logger(__name__)\n\n@torch.inference_mode()\ndef log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weight_dtype, global_step, ema=False):\n    positive_prompt = \"(masterpiece), (best quality), (ultra-detailed), {}. 
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous\"\n    negative_prompt = \"\"\"nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, \n                        \"\"\"\n    validation_prompt = [\n        \"a cat wearing sunglasses and working as a lifeguard at pool.\",\n        \"A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene.\"\n        ]\n    logger.info(f\"Running validation....\\n\")\n    model = accelerator.unwrap_model(model)\n    scheduler = DPMSolverMultistepScheduler()\n    opensora_pipeline = OpenSoraPipeline(vae=vae,\n                                         text_encoder_1=text_encoder[0],\n                                         text_encoder_2=text_encoder[1],\n                                         tokenizer=tokenizer,\n                                         scheduler=scheduler,\n                                         transformer=model).to(device=accelerator.device)\n    videos = []\n    for prompt in validation_prompt:\n        logger.info('Processing the ({}) prompt'.format(prompt))\n        video = opensora_pipeline(\n                                positive_prompt.format(prompt),\n                                negative_prompt=negative_prompt, \n                                num_frames=args.num_frames,\n                                height=args.max_height,\n                                width=args.max_width,\n     
                           num_inference_steps=args.num_sampling_steps,\n                                guidance_scale=args.guidance_scale,\n                                enable_temporal_attentions=True,\n                                num_images_per_prompt=1,\n                                mask_feature=True,\n                                max_sequence_length=args.model_max_length,\n                                ).images\n        videos.append(video[0])\n    # import ipdb;ipdb.set_trace()\n    gc.collect()\n    torch.cuda.empty_cache()\n    videos = torch.stack(videos).numpy()\n    videos = rearrange(videos, 'b t h w c -> b t c h w')\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            if videos.shape[1] == 1:\n                assert args.num_frames == 1\n                images = rearrange(videos, 'b 1 c h w -> (b 1) h w c')\n                np_images = np.stack([np.asarray(img) for img in images])\n                tracker.writer.add_images(f\"{'ema_' if ema else ''}validation\", np_images, global_step, dataformats=\"NHWC\")\n            else:\n                np_videos = np.stack([np.asarray(vid) for vid in videos])\n                tracker.writer.add_video(f\"{'ema_' if ema else ''}validation\", np_videos, global_step, fps=24)\n        if tracker.name == \"wandb\":\n            import wandb\n            if videos.shape[1] == 1:\n                images = rearrange(videos, 'b 1 c h w -> (b 1) h w c')\n                logs = {\n                    f\"{'ema_' if ema else ''}validation\": [\n                        wandb.Image(image, caption=f\"{i}: {prompt}\")\n                        for i, (image, prompt) in enumerate(zip(images, validation_prompt))\n                    ]\n                }\n            else:\n                logs = {\n                    f\"{'ema_' if ema else ''}validation\": [\n                        wandb.Video(video, caption=f\"{i}: {prompt}\", fps=24)\n                        for i, 
(video, prompt) in enumerate(zip(videos, validation_prompt))\n                    ]\n                }\n            tracker.log(logs, step=global_step)\n\n    del opensora_pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nclass ProgressInfo:\n    def __init__(self, global_step, train_loss=0.0):\n        self.global_step = global_step\n        self.train_loss = train_loss\n\n\n#################################################################################\n#                                  Training Loop                                #\n#################################################################################\n\ndef main(args):\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch_npu is not None and npu_config is not None:\n        npu_config.print_msg(args)\n        npu_config.seed_everything(args.seed)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.num_frames != 1:\n        initialize_sequence_parallel_state(args.sp_size)\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n        # if accelerator.is_main_process:\n        #     from threading import Thread\n        #     Thread(target=monitor_npu_power, daemon=True).start()\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if 
accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed, device_specific=True)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Create model:\n    \n    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.\n    # For this to work properly all models must be run through `accelerate.prepare`. 
But accelerate\n    # will try to assign the same optimizer with the same weights to all models during\n    # `deepspeed.initialize`, which of course doesn't work.\n    #\n    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2\n    # frozen models from being partitioned during `zero.Init` which gets called during\n    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding\n    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        kwargs = {}\n        ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval()\n        \n        if args.enable_tiling:\n            ae.vae.enable_tiling()\n\n        kwargs = {\n            'torch_dtype': weight_dtype, \n            'low_cpu_mem_usage': False\n            }\n        text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args, **kwargs).eval()\n\n        text_enc_2 = None\n        if args.text_encoder_name_2 is not None:\n            text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args, **kwargs).eval()\n\n    ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]\n    ae.vae_scale_factor = (ae_stride_t, ae_stride_h, ae_stride_w)\n    assert ae_stride_h == ae_stride_w, f\"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})\"\n    args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = 
ae_stride_t, ae_stride_h, ae_stride_w\n    args.ae_stride = args.ae_stride_h\n    patch_size = args.model[-3:]\n    patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2])\n    args.patch_size = patch_size_h\n    args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w\n    assert patch_size_h == patch_size_w, f\"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})\"\n    assert (args.num_frames - 1) % ae_stride_t == 0, f\"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t}).\"\n    assert args.max_height % ae_stride_h == 0, f\"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h}).\"\n    assert args.max_width % ae_stride_h == 0, f\"Width size must be divisible by ae_stride_h, but found Width ({args.max_width}), ae_stride_h ({ae_stride_h}).\"\n\n    args.stride_t = ae_stride_t * patch_size_t\n    args.stride = ae_stride_h * patch_size_h\n    ae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)\n    args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1\n    model = Diffusion_models[args.model](\n        in_channels=ae_channel_config[args.ae],\n        out_channels=ae_channel_config[args.ae],\n        sample_size_h=latent_size,\n        sample_size_w=latent_size,\n        sample_size_t=latent_size_t,\n        interpolation_scale_h=args.interpolation_scale_h,\n        interpolation_scale_w=args.interpolation_scale_w,\n        interpolation_scale_t=args.interpolation_scale_t,\n        sparse1d=args.sparse1d, \n        sparse_n=args.sparse_n, \n        skip_connection=args.skip_connection, \n    )\n\n    # # use pretrained model?\n    if args.pretrained:\n        model_state_dict = model.state_dict()\n        print(f'Load from 
{args.pretrained}')\n        if args.pretrained.endswith('.safetensors'):  \n            from safetensors.torch import load_file as safe_load\n            pretrained_checkpoint = safe_load(args.pretrained, device=\"cpu\")\n            pretrained_keys = set(list(pretrained_checkpoint.keys()))\n            model_keys = set(list(model_state_dict.keys()))\n            common_keys = list(pretrained_keys & model_keys)\n            checkpoint = {k: pretrained_checkpoint[k] for k in common_keys if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()}\n            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)\n        elif os.path.isdir(args.pretrained):\n            model = Diffusion_models_class[args.model].from_pretrained(args.pretrained)\n            missing_keys, unexpected_keys = [], []\n        else:\n            pretrained_checkpoint = torch.load(args.pretrained, map_location='cpu')\n            # Unwrap checkpoints saved as {'model': state_dict, ...}\n            if 'model' in pretrained_checkpoint:\n                pretrained_checkpoint = pretrained_checkpoint['model']\n            pretrained_keys = set(list(pretrained_checkpoint.keys()))\n            model_keys = set(list(model_state_dict.keys()))\n            common_keys = list(pretrained_keys & model_keys)\n            checkpoint = {k: pretrained_checkpoint[k] for k in common_keys if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()}\n            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)\n        print(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}')\n        print(f'Successfully loaded {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!')\n\n    model.gradient_checkpointing = args.gradient_checkpointing\n    # Freeze vae and text encoders.\n    ae.vae.requires_grad_(False)\n    text_enc_1.requires_grad_(False)\n    if text_enc_2 is not None:\n        text_enc_2.requires_grad_(False)\n    # Set model as 
trainable.\n    model.train()\n\n    kwargs = dict(\n        prediction_type=args.prediction_type, \n        rescale_betas_zero_snr=args.rescale_betas_zero_snr\n    )\n    if args.cogvideox_scheduler:\n        noise_scheduler = CogVideoXDDIMScheduler(**kwargs)\n    elif args.v1_5_scheduler:\n        kwargs['beta_start'] = 0.00085\n        kwargs['beta_end'] = 0.0120\n        kwargs['beta_schedule'] = \"scaled_linear\"\n        noise_scheduler = DDPMScheduler(**kwargs)\n    elif args.rf_scheduler:\n        noise_scheduler = FlowMatchEulerDiscreteScheduler()\n        noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    else:\n        noise_scheduler = DDPMScheduler(**kwargs)\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    if not args.extra_save_mem:\n        ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)\n        text_enc_1.to(accelerator.device, dtype=weight_dtype)\n        if text_enc_2 is not None:\n            text_enc_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_model = deepcopy(model)\n        ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step,\n                             model_cls=Diffusion_models_class[args.model], model_config=ema_model.config, \n                             foreach=args.foreach_ema)\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_model.save_pretrained(os.path.join(output_dir, \"model_ema\"))\n\n                for i, 
model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"model\"))\n                    if weights:  # Don't pop if empty\n                        # make sure to pop weight so that corresponding model is not saved again\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(\n                    os.path.join(input_dir, \"model_ema\"), \n                    Diffusion_models_class[args.model], \n                    foreach=args.foreach_ema, \n                    )\n                ema_model.load_state_dict(load_model.state_dict())\n                if args.offload_ema:\n                    ema_model.pin_memory()\n                else:\n                    ema_model.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder=\"model\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    params_to_optimize = model.parameters()\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported 
optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n    logger.info(f\"optimizer: {optimizer}\")\n\n    # Setup data:\n    if args.trained_data_global_step is not None:\n        initial_global_step_for_sampler = args.trained_data_global_step\n    else:\n        initial_global_step_for_sampler = 0\n    \n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size\n    args.total_batch_size = total_batch_size\n    if args.max_hxw is not None and args.min_hxw is None:\n        args.min_hxw = args.max_hxw // 4\n    train_dataset = getdataset(args)\n    sampler = LengthGroupedSampler(\n                args.train_batch_size,\n                world_size=accelerator.num_processes, \n                gradient_accumulation_size=args.gradient_accumulation_steps, \n                initial_global_step=initial_global_step_for_sampler, \n                lengths=train_dataset.lengths, \n                group_data=args.group_data, \n            )\n    train_dataloader = DataLoader(\n        train_dataset,\n        shuffle=False,\n        pin_memory=True,\n        collate_fn=Collate(args),\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        sampler=sampler, \n        drop_last=True, \n        # prefetch_factor=4\n    )\n    logger.info(f'after train_dataloader')\n\n    # Scheduler and math around the 
number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    # model.requires_grad_(False)\n    # model.pos_embed.requires_grad_(True)\n    # model.patch_embed.requires_grad_(True)\n\n    logger.info(f'before accelerator.prepare')\n    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, lr_scheduler\n    )\n    logger.info(f'after accelerator.prepare')\n    \n    if args.use_ema:\n        if args.offload_ema:\n            ema_model.pin_memory()\n        else:\n            ema_model.to(accelerator.device)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized only on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(os.path.basename(args.output_dir), config=vars(args))\n\n    # Train!\n    print(f\"  Args = {args}\")\n    print(f\"  
noise_scheduler = {noise_scheduler}\")\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Model = {model}\")\n    logger.info(f\"  Args = {args}\")\n    logger.info(f\"  Noise_scheduler = {noise_scheduler}\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    logger.info(f\"  Total optimization steps (num_update_steps_per_epoch) = {num_update_steps_per_epoch}\")\n    logger.info(f\"  Total training parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B\")\n    \n    logger.info(f\"  AutoEncoder = {args.ae}; Dtype = {ae.vae.dtype}; Parameters = {sum(p.numel() for p in ae.parameters()) / 1e9} B\")\n    logger.info(f\"  Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B\")\n    if args.text_encoder_name_2 is not None:\n        logger.info(f\"  Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B\")\n\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n         
   path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_info = ProgressInfo(global_step, train_loss=0.0)\n\n    \n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    def sync_gradients_info(loss):\n        # Checks if the accelerator has performed an optimization step behind the scenes\n        if args.use_ema:\n            if args.offload_ema:\n                ema_model.to(device=\"cuda\", non_blocking=True)\n            ema_model.step(model.parameters())\n            if args.offload_ema:\n                ema_model.to(device=\"cpu\", non_blocking=True)\n        progress_bar.update(1)\n        
progress_info.global_step += 1\n        end_time = time.time()\n        one_step_duration = end_time - start_time\n        if progress_info.global_step % args.log_interval == 0:\n            train_loss = progress_info.train_loss.item() / args.log_interval\n            accelerator.log({\"train_loss\": train_loss, \"lr\": lr_scheduler.get_last_lr()[0]}, step=progress_info.global_step)\n            if torch_npu is not None and npu_config is not None:\n                npu_config.print_msg(f\"Step: [{progress_info.global_step}], local_loss={loss.detach().item()}, \"\n                                    f\"train_loss={train_loss}, time_cost={one_step_duration}\",\n                                    rank=0)\n            progress_info.train_loss = torch.tensor(0.0, device=loss.device)\n\n        # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n        if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n            if progress_info.global_step % args.checkpointing_steps == 0:\n                # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                if accelerator.is_main_process and args.checkpoints_total_limit is not None:\n                    checkpoints = os.listdir(args.output_dir)\n                    checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                    # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                    if len(checkpoints) >= args.checkpoints_total_limit:\n                        num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                        removing_checkpoints = checkpoints[0:num_to_remove]\n\n                        logger.info(\n                            
f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                        )\n                        logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                        for removing_checkpoint in removing_checkpoints:\n                            removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                            shutil.rmtree(removing_checkpoint)\n\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{progress_info.global_step}\")\n                accelerator.save_state(save_path)\n                logger.info(f\"Saved state to {save_path}\")\n\n        logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n        progress_bar.set_postfix(**logs)\n\n    def run(step_, model_input, model_kwargs, prof):\n        # print(\"rank {} | step {} | cd run fun\".format(accelerator.process_index, step_))\n        global start_time\n        start_time = time.time()\n\n        noise = torch.randn_like(model_input)\n        bsz = model_input.shape[0]\n        if not args.rf_scheduler:\n            if args.noise_offset:\n                # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                noise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1, 1),\n                                                        device=model_input.device)\n\n            # Sample a random timestep for each image without bias.\n            timesteps = explicit_uniform_sampling(\n                T=noise_scheduler.config.num_train_timesteps, \n                n=accelerator.num_processes, \n                rank=accelerator.process_index, \n                bsz=bsz, device=model_input.device, \n                )\n            if get_sequence_parallel_state():  # image do not need sp, disable when image batch\n                broadcast(timesteps)\n\n            # Add noise to the model 
input according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n\n            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n        else:\n            # Sample a random timestep for each image\n            # for weighting schemes where we sample timesteps non-uniformly\n            u = compute_density_for_timestep_sampling(\n                weighting_scheme=args.weighting_scheme,\n                batch_size=bsz,\n                logit_mean=args.logit_mean,\n                logit_std=args.logit_std,\n                mode_scale=args.mode_scale,\n            )\n            indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n            timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n            # Add noise according to flow matching.\n            # zt = (1 - texp) * x + texp * z1\n            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n        model_pred = model(\n            noisy_model_input,\n            timesteps,\n            **model_kwargs\n        )[0]\n        mask = model_kwargs.get('attention_mask', None)\n        if not args.rf_scheduler:\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n            elif noise_scheduler.config.prediction_type == \"sample\":\n                # We set the target to latents here, but the model_pred will return the noise sample prediction.\n                target = model_input\n                # We will have to subtract the noise residual from the prediction to get the target sample.\n         
       model_pred = model_pred - noise\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            if get_sequence_parallel_state():\n                if torch.all(mask.bool()):\n                    mask = None\n                # mask    (sp_bs*b t h w)\n                assert mask is None\n            b, c, _, _, _ = model_pred.shape\n            if mask is not None:\n                mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float()  # b t h w -> b c t h w\n                mask = mask.reshape(b, -1)\n            if args.snr_gamma is None:\n                # model_pred: b c t h w, attention_mask: b t h w\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                loss = loss.reshape(b, -1)\n                if mask is not None:\n                    loss = (loss * mask).sum() / mask.sum()  # mean loss on unpad patches\n                else:\n                    loss = loss.mean()\n            else:\n                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.\n                # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                # This is discussed in Section 4.2 of the same paper.\n                snr = compute_snr(noise_scheduler, timesteps)\n                mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                    dim=1\n                )[0]\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    mse_loss_weights = mse_loss_weights / snr\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    mse_loss_weights = mse_loss_weights / (snr + 1)\n                else:\n                    raise NameError(f'{noise_scheduler.config.prediction_type}')\n                loss = F.mse_loss(model_pred.float(), 
target.float(), reduction=\"none\")\n                loss = loss.reshape(b, -1)\n                mse_loss_weights = mse_loss_weights.reshape(b, 1)\n                if mask is not None:\n                    loss = (loss * mask * mse_loss_weights).sum() / mask.sum()  # mean loss on unpad patches\n                else:\n                    loss = (loss * mse_loss_weights).mean()\n        else:\n            if torch.all(mask.bool()):\n                mask = None\n\n            b, c, _, _, _ = model_pred.shape\n            if mask is not None:\n                mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float()  # b t h w -> b c t h w\n                mask = mask.reshape(b, -1)\n\n            # these weighting schemes use a uniform timestep sampling\n            # and instead post-weight the loss\n            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n            # flow matching loss\n            target = noise - model_input\n\n            # Compute regular loss.\n            loss_mse = (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)\n            if mask is not None:\n                loss = (loss_mse * mask).sum() / mask.sum()\n            else:\n                loss = loss_mse.mean()\n\n\n\n        # Gather the losses across all processes for logging (if we use distributed training).\n        avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n        # avg_loss = accelerator.reduce(loss, reduction=\"mean\")\n        # progress_info.train_loss += avg_loss.detach().item() / args.gradient_accumulation_steps\n        progress_info.train_loss += avg_loss.detach() / args.gradient_accumulation_steps\n        # Backpropagate\n        accelerator.backward(loss)\n        if accelerator.sync_gradients:\n            params_to_clip = model.parameters()\n            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n        
optimizer.step()\n        lr_scheduler.step()\n        optimizer.zero_grad()\n        if accelerator.sync_gradients:\n            sync_gradients_info(loss)\n\n        if accelerator.is_main_process:\n\n            if progress_info.global_step % args.checkpointing_steps == 0:\n\n                if args.enable_tracker:\n                    log_validation(\n                        args, model, ae, [text_enc_1.text_enc, getattr(text_enc_2, 'text_enc', None)], \n                        train_dataset.tokenizer, accelerator, weight_dtype, progress_info.global_step\n                    )\n\n                    if args.use_ema and npu_config is None:\n                        # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                        ema_model.store(model.parameters())\n                        ema_model.copy_to(model.parameters())\n                        log_validation(\n                            args, model, ae, [text_enc_1.text_enc, getattr(text_enc_2, 'text_enc', None)], \n                            train_dataset.tokenizer, accelerator, weight_dtype, progress_info.global_step, ema=True\n                        )\n                        # Switch back to the original UNet parameters.\n                        ema_model.restore(model.parameters())\n\n        if prof is not None:\n            prof.step()\n\n        return loss\n\n    def train_one_step(step_, data_item_, prof_=None):\n        train_loss = 0.0\n        # print(\"rank {} | step {} | unzip data\".format(accelerator.process_index, step_))\n        x, attn_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_\n        # print(f'step: {step_}, rank: {accelerator.process_index}, x: {x.shape}, dtype: {x.dtype}')\n        # assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))'\n        if args.extra_save_mem:\n            torch.cuda.empty_cache()\n            ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else 
weight_dtype)\n            text_enc_1.to(accelerator.device, dtype=weight_dtype)\n            if text_enc_2 is not None:\n                text_enc_2.to(accelerator.device, dtype=weight_dtype)\n\n        x = x.to(accelerator.device, dtype=ae.vae.dtype)  # B C T H W\n        # x = x.to(accelerator.device, dtype=torch.float32)  # B C T H W\n        attn_mask = attn_mask.to(accelerator.device)  # B T H W\n        input_ids_1 = input_ids_1.to(accelerator.device)  # B 1 L\n        cond_mask_1 = cond_mask_1.to(accelerator.device)  # B 1 L\n        input_ids_2 = input_ids_2.to(accelerator.device) if input_ids_2 is not None else input_ids_2 # B 1 L\n        cond_mask_2 = cond_mask_2.to(accelerator.device) if cond_mask_2 is not None else cond_mask_2 # B 1 L\n        \n        with torch.no_grad():\n            B, N, L = input_ids_1.shape  # B 1 L\n            # use batch inference\n            input_ids_1 = input_ids_1.reshape(-1, L)\n            cond_mask_1 = cond_mask_1.reshape(-1, L)\n            cond_1 = text_enc_1(input_ids_1, cond_mask_1)  # B L D\n            cond_1 = cond_1.reshape(B, N, L, -1)\n            cond_mask_1 = cond_mask_1.reshape(B, N, L)\n            if text_enc_2 is not None:\n                B_, N_, L_ = input_ids_2.shape  # B 1 L\n                input_ids_2 = input_ids_2.reshape(-1, L_)\n                cond_2 = text_enc_2(input_ids_2, cond_mask_2)  # B D\n                cond_2 = cond_2.reshape(B_, 1, -1)  # B 1 D\n            else:\n                cond_2 = None\n\n            # Map input images to latent space + normalize latents\n            x = ae.encode(x)  # B C T H W\n            # print(f'step: {step_}, rank: {accelerator.process_index}, after vae.encode, x: {x.shape}, dtype: {x.dtype}, mean: {x.mean()}, std: {x.std()}')\n            # x = torch.rand(1, 32, 14, 80, 80).to(x.device, dtype=x.dtype)\n            # def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:\n            #     from 
examples.rec_video import array_to_video\n            #     x = x.detach().cpu()\n            #     x = torch.clamp(x, -1, 1)\n            #     x = (x + 1) / 2\n            #     x = x.permute(1, 2, 3, 0).numpy()\n            #     x = (255*x).astype(np.uint8)\n            #     array_to_video(x, fps=fps, output_file=output_file)\n            #     return\n            # videos = ae.decode(x)[0]\n            # videos = videos.transpose(0, 1)\n            # custom_to_video(videos.to(torch.float32), fps=24, output_file='tmp.mp4')\n            # import sys;sys.exit()\n            \n        # print(\"rank {} | step {} | after encode\".format(accelerator.process_index, step_))\n        if args.extra_save_mem:\n            ae.vae.to('cpu')\n            text_enc_1.to('cpu')\n            if text_enc_2 is not None:\n                text_enc_2.to('cpu')\n            torch.cuda.empty_cache()\n\n        current_step_frame = x.shape[2]\n        current_step_sp_state = get_sequence_parallel_state()\n        if args.sp_size != 1:  # enable sp\n            if current_step_frame == 1:  # but image do not need sp\n                set_sequence_parallel_state(False)\n            else:\n                set_sequence_parallel_state(True)\n        if get_sequence_parallel_state():\n            x, cond_1, attn_mask, cond_mask_1, cond_2 = prepare_parallel_data(\n                x, cond_1, attn_mask, cond_mask_1, cond_2\n                )        \n            # x            (b c t h w)   -gather0-> (sp*b c t h w)   -scatter2-> (sp*b c t//sp h w)\n            # cond_1       (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)\n            # attn_mask    (b t*sp h w)  -gather0-> (sp*b t*sp h w)  -scatter1-> (sp*b t h w)\n            # cond_mask_1  (b sp l)      -gather0-> (sp*b sp l)      -scatter1-> (sp*b 1 l)\n            # cond_2       (b sp d)      -gather0-> (sp*b sp d)      -scatter1-> (sp*b 1 d)\n            for iter in range(args.train_batch_size * args.sp_size // 
args.train_sp_batch_size):\n                with accelerator.accumulate(model):\n                    # x            (sp_bs*b c t//sp h w)\n                    # cond_1       (sp_bs*b 1 l/sp d)\n                    # attn_mask    (sp_bs*b t h w)\n                    # cond_mask_1  (sp_bs*b 1 l)\n                    # cond_2       (sp_bs*b 1 d)\n                    st_idx = iter * args.train_sp_batch_size\n                    ed_idx = (iter + 1) * args.train_sp_batch_size\n                    model_kwargs = dict(\n                        encoder_hidden_states=cond_1[st_idx: ed_idx],\n                        attention_mask=attn_mask[st_idx: ed_idx],\n                        encoder_attention_mask=cond_mask_1[st_idx: ed_idx], \n                        pooled_projections=cond_2[st_idx: ed_idx] if cond_2 is not None else None, \n                        )\n                    run(step_, x[st_idx: ed_idx], model_kwargs, prof_)\n        else:\n            with accelerator.accumulate(model):\n                # assert not torch.any(torch.isnan(x)), 'after vae'\n                x = x.to(weight_dtype)\n                model_kwargs = dict(\n                    encoder_hidden_states=cond_1, attention_mask=attn_mask, \n                    encoder_attention_mask=cond_mask_1, \n                    pooled_projections=cond_2\n                    )\n                run(step_, x, model_kwargs, prof_)\n\n        set_sequence_parallel_state(current_step_sp_state)  # in case the next step use sp, which need broadcast(timesteps)\n\n        if progress_info.global_step >= args.max_train_steps:\n            return True\n\n        return False\n\n    def train_one_epoch(prof_=None):\n        # for epoch in range(first_epoch, args.num_train_epochs):\n        progress_info.train_loss = 0.0\n        if progress_info.global_step >= args.max_train_steps:\n            return True\n        for step, data_item in enumerate(train_dataloader):\n            # print(\"rank {} | step {} | get 
data\".format(accelerator.process_index, step))\n            if train_one_step(step, data_item, prof_):\n                break\n\n            if step >= 2 and torch_npu is not None and npu_config is not None:\n                npu_config.free_mm()\n\n    if npu_config is not None and npu_config.on_npu and npu_config.profiling:\n        experimental_config = torch_npu.profiler._ExperimentalConfig(\n            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,\n            aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization\n        )\n        profile_output_path = f\"/home/image_data/npu_profiling_t2v/{os.getenv('PROJECT_NAME', 'local')}\"\n        os.makedirs(profile_output_path, exist_ok=True)\n\n        with torch_npu.profiler.profile(\n                activities=[\n                    torch_npu.profiler.ProfilerActivity.CPU, \n                    torch_npu.profiler.ProfilerActivity.NPU, \n                    ],\n                with_stack=True,\n                record_shapes=True,\n                profile_memory=True,\n                experimental_config=experimental_config,\n                schedule=torch_npu.profiler.schedule(\n                    wait=npu_config.profiling_step, warmup=0, active=1, repeat=1, skip_first=0\n                    ),\n                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f\"{profile_output_path}/\")\n        ) as prof:\n            train_one_epoch(prof)\n    else:\n        if args.enable_profiling:\n            with torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU, \n                    torch.profiler.ProfilerActivity.CUDA, \n                    ], \n                schedule=torch.profiler.schedule(wait=5, warmup=1, active=1, repeat=1, skip_first=0),\n                on_trace_ready=torch.profiler.tensorboard_trace_handler('./gpu_profiling_active_1_delmask_delbkmask_andvaemask_curope_gpu'),\n                record_shapes=True,\n             
   profile_memory=True,\n                with_stack=True\n            ) as prof:\n                train_one_epoch(prof)\n        else:\n            train_one_epoch()\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n    if get_sequence_parallel_state():\n        destroy_sequence_parallel_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    # dataset & dataloader\n    parser.add_argument(\"--dataset\", type=str, required=True)\n    parser.add_argument(\"--data\", type=str, required='')\n    parser.add_argument(\"--sample_rate\", type=int, default=1)\n    parser.add_argument(\"--train_fps\", type=int, default=24)\n    parser.add_argument(\"--drop_short_ratio\", type=float, default=1.0)\n    parser.add_argument(\"--speed_factor\", type=float, default=1.0)\n    parser.add_argument(\"--num_frames\", type=int, default=65)\n    parser.add_argument(\"--max_height\", type=int, default=320)\n    parser.add_argument(\"--max_width\", type=int, default=240)\n    parser.add_argument(\"--max_hxw\", type=int, default=None)\n    parser.add_argument(\"--min_hxw\", type=int, default=None)\n    parser.add_argument(\"--ood_img_ratio\", type=float, default=0.0)\n    parser.add_argument(\"--use_img_from_vid\", action=\"store_true\")\n    parser.add_argument(\"--model_max_length\", type=int, default=512)\n    parser.add_argument('--cfg', type=float, default=0.1)\n    parser.add_argument(\"--dataloader_num_workers\", type=int, default=10, help=\"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\")\n    parser.add_argument(\"--group_data\", action=\"store_true\")\n    parser.add_argument(\"--hw_stride\", type=int, default=32)\n    parser.add_argument(\"--force_resolution\", action=\"store_true\")\n    parser.add_argument(\"--trained_data_global_step\", type=int, default=None)\n    parser.add_argument(\"--use_decord\", action=\"store_true\")\n\n    # text encoder & vae & diffusion model\n    parser.add_argument('--vae_fp32', action='store_true')\n    parser.add_argument('--extra_save_mem', action='store_true')\n    parser.add_argument(\"--model\", type=str, choices=list(Diffusion_models.keys()), default=\"Latte-XL/122\")\n    parser.add_argument('--enable_tiling', action='store_true')\n    parser.add_argument('--interpolation_scale_h', type=float, default=1.0)\n    parser.add_argument('--interpolation_scale_w', type=float, default=1.0)\n    parser.add_argument('--interpolation_scale_t', type=float, default=1.0)\n    parser.add_argument(\"--ae\", type=str, default=\"stabilityai/sd-vae-ft-mse\")\n    parser.add_argument(\"--ae_path\", type=str, default=\"stabilityai/sd-vae-ft-mse\")\n    parser.add_argument(\"--text_encoder_name_1\", type=str, default='DeepFloyd/t5-v1_1-xxl')\n    parser.add_argument(\"--text_encoder_name_2\", type=str, default=None)\n    parser.add_argument(\"--cache_dir\", type=str, default='./cache_dir')\n    parser.add_argument(\"--pretrained\", type=str, default=None)\n    parser.add_argument('--sparse1d', action='store_true')\n    parser.add_argument('--sparse_n', type=int, default=2)\n    parser.add_argument('--skip_connection', action='store_true')\n    parser.add_argument('--cogvideox_scheduler', action='store_true')\n    parser.add_argument('--v1_5_scheduler', action='store_true')\n    parser.add_argument('--rf_scheduler', 
action='store_true')\n    parser.add_argument(\"--weighting_scheme\", type=str, default=\"logit_normal\", choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\"])\n    parser.add_argument(\"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\")\n    parser.add_argument(\"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\")\n    parser.add_argument(\"--mode_scale\", type=float, default=1.29, help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\")\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\", help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\")\n\n    # diffusion setting\n    parser.add_argument(\"--snr_gamma\", type=float, default=None, help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.\")\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\"--ema_decay\", type=float, default=0.9999)\n    parser.add_argument(\"--ema_start_step\", type=int, default=0)\n    parser.add_argument(\"--offload_ema\", action=\"store_true\", help=\"Offload EMA model to CPU during training step.\")\n    parser.add_argument(\"--foreach_ema\", action=\"store_true\", help=\"Use faster foreach implementation of EMAModel.\")\n    parser.add_argument(\"--noise_offset\", type=float, default=0.0, help=\"The scale of noise offset.\")\n    parser.add_argument(\"--prediction_type\", type=str, default='epsilon', help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\")\n    parser.add_argument('--rescale_betas_zero_snr', action='store_true')\n\n    # validation & logs\n    parser.add_argument(\"--log_interval\", type=int, default=10)\n    parser.add_argument(\"--enable_profiling\", action=\"store_true\")\n    parser.add_argument(\"--num_sampling_steps\", type=int, default=20)\n    parser.add_argument('--guidance_scale', type=float, default=4.5)\n    parser.add_argument(\"--enable_tracker\", action=\"store_true\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\"--output_dir\", type=str, default=None, help=\"The output directory where the model predictions and checkpoints will be written.\")\n    parser.add_argument(\"--checkpoints_total_limit\", type=int, default=None, help=(\"Max number of checkpoints to store.\"))\n    parser.add_argument(\"--checkpointing_steps\", type=int, default=500,\n                        help=(\n                            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n                            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n                            \" training using `--resume_from_checkpoint`.\"\n                        ),\n                        )\n    parser.add_argument(\"--resume_from_checkpoint\", type=str, default=None,\n                        help=(\n                            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n                            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n                        ),\n                        )\n    parser.add_argument(\"--logging_dir\", type=str, default=\"logs\",\n                        help=(\n                            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n                            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n                        ),\n                        )\n    parser.add_argument(\"--report_to\", type=str, default=\"tensorboard\",\n                        help=(\n                            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n                            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n                        ),\n                        )\n    \n    # optimizer & scheduler\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\"--max_train_steps\", type=int, default=None, help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\")\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=1, help=\"Number of update steps to accumulate before performing a backward/update pass.\")\n    parser.add_argument(\"--optimizer\", type=str, default=\"adamW\", help='The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]')\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"Initial learning rate (after the potential warmup period) to use.\")\n    parser.add_argument(\"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\")\n    parser.add_argument(\"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes. 
Ignored if optimizer is not set to AdamW\")\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-02, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\"--adam_weight_decay_text_encoder\", type=float, default=None, help=\"Weight decay to use for text_encoder\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-15, help=\"Epsilon value for the Adam and Prodigy optimizers.\")\n    parser.add_argument(\"--prodigy_use_bias_correction\", type=bool, default=True, help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\")\n    parser.add_argument(\"--prodigy_safeguard_warmup\", type=bool, default=True, help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--prodigy_beta3\", type=float, default=None,\n                        help=\"Coefficient for computing the Prodigy step size using running averages. If set to None, \"\n                             \"uses the square root of beta2. Ignored if optimizer is adamW\",\n                        )\n    parser.add_argument(\"--lr_scheduler\", type=str, default=\"constant\",\n                        help=(\n                            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n                            ' \"constant\", \"constant_with_warmup\"]'\n                        ),\n                        )\n    parser.add_argument(\"--allow_tf32\", action=\"store_true\",\n                        help=(\n                            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n                            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n                        ),\n                        )\n    parser.add_argument(\"--mixed_precision\", type=str, default=None, choices=[\"no\", \"fp16\", \"bf16\"],\n                        help=(\n                            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n                            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n                            \" flag passed with the `accelerate launch` command. Use this argument to override the accelerate config.\"\n                        ),\n                        )\n\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--sp_size\", type=int, default=1, help=\"Sequence-parallel group size (1 disables sequence parallelism)\")\n    parser.add_argument(\"--train_sp_batch_size\", type=int, default=1, help=\"Batch size for sequence parallel training\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "opensora/utils/communications.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom einops import rearrange\nfrom opensora.utils.parallel_states import nccl_info\n\ndef broadcast(input_: torch.Tensor):\n    sp_size = nccl_info.world_size\n    # round the rank down to a multiple of sp_size, i.e. the first rank of this sequence-parallel group\n    # (assumes groups are contiguous blocks of sp_size ranks)\n    src = nccl_info.rank // sp_size * sp_size\n    dist.broadcast(input_, src=src, group=nccl_info.group)\n\n_COUNT = 0\ndef _all_to_all(\n    input_: torch.Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n):\n    group = nccl_info.group\n    sp_size = nccl_info.world_size\n    input_list = [t.contiguous() for t in torch.tensor_split(input_, sp_size, scatter_dim)]\n    output_list = [torch.empty_like(input_list[0]) for _ in range(sp_size)]\n    dist.all_to_all(output_list, input_list, group=group)\n    return torch.cat(output_list, dim=gather_dim).contiguous()\n\ndef _single_all_to_all(\n    input_: torch.Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n    enable_HCCL=False,\n):\n    # note: enable_HCCL is currently unused\n    sp_size = nccl_info.world_size\n    inp_shape = list(input_.shape)\n    inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size\n    if scatter_dim < 1:\n        input_t = input_.reshape(\n            [sp_size, inp_shape[scatter_dim]] + \\\n            inp_shape[scatter_dim + 1:]\n        )\n    else:\n        # split scatter_dim into (sp_size, chunk), fold the leading dims together and move the\n        # rank axis to the front so all_to_all_single can scatter it (this path assumes gather_dim == 0)\n        input_t = input_.reshape(\n            [-1, sp_size, inp_shape[scatter_dim]] + \\\n            inp_shape[scatter_dim + 1:]\n        ).transpose(0, 1).contiguous()\n\n    output = torch.empty_like(input_t)\n    dist.all_to_all_single(output, input_t, group=nccl_info.group)\n    # scatter_dim == 0: move the received rank axis next to dim 1 so the final reshape\n    # concatenates the chunks along it (this path assumes gather_dim == 1)\n    if scatter_dim < 1:\n        output = output.transpose(0, 1).contiguous()\n\n    return output.reshape(\n        inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:])\n\n\nclass _AllToAll(torch.autograd.Function):\n    \"\"\"All-to-all 
communication.\n\n    Args:\n        input_: input tensor\n        scatter_dim: dimension split across ranks in the forward pass\n        gather_dim: dimension concatenated across ranks in the forward pass\n        all_to_all_func: collective to run (`_all_to_all` or `_single_all_to_all`)\n\n    The process group is taken from `nccl_info.group`.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input_, scatter_dim, gather_dim, all_to_all_func):\n        ctx.scatter_dim = scatter_dim\n        ctx.gather_dim = gather_dim\n        ctx.all_to_all = all_to_all_func\n        output = ctx.all_to_all(input_, scatter_dim, gather_dim)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # the gradient of an all-to-all is the reverse all-to-all: swap scatter_dim and gather_dim\n        grad_output = ctx.all_to_all(\n            grad_output,\n            ctx.gather_dim,\n            ctx.scatter_dim,\n        )\n        # no gradients for scatter_dim, gather_dim and all_to_all_func\n        return (\n            grad_output,\n            None,\n            None,\n            None,\n        )\n\ndef all_to_all_SBH(\n    input_: torch.Tensor,\n    scatter_dim: int = 1,\n    gather_dim: int = 0,\n):\n    return _AllToAll.apply(input_, scatter_dim, gather_dim, _single_all_to_all)\n\ndef all_to_all_BSND(\n    input_: torch.Tensor,\n    scatter_dim: int = 2,\n    gather_dim: int = 1,\n):\n    return _AllToAll.apply(input_, scatter_dim, gather_dim, _all_to_all)\n\n\ndef prepare_parallel_data(\n        hidden_states, \n        encoder_hidden_states, \n        attention_mask, \n        encoder_attention_mask, \n        pooled_projections, \n        ):\n    def all_to_all(\n            hidden_states, \n            encoder_hidden_states, \n            attention_mask, \n            encoder_attention_mask, \n            pooled_projections, \n            ):\n        # hidden_states          (b c t h w)   -gather0-> (sp*b c t h w)   -scatter2-> (sp*b c t//sp h w)\n        # encoder_hidden_states  (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)\n        # attention_mask         (b t*sp h w)  -gather0-> (sp*b t*sp h w)  -scatter1-> (sp*b t h w)\n        # encoder_attention_mask (b sp l)      -gather0-> (sp*b sp l)      -scatter1-> (sp*b 1 l)\n        # 
pooled_projections     (b sp d)      -gather0-> (sp*b sp d)      -scatter1-> (sp*b 1 d)\n        hidden_states = _single_all_to_all(hidden_states, scatter_dim=2, gather_dim=0, enable_HCCL=True)\n        encoder_hidden_states = _single_all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        attention_mask = _single_all_to_all(attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        encoder_attention_mask = _single_all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n        if pooled_projections is not None:\n            pooled_projections = _single_all_to_all(pooled_projections, scatter_dim=1, gather_dim=0, enable_HCCL=True)\n\n        return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections\n\n    sp_size = nccl_info.world_size\n    frame = hidden_states.shape[2]\n    assert frame % sp_size == 0, \"frame should be a multiple of sp_size\"\n\n    encoder_hidden_states = rearrange(\n        encoder_hidden_states, 'b 1 (n x) h -> b n x h', n=sp_size, x=encoder_hidden_states.shape[2]//sp_size\n        ).contiguous()\n    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections = all_to_all(\n        hidden_states, \n        encoder_hidden_states, \n        attention_mask.repeat(1, sp_size, 1, 1), \n        encoder_attention_mask.repeat(1, sp_size, 1), \n        pooled_projections.repeat(1, sp_size, 1) if pooled_projections is not None else None, \n        )\n\n    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections"
  },
  {
    "path": "opensora/utils/dataset_utils.py",
    "content": "import math\nfrom einops import rearrange\nimport decord\nfrom torch.nn import functional as F\nimport torch\nfrom typing import Optional\nimport torch.utils\nimport torch.utils.data\nfrom torch.utils.data import Sampler\nfrom typing import List\nfrom collections import Counter, defaultdict\nimport random\n\n\nIMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\nclass DecordInit(object):\n    \"\"\"Use Decord (https://github.com/dmlc/decord) to initialize a video reader.\"\"\"\n\n    def __init__(self, num_threads=1):\n        self.num_threads = num_threads\n        self.ctx = decord.cpu(0)\n\n    def __call__(self, filename):\n        \"\"\"Open a video file with Decord.\n        Args:\n            filename (str): Path to the video file.\n        Returns:\n            decord.VideoReader: A reader that decodes on CPU using `num_threads` threads.\n        \"\"\"\n        reader = decord.VideoReader(filename,\n                                    ctx=self.ctx,\n                                    num_threads=self.num_threads)\n        return reader\n\n    def __repr__(self):\n        repr_str = (f'{self.__class__.__name__}('\n                    f'num_threads={self.num_threads})')\n        return repr_str\n\ndef pad_to_multiple(number, ds_stride):\n    \"\"\"Round `number` up to the nearest multiple of `ds_stride`.\"\"\"\n    remainder = number % ds_stride\n    if remainder == 0:\n        return number\n    else:\n        padding = ds_stride - remainder\n        return number + padding\n\nclass Collate:\n    def __init__(self, args):\n        self.batch_size = args.train_batch_size\n        self.group_data = args.group_data\n        self.force_resolution = args.force_resolution\n\n        self.max_height = args.max_height\n        self.max_width = args.max_width\n        self.ae_stride = args.ae_stride\n\n        self.ae_stride_t = args.ae_stride_t\n        self.ae_stride_thw = 
(self.ae_stride_t, self.ae_stride, self.ae_stride)\n\n        self.patch_size = args.patch_size\n        self.patch_size_t = args.patch_size_t\n\n        self.num_frames = args.num_frames\n        self.max_thw = (self.num_frames, self.max_height, self.max_width)\n\n    def package(self, batch):\n        batch_tubes = [i['pixel_values'] for i in batch]  # b [c t h w]\n        input_ids_1 = [i['input_ids_1'] for i in batch]  # b [1 l]\n        cond_mask_1 = [i['cond_mask_1'] for i in batch]  # b [1 l]\n        input_ids_2 = [i['input_ids_2'] for i in batch]  # b [1 l]\n        cond_mask_2 = [i['cond_mask_2'] for i in batch]  # b [1 l]\n        assert all([i is None for i in input_ids_2]) or all([i is not None for i in input_ids_2])\n        assert all([i is None for i in cond_mask_2]) or all([i is not None for i in cond_mask_2])\n        if all([i is None for i in input_ids_2]):\n            input_ids_2 = None\n        if all([i is None for i in cond_mask_2]):\n            cond_mask_2 = None\n        return batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2\n\n    def __call__(self, batch):\n        batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.package(batch)\n\n        ds_stride = self.ae_stride * self.patch_size\n        t_ds_stride = self.ae_stride_t * self.patch_size_t\n        \n        pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.process(\n            batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2, \n            t_ds_stride, ds_stride, self.max_thw, self.ae_stride_thw\n        )\n        assert not torch.any(torch.isnan(pad_batch_tubes)), 'after pad_batch_tubes'\n        return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2\n\n    def process(self, batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2, t_ds_stride, ds_stride, max_thw, ae_stride_thw):\n        # pad to max multiple of ds_stride\n        
batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]\n        assert len(batch_input_size) == self.batch_size\n        if self.group_data or self.batch_size == 1:  #\n            len_each_batch = batch_input_size\n            idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)])\n            count_dict = Counter(len_each_batch)\n            if len(count_dict) != 1:\n                sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])\n                # import ipdb;ipdb.set_trace()\n                # print(batch, idx_length_dict, count_dict, sorted_by_value)\n                pick_length = sorted_by_value[-1][0]  # the highest frequency\n                candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]\n                random_select_batch = [random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch))]\n                print(batch_input_size, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch)\n                pick_idx = candidate_batch + random_select_batch\n\n                batch_tubes = [batch_tubes[i] for i in pick_idx]\n                batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]\n                input_ids_1 = [input_ids_1[i] for i in pick_idx]  # b [1, l]\n                cond_mask_1 = [cond_mask_1[i] for i in pick_idx]  # b [1, l]\n                if input_ids_2 is not None:\n                    input_ids_2 = [input_ids_2[i] for i in pick_idx]  # b [1, l]\n                if cond_mask_2 is not None:\n                    cond_mask_2 = [cond_mask_2[i] for i in pick_idx]  # b [1, l]\n\n            for i in range(1, self.batch_size):\n                assert batch_input_size[0] == batch_input_size[i]\n            max_t = max([i[1] for i in batch_input_size])\n            max_h = max([i[2] for i in batch_input_size])\n            max_w = max([i[3] for i in 
batch_input_size])\n        else:\n            max_t, max_h, max_w = max_thw\n        pad_max_t, pad_max_h, pad_max_w = pad_to_multiple(max_t-1+self.ae_stride_t, t_ds_stride), \\\n                                          pad_to_multiple(max_h, ds_stride), \\\n                                          pad_to_multiple(max_w, ds_stride)\n        pad_max_t = pad_max_t + 1 - self.ae_stride_t\n        each_pad_t_h_w = [\n            [\n                pad_max_t - i.shape[1],\n                pad_max_h - i.shape[2],\n                pad_max_w - i.shape[3]\n                ] for i in batch_tubes\n                ]\n        pad_batch_tubes = [\n            F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0) \n            for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)\n            ]\n        pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0)\n\n\n        max_tube_size = [pad_max_t, pad_max_h, pad_max_w]\n        max_latent_size = [\n            ((max_tube_size[0]-1) // ae_stride_thw[0] + 1),\n            max_tube_size[1] // ae_stride_thw[1],\n            max_tube_size[2] // ae_stride_thw[2]\n            ]\n        valid_latent_size = [\n            [\n                int(math.ceil((i[1]-1) / ae_stride_thw[0])) + 1,\n                int(math.ceil(i[2] / ae_stride_thw[1])),\n                int(math.ceil(i[3] / ae_stride_thw[2]))\n                ] for i in batch_input_size]\n        attention_mask = [\n            F.pad(torch.ones(i, dtype=pad_batch_tubes.dtype), (0, max_latent_size[2] - i[2], \n                                                               0, max_latent_size[1] - i[1],\n                                                               0, max_latent_size[0] - i[0]), value=0) for i in valid_latent_size]\n        attention_mask = torch.stack(attention_mask)  # b t h w\n        if self.batch_size == 1 or self.group_data:\n            if not torch.all(attention_mask.bool()):\n                print(batch_input_size, (max_t, max_h, max_w), 
(pad_max_t, pad_max_h, pad_max_w), each_pad_t_h_w, max_latent_size, valid_latent_size)\n            assert torch.all(attention_mask.bool())\n\n        input_ids_1 = torch.stack(input_ids_1)  # b 1 l\n        cond_mask_1 = torch.stack(cond_mask_1)  # b 1 l\n        input_ids_2 = torch.stack(input_ids_2) if input_ids_2 is not None else input_ids_2  # b 1 l\n        cond_mask_2 = torch.stack(cond_mask_2) if cond_mask_2 is not None else cond_mask_2  # b 1 l\n\n        return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2\n\n\ndef group_data_fun(lengths, generator=None):\n    # counter is decrease order\n    counter = Counter(lengths)  # counter {'1x256x256': 3, ''}   lengths ['1x256x256', '1x256x256', '1x256x256', ...]\n    grouped_indices = defaultdict(list)\n    for idx, item in enumerate(lengths):  # group idx to a list\n        grouped_indices[item].append(idx)\n\n    grouped_indices = dict(grouped_indices)  # {'1x256x256': [0, 1, 2], ...}\n    sorted_indices = [grouped_indices[item] for (item, _) in sorted(counter.items(), key=lambda x: x[1], reverse=True)]\n    \n    # shuffle in each group\n    shuffle_sorted_indices = []\n    for indice in sorted_indices:\n        shuffle_idx = torch.randperm(len(indice), generator=generator).tolist()\n        shuffle_sorted_indices.extend([indice[idx] for idx in shuffle_idx])\n    return shuffle_sorted_indices\n\ndef last_group_data_fun(shuffled_megabatches, lengths):\n    # lengths ['1x256x256', '1x256x256', '1x256x256' ...]\n    re_shuffled_megabatches = []\n    # print('shuffled_megabatches', len(shuffled_megabatches))\n    for i_megabatch, megabatch in enumerate(shuffled_megabatches):\n        re_megabatch = []\n        for i_batch, batch in enumerate(megabatch):\n            assert len(batch) != 0\n                \n            len_each_batch = [lengths[i] for i in batch]  # ['1x256x256', '1x256x256']\n            idx_length_dict = dict([*zip(batch, len_each_batch)])  # {0: 
'1x256x256', 100: '1x256x256'}\n            count_dict = Counter(len_each_batch)  # {'1x256x256': 2} or {'1x256x256': 1, '1x768x256': 1}\n            if len(count_dict) != 1:\n                sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])  # {'1x256x256': 1, '1x768x256': 1}\n                # import ipdb;ipdb.set_trace()\n                # print(batch, idx_length_dict, count_dict, sorted_by_value)\n                pick_length = sorted_by_value[-1][0]  # the highest frequency\n                candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]\n                random_select_batch = [random.choice(candidate_batch) for i in range(len(len_each_batch) - len(candidate_batch))]\n                # print(batch, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch)\n                batch = candidate_batch + random_select_batch\n                # print(batch)\n\n            for i in range(1, len(batch)):\n                # if not lengths[batch[0]] == lengths[batch[i]]:\n                #     print(batch, [lengths[i] for i in batch])\n                #     import ipdb;ipdb.set_trace()\n                assert lengths[batch[0]] == lengths[batch[i]]\n            re_megabatch.append(batch)\n        re_shuffled_megabatches.append(re_megabatch)\n    \n    \n    # for megabatch, re_megabatch in zip(shuffled_megabatches, re_shuffled_megabatches):\n    #     for batch, re_batch in zip(megabatch, re_megabatch):\n    #         for i, re_i in zip(batch, re_batch):\n    #             if i != re_i:\n    #                 print(i, re_i)\n    return re_shuffled_megabatches\n                \ndef split_to_even_chunks(megabatch, lengths, world_size, batch_size):\n    \"\"\"\n    Split a megabatch of indices into `world_size` strided chunks (one per rank),\n    padding any chunk shorter than `batch_size` by resampling indices.\n    \"\"\"\n    # batch_size=2, world_size=2 (before padding)\n    # [1, 2, 3, 4] -> [[1, 3], [2, 4]]\n    # [1, 2, 3] -> [[1, 3], [2]]\n    # [1, 2] -> [[1], [2]]\n    # [1] -> [[1], []]\n    chunks = [megabatch[i::world_size] for i in range(world_size)]\n\n    pad_chunks = []\n    for idx, chunk in enumerate(chunks):\n        if batch_size != len(chunk):  \n            assert batch_size > len(chunk)\n            if len(chunk) != 0:  # [[1, 3], [2]] -> [[1, 3], [2, 2]]\n                chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))]\n            else:\n                chunk = random.choice(pad_chunks)  # [[1], []] -> [[1, 1], [1, 1]]\n                print(chunks[idx], '->', chunk)\n        pad_chunks.append(chunk)\n    return pad_chunks\n\ndef get_length_grouped_indices(lengths, batch_size, world_size, gradient_accumulation_size, initial_global_step, generator=None, group_data=False, seed=42):\n    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.\n    if generator is None:\n        generator = torch.Generator().manual_seed(seed)  # every rank will generate a fixed order but random index\n    # print('lengths', lengths)\n    \n    if group_data:\n        indices = group_data_fun(lengths, generator)\n    else:\n        indices = torch.randperm(len(lengths), generator=generator).tolist()\n    # print('indices', len(indices))\n\n    # print('sort indices', len(indices))\n    # print('sort indices', indices)\n    # print('sort lengths', [lengths[i] for i in indices])\n    \n    megabatch_size = world_size * batch_size\n    megabatches = [indices[i: i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]\n    # import ipdb;ipdb.set_trace()\n    # print('megabatches', len(megabatches))\n    # print('\\nmegabatches', megabatches)\n    # megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]\n    # import ipdb;ipdb.set_trace()\n    # print('sort megabatches', len(megabatches))\n    megabatches_len = [[lengths[i] for i in megabatch] for megabatch in megabatches]\n    # print(f'\\nrank 
{accelerator.process_index} sorted megabatches_len', megabatches_len[0], megabatches_len[1], megabatches_len[-2], megabatches_len[-1])\n    # import ipdb;ipdb.set_trace()\n    megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches]\n    # import ipdb;ipdb.set_trace()\n    # print('nsplit_to_even_chunks megabatches', len(megabatches))\n    # print('\\nsplit_to_even_chunks megabatches', megabatches)\n    split_to_even_chunks_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in megabatches]\n    # print(f'\\nrank {accelerator.process_index} split_to_even_chunks_len', split_to_even_chunks_len[0], split_to_even_chunks_len[1], split_to_even_chunks_len[-2], split_to_even_chunks_len[-1])\n    # print('\\nsplit_to_even_chunks len', split_to_even_chunks_len)\n    # return [i for megabatch in megabatches for batch in megabatch for i in batch]\n\n    indices_mega = torch.randperm(len(megabatches), generator=generator).tolist()\n    # print(f'rank {accelerator.process_index} seed {seed}, len(megabatches) {len(megabatches)}, indices_mega, {indices_mega[:50]}')\n    shuffled_megabatches = [megabatches[i] for i in indices_mega]\n    shuffled_megabatches_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches]\n    # print(f'\\nrank {accelerator.process_index} sorted shuffled_megabatches_len', shuffled_megabatches_len[0], shuffled_megabatches_len[1], shuffled_megabatches_len[-2], shuffled_megabatches_len[-1])\n\n    # import ipdb;ipdb.set_trace()\n    # print('shuffled_megabatches', len(shuffled_megabatches))\n    if group_data:\n        shuffled_megabatches = last_group_data_fun(shuffled_megabatches, lengths)\n        group_shuffled_megabatches_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches]\n        # print(f'\\nrank {accelerator.process_index} group_shuffled_megabatches_len', group_shuffled_megabatches_len[0], 
group_shuffled_megabatches_len[1], group_shuffled_megabatches_len[-2], group_shuffled_megabatches_len[-1])\n    \n    # import ipdb;ipdb.set_trace()\n    initial_global_step = initial_global_step * gradient_accumulation_size\n    # print('shuffled_megabatches', len(shuffled_megabatches))\n    # print('have been trained idx:', len(shuffled_megabatches[:initial_global_step]))\n    # print('shuffled_megabatches[:10]', shuffled_megabatches[:10])\n    # print('have been trained idx:', shuffled_megabatches[:initial_global_step])\n    shuffled_megabatches = shuffled_megabatches[initial_global_step:]\n    print(f'Skip the data of {initial_global_step} step!')\n    # print('after shuffled_megabatches', len(shuffled_megabatches))\n    # print('after shuffled_megabatches[:10]', shuffled_megabatches[:10])\n\n    # print('\\nshuffled_megabatches', shuffled_megabatches)\n    # import ipdb;ipdb.set_trace()\n    # print('\\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch])\n\n    return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch]\n\n\nclass LengthGroupedSampler(Sampler):\n    r\"\"\"\n    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while\n    keeping a bit of randomness.\n    \"\"\"\n\n    def __init__(\n        self,\n        batch_size: int,\n        world_size: int,\n        gradient_accumulation_size: int, \n        initial_global_step: int, \n        lengths: Optional[List[int]] = None, \n        group_data=False, \n        generator=None,\n    ):\n        if lengths is None:\n            raise ValueError(\"Lengths must be provided.\")\n\n        self.batch_size = batch_size\n        self.world_size = world_size\n        self.initial_global_step = initial_global_step\n        self.gradient_accumulation_size = gradient_accumulation_size\n        self.lengths = lengths\n        self.group_data = 
group_data\n        self.generator = generator\n        # print('self.lengths, self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size', \n        #       len(self.lengths), self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size)\n\n    def __len__(self):\n        return len(self.lengths) - self.initial_global_step * self.batch_size * self.world_size * self.gradient_accumulation_size\n\n    def __iter__(self):\n        indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, \n                                             self.gradient_accumulation_size, self.initial_global_step, \n                                             group_data=self.group_data, generator=self.generator)\n        # print(len(indices), indices[23640:23690])\n        # import sys;sys.exit()\n        return iter(indices)\n"
  },
  {
    "path": "opensora/utils/downloader.py",
    "content": "import gdown\nimport os\n\nopensora_cache_home = os.path.expanduser(\n    os.getenv(\"OPENSORA_HOME\", os.path.join(\"~/.cache\", \"opensora\"))\n)\n\n\ndef gdown_download(id, fname, cache_dir=None):\n    cache_dir = opensora_cache_home if not cache_dir else cache_dir\n\n    os.makedirs(cache_dir, exist_ok=True)\n    destination = os.path.join(cache_dir, fname)\n    if os.path.exists(destination):\n        return destination\n\n    gdown.download(id=id, output=destination, quiet=False)\n    return destination\n"
  },
  {
    "path": "opensora/utils/ema.py",
    "content": "import contextlib\nimport copy\nimport random\nfrom typing import Any, Dict, Iterable, List, Optional, Union\n\nfrom diffusers.utils import (\n    deprecate,\n    is_torchvision_available,\n    is_transformers_available,\n)\n\nif is_transformers_available():\n    import transformers\n\nif is_torchvision_available():\n    from torchvision import transforms\n\nimport numpy as np\nimport torch\n\n\n# Adapted from diffusers-style ema https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py#L263\nclass EMAModel:\n    \"\"\"\n    Exponential Moving Average of models weights\n    \"\"\"\n\n    def __init__(\n        self,\n        parameters: Iterable[torch.nn.Parameter],\n        decay: float = 0.9999,\n        min_decay: float = 0.0,\n        update_after_step: int = 0,\n        use_ema_warmup: bool = False,\n        inv_gamma: Union[float, int] = 1.0,\n        power: Union[float, int] = 2 / 3,\n        model_cls: Optional[Any] = None,\n        model_config: Dict[str, Any] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            parameters (Iterable[torch.nn.Parameter]): The parameters to track.\n            decay (float): The decay factor for the exponential moving average.\n            min_decay (float): The minimum decay factor for the exponential moving average.\n            update_after_step (int): The number of steps to wait before starting to update the EMA weights.\n            use_ema_warmup (bool): Whether to use EMA warmup.\n            inv_gamma (float):\n                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.\n            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.\n            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. 
If None, the EMA\n                        weights will be stored on CPU.\n\n        @crowsonkb's notes on EMA Warmup:\n            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan\n            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),\n            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999\n            at 215.4k steps).\n        \"\"\"\n\n        if isinstance(parameters, torch.nn.Module):\n            deprecation_message = (\n                \"Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. \"\n                \"Please pass the parameters of the module instead.\"\n            )\n            deprecate(\n                \"passing a `torch.nn.Module` to `ExponentialMovingAverage`\",\n                \"1.0.0\",\n                deprecation_message,\n                standard_warn=False,\n            )\n            parameters = parameters.parameters()\n\n            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility\n            use_ema_warmup = True\n\n        if kwargs.get(\"max_value\", None) is not None:\n            deprecation_message = \"The `max_value` argument is deprecated. Please use `decay` instead.\"\n            deprecate(\"max_value\", \"1.0.0\", deprecation_message, standard_warn=False)\n            decay = kwargs[\"max_value\"]\n\n        if kwargs.get(\"min_value\", None) is not None:\n            deprecation_message = \"The `min_value` argument is deprecated. 
Please use `min_decay` instead.\"\n            deprecate(\"min_value\", \"1.0.0\", deprecation_message, standard_warn=False)\n            min_decay = kwargs[\"min_value\"]\n\n        parameters = list(parameters)\n        self.shadow_params = [p.clone().detach() for p in parameters]\n\n        if kwargs.get(\"device\", None) is not None:\n            deprecation_message = \"The `device` argument is deprecated. Please use `to` instead.\"\n            deprecate(\"device\", \"1.0.0\", deprecation_message, standard_warn=False)\n            self.to(device=kwargs[\"device\"])\n\n        self.temp_stored_params = None\n\n        self.decay = decay\n        self.min_decay = min_decay\n        self.update_after_step = update_after_step\n        self.use_ema_warmup = use_ema_warmup\n        self.inv_gamma = inv_gamma\n        self.power = power\n        self.optimization_step = 0\n        self.cur_decay_value = None  # set in `step()`\n\n        self.model_cls = model_cls\n        self.model_config = model_config\n\n    @classmethod\n    def extract_ema_kwargs(cls, kwargs):\n        \"\"\"\n        Extracts the EMA kwargs from the kwargs of a class method.\n        \"\"\"\n        ema_kwargs = {}\n        for key in [\n            \"decay\",\n            \"min_decay\",\n            \"optimization_step\",\n            \"update_after_step\",\n            \"use_ema_warmup\",\n            \"inv_gamma\",\n            \"power\",\n        ]:\n            if kwargs.get(key, None) is not None:\n                ema_kwargs[key] = kwargs.pop(key)\n        return ema_kwargs\n\n    @classmethod\n    def from_pretrained(cls, path, model_cls) -> \"EMAModel\":\n        config = model_cls.load_config(path)\n        ema_kwargs = cls.extract_ema_kwargs(config)\n        model = model_cls.from_pretrained(path)\n\n        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config)\n\n        ema_model.load_state_dict(ema_kwargs)\n        return ema_model\n\n    def 
save_pretrained(self, path):\n        if self.model_cls is None:\n            raise ValueError(\"`save_pretrained` can only be used if `model_cls` was defined at __init__.\")\n\n        if self.model_config is None:\n            raise ValueError(\"`save_pretrained` can only be used if `model_config` was defined at __init__.\")\n\n        model = self.model_cls.from_config(self.model_config)\n        state_dict = self.state_dict()\n        state_dict.pop(\"shadow_params\", None)\n\n        model.register_to_config(**state_dict)\n        self.copy_to(model.parameters())\n        model.save_pretrained(path)\n\n    def get_decay(self, optimization_step: int) -> float:\n        \"\"\"\n        Compute the decay factor for the exponential moving average.\n        \"\"\"\n        step = max(0, optimization_step - self.update_after_step - 1)\n\n        if step <= 0:\n            return 0.0\n\n        if self.use_ema_warmup:\n            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power\n        else:\n            cur_decay_value = (1 + step) / (10 + step)\n\n        cur_decay_value = min(cur_decay_value, self.decay)\n        # make sure decay is not smaller than min_decay\n        cur_decay_value = max(cur_decay_value, self.min_decay)\n        return cur_decay_value\n\n    @torch.no_grad()\n    def step(self, parameters: Iterable[torch.nn.Parameter]):\n        if isinstance(parameters, torch.nn.Module):\n            deprecation_message = (\n                \"Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. 
\"\n                \"Please pass the parameters of the module instead.\"\n            )\n            deprecate(\n                \"passing a `torch.nn.Module` to `ExponentialMovingAverage.step`\",\n                \"1.0.0\",\n                deprecation_message,\n                standard_warn=False,\n            )\n            parameters = parameters.parameters()\n\n        parameters = list(parameters)\n\n        self.optimization_step += 1\n\n        # Compute the decay factor for the exponential moving average.\n        decay = self.get_decay(self.optimization_step)\n        self.cur_decay_value = decay\n        one_minus_decay = 1 - decay\n\n        context_manager = contextlib.nullcontext\n        if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():\n            import deepspeed\n\n        for s_param, param in zip(self.shadow_params, parameters):\n            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():\n                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)\n\n            with context_manager():\n                if param.requires_grad:\n                    s_param.sub_(one_minus_decay * (s_param - param))\n                else:\n                    s_param.copy_(param)\n\n    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:\n        \"\"\"\n        Copy current averaged parameters into given collection of parameters.\n\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                updated with the stored moving averages. 
\n        \"\"\"\n        parameters = list(parameters)\n        for s_param, param in zip(self.shadow_params, parameters):\n            param.data.copy_(s_param.to(param.device).data)\n\n\n    def to(self, device=None, dtype=None) -> None:\n        r\"\"\"Move internal buffers of the ExponentialMovingAverage to `device`.\n\n        Args:\n            device: like `device` argument to `torch.Tensor.to`\n            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers\n        \"\"\"\n        # .to() on the tensors handles None correctly\n        self.shadow_params = [\n            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)\n            for p in self.shadow_params\n        ]\n\n    def state_dict(self) -> dict:\n        r\"\"\"\n        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during\n        checkpointing to save the ema state dict.\n        \"\"\"\n        # Following PyTorch conventions, references to tensors are returned:\n        # \"returns a reference to the state and not its copy!\" -\n        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict\n        return {\n            \"decay\": self.decay,\n            \"min_decay\": self.min_decay,\n            \"optimization_step\": self.optimization_step,\n            \"update_after_step\": self.update_after_step,\n            \"use_ema_warmup\": self.use_ema_warmup,\n            \"inv_gamma\": self.inv_gamma,\n            \"power\": self.power,\n            \"shadow_params\": self.shadow_params,\n        }\n\n    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:\n        r\"\"\"\n        Save the current parameters for restoring later.\n\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                temporarily stored.\n        \"\"\"\n        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]\n\n    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:\n        r\"\"\"\n        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without\n        affecting the original optimization process. Store the parameters before the `copy_to()` method. After\n        validation (or model saving), use this to restore the former parameters.\n\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                updated with the stored parameters.\n        \"\"\"\n        if self.temp_stored_params is None:\n            raise RuntimeError(\"This ExponentialMovingAverage has no `store()`ed weights \" \"to `restore()`\")\n        for c_param, param in zip(self.temp_stored_params, parameters):\n            param.data.copy_(c_param.data)\n\n        # Better memory-wise.\n        self.temp_stored_params = None\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        r\"\"\"\n        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the\n        ema state dict.\n\n        Args:\n            state_dict (dict): EMA state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        # deepcopy, to be consistent with module API\n        state_dict = copy.deepcopy(state_dict)\n\n        self.decay = state_dict.get(\"decay\", self.decay)\n        if self.decay < 0.0 or self.decay > 1.0:\n            raise ValueError(\"Decay must be between 0 and 1\")\n\n        self.min_decay = state_dict.get(\"min_decay\", self.min_decay)\n        if not isinstance(self.min_decay, float):\n            raise ValueError(\"Invalid min_decay\")\n\n        self.optimization_step = state_dict.get(\"optimization_step\", self.optimization_step)\n        if not isinstance(self.optimization_step, int):\n            raise ValueError(\"Invalid optimization_step\")\n\n        self.update_after_step = state_dict.get(\"update_after_step\", self.update_after_step)\n        if not isinstance(self.update_after_step, int):\n            raise ValueError(\"Invalid update_after_step\")\n\n        self.use_ema_warmup = state_dict.get(\"use_ema_warmup\", self.use_ema_warmup)\n        if not isinstance(self.use_ema_warmup, bool):\n            raise ValueError(\"Invalid use_ema_warmup\")\n\n        self.inv_gamma = state_dict.get(\"inv_gamma\", self.inv_gamma)\n        if not isinstance(self.inv_gamma, (float, int)):\n            raise ValueError(\"Invalid inv_gamma\")\n\n        self.power = state_dict.get(\"power\", self.power)\n        if not isinstance(self.power, (float, int)):\n            raise ValueError(\"Invalid power\")\n\n        shadow_params = state_dict.get(\"shadow_params\", None)\n        if shadow_params is not None:\n            self.shadow_params = shadow_params\n            if not isinstance(self.shadow_params, list):\n                raise ValueError(\"shadow_params must be a list\")\n            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):\n                raise ValueError(\"shadow_params must all be Tensors\")\n"
  },
  {
    "path": "opensora/utils/ema_utils.py",
    "content": "\nfrom peft import get_peft_model, PeftModel\nimport os\nfrom copy import deepcopy\nimport torch\nimport json\nfrom diffusers.training_utils import EMAModel as diffuser_EMAModel\n\n\n\nclass EMAModel(diffuser_EMAModel):\n    def __init__(self, parameters, **kwargs):\n        self.lora_config = kwargs.pop('lora_config', None)\n        super().__init__(parameters, **kwargs)\n    \n    @classmethod\n    def from_pretrained(cls, path, model_cls, lora_config, model_base) -> \"EMAModel\":\n        # 1. load model\n        if lora_config is not None:\n            # 1.1 load origin model\n            model_base = model_cls.from_pretrained(model_base)  # model_base\n            config = model_base.config\n            # 1.2 convert to lora model automatically and load lora weight\n            model = PeftModel.from_pretrained(model_base, path)  # lora_origin_model\n        else:\n            model = model_cls.from_pretrained(path)\n            config = model.config\n        # 3. ema the whole model\n        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config, lora_config=lora_config)\n        # 4. load ema_config, e.g decay...\n        with open(os.path.join(path, 'ema_config.json'), 'r') as f:\n            state_dict = json.load(f)\n        ema_model.load_state_dict(state_dict)\n        return ema_model\n\n    def save_pretrained(self, path):\n        if self.model_cls is None:\n            raise ValueError(\"`save_pretrained` can only be used if `model_cls` was defined at __init__.\")\n\n        if self.model_config is None:\n            raise ValueError(\"`save_pretrained` can only be used if `model_config` was defined at __init__.\")\n        # 1. init a base model randomly\n        model = self.model_cls.from_config(self.model_config)\n        # 1.1 convert lora_model\n        if self.lora_config is not None:\n            model = get_peft_model(model, self.lora_config)\n        # 2. 
ema_model copy to model\n        self.copy_to(model.parameters())\n        # 3. save weight\n        if self.lora_config is not None:\n            model.save_pretrained(path)  # only lora weight\n            merge_model = model.merge_and_unload()\n            merge_model.save_pretrained(path) # merge_model weight\n        else:\n            model.save_pretrained(path)  # model weight\n        # 4. save ema_config, e.g. decay...\n        state_dict = self.state_dict()  # EMA hyperparameters; shadow_params are dropped below\n        state_dict.pop(\"shadow_params\", None)\n        with open(os.path.join(path, 'ema_config.json'), 'w') as f:\n            json.dump(state_dict, f, indent=2)"
  },
  {
    "path": "opensora/utils/freeinit_utils.py",
    "content": "import torch\nimport torch.fft as fft\nimport math\n\n\ndef freq_mix_3d(x, noise, LPF):\n    \"\"\"\n    Noise reinitialization.\n\n    Args:\n        x: diffused latent\n        noise: randomly sampled noise\n        LPF: low pass filter\n    \"\"\"\n    # FFT\n    x_freq = fft.fftn(x, dim=(-3, -2, -1))\n    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))\n    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))\n    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))\n\n    # frequency mix\n    HPF = 1 - LPF\n    x_freq_low = x_freq * LPF\n    noise_freq_high = noise_freq * HPF\n    x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain\n\n    # IFFT\n    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))\n    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real\n\n    return x_mixed\n\n\ndef get_freq_filter(shape, device, filter_type, n, d_s, d_t):\n    \"\"\"\n    Form the frequency filter for noise reinitialization.\n\n    Args:\n        shape: shape of latent (B, C, T, H, W)\n        filter_type: type of the freq filter\n        n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian\n        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)\n        d_t: normalized stop frequency for temporal dimension (0.0-1.0)\n    \"\"\"\n    if filter_type == \"gaussian\":\n        return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)\n    elif filter_type == \"ideal\":\n        return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)\n    elif filter_type == \"box\":\n        return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)\n    elif filter_type == \"butterworth\":\n        return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)\n    else:\n        raise NotImplementedError\n\ndef gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):\n    \"\"\"\n    Compute the gaussian low pass filter mask.\n\n    Args:\n      
  shape: shape of the filter (volume)\n        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)\n        d_t: normalized stop frequency for temporal dimension (0.0-1.0)\n    \"\"\"\n    T, H, W = shape[-3], shape[-2], shape[-1]\n    mask = torch.zeros(shape)\n    if d_s==0 or d_t==0:\n        return mask\n    for t in range(T):\n        for h in range(H):\n            for w in range(W):\n                d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)\n                mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)\n    return mask\n\n\ndef butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):\n    \"\"\"\n    Compute the butterworth low pass filter mask.\n\n    Args:\n        shape: shape of the filter (volume)\n        n: order of the filter, larger n ~ ideal, smaller n ~ gaussian\n        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)\n        d_t: normalized stop frequency for temporal dimension (0.0-1.0)\n    \"\"\"\n    T, H, W = shape[-3], shape[-2], shape[-1]\n    mask = torch.zeros(shape)\n    if d_s==0 or d_t==0:\n        return mask\n    for t in range(T):\n        for h in range(H):\n            for w in range(W):\n                d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)\n                mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)\n    return mask\n\n\ndef ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):\n    \"\"\"\n    Compute the ideal low pass filter mask.\n\n    Args:\n        shape: shape of the filter (volume)\n        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)\n        d_t: normalized stop frequency for temporal dimension (0.0-1.0)\n    \"\"\"\n    T, H, W = shape[-3], shape[-2], shape[-1]\n    mask = torch.zeros(shape)\n    if d_s==0 or d_t==0:\n        return mask\n    for t in range(T):\n        for h in range(H):\n            for w in range(W):\n                d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + 
(2*w/W-1)**2)\n                mask[..., t,h,w] = 1 if d_square <= d_s**2 else 0\n    return mask\n\n\ndef box_low_pass_filter(shape, d_s=0.25, d_t=0.25):\n    \"\"\"\n    Compute the ideal low pass filter mask (approximated version).\n\n    Args:\n        shape: shape of the filter (volume)\n        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)\n        d_t: normalized stop frequency for temporal dimension (0.0-1.0)\n    \"\"\"\n    T, H, W = shape[-3], shape[-2], shape[-1]\n    mask = torch.zeros(shape)\n    if d_s==0 or d_t==0:\n        return mask\n\n    threshold_s = round(int(H // 2) * d_s)\n    threshold_t = round(T // 2 * d_t)\n\n    cframe, crow, ccol = T // 2, H // 2, W //2\n    mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0\n\n    return mask"
  },
  {
    "path": "opensora/utils/lora_utils.py",
    "content": "\nfrom peft import get_peft_model, PeftModel\nimport os\nfrom copy import deepcopy\nimport torch\nimport json\nimport logging\n\ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        # keep only the biases of modules that carry LoRA weights\n        for k, t in maybe_lora_bias.items():\n            if k in lora_bias_names:\n                to_return[k] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}\n    return to_return\n"
  },
  {
    "path": "opensora/utils/mask_utils.py",
    "content": "\nfrom math import floor, ceil\nfrom abc import ABC, abstractmethod\nimport cv2\nimport torch\nimport torch.nn.functional as F\nimport imageio\nimport numpy as np\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\nexcept:\n    torch_npu = None\n    npu_config = None\nimport random\nfrom enum import Enum, auto\n\nfrom einops import rearrange\n\nclass MaskType(Enum):\n    t2iv = auto() # For video, execute t2v (all frames are masked); for image, execute t2i (the image is masked)\n    i2v = auto() # Only for video, execute i2v (i.e. maintain the first frame and mask the rest)\n    transition = auto() # Only for video, execute transition (i.e. maintain the first and last frames and mask the rest)\n    continuation = auto() # Only for video, execute video continuation (i.e. maintain the starting k frames and mask the rest)\n    clear = auto() # For video and image, no frame is masked\n    random_temporal = auto() # For video, randomly keep some frames and mask the rest\n\nTYPE_TO_STR = {mask_type: mask_type.name for mask_type in MaskType}\nSTR_TO_TYPE = {mask_type.name: mask_type for mask_type in MaskType}\n\ndef save_mask_to_video(mask, save_path='mask.mp4', fps=24):\n    T, _, H, W = mask.shape\n    writer = imageio.get_writer(save_path, fps=fps, codec='libx264', quality=6)\n    for t in range(T):\n        frame = mask[t, 0].cpu().numpy() * 255\n        frame = frame.astype(np.uint8)  # ensure the data type is uint8\n        writer.append_data(frame)\n    writer.close()\n\ndef read_video(video_path):\n    # returns a float tensor of shape (T, C, H, W)\n    reader = imageio.get_reader(video_path)\n    frames = []\n    for frame in reader:\n        frame = np.transpose(frame, (2, 0, 1))\n        frames.append(frame)\n    video_array = np.stack(frames)\n    video_tensor = torch.from_numpy(video_array).float()\n    reader.close()\n    return video_tensor\n\nclass BaseMaskGenerator(ABC):\n\n    def create_system_mask(self, num_frames, height, width, device, dtype):\n        if num_frames is None or height is 
None or width is None:\n            raise ValueError('num_frames, height, and width should be provided.')\n        return torch.ones([num_frames, 1, height, width], device=device, dtype=dtype)\n\n    @abstractmethod\n    def process(self, mask):\n        # process self.mask to meet the specific task\n        pass\n\n    def __call__(self, num_frames=None, height=None, width=None, device='cuda', dtype=torch.float32):\n        mask = self.create_system_mask(num_frames, height, width, device, dtype)\n        return self.process(mask)\n\nclass T2IVMaskGenerator(BaseMaskGenerator):\n    def process(self, mask):\n        mask.fill_(1)\n        return mask\n\nclass I2VMaskGenerator(BaseMaskGenerator):\n    def process(self, mask):\n        mask[0] = 0\n        return mask\n\nclass TransitionMaskGenerator(BaseMaskGenerator):\n    def process(self, mask):\n        mask[0] = 0\n        mask[-1] = 0\n        return mask\n\nclass ContinuationMaskGenerator(BaseMaskGenerator):\n    \n    def __init__(self, min_clear_ratio=0.0, max_clear_ratio=1.0):\n        assert min_clear_ratio >= 0 and min_clear_ratio <= 1, 'min_clear_ratio should be in the range of [0, 1].'\n        assert max_clear_ratio >= 0 and max_clear_ratio <= 1, 'max_clear_ratio should be in the range of [0, 1].'\n        assert min_clear_ratio <= max_clear_ratio, 'min_clear_ratio should be less than max_clear_ratio.'\n        self.min_clear_ratio = min_clear_ratio\n        self.max_clear_ratio = max_clear_ratio\n\n    def process(self, mask):\n        num_frames = mask.shape[0]\n        end_idx = random.randint(floor(num_frames * self.min_clear_ratio), ceil(num_frames * self.max_clear_ratio))\n        mask[0:end_idx] = 0\n        return mask\n\nclass ClearMaskGenerator(BaseMaskGenerator):\n    def process(self, mask):\n        mask.zero_()\n        return mask\n\nclass RandomTemporalMaskGenerator(BaseMaskGenerator):\n\n    def __init__(self, min_clear_ratio=0.0, max_clear_ratio=1.0):\n        assert min_clear_ratio 
>= 0 and min_clear_ratio <= 1, 'min_clear_ratio should be in the range of [0, 1].'\n        assert max_clear_ratio >= 0 and max_clear_ratio <= 1, 'max_clear_ratio should be in the range of [0, 1].'\n        assert min_clear_ratio <= max_clear_ratio, 'min_clear_ratio should be less than max_clear_ratio.'\n        self.min_clear_ratio = min_clear_ratio\n        self.max_clear_ratio = max_clear_ratio\n\n    def process(self, mask):\n        num_frames = mask.shape[0]\n        num_to_select = random.randint(floor(num_frames * self.min_clear_ratio), ceil(num_frames * self.max_clear_ratio))\n        selected_indices = random.sample(range(num_frames), num_to_select)\n        mask[selected_indices] = 0\n        return mask\n\n\nclass MaskProcessor:\n    def __init__(\n        self, \n        max_height=640, \n        max_width=640, \n        min_clear_ratio=0.0, \n        max_clear_ratio=1.0, \n    ):\n        \n        self.max_height = max_height\n        self.max_width = max_width\n        self.min_clear_ratio = min_clear_ratio\n        self.max_clear_ratio = max_clear_ratio\n\n        self.init_mask_generators()\n\n    def init_mask_generators(self):\n        self.mask_generators = {\n            MaskType.t2iv: T2IVMaskGenerator(),\n            MaskType.i2v: I2VMaskGenerator(),\n            MaskType.transition: TransitionMaskGenerator(),\n            MaskType.continuation: ContinuationMaskGenerator(min_clear_ratio=self.min_clear_ratio, max_clear_ratio=self.max_clear_ratio),\n            MaskType.clear: ClearMaskGenerator(),\n            MaskType.random_temporal: RandomTemporalMaskGenerator(min_clear_ratio=self.min_clear_ratio, max_clear_ratio=self.max_clear_ratio),\n        }\n    \n    def get_mask(self, mask_generator_type, num_frames, height, width, device='cuda', dtype=torch.float32):\n        return self.mask_generators[mask_generator_type](num_frames, height, width, device=device, dtype=dtype)\n    \n    def __call__(self, pixel_values, mask_type=None, 
mask_type_ratio_dict=None):\n\n        num_frames, _, height, width = pixel_values.shape   \n\n        if mask_type_ratio_dict is not None:  \n            assert isinstance(mask_type_ratio_dict, dict), 'mask_type_ratio_dict should be a dict.'\n            assert mask_type_ratio_dict.keys() <= set(MaskType), f'Invalid mask type: {mask_type_ratio_dict.keys() - set(MaskType)}'\n            mask_generator_type = random.choices(list(mask_type_ratio_dict.keys()), list(mask_type_ratio_dict.values()))[0]\n        elif mask_type is not None:\n            assert mask_type in STR_TO_TYPE.keys() or mask_type in STR_TO_TYPE.values(), f'Invalid mask type: {mask_type}'\n            # isinstance() avoids `str in MaskType`, which raises TypeError before Python 3.12\n            mask_generator_type = mask_type if isinstance(mask_type, MaskType) else STR_TO_TYPE[mask_type]\n        else:\n            raise ValueError('mask_type or mask_type_ratio_dict should be provided.')\n        \n        mask = self.get_mask(mask_generator_type, num_frames, height, width, device=pixel_values.device, dtype=pixel_values.dtype)\n\n        # mask == 1 marks regions to generate; zero them out and keep the rest as the condition\n        masked_pixel_values = pixel_values * (mask < 0.5)\n        return dict(mask=mask, masked_pixel_values=masked_pixel_values)\n    \nclass MaskCompressor:\n    def __init__(self, ae_stride_h=8, ae_stride_w=8, ae_stride_t=4, **kwargs):\n        self.ae_stride_h = ae_stride_h\n        self.ae_stride_w = ae_stride_w\n        self.ae_stride_t = ae_stride_t\n    \n    def __call__(self, mask):\n        B, C, T, H, W = mask.shape\n        new_H, new_W = H // self.ae_stride_h, W // self.ae_stride_w\n        mask = rearrange(mask, 'b c t h w -> (b c t) 1 h w')\n        if torch_npu is not None:\n            dtype = mask.dtype\n            mask = mask.to(dtype=torch.float32)\n            mask = F.interpolate(mask, size=(new_H, new_W), mode='bilinear')\n            mask = mask.to(dtype)\n        else:\n            mask = F.interpolate(mask, size=(new_H, new_W), mode='bilinear')\n        mask = rearrange(mask, '(b c t) 1 h w -> b c t h w', t=T, b=B)\n        if T % 2 == 1:\n         
   new_T = T // self.ae_stride_t + 1  # odd T (e.g. 4k+1): the first frame gets its own latent frame, so it is repeated ae_stride_t times\n            mask_first_frame = mask[:, :, 0:1].repeat(1, 1, self.ae_stride_t, 1, 1).contiguous() \n            mask = torch.cat([mask_first_frame, mask[:, :, 1:]], dim=2)\n        else:\n            new_T = T // self.ae_stride_t\n        mask = mask.view(B, new_T, self.ae_stride_t, new_H, new_W)\n        mask = mask.transpose(1, 2).contiguous() # Transpose so that the channel dimension holds the ae_stride_t original frames folded into each latent frame\n        return mask\n    \nclass BaseNoiseAdder(ABC):\n    \n    @abstractmethod\n    def add_noise(self, mask_pixel_values, mask):\n        pass\n\n    def __call__(self, mask_pixel_values, mask):\n        return self.add_noise(mask_pixel_values, mask)\n    \nclass GaussianNoiseAdder(BaseNoiseAdder):\n    def __init__(self, mean=-3.0, std=0.5, clear_ratio=0.05):\n        self.mean = mean\n        self.std = std\n        self.clear_ratio = clear_ratio\n    # masked_pixel_values: (B, C, T, H, W)\n    # mask: (B, 1, T, H, W)\n    def add_noise(self, masked_pixel_values, mask):\n        # with probability clear_ratio, return the condition without noise\n        if random.random() < self.clear_ratio:\n            return masked_pixel_values\n        # per-sample sigma drawn from a log-normal distribution: exp(N(mean, std))\n        noise_sigma = torch.normal(mean=self.mean, std=self.std, size=(masked_pixel_values.shape[0],), device=masked_pixel_values.device)\n        noise_sigma = torch.exp(noise_sigma).to(dtype=masked_pixel_values.dtype)\n        noise = torch.randn_like(masked_pixel_values) * noise_sigma[:, None, None, None, None]\n        # add noise only to the unmasked (conditioning) regions\n        noise = torch.where(mask < 0.5, noise, torch.zeros_like(noise))\n        return masked_pixel_values + noise\n\n\nif __name__ == '__main__':\n    video_path = '/home/image_data/hxy/data/video/000184.mp4'\n    video = read_video(video_path)\n    processor = MaskProcessor()\n    ratio_dict = {\n        MaskType.t2iv: 0,\n        MaskType.i2v: 0,\n        MaskType.transition: 0,\n        MaskType.continuation: 0,\n        MaskType.clear: 0,\n        MaskType.random_temporal: 1,\n    }\n\n    mask = processor(video, 
mask_type_ratio_dict=ratio_dict)['mask']\n    print(mask.shape)\n    save_mask_to_video(mask, save_path='test_mask.mp4', fps=24)\n    "
  },
  {
    "path": "opensora/utils/parallel_states.py",
    "content": "import torch\nimport torch.distributed as dist\nimport os\n\nclass COMM_INFO:\n    def __init__(self):\n        self.group = None\n        self.world_size = 0\n        self.rank = -1\n\nnccl_info = COMM_INFO()\n_SEQUENCE_PARALLEL_STATE = False\ndef initialize_sequence_parallel_state(sequence_parallel_size):\n    global _SEQUENCE_PARALLEL_STATE\n    if sequence_parallel_size > 1:\n        _SEQUENCE_PARALLEL_STATE = True\n        initialize_sequence_parallel_group(sequence_parallel_size)\n\ndef set_sequence_parallel_state(state):\n    global _SEQUENCE_PARALLEL_STATE\n    _SEQUENCE_PARALLEL_STATE = state\n\ndef get_sequence_parallel_state():\n    return _SEQUENCE_PARALLEL_STATE\n\ndef initialize_sequence_parallel_group(sequence_parallel_size):\n    \"\"\"Initialize the sequence parallel group.\"\"\"\n    rank = int(os.getenv('RANK', '0'))\n    world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n    assert world_size % sequence_parallel_size == 0, \"world_size must be divisible by sequence_parallel_size\"\n    # world_size here is the size of one sequence parallel group; rank is the global rank\n    nccl_info.world_size = sequence_parallel_size\n    nccl_info.rank = rank\n    num_sequence_parallel_groups: int = world_size // sequence_parallel_size\n    # dist.new_group must be entered by every rank for every group, so each rank creates all groups and keeps its own\n    for i in range(num_sequence_parallel_groups):\n        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)\n        group = dist.new_group(ranks)\n        if rank in ranks:\n            nccl_info.group = group\n\n\ndef destroy_sequence_parallel_group():\n    \"\"\"Destroy all process groups, including the default group and every sequence parallel group.\"\"\"\n    dist.destroy_process_group()\n"
  },
  {
    "path": "opensora/utils/sample_utils.py",
    "content": "from diffusers.schedulers import (\n    DDIMScheduler, DDPMScheduler, PNDMScheduler,\n    EulerDiscreteScheduler, DPMSolverMultistepScheduler,\n    HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,\n    DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, \n    DPMSolverSinglestepScheduler, CogVideoXDDIMScheduler, \n    FlowMatchEulerDiscreteScheduler\n    )\nfrom einops import rearrange\nimport time\nimport torch\nimport os\nimport torch.distributed as dist\nfrom torchvision.utils import save_image\nimport imageio\nimport math\nimport argparse\nfrom transformers import AutoModelForCausalLM\n\ntry:\n    import torch_npu\n    from opensora.npu_config import npu_config\n    from opensora.acceleration.parallel_states import initialize_sequence_parallel_state, hccl_info\nexcept:\n    torch_npu = None\n    npu_config = None\n    from opensora.utils.parallel_states import initialize_sequence_parallel_state, nccl_info\n    pass\n\nfrom opensora.utils.utils import set_seed\nfrom opensora.models.causalvideovae import ae_stride_config, ae_wrapper\nfrom opensora.sample.pipeline_opensora import OpenSoraPipeline\nfrom opensora.sample.pipeline_inpaint import OpenSoraInpaintPipeline\nfrom opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3\nfrom opensora.models.diffusion.opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3\nfrom transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer, MT5EncoderModel, CLIPTextModelWithProjection\n\ndef get_scheduler(args):\n    kwargs = dict(\n        prediction_type=args.prediction_type, \n        rescale_betas_zero_snr=args.rescale_betas_zero_snr, \n        timestep_spacing=\"trailing\" if args.rescale_betas_zero_snr else 'leading', \n    )\n    if args.v1_5_scheduler:\n        kwargs['beta_start'] = 0.00085\n        kwargs['beta_end'] = 0.0120\n        kwargs['beta_schedule'] = \"scaled_linear\"\n    if args.sample_method == 'DDIM':  \n        scheduler_cls = DDIMScheduler\n  
      kwargs['clip_sample'] = False\n    elif args.sample_method == 'EulerDiscrete':\n        scheduler_cls = EulerDiscreteScheduler\n    elif args.sample_method == 'DDPM':  \n        scheduler_cls = DDPMScheduler\n        kwargs['clip_sample'] = False\n    elif args.sample_method == 'DPMSolverMultistep':\n        scheduler_cls = DPMSolverMultistepScheduler\n    elif args.sample_method == 'DPMSolverSinglestep':\n        scheduler_cls = DPMSolverSinglestepScheduler\n    elif args.sample_method == 'PNDM':\n        scheduler_cls = PNDMScheduler\n        kwargs.pop('rescale_betas_zero_snr', None)\n    elif args.sample_method == 'HeunDiscrete':  ########\n        scheduler_cls = HeunDiscreteScheduler\n    elif args.sample_method == 'EulerAncestralDiscrete':\n        scheduler_cls = EulerAncestralDiscreteScheduler\n    elif args.sample_method == 'DEISMultistep':\n        scheduler_cls = DEISMultistepScheduler\n        kwargs.pop('rescale_betas_zero_snr', None)\n    elif args.sample_method == 'KDPM2AncestralDiscrete':  #########\n        scheduler_cls = KDPM2AncestralDiscreteScheduler\n    elif args.sample_method == 'CogVideoX':\n        scheduler_cls = CogVideoXDDIMScheduler\n    elif args.sample_method == 'FlowMatchEulerDiscrete':\n        scheduler_cls = FlowMatchEulerDiscreteScheduler\n        kwargs = {}\n    else:\n        raise NameError(f'Unsupported sample_method {args.sample_method}')\n    scheduler = scheduler_cls(**kwargs)\n    return scheduler\n\ndef prepare_pipeline(args, dtype, device):\n    \n    weight_dtype = dtype\n\n    vae = ae_wrapper[args.ae](args.ae_path)\n    vae.vae = vae.vae.to(device=device, dtype=weight_dtype).eval()\n    vae.vae_scale_factor = ae_stride_config[args.ae]\n    if args.enable_tiling:\n        vae.vae.enable_tiling()\n\n    if 'mt5' in args.text_encoder_name_1:\n        text_encoder_1 = MT5EncoderModel.from_pretrained(\n            args.text_encoder_name_1, cache_dir=args.cache_dir, \n            torch_dtype=weight_dtype\n           
 ).eval()\n    else:\n        text_encoder_1 = T5EncoderModel.from_pretrained(\n            args.text_encoder_name_1, cache_dir=args.cache_dir, \n            torch_dtype=weight_dtype\n            ).eval()\n    tokenizer_1 = AutoTokenizer.from_pretrained(\n        args.text_encoder_name_1, cache_dir=args.cache_dir\n        )\n\n    if args.text_encoder_name_2 is not None:\n        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(\n            args.text_encoder_name_2, cache_dir=args.cache_dir, \n            torch_dtype=weight_dtype\n            ).eval()\n        tokenizer_2 = AutoTokenizer.from_pretrained(\n            args.text_encoder_name_2, cache_dir=args.cache_dir\n            )\n    else:\n        text_encoder_2, tokenizer_2 = None, None\n\n    if args.version == 'v1_3':\n        if args.model_type == 'inpaint' or args.model_type == 'i2v':\n            transformer_model = OpenSoraInpaint_v1_3.from_pretrained(\n                args.model_path, cache_dir=args.cache_dir,\n                device_map=None, torch_dtype=weight_dtype\n                ).eval()\n        else:\n            transformer_model = OpenSoraT2V_v1_3.from_pretrained(\n                args.model_path, cache_dir=args.cache_dir,\n                device_map=None, torch_dtype=weight_dtype\n                ).eval()\n    elif args.version == 'v1_5':\n        if args.model_type == 'inpaint' or args.model_type == 'i2v':\n            raise NotImplementedError('Inpainting model is not available in v1_5')\n        else:\n            from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5\n            transformer_model = OpenSoraT2V_v1_5.from_pretrained(\n                args.model_path, cache_dir=args.cache_dir, \n                # device_map=None, \n                torch_dtype=weight_dtype\n                ).eval()\n    \n    scheduler = get_scheduler(args)\n    pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' 
else OpenSoraPipeline\n\n    pipeline = pipeline_class(\n        vae=vae,\n        text_encoder=text_encoder_1,\n        tokenizer=tokenizer_1,\n        scheduler=scheduler,\n        transformer=transformer_model, \n        text_encoder_2=text_encoder_2,\n        tokenizer_2=tokenizer_2,\n    ).to(device)\n\n    if args.save_memory:\n        print('enable_model_cpu_offload AND enable_sequential_cpu_offload AND enable_tiling')\n        pipeline.enable_model_cpu_offload()\n        pipeline.enable_sequential_cpu_offload()\n        # torch.cuda.empty_cache()\n        vae.vae.enable_tiling()\n        vae.vae.t_chunk_enc = 8\n        vae.vae.t_chunk_dec = vae.vae.t_chunk_enc // 2\n        \n    if args.compile:\n        pipeline.transformer = torch.compile(pipeline.transformer)\n\n    return pipeline\n\ndef init_gpu_env(args):\n    local_rank = int(os.getenv('RANK', 0))\n    world_size = int(os.getenv('WORLD_SIZE', 1))\n    args.local_rank = local_rank\n    args.world_size = world_size\n    torch.cuda.set_device(local_rank)\n    dist.init_process_group(\n        backend='nccl', init_method='env://', \n        world_size=world_size, rank=local_rank\n        )\n    if args.sp:\n        initialize_sequence_parallel_state(world_size)\n    return args\n\ndef init_npu_env(args):\n    local_rank = int(os.getenv('RANK', 0))\n    world_size = int(os.getenv('WORLD_SIZE', 1))\n    args.local_rank = local_rank\n    args.world_size = world_size\n    torch_npu.npu.set_device(local_rank)\n    dist.init_process_group(\n        backend='hccl', init_method='env://', \n        world_size=world_size, rank=local_rank\n        )\n    if args.sp:\n        initialize_sequence_parallel_state(world_size)\n    return args\n\n\ndef save_video_grid(video, nrow=None):\n    b, t, h, w, c = video.shape\n\n    if nrow is None:\n        nrow = math.ceil(math.sqrt(b))\n    ncol = math.ceil(b / nrow)\n    padding = 1\n    video_grid = torch.zeros(\n        (\n            t, \n            (padding + h) * 
nrow + padding, \n            (padding + w) * ncol + padding, \n            c\n        ), \n        dtype=torch.uint8\n        )\n\n    for i in range(b):\n        r = i // ncol\n        c = i % ncol\n        start_r = (padding + h) * r\n        start_c = (padding + w) * c\n        video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]\n\n    return video_grid\n\n\ndef run_model_and_save_samples(args, pipeline, caption_refiner_model=None, enhance_video_model=None):\n    if args.seed is not None:\n        set_seed(args.seed, rank=args.local_rank, device_specific=True)\n    if args.seed is not None and args.local_rank >= 0:\n        # offset the seed per rank so that different ranks draw different samples\n        torch.manual_seed(args.seed + args.local_rank)\n    if not os.path.exists(args.save_img_path):\n        os.makedirs(args.save_img_path, exist_ok=True)\n\n    video_grids = []\n    if not isinstance(args.text_prompt, list):\n        args.text_prompt = [args.text_prompt]\n    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):\n        text_prompt = open(args.text_prompt[0], 'r').readlines()\n        args.text_prompt = [i.strip() for i in text_prompt]\n    \n    if args.model_type == 'inpaint' or args.model_type == 'i2v':\n        if not isinstance(args.conditional_pixel_values_path, list):\n            args.conditional_pixel_values_path = [args.conditional_pixel_values_path]\n        if len(args.conditional_pixel_values_path) == 1 and args.conditional_pixel_values_path[0].endswith('txt'):\n            # one line per prompt, each a comma-separated list of condition paths\n            temp = open(args.conditional_pixel_values_path[0], 'r').readlines()\n            conditional_pixel_values_path = [i.strip().split(',') for i in temp]\n        else:\n            # a single comma-separated list of condition paths, paired with the first prompt\n            conditional_pixel_values_path = [args.conditional_pixel_values_path[0].split(',')]\n        \n        mask_type = args.mask_type\n\n    positive_prompt = \"\"\"\n    high quality, high aesthetic, {}\n    \"\"\"\n\n    negative_prompt = \"\"\"\n    nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, \n    low quality, normal quality, jpeg artifacts, signature, 
watermark, username, blurry.\n    \"\"\"\n    \n    def generate(prompt, conditional_pixel_values_path=None, mask_type=None):\n        \n        if args.caption_refiner is not None:\n            if args.model_type != 'inpaint' and args.model_type != 'i2v':\n                refine_prompt = caption_refiner_model.get_refiner_output(prompt)\n                print(f'\\nOrigin prompt: {prompt}\\n->\\nRefine prompt: {refine_prompt}')\n                prompt = refine_prompt\n            else:\n                # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode.\n                print('Caption refiner is not available for inpainting model, use the original prompt...')\n                time.sleep(3)\n        input_prompt = positive_prompt.format(prompt)\n        \n        if args.model_type == 'inpaint' or args.model_type == 'i2v':\n            print(f'\\nConditional pixel values path: {conditional_pixel_values_path}')\n            videos = pipeline(\n                conditional_pixel_values_path=conditional_pixel_values_path,\n                mask_type=mask_type,\n                crop_for_hw=args.crop_for_hw,\n                max_hxw=args.max_hxw,\n                noise_strength=args.noise_strength,\n                prompt=input_prompt, \n                negative_prompt=negative_prompt, \n                num_frames=args.num_frames,\n                height=args.height,\n                width=args.width,\n                num_inference_steps=args.num_sampling_steps,\n                guidance_scale=args.guidance_scale,\n                num_samples_per_prompt=args.num_samples_per_prompt,\n                max_sequence_length=args.max_sequence_length,\n            ).videos\n        else:\n            videos = pipeline(\n                input_prompt, \n                negative_prompt=negative_prompt, \n                num_frames=args.num_frames,\n       
         height=args.height,\n                width=args.width,\n                num_inference_steps=args.num_sampling_steps,\n                guidance_scale=args.guidance_scale,\n                num_samples_per_prompt=args.num_samples_per_prompt,\n                max_sequence_length=args.max_sequence_length,\n            ).videos\n        if enhance_video_model is not None:\n            # b t h w c\n            videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250)\n        if (not args.sp) or (args.sp and args.local_rank <= 0):\n            if args.num_frames == 1:\n                videos = rearrange(videos, 'b t h w c -> (b t) c h w')\n                if args.num_samples_per_prompt != 1:\n                    for i, image in enumerate(videos):\n                        save_image(\n                            image / 255.0, \n                            os.path.join(\n                                args.save_img_path, \n                                f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}_i{i}.jpg'\n                                ),\n                            nrow=math.ceil(math.sqrt(videos.shape[0])), \n                            normalize=True, \n                            value_range=(0, 1)\n                            )  # b c h w\n                save_image(\n                    videos / 255.0, \n                    os.path.join(\n                        args.save_img_path, \n                        f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.jpg'\n                        ),\n                    nrow=math.ceil(math.sqrt(videos.shape[0])), \n                    normalize=True, \n                    value_range=(0, 1)\n                    )  # b c h w\n            else:\n                if args.num_samples_per_prompt == 1:\n                    imageio.mimwrite(\n                        os.path.join(\n                            
args.save_img_path,\n                            f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'\n                        ), \n                        videos[0],\n                        fps=args.fps, \n                        quality=6\n                        )  # highest quality is 10, lowest is 0\n                else:\n                    for i in range(args.num_samples_per_prompt):\n                        imageio.mimwrite(\n                            os.path.join(\n                                args.save_img_path,\n                                f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}_i{i}.mp4'\n                            ), videos[i],\n                            fps=args.fps, \n                            quality=6\n                            )  # highest quality is 10, lowest is 0\n                        \n                    videos = save_video_grid(videos)\n                    imageio.mimwrite(\n                        os.path.join(\n                            args.save_img_path,\n                            f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'\n                        ), \n                        videos,\n                        fps=args.fps, \n                        quality=6\n                        )  # highest quality is 10, lowest is 0)\n                    videos = videos.unsqueeze(0) # 1 t h w c\n            video_grids.append(videos)\n\n    if args.model_type == 'inpaint' or args.model_type == 'i2v':\n        for index, (prompt, cond_path) in enumerate(zip(args.text_prompt, conditional_pixel_values_path)):\n            if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank:\n                continue\n            generate(prompt, cond_path, mask_type)\n    else:\n        for index, prompt in enumerate(args.text_prompt):\n            if not args.sp and args.local_rank != -1 
and index % args.world_size != args.local_rank:\n                continue  # skip when ddp\n            generate(prompt)\n\n    if (args.model_type == \"inpaint\" or args.model_type == \"i2v\") and not args.crop_for_hw:\n        print('completed, please check the saved images and videos')\n    else:\n        if not args.sp:\n            if args.local_rank != -1:\n                dist.barrier()\n                video_grids = torch.cat(video_grids, dim=0).cuda()\n                shape = list(video_grids.shape)\n                shape[0] *= args.world_size\n                gathered_tensor = torch.zeros(shape, dtype=video_grids.dtype).cuda()\n                dist.all_gather_into_tensor(gathered_tensor, video_grids.contiguous())\n                video_grids = gathered_tensor.cpu()\n                dist.barrier()\n            else:\n                video_grids = torch.cat(video_grids, dim=0)\n        elif args.sp and args.local_rank <= 0:\n            video_grids = torch.cat(video_grids)\n        \n        if args.local_rank <= 0:\n            if args.num_frames == 1:\n                save_image(\n                    video_grids / 255.0, \n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.jpg'\n                        ), \n                    nrow=math.ceil(math.sqrt(len(video_grids))), \n                    normalize=True, \n                    value_range=(0, 1)\n                    )\n            else:\n                video_grids = save_video_grid(video_grids)\n                imageio.mimwrite(\n                    os.path.join(\n                        args.save_img_path,\n                        f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'\n                    ), \n                    video_grids, \n                    fps=args.fps, \n                    quality=6\n                    )\n            
print('save path {}'.format(args.save_img_path))\n\n\n\ndef run_model_and_save_samples_npu(args, pipeline, caption_refiner_model=None, enhance_video_model=None):\n    \n    # experimental_config = torch_npu.profiler._ExperimentalConfig(\n    #     profiler_level=torch_npu.profiler.ProfilerLevel.Level1,\n    #     aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization\n    # )\n    # profile_output_path = \"/home/image_data/npu_profiling_t2v\"\n    # os.makedirs(profile_output_path, exist_ok=True)\n    # with torch_npu.profiler.profile(\n    #         activities=[\n    #             torch_npu.profiler.ProfilerActivity.NPU, \n    #             torch_npu.profiler.ProfilerActivity.CPU\n    #             ],\n    #         with_stack=True,\n    #         record_shapes=True,\n    #         profile_memory=True,\n    #         experimental_config=experimental_config,\n    #         schedule=torch_npu.profiler.schedule(\n    #             wait=10000, warmup=0, active=1, repeat=1, skip_first=0\n    #             ),\n    #         on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f\"{profile_output_path}/\")\n    # ) as prof:\n    run_model_and_save_samples(args, pipeline, caption_refiner_model, enhance_video_model)\n        # prof.step()\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model_path\", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')\n    parser.add_argument(\"--version\", type=str, default='v1_3', choices=['v1_3', 'v1_5'])\n    parser.add_argument(\"--model_type\", type=str, default='t2v', choices=['t2v', 'inpaint', 'i2v'])\n    parser.add_argument(\"--num_frames\", type=int, default=1)\n    parser.add_argument(\"--height\", type=int, default=512)\n    parser.add_argument(\"--width\", type=int, default=512)\n    parser.add_argument(\"--device\", type=str, default='cuda:0')\n    parser.add_argument(\"--cache_dir\", type=str, default='./cache_dir')\n    parser.add_argument(\"--caption_refiner\", 
type=str, default=None)\n    parser.add_argument(\"--ae\", type=str, default='CausalVAEModel_4x8x8')\n    parser.add_argument(\"--ae_path\", type=str, default='CausalVAEModel_4x8x8')\n    parser.add_argument(\"--enhance_video\", type=str, default=None)\n    parser.add_argument(\"--text_encoder_name_1\", type=str, default='DeepFloyd/t5-v1_1-xxl')\n    parser.add_argument(\"--text_encoder_name_2\", type=str, default=None)\n    parser.add_argument(\"--save_img_path\", type=str, default=\"./sample_videos/t2v\")\n    parser.add_argument(\"--guidance_scale\", type=float, default=7.5)\n    parser.add_argument(\"--sample_method\", type=str, default=\"PNDM\")\n    parser.add_argument(\"--num_sampling_steps\", type=int, default=50)\n    parser.add_argument(\"--fps\", type=int, default=24)\n    parser.add_argument(\"--max_sequence_length\", type=int, default=512)\n    parser.add_argument(\"--text_prompt\", nargs='+')\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--num_samples_per_prompt\", type=int, default=1)\n    parser.add_argument('--enable_tiling', action='store_true')\n    parser.add_argument('--refine_caption', action='store_true')\n    parser.add_argument('--compile', action='store_true')\n    parser.add_argument('--save_memory', action='store_true') \n    parser.add_argument(\"--prediction_type\", type=str, default='epsilon', help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\")\n    parser.add_argument('--rescale_betas_zero_snr', action='store_true')\n    parser.add_argument('--local_rank', type=int, default=-1)\n    parser.add_argument('--world_size', type=int, default=1)\n    parser.add_argument('--sp', action='store_true')\n\n    parser.add_argument('--v1_5_scheduler', action='store_true')\n    parser.add_argument('--conditional_pixel_values_path', type=str, default=None)\n    parser.add_argument('--mask_type', type=str, default=None)\n    parser.add_argument('--crop_for_hw', action='store_true')\n    parser.add_argument('--max_hxw', type=int, default=236544)  # max height*width in pixels, e.g. 352*672\n    parser.add_argument('--noise_strength', type=float, default=0.0)\n    args = parser.parse_args()\n    assert not (args.sp and args.num_frames == 1), \"--sp requires --num_frames > 1\"\n    return args"
  },
  {
    "path": "opensora/utils/utils.py",
    "content": "import os\nimport math\nimport torch\nimport logging\nimport random\nimport subprocess\nimport numpy as np\nimport torch.distributed as dist\n\n# from torch._six import inf\nimport accelerate\nfrom torch import inf\nfrom PIL import Image\nfrom typing import Union, Iterable\nimport collections\nfrom collections import OrderedDict\nfrom torch.utils.tensorboard import SummaryWriter\nimport wandb\nimport time\n\nfrom diffusers.utils import is_bs4_available, is_ftfy_available\n\nimport html\nimport re\nimport urllib.parse as ul\n\nif is_bs4_available():\n    from bs4 import BeautifulSoup\n\nif is_ftfy_available():\n    import ftfy\n\n_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]\n\ndef to_2tuple(x):\n    if isinstance(x, collections.abc.Iterable):\n        return x\n    return (x, x)\n\n\n\n\ndef explicit_uniform_sampling(T, n, rank, bsz, device):\n    \"\"\"\n    Explicit Uniform Sampling with integer timesteps and PyTorch.\n\n    Each rank draws its timesteps from its own contiguous 1/n slice of the timestep range.\n\n    Args:\n        T (int): Maximum timestep value.\n        n (int): Number of ranks (data parallel processes).\n        rank (int): The rank of the current process (from 0 to n-1).\n        bsz (int): Batch size, number of timesteps to return.\n        device (torch.device or str): Device on which the returned tensor is created.\n\n    Returns:\n        torch.Tensor: A tensor of shape (bsz,) containing uniformly sampled integer timesteps\n                      within the rank's interval.\n    \"\"\"\n    interval_size = T / n  # True (float) division, so interval boundaries need not be integers\n    lower_bound = interval_size * rank - 0.5\n    upper_bound = interval_size * (rank + 1) - 0.5\n\n    # Uniformly sample within the rank's interval, then round to integers\n    sampled_timesteps = torch.tensor([round(random.uniform(lower_bound, upper_bound)) for _ in range(bsz)], device=device)\n    sampled_timesteps = sampled_timesteps.long()\n    return 
sampled_timesteps\n\n\n\n#################################################################################\n#                             Training Clip Gradients                           #\n#################################################################################\n\ndef get_grad_norm(\n        parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:\n    r\"\"\"\n    Adapted from torch.nn.utils.clip_grad_norm_\n\n    Computes the total gradient norm of an iterable of parameters without clipping.\n\n    The norm is computed over all gradients together, as if they were\n    concatenated into a single vector. Gradients are not modified.\n\n    Args:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor whose gradients are measured\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n\n    Returns:\n        Total norm of the parameter gradients (viewed as a single vector).\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    grads = [p.grad for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(grads) == 0:\n        return torch.tensor(0.)\n    device = grads[0].device\n    if norm_type == inf:\n        norms = [g.detach().abs().max().to(device) for g in grads]\n        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)\n    return total_norm\n\n\ndef clip_grad_norm_(\n        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,\n        error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor:\n    r\"\"\"\n    Adapted from torch.nn.utils.clip_grad_norm_\n\n    Clips the gradient norm of an iterable of parameters.\n\n    The norm is computed over all gradients together, as if they were\n    concatenated into a single vector. Gradients are modified in-place.\n\n    Args:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor that will have gradients normalized\n        max_norm (float or int): max norm of the gradients\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n        clip_grad (bool): if False, only the total norm is computed and the\n            gradients are left unchanged. Default: True\n        error_if_nonfinite (bool): if True, an error is thrown if the total\n            norm of the gradients from :attr:`parameters` is ``nan``,\n            ``inf``, or ``-inf``. 
Default: False (will switch to True in the future)\n\n    Returns:\n        Total norm of the parameter gradients (viewed as a single vector).\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    grads = [p.grad for p in parameters if p.grad is not None]\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    if len(grads) == 0:\n        return torch.tensor(0.)\n    device = grads[0].device\n    if norm_type == inf:\n        norms = [g.detach().abs().max().to(device) for g in grads]\n        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)\n\n    if clip_grad:\n        if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):\n            raise RuntimeError(\n                f'The total norm of order {norm_type} for gradients from '\n                '`parameters` is non-finite, so it cannot be clipped. 
To disable '\n                'this error and scale the gradients by the non-finite norm anyway, '\n                'set `error_if_nonfinite=False`')\n        clip_coef = max_norm / (total_norm + 1e-6)\n        # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so\n        # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization\n        # when the gradients do not reside in CPU memory.\n        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)\n        for g in grads:\n            g.detach().mul_(clip_coef_clamped.to(g.device))\n        # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)\n        # print(gradient_cliped)\n    return total_norm\n\n\ndef get_experiment_dir(root_dir, args):\n    # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained:\n    #     root_dir += '-WOPRE'\n    if args.use_compile:\n        root_dir += '-Compile'  # speedup by torch compile\n    if args.attention_mode:\n        root_dir += f'-{args.attention_mode.upper()}'\n    # if args.enable_xformers_memory_efficient_attention:\n    #     root_dir += '-Xfor'\n    if args.gradient_checkpointing:\n        root_dir += '-Gc'\n    if args.mixed_precision:\n        root_dir += f'-{args.mixed_precision.upper()}'\n    root_dir += f'-{args.max_image_size}'\n    return root_dir\n\ndef get_precision(args):\n    if args.mixed_precision == \"bf16\":\n        dtype = torch.bfloat16\n    elif args.mixed_precision == \"fp16\":\n        dtype = torch.float16\n    else:\n        dtype = torch.float32\n    return dtype\n\n#################################################################################\n#                             Training Logger                                   #\n#################################################################################\n\ndef create_logger(logging_dir):\n    \"\"\"\n    Create a logger 
that writes to a log file and stdout.\n    \"\"\"\n    if dist.get_rank() == 0:  # real logger\n        logging.basicConfig(\n            level=logging.INFO,\n            # format='[\\033[34m%(asctime)s\\033[0m] %(message)s',\n            format='[%(asctime)s] %(message)s',\n            datefmt='%Y-%m-%d %H:%M:%S',\n            handlers=[logging.StreamHandler(), logging.FileHandler(f\"{logging_dir}/log.txt\")]\n        )\n        logger = logging.getLogger(__name__)\n\n    else:  # dummy logger (does nothing)\n        logger = logging.getLogger(__name__)\n        logger.addHandler(logging.NullHandler())\n    return logger\n\n\ndef create_tensorboard(tensorboard_dir):\n    \"\"\"\n    Create a TensorBoard SummaryWriter for logging losses on rank 0.\n    Returns None on all other ranks.\n    \"\"\"\n    writer = None  # non-zero ranks get None instead of an unbound name\n    if dist.get_rank() == 0:  # real tensorboard\n        writer = SummaryWriter(tensorboard_dir)\n\n    return writer\n\n\ndef write_tensorboard(writer, *args):\n    '''\n    Write the loss information to a tensorboard file.\n    Only for pytorch DDP mode.\n    '''\n    if dist.get_rank() == 0:  # real tensorboard\n        writer.add_scalar(args[0], args[1], args[2])\n\ndef get_npu_power():\n    result = subprocess.run([\"npu-smi\", \"info\"], stdout=subprocess.PIPE, text=True)\n    power_data = {}\n    npu_id = None\n\n    # Parse the output of npu-smi\n    for line in result.stdout.splitlines():\n        if line.startswith(\"| NPU\"):\n            npu_id = 0  # start a new NPU record\n        elif line.startswith(\"|\") and npu_id is not None:\n            parts = line.split(\"|\")\n            if len(parts) > 4:\n                power = parts[4].strip().split()[0]  # extract Power(W)\n\n                # record the power reading of each NPU\n                power_data[f\"NPU_{npu_id}_Power_W\"] = float(power)\n\n                npu_id += 1\n\n    return power_data\n\ndef monitor_npu_power():\n    while wandb.run is not None:\n        power_data = get_npu_power()\n        wandb.log(power_data)  # log NPU power readings to wandb in real time\n        
time.sleep(10)  # collect power data every 10 seconds\n\n#################################################################################\n#                      EMA Update/ DDP Training Utils                           #\n#################################################################################\n\n@torch.no_grad()\ndef update_ema(ema_model, model, decay=0.9999):\n    \"\"\"\n    Step the EMA model towards the current model.\n    \"\"\"\n    ema_params = OrderedDict(ema_model.named_parameters())\n    model_params = OrderedDict(model.named_parameters())\n\n    for name, param in model_params.items():\n        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed\n        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)\n\n\ndef requires_grad(model, flag=True):\n    \"\"\"\n    Set requires_grad flag for all parameters in a model.\n    \"\"\"\n    for p in model.parameters():\n        p.requires_grad = flag\n\n\ndef cleanup():\n    \"\"\"\n    End DDP training.\n    \"\"\"\n    dist.destroy_process_group()\n\n\n# adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/random.py#L31\ndef set_seed(seed, rank, device_specific=True):\n    if device_specific:\n        seed += rank\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\ndef setup_distributed(backend=\"nccl\", port=None):\n    \"\"\"Initialize the distributed training environment.\n    Supports both SLURM and torch.distributed.launch;\n    see torch.distributed.init_process_group() for more details.\n    \"\"\"\n    num_gpus = torch.cuda.device_count()\n\n    if \"SLURM_JOB_ID\" in os.environ:\n        rank = int(os.environ[\"SLURM_PROCID\"])\n        world_size = int(os.environ[\"SLURM_NTASKS\"])\n        node_list = os.environ[\"SLURM_NODELIST\"]\n    
    addr = subprocess.getoutput(f\"scontrol show hostname {node_list} | head -n1\")\n        # specify master port\n        if port is not None:\n            os.environ[\"MASTER_PORT\"] = str(port)\n        elif \"MASTER_PORT\" not in os.environ:\n            # os.environ[\"MASTER_PORT\"] = \"29566\"\n            os.environ[\"MASTER_PORT\"] = str(29567 + num_gpus)\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = addr\n        os.environ[\"WORLD_SIZE\"] = str(world_size)\n        os.environ[\"LOCAL_RANK\"] = str(rank % num_gpus)\n        os.environ[\"RANK\"] = str(rank)\n    else:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n\n    # torch.cuda.set_device(rank % num_gpus)\n\n    dist.init_process_group(\n        backend=backend,\n        world_size=world_size,\n        rank=rank,\n    )\n\n\n#################################################################################\n#                             MMCV  Utils                                    #\n#################################################################################\n\n\ndef collect_env():\n    # Copyright (c) OpenMMLab. 
All rights reserved.\n    from mmcv.utils import collect_env as collect_base_env\n    from mmcv.utils import get_git_hash\n    \"\"\"Collect the information of the running environments.\"\"\"\n\n    env_info = collect_base_env()\n    env_info['MMClassification'] = get_git_hash()[:7]\n\n    for name, val in env_info.items():\n        print(f'{name}: {val}')\n\n    print(torch.cuda.get_arch_list())\n    print(torch.version.cuda)\n\n\n#################################################################################\n#                          Pixart-alpha  Utils                                  #\n#################################################################################\n\n\nbad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\\)'+'\\('+'\\]'+'\\['+'\\}'+'\\{'+'\\|'+'\\\\'+'\\/'+'\\*' + r']{1,}')  # noqa\n\ndef text_preprocessing(text, support_Chinese=True):\n    # The exact text cleaning as was in the training stage:\n    text = clean_caption(text, support_Chinese=support_Chinese)\n    return text\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\ndef clean_caption(caption, support_Chinese=True):\n    caption = str(caption)\n    caption = ul.unquote_plus(caption)\n    caption = caption.strip().lower()\n    caption = re.sub('<person>', 'person', caption)\n    # urls:\n    caption = re.sub(\n        r'\\b((?:https?:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))',  # noqa\n        '', caption)  # regex for urls\n    caption = re.sub(\n        r'\\b((?:www:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))',  # noqa\n        '', caption)  # regex for urls\n    # html:\n    caption = BeautifulSoup(caption, features='html.parser').text\n\n    # @<nickname>\n    caption = re.sub(r'@[\\w\\d]+\\b', '', caption)\n\n    # 31C0—31EF CJK Strokes\n    # 31F0—31FF Katakana Phonetic 
Extensions\n    # 3200—32FF Enclosed CJK Letters and Months\n    # 3300—33FF CJK Compatibility\n    # 3400—4DBF CJK Unified Ideographs Extension A\n    # 4DC0—4DFF Yijing Hexagram Symbols\n    # 4E00—9FFF CJK Unified Ideographs\n    caption = re.sub(r'[\\u31c0-\\u31ef]+', '', caption)\n    caption = re.sub(r'[\\u31f0-\\u31ff]+', '', caption)\n    caption = re.sub(r'[\\u3200-\\u32ff]+', '', caption)\n    caption = re.sub(r'[\\u3300-\\u33ff]+', '', caption)\n    caption = re.sub(r'[\\u3400-\\u4dbf]+', '', caption)\n    caption = re.sub(r'[\\u4dc0-\\u4dff]+', '', caption)\n    if not support_Chinese:\n        caption = re.sub(r'[\\u4e00-\\u9fff]+', '', caption)  # Chinese\n    #######################################################\n\n    # all types of dash --> \"-\"\n    caption = re.sub(\n        r'[\\u002D\\u058A\\u05BE\\u1400\\u1806\\u2010-\\u2015\\u2E17\\u2E1A\\u2E3A\\u2E3B\\u2E40\\u301C\\u3030\\u30A0\\uFE31\\uFE32\\uFE58\\uFE63\\uFF0D]+',  # noqa\n        '-', caption)\n\n    # normalize quotes to a single standard\n    caption = re.sub(r'[`´«»“”¨]', '\"', caption)\n    caption = re.sub(r'[‘’]', \"'\", caption)\n\n    # &quot;\n    caption = re.sub(r'&quot;?', '', caption)\n    # &amp\n    caption = re.sub(r'&amp', '', caption)\n\n    # ip addresses:\n    caption = re.sub(r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}', ' ', caption)\n\n    # article ids:\n    caption = re.sub(r'\\d:\\d\\d\\s+$', '', caption)\n\n    # \\n\n    caption = re.sub(r'\\\\n', ' ', caption)\n\n    # \"#123\"\n    caption = re.sub(r'#\\d{1,3}\\b', '', caption)\n    # \"#12345..\"\n    caption = re.sub(r'#\\d{5,}\\b', '', caption)\n    # \"123456..\"\n    caption = re.sub(r'\\b\\d{6,}\\b', '', caption)\n    # filenames:\n    caption = re.sub(r'[\\S]+\\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)\n\n    #\n    caption = re.sub(r'[\\\"\\']{2,}', r'\"', caption)  # \"\"\"AUSVERKAUFT\"\"\"\n    caption = re.sub(r'[\\.]{2,}', r' ', caption)  # ellipses: \"...\"\n\n    
caption = re.sub(bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT\n    caption = re.sub(r'\\s+\\.\\s+', r' ', caption)  # \" . \"\n\n    # this-is-my-cute-cat / this_is_my_cute_cat\n    regex2 = re.compile(r'(?:\\-|\\_)')\n    if len(re.findall(regex2, caption)) > 3:\n        caption = re.sub(regex2, ' ', caption)\n\n    caption = basic_clean(caption)\n\n    caption = re.sub(r'\\b[a-zA-Z]{1,3}\\d{3,15}\\b', '', caption)  # jc6640\n    caption = re.sub(r'\\b[a-zA-Z]+\\d+[a-zA-Z]+\\b', '', caption)  # jc6640vc\n    caption = re.sub(r'\\b\\d+[a-zA-Z]+\\d+\\b', '', caption)  # 6640vc231\n\n    caption = re.sub(r'(worldwide\\s+)?(free\\s+)?shipping', '', caption)\n    caption = re.sub(r'(free\\s)?download(\\sfree)?', '', caption)\n    caption = re.sub(r'\\bclick\\b\\s(?:for|on)\\s\\w+', '', caption)\n    caption = re.sub(r'\\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\\simage[s]?)?', '', caption)\n    caption = re.sub(r'\\bpage\\s+\\d+\\b', '', caption)\n\n    caption = re.sub(r'\\b\\d*[a-zA-Z]+\\d+[a-zA-Z]+\\d+[a-zA-Z\\d]*\\b', r' ', caption)  # j2d1a2a...\n\n    caption = re.sub(r'\\b\\d+\\.?\\d*[xх×]\\d+\\.?\\d*\\b', '', caption)\n\n    caption = re.sub(r'\\b\\s+\\:\\s+', r': ', caption)\n    caption = re.sub(r'(\\D[,\\./])\\b', r'\\1 ', caption)\n    caption = re.sub(r'\\s+', ' ', caption)\n\n    caption.strip()\n\n    caption = re.sub(r'^[\\\"\\']([\\w\\W]+)[\\\"\\']$', r'\\1', caption)\n    caption = re.sub(r'^[\\'\\_,\\-\\:;]', r'', caption)\n    caption = re.sub(r'[\\'\\_,\\-\\:\\-\\+]$', r'', caption)\n    caption = re.sub(r'^\\.\\S+$', '', caption)\n\n    return caption.strip()\n\n\nif __name__ == '__main__':\n    \n    # caption = re.sub(r'[\\u4e00-\\u9fff]+', '', caption)\n    a = \"امرأة مسنة بشعر أبيض ووجه مليء بالتجاعيد تجلس داخل سيارة قديمة الطراز، تنظر من خلال النافذة الجانبية بتعبير تأملي أو حزين قليلاً.\"\n    print(a)\n    print(text_preprocessing(a))\n\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"opensora\"\nversion = \"1.3.0\"\ndescription = \"Reproduce OpenAI's Sora.\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n]\ndependencies = [\n    \"transformers==4.44.2\", \"tokenizers==0.19.1\", \n    \"albumentations==1.4.0\", \"av==11.0.0\", \"decord==0.6.0\", \"einops==0.7.0\", \"fastapi==0.110.0\",\n    \"gdown==5.1.0\", \"h5py==3.10.0\", \"idna==3.8\", 'imageio==2.34.0', \"matplotlib==3.7.5\", \"numpy==1.24.4\",\n    \"omegaconf==2.1.1\", \"opencv-python==4.9.0.80\", \"opencv-python-headless==4.9.0.80\", \"pandas==2.0.3\", \"pillow==10.2.0\",\n    \"pydub==0.25.1\", \"pytorchvideo==0.1.5\", \"PyYAML==6.0.2\", \"regex==2024.7.24\",\n    \"requests==2.32.3\", \"scikit-learn==1.3.2\", \"scipy==1.10.1\", \"six==1.16.0\", \"test-tube==0.7.5\",\n    \"timm==0.9.16\", \"torchdiffeq==0.2.3\", \"torchmetrics==1.3.2\", \"tqdm==4.66.5\", \"urllib3==2.2.2\", \"uvicorn==0.27.1\",\n    \"scikit-video==1.1.11\", \"imageio-ffmpeg==0.4.9\", \"sentencepiece==0.1.99\", \"beautifulsoup4==4.12.3\",\n    \"ftfy==6.1.3\", \"moviepy==1.0.3\", \"wandb==0.16.3\", \"tensorboard==2.14.0\", \"pydantic==2.6.4\", \"gradio==4.0.0\", \n    \"torch==2.1.0\", \"torchvision==0.16.0\", \"xformers==0.0.22.post7\", \"accelerate==0.34.0\", \"diffusers==0.30.2\", \"deepspeed==0.12.6\"\n]\n\n[project.optional-dependencies]\ndev = [\"mypy==1.8.0\"]\n\n\n[project.urls]\n\"Homepage\" = \"https://github.com/PKU-YuanGroup/Open-Sora-Plan\"\n\"Bug Tracker\" = \"https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues\"\n\n[tool.setuptools.packages.find]\nexclude = [\"assets*\", \"docker*\", \"docs\", \"scripts*\"]\n\n[tool.wheel]\nexclude = [\"assets*\", \"docker*\", \"docs\", \"scripts*\"]\n\n[tool.mypy]\nwarn_return_any = true\nwarn_unused_configs 
= true\nignore_missing_imports = true\ndisallow_untyped_calls = true\ncheck_untyped_defs = true\nno_implicit_optional = true\n"
  },
  {
    "path": "scripts/accelerate_configs/ddp_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: MULTI_GPU\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29501\nmain_training_function: main\nnum_machines: 1\nnum_processes: 1\ngpu_ids: 0,\nuse_cpu: false"
  },
  {
    "path": "scripts/accelerate_configs/deepspeed_zero2_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n deepspeed_config_file: scripts/accelerate_configs/zero2.json\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29513\nmain_training_function: main\nnum_machines: 1\nnum_processes: 8\ngpu_ids: 0,1,2,3,4,5,6,7\nuse_cpu: false\n"
  },
  {
    "path": "scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n deepspeed_config_file: scripts/accelerate_configs/zero2_offload.json\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29501\nmain_training_function: main\nnum_machines: 1\nnum_processes: 8\ngpu_ids: 0,1,2,3,4,5,6,7\nuse_cpu: false"
  },
  {
    "path": "scripts/accelerate_configs/deepspeed_zero3_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n deepspeed_config_file: scripts/accelerate_configs/zero3.json\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29501\nmain_training_function: main\nnum_machines: 1\nnum_processes: 8\ngpu_ids: 0,1,2,3,4,5,6,7\nuse_cpu: false"
  },
  {
    "path": "scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n deepspeed_config_file: scripts/accelerate_configs/zero3_offload.json\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29501\nmain_training_function: main\nnum_machines: 1\nnum_processes: 8\ngpu_ids: 0,1,2,3,4,5,6,7\nuse_cpu: false"
  },
  {
    "path": "scripts/accelerate_configs/default_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: MULTI_GPU\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: 29501\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\ngpu_ids: 0,1,2,3,4,5,6,7\nuse_cpu: false"
  },
  {
    "path": "scripts/accelerate_configs/hostfile",
    "content": "100.64.24.30 slots=8\n100.64.24.6 slots=8\n100.64.24.7 slots=8\n100.64.24.8 slots=8\n100.64.24.10 slots=8\n100.64.24.11 slots=8\n100.64.24.13 slots=8\n100.64.24.14 slots=8\n100.64.24.17 slots=8\n100.64.24.19 slots=8\n100.64.24.26 slots=8\n100.64.24.27 slots=8\n100.64.24.28 slots=8\n100.64.24.29 slots=8\n100.64.24.31 slots=8\n100.64.24.32 slots=8"
  },
  {
    "path": "scripts/accelerate_configs/zero2.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": false,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"communication_data_type\": \"fp32\",\n    \"gradient_clipping\": 1.0,\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 5e8\n    }\n}"
  },
  {
    "path": "scripts/accelerate_configs/zero2_npu.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": false,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"communication_data_type\": \"fp32\",\n    \"gradient_clipping\": 1.0,\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"overlap_comm\": true,\n        \"allgather_bucket_size\": 536870912,\n        \"contiguous_gradients\": true,\n        \"reduce_bucket_size\": 536870912\n    }\n}"
  },
  {
    "path": "scripts/accelerate_configs/zero2_offload.json",
    "content": "{\r\n    \"fp16\": {\r\n        \"enabled\": \"auto\",\r\n        \"loss_scale\": 0,\r\n        \"loss_scale_window\": 1000,\r\n        \"initial_scale_power\": 16,\r\n        \"hysteresis\": 2,\r\n        \"min_loss_scale\": 1\r\n    },\r\n    \"bf16\": {\r\n        \"enabled\": \"auto\"\r\n    },\r\n    \"communication_data_type\": \"fp32\",\r\n    \"gradient_clipping\": 1.0,\r\n    \"train_micro_batch_size_per_gpu\": \"auto\",\r\n    \"train_batch_size\": \"auto\",\r\n    \"gradient_accumulation_steps\": \"auto\",\r\n    \"zero_optimization\": {\r\n        \"stage\": 2,\r\n        \"offload_optimizer\": {\r\n            \"device\": \"cpu\"\r\n        },\r\n        \"overlap_comm\": true,\r\n        \"contiguous_gradients\": true,\r\n        \"sub_group_size\": 1e9,\r\n        \"reduce_bucket_size\": 5e8, \r\n        \"round_robin_gradients\": true\r\n    }\r\n}"
  },
  {
    "path": "scripts/accelerate_configs/zero3.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"communication_data_type\": \"fp32\",\n    \"gradient_clipping\": 1.0,\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 5e8,\n        \"stage3_prefetch_bucket_size\": 5e8,\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    }\n}"
  },
  {
    "path": "scripts/accelerate_configs/zero3_offload.json",
    "content": "{\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"loss_scale\": 0,\n    \"loss_scale_window\": 1000,\n    \"initial_scale_power\": 16,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"offload_param\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true,\n    \"sub_group_size\": 1e9,\n    \"reduce_bucket_size\": 5e8,\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"stage3_max_live_parameters\": 1e9,\n    \"stage3_max_reuse_distance\": 1e9,\n    \"gather_16bit_weights_on_model_save\": true\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"steps_per_print\": 1e5,\n  \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "scripts/causalvae/eval.sh",
    "content": "EXP_NAME=wfvae-4dim\nSAMPLE_RATE=1\nNUM_FRAMES=33\nRESOLUTION=256\nMETRIC=lpips\nSUBSET_SIZE=0\nORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin\nRECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}\n\npython opensora/models/causalvideovae/eval/eval.py \\\n    --batch_size 8 \\\n    --real_video_dir ${ORIGIN_DIR} \\\n    --generated_video_dir ${RECON_DIR} \\\n    --device cuda:1 \\\n    --sample_fps 1 \\\n    --sample_rate ${SAMPLE_RATE} \\\n    --num_frames ${NUM_FRAMES} \\\n    --resolution ${RESOLUTION} \\\n    --crop_size ${RESOLUTION} \\\n    --subset_size ${SUBSET_SIZE} \\\n    --metric ${METRIC}"
  },
  {
    "path": "scripts/causalvae/prepare_eval.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nDATASET_DIR=test_video\nEXP_NAME=wfvae\nSAMPLE_RATE=1\nNUM_FRAMES=33\nRESOLUTION=256\nCKPT=ckpt\nSUBSET_SIZE=0\n\naccelerate launch \\\n    --config_file scripts/accelerate_configs/default_config.yaml \\\n    opensora/models/causalvideovae/sample/rec_video_vae.py \\\n    --batch_size 1 \\\n    --real_video_dir ${DATASET_DIR} \\\n    --generated_video_dir video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} \\\n    --device cuda \\\n    --sample_fps 24 \\\n    --sample_rate ${SAMPLE_RATE} \\\n    --num_frames ${NUM_FRAMES} \\\n    --resolution ${RESOLUTION} \\\n    --subset_size ${SUBSET_SIZE} \\\n    --num_workers 8 \\\n    --from_pretrained ${CKPT} \\\n    --model_name WFVAE \\\n    --output_origin \\\n    --crop_size ${RESOLUTION}\n"
  },
  {
    "path": "scripts/causalvae/rec_image.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python examples/rec_image.py \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --image_path /storage/dataset/image/anytext3m/ocr_data/Art/images/gt_5544.jpg \\\n    --rec_path rec_.jpg \\\n    --device cuda \\\n    --short_size 512 "
  },
  {
    "path": "scripts/causalvae/rec_video.sh",
    "content": "CUDA_VISIBLE_DEVICES=1 python examples/rec_video.py \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --video_path /storage/lcm/WF-VAE/testvideo/gm1190263332-337350271.mp4 \\\n    --rec_path rec_tile_.mp4 \\\n    --device cuda \\\n    --sample_rate 1 \\\n    --num_frames 65 \\\n    --height 512 \\\n    --width 512 \\\n    --fps 30 \\\n    --enable_tiling"
  },
  {
    "path": "scripts/causalvae/train.sh",
    "content": "export WANDB_PROJECT=WFVAE\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nexport GLOO_SOCKET_IFNAME=bond0\nexport NCCL_SOCKET_IFNAME=bond0\nexport NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1\nexport NCCL_IB_GID_INDEX=3\nexport NCCL_IB_TC=162\nexport NCCL_IB_TIMEOUT=22\nexport NCCL_PXN_DISABLE=0\nexport NCCL_IB_QPS_PER_CONNECTION=4\nexport NCCL_ALGO=Ring\nexport OMP_NUM_THREADS=1\nexport MKL_NUM_THREADS=1\n\nEXP_NAME=TRAIN\n\ntorchrun \\\n    --nnodes=1 --nproc_per_node=8 \\\n    --master_addr=localhost \\\n    --master_port=12133 \\\n    opensora/train/train_causalvae.py \\\n    --exp_name ${EXP_NAME} \\\n    --video_path /storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_train/  \\\n    --eval_video_path /storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_val/ \\\n    --model_name WFVAE \\\n    --model_config scripts/causalvae/wfvae_4dim.json \\\n    --resolution 256 \\\n    --num_frames 25 \\\n    --batch_size 1 \\\n    --lr 0.00001 \\\n    --epochs 4 \\\n    --disc_start 0 \\\n    --save_ckpt_step 5000 \\\n    --eval_steps 1000 \\\n    --eval_batch_size 1 \\\n    --eval_num_frames 33 \\\n    --eval_sample_rate 1 \\\n    --eval_subset_size 500 \\\n    --eval_lpips \\\n    --ema \\\n    --ema_decay 0.999 \\\n    --perceptual_weight 1.0 \\\n    --loss_type l1 \\\n    --sample_rate 1 \\\n    --disc_cls opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator3D \\\n    --wavelet_loss \\\n    --wavelet_weight 0.1"
  },
  {
    "path": "scripts/causalvae/wfvae_4dim.json",
    "content": "{\n    \"_class_name\": \"WFVAEModel\",\n    \"_diffusers_version\": \"0.30.2\",\n    \"base_channels\": 128,\n    \"connect_res_layer_num\": 1,\n    \"decoder_energy_flow_hidden_size\": 128,\n    \"decoder_num_resblocks\": 2,\n    \"dropout\": 0.0,\n    \"encoder_energy_flow_hidden_size\": 128,\n    \"encoder_num_resblocks\": 2,\n    \"l1_dowmsample_block\": \"Downsample\",\n    \"l1_downsample_wavelet\": \"HaarWaveletTransform2D\",\n    \"l1_upsample_block\": \"Upsample\",\n    \"l1_upsample_wavelet\": \"InverseHaarWaveletTransform2D\",\n    \"l2_dowmsample_block\": \"Spatial2xTime2x3DDownsample\",\n    \"l2_downsample_wavelet\": \"HaarWaveletTransform3D\",\n    \"l2_upsample_block\": \"Spatial2xTime2x3DUpsample\",\n    \"l2_upsample_wavelet\": \"InverseHaarWaveletTransform3D\",\n    \"latent_dim\": 4,\n    \"norm_type\": \"layernorm\",\n    \"t_interpolation\": \"trilinear\",\n    \"use_attention\": true\n}"
  },
  {
    "path": "scripts/slurm/placeholder",
    "content": ""
  },
  {
    "path": "scripts/text_condition/gpu/sample_inpaint_v1_3.sh",
    "content": "\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \\\n    -m opensora.sample.sample \\\n    --model_type \"inpaint\" \\\n    --model_path model_path \\\n    --version v1_3 \\\n    --num_frames 93 \\\n    --height 352 \\\n    --width 640 \\\n    --max_hxw 236544 \\\n    --crop_for_hw \\\n    --cache_dir \"../cache_dir\" \\\n    --text_encoder_name_1 \"/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl\" \\\n    --text_prompt examples/cond_prompt.txt \\\n    --conditional_pixel_values_path examples/cond_pix_path.txt \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --save_img_path \"./save_path\" \\\n    --fps 18 \\\n    --guidance_scale 7.5 \\\n    --num_sampling_steps 100 \\\n    --max_sequence_length 512 \\\n    --sample_method EulerAncestralDiscrete \\\n    --seed 1234 \\\n    --num_samples_per_prompt 1 \\\n    --rescale_betas_zero_snr \\\n    --prediction_type \"v_prediction\" \\\n    --noise_strength 0.0\n"
  },
  {
    "path": "scripts/text_condition/gpu/sample_t2v_v1_3.sh",
    "content": "\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29514 \\\n    -m opensora.sample.sample \\\n    --model_path /storage/ongoing/9.29/mmdit/Open-Sora-Plan/final_ft_any93x352x640_v1_3_bs512_lr1e-5_snr5.0_fps16_zsnr_nofix_16node/checkpoint-5500/model_ema \\\n    --version v1_3 \\\n    --num_frames 93 \\\n    --height 352 \\\n    --width 640 \\\n    --cache_dir \"../cache_dir\" \\\n    --text_encoder_name_1 \"/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl\" \\\n    --text_prompt \"examples/sora.txt\" \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --save_img_path \"./train_1_3_nomotion_fps18\" \\\n    --fps 18 \\\n    --guidance_scale 7.5 \\\n    --num_sampling_steps 100 \\\n    --max_sequence_length 512 \\\n    --sample_method EulerAncestralDiscrete \\\n    --seed 1234 \\\n    --num_samples_per_prompt 1 \\\n    --rescale_betas_zero_snr \\\n    --prediction_type \"v_prediction\" "
  },
  {
    "path": "scripts/text_condition/gpu/train_inpaint_v1_3.sh",
    "content": "\nexport HF_DATASETS_OFFLINE=1 \nexport TRANSFORMERS_OFFLINE=1\nexport PDSH_RCMD_TYPE=ssh\n# NCCL setting\nexport GLOO_SOCKET_IFNAME=bond0\nexport NCCL_SOCKET_IFNAME=bond0\nexport NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1\nexport NCCL_IB_GID_INDEX=3\nexport NCCL_IB_TC=162\nexport NCCL_IB_TIMEOUT=25\nexport NCCL_PXN_DISABLE=0\nexport NCCL_IB_QPS_PER_CONNECTION=4\nexport NCCL_ALGO=Ring\nexport OMP_NUM_THREADS=1\nexport MKL_NUM_THREADS=1\nexport NCCL_IB_RETRY_CNT=32\n# export NCCL_ALGO=Tree\n\naccelerate launch \\\n    --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \\\n    opensora/train/train_inpaint.py \\\n    --model OpenSoraInpaint_v1_3-2B/122 \\\n    --text_encoder_name_1 google/mt5-xxl \\\n    --cache_dir \"../../cache_dir/\" \\\n    --dataset inpaint \\\n    --data \"scripts/train_data/video_data.txt\" \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --sample_rate 1 \\\n    --num_frames 93 \\\n    --max_hxw 236544 \\\n    --min_hxw 102400 \\\n    --interpolation_scale_t 1.0 \\\n    --interpolation_scale_h 1.0 \\\n    --interpolation_scale_w 1.0 \\\n    --gradient_checkpointing \\\n    --train_batch_size=1 \\\n    --dataloader_num_workers 8 \\\n    --gradient_accumulation_steps=1 \\\n    --max_train_steps=1000000 \\\n    --learning_rate=1e-5 \\\n    --lr_scheduler=\"constant\" \\\n    --lr_warmup_steps=0 \\\n    --mixed_precision=\"bf16\" \\\n    --report_to=\"wandb\" \\\n    --checkpointing_steps=1000 \\\n    --allow_tf32 \\\n    --model_max_length 512 \\\n    --use_ema \\\n    --ema_start_step 0 \\\n    --cfg 0.1 \\\n    --resume_from_checkpoint=\"latest\" \\\n    --speed_factor 1.0 \\\n    --ema_decay 0.9999 \\\n    --drop_short_ratio 0.0 \\\n    --hw_stride 32 \\\n    --sparse1d --sparse_n 4 \\\n    --train_fps 18 \\\n    --seed 1234 \\\n    --trained_data_global_step 0 \\\n    --group_data \\\n    --use_decord \\\n    --prediction_type \"v_prediction\" \\\n    --output_dir=\"debug\" \\\n    --rescale_betas_zero_snr \\\n    --mask_config scripts/train_configs/mask_config.yaml \\\n    --add_noise_to_condition \\\n    --default_text_ratio 0.5 \n    # --pretrained \"\" \n"
  },
  {
    "path": "scripts/text_condition/gpu/train_t2v_v1_3.sh",
    "content": "\nexport HF_DATASETS_OFFLINE=1 \nexport TRANSFORMERS_OFFLINE=1\nexport PDSH_RCMD_TYPE=ssh\n# NCCL setting\nexport GLOO_SOCKET_IFNAME=bond0\nexport NCCL_SOCKET_IFNAME=bond0\nexport NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1\nexport NCCL_IB_GID_INDEX=3\nexport NCCL_IB_TC=162\nexport NCCL_IB_TIMEOUT=25\nexport NCCL_PXN_DISABLE=0\nexport NCCL_IB_QPS_PER_CONNECTION=4\nexport NCCL_ALGO=Ring\nexport OMP_NUM_THREADS=1\nexport MKL_NUM_THREADS=1\nexport NCCL_IB_RETRY_CNT=32\n# export NCCL_ALGO=Tree\n\naccelerate launch \\\n    --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \\\n    opensora/train/train_t2v_diffusers.py \\\n    --model OpenSoraT2V_v1_3-2B/122 \\\n    --text_encoder_name_1 google/mt5-xxl \\\n    --cache_dir \"../../cache_dir/\" \\\n    --dataset t2v \\\n    --data \"scripts/train_data/merge_data.txt\" \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/storage/lcm/WF-VAE/results/latent8\" \\\n    --sample_rate 1 \\\n    --num_frames 1 \\\n    --max_height 352 \\\n    --max_width 640 \\\n    --interpolation_scale_t 1.0 \\\n    --interpolation_scale_h 1.0 \\\n    --interpolation_scale_w 1.0 \\\n    --gradient_checkpointing \\\n    --train_batch_size=4 \\\n    --dataloader_num_workers 16 \\\n    --gradient_accumulation_steps=1 \\\n    --max_train_steps=1000000 \\\n    --learning_rate=1e-5 \\\n    --lr_scheduler=\"constant\" \\\n    --lr_warmup_steps=0 \\\n    --mixed_precision=\"bf16\" \\\n    --report_to=\"wandb\" \\\n    --checkpointing_steps=500 \\\n    --allow_tf32 \\\n    --model_max_length 512 \\\n    --use_ema \\\n    --ema_start_step 0 \\\n    --cfg 0.1 \\\n    --resume_from_checkpoint=\"latest\" \\\n    --speed_factor 1.0 \\\n    --ema_decay 0.9999 \\\n    --drop_short_ratio 0.0 \\\n    --pretrained \"\" \\\n    --hw_stride 32 \\\n    --sparse1d --sparse_n 4 \\\n    --train_fps 16 \\\n    --seed 1234 \\\n    --trained_data_global_step 0 \\\n    --group_data \\\n    --use_decord \\\n    --prediction_type \"v_prediction\" \\\n    --snr_gamma 5.0 \\\n    --force_resolution \\\n    --rescale_betas_zero_snr \\\n    --output_dir=\"debug\""
  },
  {
    "path": "scripts/text_condition/npu/sample_inpaint_v1_3.sh",
    "content": "\nexport TASK_QUEUE_ENABLE=0\ntorchrun --nnodes=1 --nproc_per_node 8 --master_port 29522 \\\n    -m opensora.sample.sample \\\n    --model_type \"inpaint\" \\\n    --model_path model_path \\\n    --version v1_3 \\\n    --num_frames 93 \\\n    --crop_for_hw \\\n    --height 352 \\\n    --width 640 \\\n    --max_hxw 236544 \\\n    --cache_dir \"../cache_dir\" \\\n    --text_encoder_name_1 \"/home/save_dir/pretrained/mt5-xxl\" \\\n    --text_prompt /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/prompt.txt \\\n    --conditional_pixel_values_path /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/cond_imgs_path.txt \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/home/save_dir/lzj/formal_8dim/latent8\" \\\n    --save_img_path \"./test\" \\\n    --fps 18 \\\n    --guidance_scale 7.5 \\\n    --num_sampling_steps 50 \\\n    --max_sequence_length 512 \\\n    --sample_method EulerAncestralDiscrete \\\n    --seed 2514 \\\n    --num_samples_per_prompt 1 \\\n    --prediction_type \"v_prediction\" \\\n    --rescale_betas_zero_snr \\\n    --noise_strength 0.0 \\\n    # --mask_type i2v \\\n    # --enable_tiling \n"
  },
  {
    "path": "scripts/text_condition/npu/sample_t2v_v1_3.sh",
    "content": "\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \\\n    -m opensora.sample.sample \\\n    --model_path model_path \\\n    --version v1_3 \\\n    --num_frames 93 \\\n    --height 352 \\\n    --width 640 \\\n    --cache_dir \"../cache_dir\" \\\n    --text_encoder_name_1 \"/home/save_dir/pretrained/mt5-xxl\" \\\n    --text_prompt examples/sora_refine.txt \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/home/save_dir/lzj/formal_8dim/latent8\" \\\n    --save_img_path \"./test\" \\\n    --fps 18 \\\n    --guidance_scale 7.5 \\\n    --num_sampling_steps 100 \\\n    --max_sequence_length 512 \\\n    --sample_method EulerAncestralDiscrete \\\n    --seed 1234 \\\n    --num_samples_per_prompt 1 \\\n    --rescale_betas_zero_snr \\\n    --prediction_type \"v_prediction\""
  },
  {
    "path": "scripts/text_condition/npu/train_inpaint_v1_3.sh",
    "content": "\nexport PROJECT=$PROJECT_NAME\n# export PROJECT='test'\nexport HF_DATASETS_OFFLINE=1 \nexport TRANSFORMERS_OFFLINE=1\n\nexport TASK_QUEUE_ENABLE=0\nexport HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE\nexport MULTI_STREAM_MEMORY_REUSE=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n# export HCCL_ALGO=\"level0:NA;level1:H-D_R\"\n# --machine_rank=${MACHINE_RANK} \\\n# --main_process_ip=${MAIN_PROCESS_IP_VALUE} \\\n# multi_node_example_by_deepspeed.yaml\n# deepspeed_zero2_config.yaml\n\naccelerate launch \\\n    --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \\\n    opensora/train/train_inpaint.py \\\n    --model OpenSoraInpaint_v1_3-2B/122 \\\n    --text_encoder_name_1 google/mt5-xxl \\\n    --cache_dir \"../../cache_dir/\" \\\n    --dataset inpaint \\\n    --data \"scripts/train_data/video_data.txt\" \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/home/save_dir/lzj/formal_8dim/latent8\" \\\n    --vae_fp32 \\\n    --sample_rate 1 \\\n    --num_frames 93 \\\n    --max_hxw 236544 \\\n    --min_hxw 102400 \\\n    --snr_gamma 5.0 \\\n    --interpolation_scale_t 1.0 \\\n    --interpolation_scale_h 1.0 \\\n    --interpolation_scale_w 1.0 \\\n    --gradient_checkpointing \\\n    --train_batch_size=1 \\\n    --dataloader_num_workers 8 \\\n    --gradient_accumulation_steps=1 \\\n    --max_train_steps=1000000 \\\n    --learning_rate=1e-5 \\\n    --lr_scheduler=\"constant\" \\\n    --lr_warmup_steps=0 \\\n    --mixed_precision=\"bf16\" \\\n    --report_to=\"wandb\" \\\n    --checkpointing_steps=500 \\\n    --allow_tf32 \\\n    --model_max_length 512 \\\n    --use_ema \\\n    --ema_start_step 0 \\\n    --cfg 0.1 \\\n    --speed_factor 1.0 \\\n    --ema_decay 0.9999 \\\n    --drop_short_ratio 0.0 \\\n    --hw_stride 32 \\\n    --sparse1d --sparse_n=4 \\\n    --train_fps 16 \\\n    --seed 1234 \\\n    --trained_data_global_step 0 \\\n    --group_data \\\n    --use_decord \\\n    --prediction_type \"v_prediction\" \\\n    --output_dir=\"/home/save_dir/runs/$PROJECT\" \\\n    --mask_config scripts/train_configs/mask_config.yaml \\\n    --add_noise_to_condition \\\n    --default_text_ratio 0.5 \\\n    --resume_from_checkpoint=\"latest\" \n    # --pretrained \"/home/save_dir/pretrained/93x640x640_144k_ema\" \n    # --force_resolution\n    # --force_resolution \\\n    # --max_height 352 \\\n    # --max_width 640 \\\n"
  },
  {
    "path": "scripts/text_condition/npu/train_t2v_v1_3.sh",
    "content": "\nexport PROJECT=$PROJECT_NAME\n# export PROJECT='test'\nexport HF_DATASETS_OFFLINE=1 \nexport TRANSFORMERS_OFFLINE=1\n\nexport TASK_QUEUE_ENABLE=0\nexport HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE\nexport MULTI_STREAM_MEMORY_REUSE=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n# export HCCL_ALGO=\"level0:NA;level1:H-D_R\"\n# --machine_rank=${MACHINE_RANK} \\\n# --main_process_ip=${MAIN_PROCESS_IP_VALUE} \\\n# multi_node_example_by_deepspeed.yaml\n# deepspeed_zero2_config.yaml\n\naccelerate launch \\\n    --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \\\n    opensora/train/train_t2v_diffusers.py \\\n    --model OpenSoraT2V_v1_3-2B/122 \\\n    --text_encoder_name_1 google/mt5-xxl \\\n    --cache_dir \"../../cache_dir/\" \\\n    --dataset t2v \\\n    --data \"scripts/train_data/video_data_debug_on_npu.txt\" \\\n    --ae WFVAEModel_D8_4x8x8 \\\n    --ae_path \"/home/save_dir/lzj/formal_8dim/latent8\" \\\n    --sample_rate 1 \\\n    --num_frames 93 \\\n    --max_height 352 \\\n    --max_width 640 \\\n    --force_resolution \\\n    --interpolation_scale_t 1.0 \\\n    --interpolation_scale_h 1.0 \\\n    --interpolation_scale_w 1.0 \\\n    --gradient_checkpointing \\\n    --train_batch_size=1 \\\n    --dataloader_num_workers 8 \\\n    --gradient_accumulation_steps=1 \\\n    --max_train_steps=1000000 \\\n    --learning_rate=1e-5 \\\n    --lr_scheduler=\"constant\" \\\n    --lr_warmup_steps=0 \\\n    --mixed_precision=\"bf16\" \\\n    --report_to=\"wandb\" \\\n    --checkpointing_steps=500 \\\n    --allow_tf32 \\\n    --model_max_length 512 \\\n    --use_ema \\\n    --ema_start_step 0 \\\n    --cfg 0.1 \\\n    --resume_from_checkpoint=\"latest\" \\\n    --speed_factor 1.0 \\\n    --ema_decay 0.9999 \\\n    --drop_short_ratio 0.0 \\\n    --pretrained \"/home/save_dir/pretrained/93x640x640_144k_ema\" \\\n    --hw_stride 32 \\\n    --sparse1d --sparse_n 4 \\\n    --train_fps 16 \\\n    --seed 1234 \\\n    --trained_data_global_step 0 \\\n    --group_data \\\n    --use_decord \\\n    --prediction_type \"v_prediction\" \\\n    --snr_gamma 5.0 \\\n    --rescale_betas_zero_snr \\\n    --output_dir=\"debug\""
  },
  {
    "path": "scripts/train_configs/mask_config.yaml",
    "content": "# mask processor args\nmin_clear_ratio: 0.0\nmax_clear_ratio: 1.0 \n\n# mask_type_ratio_dict_video\nmask_type_ratio_dict_video:\n  t2iv: 1\n  i2v: 8\n  transition: 8\n  continuation: 2\n  clear: 0\n  random_temporal: 1\n\nmask_type_ratio_dict_image:\n  t2iv: 0\n  clear: 0"
  },
  {
    "path": "scripts/train_data/merge_data.txt",
    "content": "/storage/dataset/recap_datacomp_1b_data/output,/storage/anno_pkl/img_nocn_res160_pkl/recap_64part_filter_aes_res160_pkl/part0_7036495.pkl"
  }
]